From b1c7bbc8c3e08c2259b0c627e8c4cbfad0ae9e87 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Wed, 20 May 2026 11:49:27 -0700 Subject: [PATCH] Fix sample printer --- src/megatron/energon/sample_utils.py | 64 ++++++++++++++++++++------ tests/test_sample_utils.py | 69 ++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 13 deletions(-) create mode 100644 tests/test_sample_utils.py diff --git a/src/megatron/energon/sample_utils.py b/src/megatron/energon/sample_utils.py index a695817f..87cce975 100644 --- a/src/megatron/energon/sample_utils.py +++ b/src/megatron/energon/sample_utils.py @@ -157,22 +157,43 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str: Example: >>> print(format_sample_detailed({"image": torch.zeros(3, 224, 224), "label": 5})) - - image: Tensor(shape=(3, 224, 224), dtype=torch.float32, ...) - - label: 5 + image: Tensor(shape=(3, 224, 224), dtype=torch.float32, ...) + label: 5 """ + + def _child_indent(cur: str, value: Any) -> str: + if cur: + return cur + " " + if isinstance(value, (dict, list, tuple)): + return " " + if dataclasses.is_dataclass(value): + return " " + return " " + if isinstance(sample, dict): result = [] for _, (key, value) in zip(range(25), sample.items()): - result.append(f"{indent} - {key}: {format_sample_detailed(value, indent + ' ')}") + nested = format_sample_detailed(value, _child_indent(indent, value)) + head = f"{indent}{key}:" + if "\n" not in nested: + result.append(f"{head} {nested}") + elif isinstance(value, str) or dataclasses.is_dataclass(value): + result.append(f"{head} {nested}") + else: + result.append(f"{head}\n{nested}") if len(sample) > 25: - result.append(f"{indent} - ... (and {len(sample) - 25} more items)") + result.append(f"{indent}... (and {len(sample) - 25} more items)") return "\n".join(result) elif isinstance(sample, str): if len(sample) > 1000: sample = f"{sample[:1000]}... (and {len(sample) - 1000} more characters)" if "\n" in sample: - # represent as """ string if it contains newlines: - return '"""' + sample.replace("\n", "\n " + indent) + '"""' + lines = sample.split("\n") + out = '"""' + indent + lines[0] + for line in lines[1:]: + out += "\n" + indent + line + out += '"""' + return out return repr(sample) elif isinstance(sample, (int, float, bool, type(None))): return repr(sample) @@ -181,9 +202,22 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str: return f"[{', '.join(repr(value) for value in sample)}]" result = [] for _, value in zip(range(10), sample): - result.append(f"{indent} - {format_sample_detailed(value, indent + ' ')}") + if isinstance(value, dict) and len(value) == 1: + (k, v), = value.items() + nested_v = format_sample_detailed(v, indent + " ") + item_head = f"{indent}- {k}:" + if "\n" not in nested_v: + result.append(f"{item_head} {nested_v}") + else: + result.append(f"{item_head}\n{nested_v}") + else: + nested = format_sample_detailed(value, indent + " ") + if "\n" not in nested: + result.append(f"{indent}- {nested}") + else: + result.append(f"{indent}-\n{nested}") if len(sample) > 10: - result.append(f"{indent} - ... (and {len(sample) - 10} more items)") + result.append(f"{indent}- ... (and {len(sample) - 10} more items)") return "\n".join(result) elif isinstance(sample, torch.Tensor): try: @@ -235,12 +269,16 @@ def format_sample_detailed(sample: Any, indent: str = "") -> str: # Handle empty arrays or non-numeric dtypes return f"np.ndarray(shape={sample.shape}, dtype={sample.dtype})" elif dataclasses.is_dataclass(sample): - result = [f"{indent}{type(sample).__name__}("] + result = [f"{type(sample).__name__}("] for field in dataclasses.fields(sample): - result.append( - f"{indent} {field.name}={format_sample_detailed(getattr(sample, field.name), indent + ' ')}" - ) - result.append(f"{indent})") + field_val = getattr(sample, field.name) + nested = format_sample_detailed(field_val, indent + " ") + head = f"{indent}{field.name}:" + if "\n" not in nested: + result.append(f"{head} {nested}") + else: + result.append(f"{head}\n{nested}") + result.append(")") return "\n".join(result) else: repr_str = repr(sample) diff --git a/tests/test_sample_utils.py b/tests/test_sample_utils.py new file mode 100644 index 00000000..6a6877e2 --- /dev/null +++ b/tests/test_sample_utils.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +"""Tests for sample formatting helpers.""" + +import unittest +from dataclasses import dataclass + +import numpy as np +import torch + +from megatron.energon.sample_utils import format_sample_detailed + + +@dataclass +class _FormatProbeDc: + n: int + nested: dict[str, int] + + +class TestFormatSampleDetailed(unittest.TestCase): + def test_format_sample_detailed_complex_types(self) -> None: + """Exercise dict, scalars, strings, lists/tuples, tensors, ndarray, dataclass, and fallback repr.""" + + class _Unknown: + def __repr__(self) -> str: + return "" + + sample = { + "scalars": {"i": -3, "f": 2.5, "b": False, "n": None}, + "plain_str": "hi", + "multiline_str": "line1\nline2", + "primitive_seq": (1, 2, "x"), + "hetero_list": [{"k": 1}, {"k": 2}], + "tensor": torch.tensor([0.0, 2.0], dtype=torch.float32), + "array": np.array([1, 2, 3], dtype=np.int64), + "dataclass": _FormatProbeDc(n=9, nested={"a": 1, "b": 2}), + "unknown": _Unknown(), + } + out = format_sample_detailed(sample) + + print(out) + + assert out == '''\ +scalars: + i: -3 + f: 2.5 + b: False + n: None +plain_str: 'hi' +multiline_str: """\ + line1 + line2""" +primitive_seq: [1, 2, 'x'] +hetero_list: + - k: 1 + - k: 2 +tensor: Tensor(shape=torch.Size([2]), dtype=torch.float32, device=cpu, min=0.0, max=2.0, values=[0.0, 2.0]) +array: np.ndarray(shape=(3,), dtype=int64, min=1, max=3, values=[np.int64(1), np.int64(2), np.int64(3)]) +dataclass: _FormatProbeDc( + n: 9 + nested: + a: 1 + b: 2 +) +unknown: ''' + +if __name__ == "__main__": + unittest.main()