Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 51 additions & 13 deletions src/megatron/energon/sample_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,22 +157,43 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str:

Example:
>>> print(format_sample_detailed({"image": torch.zeros(3, 224, 224), "label": 5}))
- image: Tensor(shape=(3, 224, 224), dtype=torch.float32, ...)
- label: 5
image: Tensor(shape=(3, 224, 224), dtype=torch.float32, ...)
label: 5
"""

def _child_indent(cur: str, value: Any) -> str:
if cur:
return cur + " "
if isinstance(value, (dict, list, tuple)):
return " "
if dataclasses.is_dataclass(value):
return " "
return " "

if isinstance(sample, dict):
result = []
for _, (key, value) in zip(range(25), sample.items()):
result.append(f"{indent} - {key}: {format_sample_detailed(value, indent + ' ')}")
nested = format_sample_detailed(value, _child_indent(indent, value))
head = f"{indent}{key}:"
if "\n" not in nested:
result.append(f"{head} {nested}")
elif isinstance(value, str) or dataclasses.is_dataclass(value):
result.append(f"{head} {nested}")
else:
result.append(f"{head}\n{nested}")
if len(sample) > 25:
result.append(f"{indent} - ... (and {len(sample) - 25} more items)")
result.append(f"{indent}... (and {len(sample) - 25} more items)")
return "\n".join(result)
elif isinstance(sample, str):
if len(sample) > 1000:
sample = f"{sample[:1000]}... (and {len(sample) - 1000} more characters)"
if "\n" in sample:
# represent as """ string if it contains newlines:
return '"""' + sample.replace("\n", "\n " + indent) + '"""'
lines = sample.split("\n")
out = '"""' + indent + lines[0]
for line in lines[1:]:
out += "\n" + indent + line
out += '"""'
return out
return repr(sample)
elif isinstance(sample, (int, float, bool, type(None))):
return repr(sample)
Expand All @@ -181,9 +202,22 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str:
return f"[{', '.join(repr(value) for value in sample)}]"
result = []
for _, value in zip(range(10), sample):
result.append(f"{indent} - {format_sample_detailed(value, indent + ' ')}")
if isinstance(value, dict) and len(value) == 1:
(k, v), = value.items()
nested_v = format_sample_detailed(v, indent + " ")
item_head = f"{indent}- {k}:"
if "\n" not in nested_v:
result.append(f"{item_head} {nested_v}")
else:
result.append(f"{item_head}\n{nested_v}")
else:
nested = format_sample_detailed(value, indent + " ")
if "\n" not in nested:
result.append(f"{indent}- {nested}")
else:
result.append(f"{indent}-\n{nested}")
if len(sample) > 10:
result.append(f"{indent} - ... (and {len(sample) - 10} more items)")
result.append(f"{indent}- ... (and {len(sample) - 10} more items)")
return "\n".join(result)
elif isinstance(sample, torch.Tensor):
try:
Expand Down Expand Up @@ -235,12 +269,16 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str:
# Handle empty arrays or non-numeric dtypes
return f"np.ndarray(shape={sample.shape}, dtype={sample.dtype})"
elif dataclasses.is_dataclass(sample):
result = [f"{indent}{type(sample).__name__}("]
result = [f"{type(sample).__name__}("]
for field in dataclasses.fields(sample):
result.append(
f"{indent} {field.name}={format_sample_detailed(getattr(sample, field.name), indent + ' ')}"
)
result.append(f"{indent})")
field_val = getattr(sample, field.name)
nested = format_sample_detailed(field_val, indent + " ")
head = f"{indent}{field.name}:"
if "\n" not in nested:
result.append(f"{head} {nested}")
else:
result.append(f"{head}\n{nested}")
result.append(")")
return "\n".join(result)
else:
repr_str = repr(sample)
Expand Down
69 changes: 69 additions & 0 deletions tests/test_sample_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

"""Tests for sample formatting helpers."""

import unittest
from dataclasses import dataclass

import numpy as np
import torch

from megatron.energon.sample_utils import format_sample_detailed


@dataclass
class _FormatProbeDc:
n: int
nested: dict[str, int]


class TestFormatSampleDetailed(unittest.TestCase):
def test_format_sample_detailed_complex_types(self) -> None:
"""Exercise dict, scalars, strings, lists/tuples, tensors, ndarray, dataclass, and fallback repr."""

class _Unknown:
def __repr__(self) -> str:
return "<unknown-probe>"

sample = {
"scalars": {"i": -3, "f": 2.5, "b": False, "n": None},
"plain_str": "hi",
"multiline_str": "line1\nline2",
"primitive_seq": (1, 2, "x"),
"hetero_list": [{"k": 1}, {"k": 2}],
"tensor": torch.tensor([0.0, 2.0], dtype=torch.float32),
"array": np.array([1, 2, 3], dtype=np.int64),
"dataclass": _FormatProbeDc(n=9, nested={"a": 1, "b": 2}),
"unknown": _Unknown(),
}
out = format_sample_detailed(sample)

print(out)

assert out == '''\
scalars:
i: -3
f: 2.5
b: False
n: None
plain_str: 'hi'
multiline_str: """\
line1
line2"""
primitive_seq: [1, 2, 'x']
hetero_list:
- k: 1
- k: 2
tensor: Tensor(shape=torch.Size([2]), dtype=torch.float32, device=cpu, min=0.0, max=2.0, values=[0.0, 2.0])
array: np.ndarray(shape=(3,), dtype=int64, min=1, max=3, values=[np.int64(1), np.int64(2), np.int64(3)])
dataclass: _FormatProbeDc(
n: 9
nested:
a: 1
b: 2
)
unknown: <unknown-probe>'''

if __name__ == "__main__":
unittest.main()
Loading