From 7244d1551489ec77801a7e2b47c80410aaa58afd Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 13 May 2025 17:41:15 +0200 Subject: [PATCH 1/7] Implement new debug output as loadable chrome tracing --- .../flavors/webdataset/sample_loader.py | 381 ++--- src/megatron/energon/fork_hook.py | 133 +- src/megatron/energon/savable_loader.py | 165 +-- src/megatron/energon/tracing.py | 1284 +++++++++++++++++ src/megatron/energon/worker.py | 59 +- .../energon/wrappers/batch_dataset.py | 167 ++- .../energon/wrappers/blend_dataset.py | 112 +- .../energon/wrappers/concat_dataset.py | 20 +- .../energon/wrappers/epochize_dataset.py | 70 +- .../energon/wrappers/filter_dataset.py | 33 +- src/megatron/energon/wrappers/gc_dataset.py | 36 +- .../energon/wrappers/group_batch_dataset.py | 196 ++- .../energon/wrappers/iter_map_dataset.py | 101 +- .../energon/wrappers/limit_dataset.py | 62 +- .../energon/wrappers/log_sample_dataset.py | 43 +- src/megatron/energon/wrappers/map_dataset.py | 130 +- .../energon/wrappers/packing_dataset.py | 201 ++- .../energon/wrappers/repeat_dataset.py | 87 +- .../wrappers/shuffle_buffer_dataset.py | 66 +- .../energon/wrappers/watchdog_dataset.py | 31 +- tests/test_dataset.py | 25 +- 21 files changed, 2502 insertions(+), 900 deletions(-) create mode 100644 src/megatron/energon/tracing.py diff --git a/src/megatron/energon/flavors/webdataset/sample_loader.py b/src/megatron/energon/flavors/webdataset/sample_loader.py index 29c23464..3ab8bb30 100644 --- a/src/megatron/energon/flavors/webdataset/sample_loader.py +++ b/src/megatron/energon/flavors/webdataset/sample_loader.py @@ -203,42 +203,42 @@ def _slices_iter(self) -> Generator[RawSampleData, None, None]: """Iterates the samples in a list of slices, possibly using multiple parallel iterators over the slices.""" - assert self.slice_offsets is not None - - active_slice_probs = torch.zeros(self.parallel_slice_iters, dtype=torch.float32) - active_slices = 
self._active_slice_state - pending_slice_indexes = self._pending_slice_indexes + trace = self.worker_config.worker_trace_span() - def slice_at(idx: int) -> SliceState: + with trace.span("WebdatasetSampleLoaderDataset._slices_iter", level=1) as fn_span: assert self.slice_offsets is not None - return SliceState( - index=idx, - current=self.slice_offsets[idx], - ) - - # Weight the slices by their size to get a more even distribution of samples - if any(s is not None for s in active_slices) or self._pending_slices_offset is not None: - # Having an active state, or pending slices. This means we are resuming an epoch. - if pending_slice_indexes is None: - # Need to restore the pending slices - pending_slice_indexes = self._slices_once() - assert pending_slice_indexes is not None - # Restore the state - assert len(active_slices) == self.parallel_slice_iters - for idx, slice_state in enumerate(active_slices): - if slice_state is not None: - active_slice_probs[idx] = ( - self.slice_offsets[slice_state.index + 1] - - self.slice_offsets[slice_state.index] - ) + active_slice_probs = torch.zeros(self.parallel_slice_iters, dtype=torch.float32) + active_slices = self._active_slice_state + pending_slice_indexes = self._pending_slice_indexes - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( + def slice_at(idx: int) -> SliceState: + assert self.slice_offsets is not None + return SliceState( + index=idx, + current=self.slice_offsets[idx], + ) + + # Weight the slices by their size to get a more even distribution of samples + if any(s is not None for s in active_slices) or self._pending_slices_offset is not None: + # Having an active state, or pending slices. This means we are resuming an epoch. 
+ if pending_slice_indexes is None: + # Need to restore the pending slices + pending_slice_indexes = self._slices_once() + assert pending_slice_indexes is not None + + # Restore the state + assert len(active_slices) == self.parallel_slice_iters + for idx, slice_state in enumerate(active_slices): + if slice_state is not None: + active_slice_probs[idx] = ( + self.slice_offsets[slice_state.index + 1] + - self.slice_offsets[slice_state.index] + ) + + fn_span.update_args( { - "t": "WebdatasetSampleLoaderDataset._slices_iter.resume_epoch", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), + "mode": "resume_epoch", "pending_slice_indexes": pending_slice_indexes, "active_slices": [ ( @@ -257,18 +257,37 @@ def slice_at(idx: int) -> SliceState: "probs": active_slice_probs.tolist(), } ) + trace.instant( + "WebdatasetSampleLoaderDataset._slices_iter.resume_epoch", + args={ + "pending_slice_indexes": pending_slice_indexes, + "active_slices": [ + ( + None + if state is None + else { + "index": state.index, + "current": state.current, + } + ) + for state in active_slices + ], + "count": self._sample_count, + "epoch": self._epoch_count, + "epoch_count": self._epoch_sample_count, + "probs": active_slice_probs.tolist(), + }, + level=1, + ) - else: - # Start a new epoch - assert pending_slice_indexes is None - pending_slice_indexes = self._slices_once() + else: + # Start a new epoch + assert pending_slice_indexes is None + pending_slice_indexes = self._slices_once() - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( + fn_span.update_args( { - "t": "WebdatasetSampleLoaderDataset._slices_iter.next_epoch", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), + "mode": "next_epoch", "pending_slice_indexes": pending_slice_indexes, "count": self._sample_count, "epoch": self._epoch_count, @@ -277,136 +296,146 @@ def slice_at(idx: int) -> SliceState: "shuffle_over_epochs": self.shuffle_over_epochs, } ) - - assert 
self._pending_slices_offset is not None - - # List of slice iterators, always of length `parallel_slice_iters`. May contain `None`. - active_slices.clear() - # Fill up the slice iterators - while len(pending_slice_indexes) > 0 and len(active_slices) < self.parallel_slice_iters: - slice_index = pending_slice_indexes.pop() - self._pending_slices_offset += 1 - slice_state = slice_at(slice_index) - active_slice_probs[len(active_slices)] = ( - self.slice_offsets[slice_state.index + 1] - - self.slice_offsets[slice_state.index] + trace.instant( + "WebdatasetSampleLoaderDataset._slices_iter.next_epoch", + args={ + "pending_slice_indexes": pending_slice_indexes, + "count": self._sample_count, + "epoch": self._epoch_count, + "epoch_count": self._epoch_sample_count, + "probs": active_slice_probs.tolist(), + "shuffle_over_epochs": self.shuffle_over_epochs, + }, + level=1, ) - active_slices.append(slice_state) - # Fill up the slice iterators with None - for _ in range(len(active_slices), self.parallel_slice_iters): - active_slices.append(None) - - # print( - # f"Next slice iters generated for {self.worker_config.rank}:{self.worker_config.rank_worker_id()}: probs={active_slice_probs}" - # ) - # for slice_state in active_slices: - # if slice_state is None: - # print(" - None") - # else: - # print( - # f" - [{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] at {slice_state.current}" - # ) - - # Iterate over the slice iterators while there is an iterator left - while torch.count_nonzero(active_slice_probs).item() > 0: - if self.shuffle_over_epochs is None: - # No shuffling, deterministic order, always the same - assert self.parallel_slice_iters == 1 - slice_idx = 0 - else: - # Take a random slice iterator - slice_idx = self._worker_rng.choice_idx(active_slice_probs) - slice_state = active_slices[slice_idx] - assert slice_state is not None - sample = self._get_sample(slice_state.current) - # print(f"Read sample at {slice_state.current} -> {'None' if sample 
is None or sample.data[0] is None else sample.data[0]['__key__']}") - slice_state.current += 1 - self._sample_count += 1 - self._epoch_sample_count += 1 - if slice_state.current >= self.slice_offsets[slice_state.index + 1]: - # Iterator exhausted -> take next / remove from list - if len(pending_slice_indexes) > 0 or self.shuffle_over_epochs == -1: - if len(pending_slice_indexes) > 0: - # Take the next slice (without replacement) - next_idx = pending_slice_indexes.pop() - assert self._pending_slices_offset is not None - self._pending_slices_offset += 1 - else: - # Randomly select a new slice directly (with replacement) - num_slices = len(self.slice_offsets) - 1 - next_idx = self._worker_rng.randbelow(num_slices) - next_slice_state = slice_at(next_idx) - active_slice_probs[slice_idx] = ( - self.slice_offsets[next_slice_state.index + 1] - - self.slice_offsets[next_slice_state.index] - ) - active_slices[slice_idx] = next_slice_state - # print( - # f"Slice iter for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} " - # f"[{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] exhausted at {slice_state.current}, " - # f"taking next slice {next_slice_state} [{slice_offsets[next_slice_state.index]}, {slice_offsets[next_slice_state.index + 1]}], " - # f"{len(pending_slice_indexes)} slices left, probs={active_slice_probs.tolist()}" - # ) - else: - active_slice_probs[slice_idx] = 0 - active_slices[slice_idx] = None - # print( - # f"Slice iter for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} " - # f"[{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] exhausted at {slice_state.current}, " - # f"no next slice, probs={active_slice_probs.tolist()}" - # ) - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "WebdatasetSampleLoaderDataset._slices_iter.exhausted", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "remaining": 
len(pending_slice_indexes), - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, - "probs": active_slice_probs.tolist(), - } - ) - if sample.data[0] is not None: - # Otherwise the sample was skipped. - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "WebdatasetSampleLoaderDataset._slices_iter.yield", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "index": sample.__restore_key__[1], - "key": sample.data[0]["__key__"], - "shard": sample.data[0]["__shard__"], - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, - } + assert self._pending_slices_offset is not None + + # List of slice iterators, always of length `parallel_slice_iters`. May contain `None`. + active_slices.clear() + # Fill up the slice iterators + while ( + len(pending_slice_indexes) > 0 + and len(active_slices) < self.parallel_slice_iters + ): + slice_index = pending_slice_indexes.pop() + self._pending_slices_offset += 1 + slice_state = slice_at(slice_index) + active_slice_probs[len(active_slices)] = ( + self.slice_offsets[slice_state.index + 1] + - self.slice_offsets[slice_state.index] ) - # Now, yield the sample - yield sample - del sample - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "WebdatasetSampleLoaderDataset._slices_iter.all_exhausted", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, - } - ) - - # Epoch has finished, reset states. 
- self._epoch_count += 1 - self._epoch_sample_count = 0 - self._pending_slice_indexes = None - self._pending_slices_offset = None - # print( - # f"slice iters exhausted for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} after {cnt} samples" - # ) + active_slices.append(slice_state) + # Fill up the slice iterators with None + for _ in range(len(active_slices), self.parallel_slice_iters): + active_slices.append(None) + + # print( + # f"Next slice iters generated for {self.worker_config.rank}:{self.worker_config.rank_worker_id()}: probs={active_slice_probs}" + # ) + # for slice_state in active_slices: + # if slice_state is None: + # print(" - None") + # else: + # print( + # f" - [{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] at {slice_state.current}" + # ) + + # Iterate over the slice iterators while there is an iterator left + while torch.count_nonzero(active_slice_probs).item() > 0: + with trace.span("WebdatasetSampleLoaderDataset._slices_iter.iter", level=1): + if self.shuffle_over_epochs is None: + # No shuffling, deterministic order, always the same + assert self.parallel_slice_iters == 1 + slice_idx = 0 + else: + # Take a random slice iterator + slice_idx = self._worker_rng.choice_idx(active_slice_probs) + slice_state = active_slices[slice_idx] + assert slice_state is not None + sample = self._get_sample(slice_state.current) + # print(f"Read sample at {slice_state.current} -> {'None' if sample is None or sample.data[0] is None else sample.data[0]['__key__']}") + slice_state.current += 1 + self._sample_count += 1 + self._epoch_sample_count += 1 + if slice_state.current >= self.slice_offsets[slice_state.index + 1]: + # Iterator exhausted -> take next / remove from list + if len(pending_slice_indexes) > 0 or self.shuffle_over_epochs == -1: + if len(pending_slice_indexes) > 0: + # Take the next slice (without replacement) + next_idx = pending_slice_indexes.pop() + assert self._pending_slices_offset is not None + 
self._pending_slices_offset += 1 + else: + # Randomly select a new slice directly (with replacement) + num_slices = len(self.slice_offsets) - 1 + next_idx = self._worker_rng.randbelow(num_slices) + next_slice_state = slice_at(next_idx) + active_slice_probs[slice_idx] = ( + self.slice_offsets[next_slice_state.index + 1] + - self.slice_offsets[next_slice_state.index] + ) + active_slices[slice_idx] = next_slice_state + # print( + # f"Slice iter for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} " + # f"[{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] exhausted at {slice_state.current}, " + # f"taking next slice {next_slice_state} [{slice_offsets[next_slice_state.index]}, {slice_offsets[next_slice_state.index + 1]}], " + # f"{len(pending_slice_indexes)} slices left, probs={active_slice_probs.tolist()}" + # ) + else: + active_slice_probs[slice_idx] = 0 + active_slices[slice_idx] = None + # print( + # f"Slice iter for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} " + # f"[{slice_offsets[slice_state.index]}, {slice_offsets[slice_state.index + 1]}] exhausted at {slice_state.current}, " + # f"no next slice, probs={active_slice_probs.tolist()}" + # ) + trace.instant( + "WebdatasetSampleLoaderDataset._slices_iter.exhausted", + args={ + "remaining": len(pending_slice_indexes), + "count": self._sample_count, + "epoch": self._epoch_count, + "epoch_count": self._epoch_sample_count, + "probs": active_slice_probs.tolist(), + }, + level=2, + ) + if sample.data[0] is not None: + # Otherwise the sample was skipped. 
+ with trace.span( + "WebdatasetSampleLoaderDataset._slices_iter.yield", + args={ + "index": sample.__restore_key__[1], + "key": sample.data[0]["__key__"], + "shard": sample.data[0]["__shard__"], + "count": self._sample_count, + "epoch": self._epoch_count, + "epoch_count": self._epoch_sample_count, + }, + level=2, + ): + # Now, yield the sample + yield sample + del sample + if self.worker_config.should_log(level=2): + trace.instant( + "WebdatasetSampleLoaderDataset._slices_iter.all_exhausted", + args={ + "count": self._sample_count, + "epoch": self._epoch_count, + "epoch_count": self._epoch_sample_count, + }, + level=2, + ) + + # Epoch has finished, reset states. + self._epoch_count += 1 + self._epoch_sample_count = 0 + self._pending_slice_indexes = None + self._pending_slices_offset = None + # print( + # f"slice iters exhausted for {self.worker_config.rank}:{self.worker_config.rank_worker_id()} after {cnt} samples" + # ) def __len__(self) -> int: return sum( @@ -425,17 +454,11 @@ def __iter__(self) -> Iterator[RawSampleData]: self.ensure_slice_offsets() assert self.slice_offsets is not None - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "WebdatasetSampleLoaderDataset.__iter__", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "slice_offsets": self.slice_offsets, - "parallel_slice_iters": self.parallel_slice_iters, - "shuffle_over_epochs": self.shuffle_over_epochs, - } - ) + self.worker_config.worker_trace_writer().instant( + "WebdatasetSampleLoaderDataset.__iter__", + args=self.config(), + level=1, + ) if len(self.slice_offsets) <= 1: return diff --git a/src/megatron/energon/fork_hook.py b/src/megatron/energon/fork_hook.py index b5184207..087363be 100644 --- a/src/megatron/energon/fork_hook.py +++ b/src/megatron/energon/fork_hook.py @@ -1,15 +1,15 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause +from dataclasses import dataclass import functools import os import weakref from typing import Callable, Protocol, Type, TypeVar -_after_in_child_fork_hooks = weakref.WeakKeyDictionary() -_after_in_parent_fork_hooks = weakref.WeakKeyDictionary() -_before_fork_hooks = weakref.WeakKeyDictionary() - +_after_in_child_fork_hooks = dict() +_after_in_parent_fork_hooks = dict() +_before_fork_hooks = dict() T = TypeVar("T", bound=Callable[[], None]) @@ -18,34 +18,41 @@ def before_fork_hook(callable: Callable[[], None]): """ Run function before the fork of a worker process. The function must be persistent. """ - # Make sure, that callable is a method of object - assert getattr(callable, "__self__", None) is None, ( - f"Callable must not be a method: {callable.__name__}" - ) - # print(f"Adding before_fork_hook for {callable.__name__}\n", end="") - _before_fork_hooks[callable] = callable + if getattr(callable, "__self__", None): + self = callable.__self__ + _before_fork_hooks[id(self)] = callable + weakref.finalize(self, lambda: _before_fork_hooks.pop(id(self))) + else: + _before_fork_hooks[id(callable)] = callable + weakref.finalize(callable, lambda: _before_fork_hooks.pop(id(callable))) -def after_in_parent_fork_hook(callable: T): +def after_in_parent_fork_hook(callable: Callable[[], None]): """ Run function after the fork of a worker process. The function must be persistent. 
""" # print(f"Adding after_in_child_fork_hook for {callable.__name__}\n", end="") - assert getattr(callable, "__self__", None) is None, ( - f"Callable must not be a method: {callable.__name__}" - ) - _after_in_parent_fork_hooks[callable] = callable + if getattr(callable, "__self__", None): + self = callable.__self__ + _after_in_parent_fork_hooks[id(self)] = callable + weakref.finalize(self, lambda: _after_in_parent_fork_hooks.pop(id(self))) + else: + _after_in_parent_fork_hooks[id(callable)] = callable + weakref.finalize(callable, lambda: _after_in_parent_fork_hooks.pop(id(callable))) -def after_in_child_fork_hook(callable: T): +def after_in_child_fork_hook(callable: Callable[[], None]): """ Run function after the fork of a worker process. The function must be persistent. """ # print(f"Adding after_in_child_fork_hook for {callable.__name__}\n", end="") - assert getattr(callable, "__self__", None) is None, ( - f"Callable must not be a method: {callable.__name__}" - ) - _after_in_child_fork_hooks[callable] = callable + if getattr(callable, "__self__", None): + self = callable.__self__ + _after_in_child_fork_hooks[id(self)] = callable + weakref.finalize(self, lambda: _after_in_child_fork_hooks.pop(id(self))) + else: + _after_in_child_fork_hooks[id(callable)] = callable + weakref.finalize(callable, lambda: _after_in_child_fork_hooks.pop(id(callable))) class ForkMixin: @@ -55,18 +62,21 @@ class ForkMixin: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + self.__post_init__() + + def __post_init__(self): if getattr(self.__before_fork__, "__func__", None) is not ForkMixin.__before_fork__: - _before_fork_hooks[self] = "__before_fork__" + before_fork_hook(self.__before_fork__) if ( getattr(self.__after_in_child_fork__, "__func__", None) is not ForkMixin.__after_in_child_fork__ ): - _after_in_child_fork_hooks[self] = "__after_in_child_fork__" + after_in_child_fork_hook(self.__after_in_child_fork__) if ( getattr(self.__after_in_parent_fork__, 
"__func__", None) is not ForkMixin.__after_in_parent_fork__ ): - _after_in_parent_fork_hooks[self] = "__after_in_parent_fork__" + after_in_parent_fork_hook(self.__after_in_parent_fork__) def __after_in_child_fork__(self): """ @@ -86,59 +96,43 @@ def __before_fork__(self): """ pass - -class ForkHookProtocol(Protocol): +@dataclass +class DataclassForkMixin: """ - A protocol that defines a method that runs before and after the fork of a worker process. + A mixin that runs a method after the fork of a worker process. """ + def __post_init__(self): + if getattr(self.__before_fork__, "__func__", None) is not ForkMixin.__before_fork__: + before_fork_hook(self.__before_fork__) + if ( + getattr(self.__after_in_child_fork__, "__func__", None) + is not ForkMixin.__after_in_child_fork__ + ): + after_in_child_fork_hook(self.__after_in_child_fork__) + if ( + getattr(self.__after_in_parent_fork__, "__func__", None) + is not ForkMixin.__after_in_parent_fork__ + ): + after_in_parent_fork_hook(self.__after_in_parent_fork__) + def __after_in_child_fork__(self): """ A method that runs after the fork in the child process. """ - ... + pass def __after_in_parent_fork__(self): """ A method that runs after the fork in the parent process. """ - ... + pass def __before_fork__(self): """ A method that runs before the fork of a worker process. """ - ... - - -T_CLS = TypeVar("T_CLS", bound=Type[ForkHookProtocol]) - - -def fork_hook_class(cls: T_CLS) -> T_CLS: - """ - A decorator that runs a function after the fork of a worker process. 
- """ - if hasattr(cls, "__init__"): - orig_init = cls.__init__ - - @functools.wraps(orig_init) - def __init__(self, *args, **kwargs): - _after_in_child_fork_hooks[self] = "__after_in_child_fork__" - _after_in_parent_fork_hooks[self] = "__after_in_parent_fork__" - _before_fork_hooks[self] = "__before_fork__" - orig_init(self, *args, **kwargs) - - cls.__init__ = __init__ - else: - - def __init__(self, *args, **kwargs): - _after_in_child_fork_hooks[cls] = "__after_in_child_fork__" - _after_in_parent_fork_hooks[cls] = "__after_in_parent_fork__" - _before_fork_hooks[cls] = "__before_fork__" - cls(*args, **kwargs) - - cls.__init__ = __init__ - return cls + pass def _run_before_fork_hooks(): @@ -146,12 +140,9 @@ def _run_before_fork_hooks(): Run all the functions that were registered with the before_fork_hook decorator. """ # print(f"Running before_fork_hooks for pid={os.getpid()}") - for obj, hook in _before_fork_hooks.items(): + for hook in _before_fork_hooks.values(): # print(f"Running before_fork_hook for {hook}\n", end="") - if callable(hook): - hook() - else: - getattr(obj, hook)() + hook() def _run_after_in_child_fork_hooks(): @@ -159,12 +150,9 @@ def _run_after_in_child_fork_hooks(): Run all the functions that were registered with the after_in_child_fork_hook decorator. """ # print(f"Running after_in_child_fork_hooks for pid={os.getpid()}") - for obj, hook in _after_in_child_fork_hooks.items(): + for hook in _after_in_child_fork_hooks.values(): # print(f"Running after_in_child_fork_hook for {hook}\n", end="") - if callable(hook): - hook() - else: - getattr(obj, hook)() + hook() def _run_after_in_parent_fork_hooks(): @@ -172,12 +160,9 @@ def _run_after_in_parent_fork_hooks(): Run all the functions that were registered with the after_in_parent_fork_hook decorator. 
""" # print(f"Running after_in_parent_fork_hooks for pid={os.getpid()}") - for obj, hook in _after_in_parent_fork_hooks.items(): + for hook in _after_in_parent_fork_hooks.values(): # print(f"Running after_in_parent_fork_hook for {hook}\n", end="") - if callable(hook): - hook() - else: - getattr(obj, hook)() + hook() os.register_at_fork( diff --git a/src/megatron/energon/savable_loader.py b/src/megatron/energon/savable_loader.py index 7cd45bc1..227e8e0c 100644 --- a/src/megatron/energon/savable_loader.py +++ b/src/megatron/energon/savable_loader.py @@ -788,16 +788,15 @@ def __init__( **kwargs, ) - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "SavableLoader.__init__", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "config": dataset.config(), - } - ) + self.worker_config.worker_trace_writer().trace_object( + self, + "SavableDataLoader", + args={ + "id": self.id, + "config": dataset.config(), + }, + level=1, + ) @staticmethod def next_id() -> int: @@ -821,16 +820,15 @@ class InnerIterator: def __init__(self, iterator): self._iterator = iterator self.id = outerself.next_id() - if outerself.worker_config.should_log(level=1): - outerself.worker_config.worker_log( - { - "t": "SavableDataLoader.iter", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, - } - ) + outerself.worker_config.worker_trace_writer().trace_object( + self, + "SavableDataLoader.iter", + args={ + "id": outerself.id, + "iter_id": self.id, + }, + level=1, + ) # self._debugf = open( # f"worker_samples_rank{outerself.worker_config.rank:02}_t{int(time.time())}.log", "w" @@ -851,22 +849,18 @@ def __next__(self): # self._debugf.flush() if outerself.worker_config.should_log(level=1): keys = default_get_keys(sample) - outerself.worker_config.worker_log( - { - **{ - "t": "SavableDataLoader.yield", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, - "worker_id": 
worker_id, - "worker_idx": sample_idx, - "idx": outerself._sample_idx, - "iter_idx": self.iter_idx, - "global_idx": outerself._global_sample_idx, - }, + outerself.worker_config.worker_trace_writer().instant( + "SavableDataLoader.yield", + args={ + "id": outerself.id, + "iter_id": self.id, + "worker_id": worker_id, + "worker_idx": sample_idx, + "idx": outerself._sample_idx, + "iter_idx": self.iter_idx, + "global_idx": outerself._global_sample_idx, **({} if keys is None else {"keys": keys}), - } + }, ) outerself._sample_idx += 1 outerself._global_sample_idx += 1 @@ -875,16 +869,14 @@ def __next__(self): except StopIteration: self.finished = True outerself._next_worker_id = 0 - if outerself.worker_config.should_log(level=1): - outerself.worker_config.worker_log( - { - "t": "SavableDataLoader.StopIteration", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, - } - ) + outerself.worker_config.worker_trace_writer().instant( + "SavableDataLoader.StopIteration", + args={ + "id": outerself.id, + "iter_id": self.id, + }, + level=1, + ) raise if self.num_workers > 0: @@ -1293,16 +1285,15 @@ def __init__( worker_init_fn=partial(_init_worker, seed_per_worker), **kwargs, ) - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "BasicDataLoader.__init__", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "config": self.config(), - } - ) + self.worker_config.worker_trace_writer().trace_object( + self, + "BasicDataLoader", + args={ + "id": self.id, + "config": self.config(), + }, + level=1, + ) def __iter__(self): outerself = self @@ -1320,16 +1311,15 @@ def __init__(self, iterator): self._iterator = iterator self.id = SavableDataLoader.next_id() - if outerself.worker_config.should_log(level=1): - outerself.worker_config.worker_log( - { - "t": "BasicDataLoader.iter", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, - } - ) + 
outerself.worker_config.worker_trace_writer().trace_object( + self, + "BasicDataLoader.iter", + args={ + "id": outerself.id, + "iter_id": self.id, + }, + level=1, + ) def __iter__(self): return self @@ -1341,38 +1331,33 @@ def __next__(self): self.next_worker_id = (worker_id + 1) % max(outerself.num_workers, 1) if outerself.worker_config.should_log(level=1): keys = default_get_keys(sample) - outerself.worker_config.worker_log( - { - **{ - "t": "BasicDataLoader.yield", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, - "worker_id": worker_id, - "worker_idx": sample_idx, - "idx": self.iter_idx, - "iter_idx": self.iter_idx, - "global_idx": outerself._sample_idx, - }, + outerself.worker_config.worker_trace_writer().instant( + "BasicDataLoader.yield", + args={ + "id": outerself.id, + "iter_id": self.id, + "worker_id": worker_id, + "worker_idx": sample_idx, + "idx": self.iter_idx, + "iter_idx": self.iter_idx, + "global_idx": outerself._sample_idx, **({} if keys is None else {"keys": keys}), - } + }, + level=1, ) outerself._sample_idx += 1 self.iter_idx += 1 return sample except StopIteration: self.next_worker_id = 0 - if outerself.worker_config.should_log(level=1): - outerself.worker_config.worker_log( - { - "t": "BasicDataLoader.StopIteration", - "r": outerself.worker_config.rank, - "w": None, - "id": outerself.id, - "iter_id": self.id, - } - ) + outerself.worker_config.worker_trace_writer().instant( + "BasicDataLoader.StopIteration", + args={ + "id": outerself.id, + "iter_id": self.id, + }, + level=1, + ) raise return InnerIterator(super().__iter__()) diff --git a/src/megatron/energon/tracing.py b/src/megatron/energon/tracing.py new file mode 100644 index 00000000..6280659a --- /dev/null +++ b/src/megatron/energon/tracing.py @@ -0,0 +1,1284 @@ +from __future__ import annotations + +import gc +import json +import os +import threading +import time +import weakref +from contextlib import AbstractContextManager +from dataclasses 
class JsonEncoder(json.JSONEncoder):
    """JSON encoder that also handles numpy values, torch tensors and dataclass instances.

    * numpy arrays, numpy scalars (e.g. ``np.int64``) and torch tensors are
      converted through their ``tolist()`` method; if that fails, a truncated
      ``str()`` representation (at most 250 characters) is emitted instead.
    * dataclass *instances* become a dict of their fields plus a ``__type__``
      entry holding the class name.
    * everything else is delegated to :class:`json.JSONEncoder`, which raises
      ``TypeError`` for unsupported objects.
    """

    def default(self, o: Any) -> Any:
        # ``np.generic`` covers numpy scalars, which the stock encoder rejects
        # (``np.int64`` is not an ``int`` subclass).
        if isinstance(o, (np.ndarray, np.generic, torch.Tensor)):
            try:
                return o.tolist()
            except Exception:
                # Best effort: tracing must not fail because of an odd value.
                return str(o)[:250]

        # ``is_dataclass`` is also true for dataclass *types*, but ``asdict``
        # only accepts instances, so types fall through to the base class.
        if is_dataclass(o) and not isinstance(o, type):
            return {"__type__": type(o).__name__, **asdict(o)}

        return super().default(o)
+ + The public surface consists of one generic :py:meth:`emit` method that + serialises an *event dictionary* directly plus a set of convenience + helpers – :py:meth:`span`, :py:meth:`instant`, :py:meth:`async_begin`, + :py:meth:`flow_start`, :py:meth:`counter`, :py:meth:`object_new`, … – that + wrap the *phase* field (``ph``) semantics defined in the Chromium spec: + + ============ ============================================================= + Phase (``ph``) Helper(s) + ------------ ------------------------------------------------------------- + ``B``/``E`` :py:meth:`span` (or :pyclass:`Span` ctx-mgr) + ``i`` :py:meth:`instant` + ``b``/``n``/``e`` :py:meth:`async_begin`, :py:meth:`async_instant`, + :py:meth:`async_end` and the :pyclass:`AsyncSpan` + context-manager + ``s``/``t``/``f`` :py:meth:`flow_start`, :py:meth:`flow_step`, + :py:meth:`flow_end` + ``C`` :py:meth:`counter` + ``N``/``O``/``D`` :py:meth:`object_new`, :py:meth:`object_snapshot`, + :py:meth:`object_delete` and :pyclass:`ObjectTrace` + ============ ============================================================= + + For further background on each event family refer to the *Event + Descriptions* section in the Trace-Event specification. + """ + + _write_lock: threading.Lock + _pid: int + _events: int + _closed: bool + _stream: IO[bytes] + _own_stream: Optional[IO[bytes]] + _flush_interval: int + _pending: int + _log_level: int + + _global_next_id_lock: ClassVar[threading.Lock] = threading.Lock() + _global_next_id: ClassVar[int] = 0 + + def __init__( + self, + stream: Union[str, Path, IO[bytes]], + *, + pid: int | None = None, + log_level: int = 0, + ) -> None: + self._pid = pid if pid is not None else os.getpid() + self._events = 0 + self._closed = False + self._write_lock = threading.Lock() + + if isinstance(stream, (str, Path)): + # Ensure parent directory exists when stream is a path. 
def _write_raw(self, json_event: bytes, *, flush: bool = False) -> None:
    """Append one serialised event while keeping the trace file valid JSON.

    The closing bracket written previously is overwritten by the new event
    (preceded by a separator unless it is the first event) and then written
    again, so the file is a complete JSON array after every call.

    Args:
        json_event: A fully-serialised event as UTF-8 encoded JSON bytes.
        flush: If *True* the underlying stream is flushed after the write.
    """
    with self._write_lock:
        self._stream.seek(-len(_JSON_CLOSE), os.SEEK_END)
        if self._events > 0:
            json_event = _JSON_NEXT + json_event + _JSON_CLOSE
        else:
            json_event = json_event + _JSON_CLOSE
        self._stream.write(json_event)
        # Count the event while still holding the lock: the separator decision
        # above reads ``_events``, so both must be updated atomically.
        self._events += 1
        self._pending += len(json_event)
        if flush or self._pending >= self._flush_interval:
            self._stream.flush()
            self._pending = 0

def close(self) -> None:
    """Flush the stream and close it if this writer opened it. Idempotent."""
    if not self._closed:
        self._closed = True
        with self._write_lock:
            self._stream.flush()
            if self._own_stream is not None:
                self._own_stream.close()
                self._own_stream = None

def flush(self) -> None:
    """Flush buffered events to the underlying stream."""
    with self._write_lock:
        self._stream.flush()

def _emit(self, event: Dict[str, Any]) -> None:
    """Serialize *event* mapping and append it to the trace.

    Args:
        event: A dictionary that already fulfills the Trace-Event schema
            expectations.
    """
    json_event = json.dumps(
        event, separators=(",", ":"), ensure_ascii=False, cls=JsonEncoder
    ).encode("utf-8")
    # ``_write_raw`` updates the event counter under the write lock.
    self._write_raw(json_event)
def span(
    self,
    name: str,
    *,
    cat: str | None = None,
    args: Optional[Dict[str, Any]] = None,
    level: int = 0,
) -> "Span":
    """Create a context manager that brackets a block with ``B``/``E`` events.

    Args:
        name: Displayed slice name.
        cat: Optional comma-separated category list.
        args: Extra arguments object handed to the span.
        level: Logging level; a level below the writer's level yields the
            shared no-op span.

    Returns:
        A :class:`Span`, or the no-op span when the level is filtered.
    """
    enabled = level >= self._log_level
    return Span(self, name=name, cat=cat, args=args) if enabled else _NOOP_SPAN
def async_begin(
    self,
    name: str,
    *,
    id: Union[int, str, None] = None,
    cat: str | None = None,
    scope: str | None = None,
    args: Optional[Dict[str, Any]] = None,
    level: int = 0,
) -> Union[int, str]:
    """Open a *nestable async* slice (``ph='b'``).

    Args:
        name: Event display name.
        id: Correlation identifier (int or str); a fresh one is allocated
            when omitted.
        cat: Optional categories.
        scope: Extra scope string to avoid id collisions.
        args: Optional argument object.
        level: Logging level.

    Returns:
        The correlation identifier used for the event, or ``None`` when the
        event is filtered out by *level*.
    """
    if level < self._log_level:
        return None

    correlation_id = self._next_id() if id is None else id
    event = {
        "name": name,
        "ph": "b",
        "id": correlation_id,
        "ts": _timestamp_us(),
        "pid": self._pid,
        "tid": threading.get_ident(),
    }
    # "scope" is spelled out in full to avoid a clash with the "s" key of instant events.
    for key, value in (("cat", cat), ("scope", scope)):
        if value is not None:
            event[key] = value
    if args:
        event["args"] = dict(args)
    self._emit(event)
    return correlation_id
def async_end(
    self,
    name: str,
    *,
    id: Union[int, str],
    cat: str | None = None,
    scope: str | None = None,
    args: Optional[Dict[str, Any]] = None,
    level: int = 0,
) -> None:
    """Finish a *nestable async* chain (``ph='e'``).

    Args:
        name: Event name; should equal the name given to the matching
            :py:meth:`async_begin` call.
        id: Correlation identifier.
        cat: Categories.
        scope: Optional scope string.
        args: Additional arguments.
        level: Logging level.
    """
    if level < self._log_level:
        return
    event = {
        # The name was previously accepted but never written to the event.
        "name": name,
        "ph": "e",
        "id": id,
        "ts": _timestamp_us(),
        "pid": self._pid,
        "tid": threading.get_ident(),
    }
    if cat is not None:
        event["cat"] = cat
    if scope is not None:
        event["scope"] = scope
    if args:
        event["args"] = dict(args)
    self._emit(event)
def counter(
    self,
    name: str,
    value: Union[int, float, Dict[str, Union[int, float]]],
    *,
    id: Union[int, str, None] = None,
    cat: str | None = None,
    level: int = 0,
) -> None:
    """Emit a numerical *counter* (``ph='C'``).

    Args:
        name: Counter track name.
        value: Either a single numeric value or a mapping of series-name to
            numeric value.
        id: Optional counter identifier (name+id pair becomes counter key).
            When omitted, no ``id`` field is written, so successive samples
            with the same *name* share one key.
        cat: Categories.
        level: Logging level.
    """
    if level < self._log_level:
        return
    if isinstance(value, Mapping):
        args_field = dict(value)
    else:
        args_field = {"value": value}

    # No id is auto-generated here: a fresh id per call would give every
    # sample its own name+id key.
    event: Dict[str, Any] = {
        "name": name,
        "ph": "C",
        "ts": _timestamp_us(),
        "pid": self._pid,
        "tid": threading.get_ident(),
        "args": args_field,
    }
    if id is not None:
        event["id"] = id
    if cat is not None:
        event["cat"] = cat
    self._emit(event)
def object_snapshot(
    self,
    name: str,
    *,
    id: Union[int, str, None] = None,
    snapshot: Dict[str, Any],
    cat: str | None = None,
    scope: str | None = None,
    level: int = 0,
) -> None:
    """Record the state of a traced object (``ph='O'``).

    Args:
        name: Object name.
        id: Identifier matching a previously created object; a fresh one is
            allocated when omitted.
        snapshot: Arbitrary JSON-serialisable state payload (copied).
        cat: Categories.
        scope: Optional scope string.
        level: Logging level.
    """
    if level < self._log_level:
        return

    object_id = self._next_id() if id is None else id
    event = {
        "name": name,
        "ph": "O",
        "id": object_id,
        "ts": _timestamp_us(),
        "pid": self._pid,
        "tid": threading.get_ident(),
        "args": {"snapshot": dict(snapshot)},
    }
    for key, value in (("cat", cat), ("scope", scope)):
        if value is not None:
            event[key] = value
    self._emit(event)
def trace_object(
    self,
    obj: Any,
    name: str,
    *,
    cat: str | None = None,
    scope: str | None = None,
    args: Optional[Dict[str, Any]] = None,
    level: int = 0,
) -> "ObjectTrace":
    """Attach tracing to an existing Python *obj* until GC.

    The object's ``id()`` is used as trace identifier and a
    ``weakref.finalize`` callback emits the deletion event once *obj* is
    collected.

    Args:
        obj: Target instance to monitor.
        name: Trace-viewer object name.
        cat: Categories.
        scope: Optional scope string.
        args: Optional initial snapshot emitted right after creation.
        level: Logging level, forwarded to the creation and snapshot events.

    Returns:
        ObjectTrace handle (the no-op handle when filtered by *level*).

    Raises:
        ValueError: If tracing is active for *level* and *obj* is not tracked
            by the garbage collector.
    """
    # Check the level first so that disabled tracing never raises.
    if level < self._log_level:
        return _NOOP_OBJECT_TRACE
    if not gc.is_tracked(obj):
        raise ValueError("Object is not tracked by the garbage collector")
    # Forward the level; otherwise the nested calls fall back to level 0 and
    # may be filtered although this call was accepted.
    trace = self.object_trace(name, id=id(obj), cat=cat, scope=scope, level=level)
    weakref.finalize(obj, trace.delete)
    if args:
        trace.snapshot(args, level=level)
    return trace
class AsyncSpan(AbstractContextManager):
    """Context manager for *nestable async* events.

    ``begin`` (or entering the context) calls the writer's ``async_begin``;
    ``end`` (or leaving the context) calls ``async_end`` with the same name,
    id, category and scope. Arguments passed at construction are attached to
    the begin event; arguments collected afterwards through
    :py:meth:`update_args` or ``end(args)`` are attached to the end event.
    """

    __slots__ = (
        "_writer",
        "_name",
        "_id",
        "_cat",
        "_scope",
        "_args",
    )

    def __init__(
        self,
        writer: TraceWriter,
        *,
        name: str,
        id: Union[int, str],
        cat: str | None = None,
        scope: str | None = None,
        args: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._writer = writer
        self._name = name
        self._id = id
        self._cat = cat
        self._scope = scope
        # Copy so later ``update_args`` calls never mutate the caller's dict.
        self._args = dict(args) if args else None

    # ------------------------------------------------------------------
    # Context management
    # ------------------------------------------------------------------

    def begin(self) -> None:
        """Emit the begin event with the currently collected arguments."""
        self._writer.async_begin(
            self._name, id=self._id, cat=self._cat, scope=self._scope, args=self._args or None
        )
        # Arguments were consumed by the begin event; start collecting anew.
        self._args = None

    def __enter__(self):  # noqa: D401
        self.begin()
        return self

    def update_args(self, args: Dict[str, Any]) -> None:
        """Merge *args* into the arguments attached to the next emitted event."""
        if self._args is None:
            self._args = dict(args)
        else:
            self._args.update(args)

    def end(self, args: Optional[Dict[str, Any]] = None) -> None:
        """Emit the end event, attaching pending arguments merged with *args*."""
        if args:
            # Previously *args* was dropped whenever no arguments were pending,
            # which is always the case right after ``begin``.
            self.update_args(args)
        self._writer.async_end(
            self._name, id=self._id, cat=self._cat, scope=self._scope, args=self._args or None
        )
        self._args = None

    def __exit__(self, exc_type, exc, tb):  # noqa: D401, N802
        self.end()
class IterableNextWrapper(Iterator[T], Generic[T]):
    """Iterator wrapper that runs every ``next`` call inside a fresh span.

    Args:
        iterable: The iterable to wrap; ``iter()`` is called on it once.
        span: Zero-argument factory returning a context manager. One is
            created and entered for each ``__next__`` call, including the
            final call that raises ``StopIteration``.
    """

    # The slot names must match the attributes assigned in ``__init__``:
    # the base classes are slotted too, so instances have no ``__dict__``.
    __slots__ = (
        "_iterator",
        "_span",
    )

    _iterator: Iterator[T]
    _span: Callable[[], ContextManager]

    def __init__(self, iterable: Iterable[T], *, span: Callable[[], ContextManager]):
        self._iterator = iter(iterable)
        self._span = span

    def __iter__(self):
        return self

    def __next__(self):
        with self._span():
            return next(self._iterator)
class NoopAsyncFlow(AbstractContextManager):
    """Do-nothing stand-in for :class:`AsyncFlow`, used when tracing is filtered out.

    It offers the same method names as ``AsyncFlow`` (``instant``, ``start``,
    ``end``, ``span``, ``iterable``, ``exception``) so callers need not check
    whether tracing is enabled.
    """

    def instant(self, *args, **kwargs) -> None:
        pass

    def start(self, *args, **kwargs) -> None:
        pass

    def end(self, *args, **kwargs) -> None:
        pass

    # Kept for backward compatibility with the earlier method names.
    def async_start(self, *args, **kwargs) -> None:
        pass

    def async_end(self, *args, **kwargs) -> None:
        pass

    def span(self, *args, **kwargs) -> AsyncSpan:
        return _NOOP_ASYNC_SPAN

    def iterable(self, iterable, *args, **kwargs) -> Iterable:
        # Hand back the iterable untouched: no per-item spans when disabled.
        return iterable

    def exception(self, *args, **kwargs) -> None:
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        pass
Used when tracing is disabled.""" + + def duration_begin(self, *args, **kwargs) -> None: + pass + + def duration_end(self, *args, **kwargs) -> None: + pass + + def span(self, *args, **kwargs) -> "Span": + return _NOOP_SPAN + + def instant(self, *args, **kwargs) -> None: + pass + + def async_begin(self, *args, **kwargs) -> None: + pass + + def async_instant(self, *args, **kwargs) -> None: + pass + + def async_end(self, *args, **kwargs) -> None: + pass + + def async_span(self, *args, **kwargs) -> "AsyncSpan": + return _NOOP_ASYNC_SPAN + + def async_flow(self, *args, **kwargs) -> "AsyncFlow": + return _NOOP_ASYNC_FLOW + + def flow_start(self, *args, **kwargs) -> None: + pass + + def flow_step(self, *args, **kwargs) -> None: + pass + + def flow_end(self, *args, **kwargs) -> None: + pass + + def counter(self, *args, **kwargs) -> None: + pass + + def object_new(self, *args, **kwargs) -> None: + pass + + def object_snapshot(self, *args, **kwargs) -> None: + pass + + def object_delete(self, *args, **kwargs) -> None: + pass + + def object_trace(self, *args, **kwargs) -> "ObjectTrace": + return _NOOP_OBJECT_TRACE + + def trace_object(self, *args, **kwargs) -> "ObjectTrace": + return _NOOP_OBJECT_TRACE + + def metadata(self, *args, **kwargs) -> None: + pass + + def close(self) -> None: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + pass + + def __repr__(self) -> str: + return "" + + +NOOP_TRACE_WRITER: TraceWriter = cast(TraceWriter, NoopTraceWriter()) diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index 8ea9bed7..cb72e24d 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -2,7 +2,6 @@ # SPDX-License-Identifier: BSD-3-Clause import hashlib -import json import multiprocessing from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, TextIO, TypeVar @@ -12,6 +11,8 @@ from megatron.energon.cache import CachePool from megatron.energon.dataclass_slots 
import dataclass_slots +from megatron.energon.fork_hook import DataclassForkMixin +from megatron.energon.tracing import NOOP_TRACE_WRITER, AsyncFlow, TraceWriter __all__ = ("WorkerConfig",) @@ -19,7 +20,7 @@ @dataclass_slots(eq=False) -class WorkerConfig: +class WorkerConfig(DataclassForkMixin): """ Provides information about the current worker and the global configuration. This gives each data parallel rank its proper config. Every `rank` (up to `world_size-1`) must be used. @@ -57,10 +58,10 @@ class WorkerConfig: worker_debug_path: Optional[str] = None #: Log level for worker logging. worker_log_level: int = 0 + #: The current trace writer for the worker. + _worker_trace_writer: Optional[TraceWriter] = None #: The opened file for the current worker. Should not be set from outside. _worker_debug_file: Optional[TextIO] = None - #: worker_id of the opened worker debug file - _worker_debug_file_worker_id: Optional[int] = None #: The current sample index within the current iterating worker _sample_index_stack: ClassVar[Optional[List[int]]] = None @@ -68,7 +69,7 @@ class WorkerConfig: active_worker_config: ClassVar[Optional["WorkerConfig"]] = None #: The global rank override for the worker. Required for restoring samples. - _worker_override_global_rank: ClassVar[Optional[List[int]]] = None + _worker_override_global_rank: ClassVar[Optional[int]] = None #: The current cache pool for the worker. 
_cache_pool: ClassVar[Optional[CachePool]] = None @@ -252,27 +253,43 @@ def config(self) -> Dict[str, Any]: def should_log(self, level: int) -> bool: return level <= self.worker_log_level - def worker_log(self, data: dict) -> None: - """Logs the given data to the worker debug file.""" + def __after_in_child_fork__(self): + if self._worker_trace_writer is not None: + self._worker_trace_writer.close() + self._worker_trace_writer = None + + def __before_fork__(self): + if self._worker_trace_writer is not None: + self._worker_trace_writer.flush() + + def worker_trace_writer(self) -> TraceWriter: if self.worker_debug_path is None: - print(json.dumps(data) + "\n", end="", flush=True) - else: + return NOOP_TRACE_WRITER + if self._worker_trace_writer is None: in_worker = torch.utils.data.get_worker_info() is not None # Additional "worker" with rank_worker_id=0 is the main process. All workers have +1 # as their worker_id. worker_id = ( self.rank * (self.num_workers + 1) + self.rank_worker_id() + (1 if in_worker else 0) ) - if self._worker_debug_file is None or self._worker_debug_file_worker_id != worker_id: - if self._worker_debug_file is not None: - self._worker_debug_file.close() - path = Path( - self.worker_debug_path.format( - worker_id=worker_id, pid=multiprocessing.current_process().ident - ) + if self._worker_trace_writer is not None: + self._worker_trace_writer.close() + path = Path( + self.worker_debug_path.format( + worker_id=worker_id, pid=multiprocessing.current_process().ident ) - path.parent.mkdir(exist_ok=True, parents=True) - self._worker_debug_file = path.open("w") - self._worker_debug_file_worker_id = worker_id - self._worker_debug_file.write(json.dumps(data) + "\n") - self._worker_debug_file.flush() + ) + path.parent.mkdir(exist_ok=True, parents=True) + proc_name = f"dprank{self.global_rank()}" + if in_worker: + proc_name += f"_worker{self.rank_worker_id()}" + self._worker_trace_writer = TraceWriter(path, log_level=self.worker_log_level) + 
self._worker_trace_writer.metadata_process_name(multiprocessing.current_process().name) + self._worker_trace_writer.metadata_process_labels(proc_name) + self._worker_trace_writer.metadata_process_sort_index(worker_id) + self._worker_trace_writer.metadata_thread_name("worker_main") + self._worker_trace_writer.metadata_thread_sort_index(0) + return self._worker_trace_writer + + def worker_trace_span(self) -> AsyncFlow: + return self.worker_trace_writer().async_flow() diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index 5439ff6b..0ce8c9f6 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -102,52 +102,30 @@ def __len__(self): ) def __iter__(self) -> Iterator[T_batch]: - batch: List[T_batch_sample] = [] - sample_restore_keys = [] - - if self._generator_sample_keys is not None: - sample_restore_keys = self._generator_sample_keys - assert self._generator_offset is not None - batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys] - with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: - batch_sample = self.batcher(batch) - assert isinstance(batch_sample, Generator) - assert inspect.isgeneratorfunction(self.batcher), ( - f"Generator in {self.batcher} but not marked as such." 
- ) - target_offset = self._generator_offset - self._generator_offset = 0 - for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( - self._sample_index.iter_ctx(batch_sample, sample_idx) - ): - # Skip other samples - if batch_sub_idx >= target_offset: - self._generator_offset = batch_sub_idx + 1 - yield set_sample_restore_key( - inner_batch_sample, - sample_idx, - batch_sub_idx, - *sample_restore_keys, - src=self, - ) - self._generator_sample_keys = None - self._generator_offset = None - batch.clear() + trace_span = self.worker_config.worker_trace_span() + with trace_span.span("BatchDataset.__iter__", args={"config": self._own_config()}, level=1): + batch: List[T_batch_sample] = [] sample_restore_keys = [] - def flush(): - try: - with self._sample_index.ctx() as sample_idx: + if self._generator_sample_keys is not None: + sample_restore_keys = self._generator_sample_keys + assert self._generator_offset is not None + batch = [ + self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys + ] + with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: batch_sample = self.batcher(batch) - if isinstance(batch_sample, Generator): - assert inspect.isgeneratorfunction(self.batcher), ( - f"Generator in {self.batcher} but not marked as such." - ) - self._generator_sample_keys = sample_restore_keys - self._generator_offset = 0 - for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( - self._sample_index.iter_ctx(batch_sample, sample_idx) - ): + assert isinstance(batch_sample, Generator) + assert inspect.isgeneratorfunction(self.batcher), ( + f"Generator in {self.batcher} but not marked as such." 
+ ) + target_offset = self._generator_offset + self._generator_offset = 0 + for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( + self._sample_index.iter_ctx(batch_sample, sample_idx) + ): + # Skip other samples + if batch_sub_idx >= target_offset: self._generator_offset = batch_sub_idx + 1 yield set_sample_restore_key( inner_batch_sample, @@ -156,27 +134,70 @@ def flush(): *sample_restore_keys, src=self, ) - self._generator_sample_keys = None - self._generator_offset = None - else: - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) - yield batch_sample - sample_restore_keys.clear() - except SkipSample: - pass - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(batch) - except Exception as e: - self.error_handler(e, batch) - - for sample in self.dataset: - batch.append(sample) - sample_restore_keys.append(get_sample_restore_key(sample)) - if len(batch) == self.batch_size: + self._generator_sample_keys = None + self._generator_offset = None + batch.clear() + sample_restore_keys = [] + + def flush() -> Generator[T_batch, None, None]: + try: + with self._sample_index.ctx() as sample_idx: + batch_sample = self.batcher(batch) + if isinstance(batch_sample, Generator): + assert inspect.isgeneratorfunction(self.batcher), ( + f"Generator in {self.batcher} but not marked as such." 
+ ) + self._generator_sample_keys = sample_restore_keys + self._generator_offset = 0 + for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( + self._sample_index.iter_ctx(batch_sample, sample_idx) + ): + self._generator_offset = batch_sub_idx + 1 + yield set_sample_restore_key( + inner_batch_sample, + sample_idx, + batch_sub_idx, + *sample_restore_keys, + src=self, + ) + self._generator_sample_keys = None + self._generator_offset = None + else: + set_sample_restore_key( + batch_sample, sample_idx, *sample_restore_keys, src=self + ) + yield batch_sample + sample_restore_keys.clear() + except SkipSample: + trace_span.instant("BatchDataset.__iter__.skip", level=2) + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(batch) + except Exception as e: + self.error_handler(e, batch) + trace_span.instant( + "BatchDataset.__iter__.error/skip", + args={"exception": f"{type(e).__name__}: {str(e)}"}, + level=2, + ) + + batch_span = trace_span.span("BatchDataset.__iter__.collect", level=2) + + try: + for sample in self.dataset: + batch.append(sample) + sample_restore_keys.append(get_sample_restore_key(sample)) + if len(batch) == self.batch_size: + batch_span.end() + yield from flush() + batch = [] + batch_span = trace_span.span("BatchDataset.__iter__.collect", level=2) + finally: + batch_span.end() + if len(batch) > 0 and not self.drop_last: + batch_span = trace_span.span( + "BatchDataset.__iter__.last", args={"batch_size": len(batch)}, level=1 + ) yield from flush() - batch = [] - if len(batch) > 0 and not self.drop_last: - yield from flush() def can_restore_sample(self) -> bool: # Cannot really verify if the returned elements contain a __restore_key__. 
@@ -225,6 +246,26 @@ def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch: src=self, ) + def _own_config(self) -> Dict[str, Any]: + return { + "batch_size": self.batch_size, + "batcher": self._function_config(self.batcher), + **( + { + "batcher_config": ( + self.batcher_config() + if callable(self.batcher_config) + else self.batcher_config + ) + } + if self.batcher_config + else {} + ), + "batcher_stateless": self.batcher_stateless, + "drop_last": self.drop_last, + "error_handler": self._function_config(self.error_handler), + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/blend_dataset.py b/src/megatron/energon/wrappers/blend_dataset.py index 5db8b682..69cebfcc 100644 --- a/src/megatron/energon/wrappers/blend_dataset.py +++ b/src/megatron/energon/wrappers/blend_dataset.py @@ -57,54 +57,70 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[T_sample]: assert self.worker_has_samples(), "Cannot blend all empty datasets" - # Create a list of datasets and their weights, but - # set the weight to 0 if the dataset has no samples on this worker. - - dataset_iters = [] - weights = [] - for idx, (dataset, weight) in enumerate(self.dataset_weights): - assert weight > 0, "All blending weights must be > 0" - - if dataset.worker_has_samples(): - dataset_iters.append(iter(dataset)) - weights.append(weight) - else: - dataset_iters.append(None) - weights.append(0) - - weights = torch.tensor(weights, dtype=torch.float32) - if weights.sum() == 0: - raise RuntimeError( - "There is a worker with no samples in any of the blended datasets. " - "This can happen if you have a lot of workers and your dataset is too small. " - "Currently this case is not supported." - ) - - # Some may already be exhausted on this worker when restoring a state. 
- for idx, exhausted in enumerate(self.exhausted): - if exhausted: - weights[idx] = 0 - dataset_iters[idx] = None - - while True: - ds_idx = self._worker_rng.choice_idx(probs=weights) - - if dataset_iters[ds_idx] is None: - if all(dataset_iter is None for dataset_iter in dataset_iters): - break - continue - try: - sample = next(dataset_iters[ds_idx]) - except StopIteration: - dataset_iters[ds_idx] = None - weights[ds_idx] = 0 - self.exhausted[ds_idx] = True - if all(dataset_iter is None for dataset_iter in dataset_iters): - break - else: - yield add_sample_restore_key(sample, ds_idx, src=self) - - self.exhausted = [False] * len(self.dataset_weights) + trace_span = self.worker_config.worker_trace_span() + with trace_span.span("BlendDataset.__iter__", args={"config": self._own_config()}, level=1): + # Create a list of datasets and their weights, but + # set the weight to 0 if the dataset has no samples on this worker. + + dataset_iters = [] + weights = [] + for idx, (dataset, weight) in enumerate(self.dataset_weights): + assert weight > 0, "All blending weights must be > 0" + + if dataset.worker_has_samples(): + dataset_iters.append(iter(dataset)) + weights.append(weight) + else: + dataset_iters.append(None) + weights.append(0) + + weights = torch.tensor(weights, dtype=torch.float32) + if weights.sum() == 0: + raise RuntimeError( + "There is a worker with no samples in any of the blended datasets. " + "This can happen if you have a lot of workers and your dataset is too small. " + "Currently this case is not supported." + ) + + # Some may already be exhausted on this worker when restoring a state. 
+ for idx, exhausted in enumerate(self.exhausted): + if exhausted: + weights[idx] = 0 + dataset_iters[idx] = None + + while True: + ds_idx = self._worker_rng.choice_idx(probs=weights) + trace_span.instant( + "BlendDataset.__iter__.sample", + args={"weights": weights, "ds_idx": ds_idx}, + level=2, + ) + + if dataset_iters[ds_idx] is None: + if all(dataset_iter is None for dataset_iter in dataset_iters): + break + continue + try: + sample = next(dataset_iters[ds_idx]) + except StopIteration: + trace_span.instant( + "BlendDataset.__iter__.exhausted", args={"ds_idx": ds_idx}, level=1 + ) + dataset_iters[ds_idx] = None + weights[ds_idx] = 0 + self.exhausted[ds_idx] = True + if all(dataset_iter is None for dataset_iter in dataset_iters): + break + else: + yield add_sample_restore_key(sample, ds_idx, src=self) + + trace_span.instant("BlendDataset.__iter__.reset", level=1) + self.exhausted = [False] * len(self.dataset_weights) + + def _own_config(self) -> Dict[str, Any]: + return { + "dataset_weights": [weight for _, weight in self.dataset_weights], + } def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/concat_dataset.py b/src/megatron/energon/wrappers/concat_dataset.py index 7388e3b2..a76d365b 100644 --- a/src/megatron/energon/wrappers/concat_dataset.py +++ b/src/megatron/energon/wrappers/concat_dataset.py @@ -33,13 +33,19 @@ def __len__(self): return sum(len(dataset) for dataset in self.datasets) def __iter__(self) -> Iterator[T_sample]: - for ds_idx, dataset in enumerate(self.datasets): - for sample in dataset: - yield add_sample_restore_key( - sample, - ds_idx, - src=self, - ) + trace_span = self.worker_config.worker_trace_span() + with trace_span.span("ConcatDataset.__iter__", level=1): + for ds_idx, dataset in enumerate(self.datasets): + with trace_span.span( + "ConcatDataset.next_dataset.yield_from", args={"ds_idx": ds_idx}, level=1 + ): + for sample in dataset: + yield add_sample_restore_key( + sample, + ds_idx, + src=self, + ) 
+ trace_span.instant("ConcatDataset.__iter__.done", level=1) def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/epochize_dataset.py b/src/megatron/energon/wrappers/epochize_dataset.py index 6058cd27..8d85ea73 100644 --- a/src/megatron/energon/wrappers/epochize_dataset.py +++ b/src/megatron/energon/wrappers/epochize_dataset.py @@ -49,8 +49,8 @@ def reset_state_own(self) -> None: self._offset = 0 def __iter__(self) -> Iterator[T_sample]: + trace_span = self.worker_config.worker_trace_span() # Compute the local length for this worker, i.e. all worker's lengths sum up to the total - if self.worker_config.num_workers <= 1: local_length = self.length else: @@ -58,54 +58,44 @@ def __iter__(self) -> Iterator[T_sample]: if self.worker_config.rank_worker_id() < self.length % self.worker_config.num_workers: local_length += 1 - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "EpochizeDataset.epoch_start", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "offset": self._offset, - "local_length": local_length, - "length": self.length, - } - ) - - offset_range = list(range(self._offset, local_length)) - - # Only iterate if there are samples to iterate - if len(offset_range) > 0: - if self._active_iter is None: - self._active_iter = iter(self.dataset) - - for idx in offset_range: - self._offset = (idx + 1) % local_length - try: - sample = next(self._active_iter) - except StopIteration: - break - yield sample - - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "EpochizeDataset.epoch_end", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "offset": self._offset, - "local_length": local_length, - "length": self.length, - } - ) + with trace_span.span( + "EpochizeDataset.__iter__", + args={ + "offset": self._offset, + "local_length": local_length, + "config": self._own_config(), + }, + level=1, + ): + offset_range = 
list(range(self._offset, local_length)) + + # Only iterate if there are samples to iterate + if len(offset_range) > 0: + if self._active_iter is None: + self._active_iter = iter(self.dataset) + + for idx in offset_range: + self._offset = (idx + 1) % local_length + try: + sample = next(self._active_iter) + except StopIteration: + break + yield sample + trace_span.instant("EpochizeDataset.__iter__.done", level=1) def __len__(self) -> int: return self.length + def _own_config(self) -> Dict[str, Any]: + return { + "length": self.length, + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, "dataset": self.dataset.config(), "length": self.length, - "worker_config": self.worker_config.config(), } def __str__(self): diff --git a/src/megatron/energon/wrappers/filter_dataset.py b/src/megatron/energon/wrappers/filter_dataset.py index ff5e09a7..93ecdde8 100644 --- a/src/megatron/energon/wrappers/filter_dataset.py +++ b/src/megatron/energon/wrappers/filter_dataset.py @@ -51,15 +51,38 @@ def __len__(self): return len(self.dataset) def __iter__(self) -> Iterator[T_sample]: - for sample in self.dataset: - with self._sample_index.ctx(): - filter_res = self.filter_fn(sample) - if filter_res: - yield sample + with self.worker_config.worker_trace_span().span( + "FilterDataset.__iter__", args={"config": self._own_config()}, level=1 + ): + for sample in self.dataset: + with self._sample_index.ctx(): + filter_res = self.filter_fn(sample) + if filter_res: + yield sample + else: + self.worker_config.worker_trace_span().instant( + "FilterDataset.__iter__.reject", level=3 + ) def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample: return self.dataset.restore_sample(index) + def _own_config(self) -> Dict[str, Any]: + return { + "filter_fn": self._function_config(self.filter_fn), + **( + { + "filter_fn_config": ( + self.filter_fn_config() + if callable(self.filter_fn_config) + else self.filter_fn_config + ) + } + if self.filter_fn_config 
+ else {} + ), + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/gc_dataset.py b/src/megatron/energon/wrappers/gc_dataset.py index b8a0d4c4..ac5a379b 100644 --- a/src/megatron/energon/wrappers/gc_dataset.py +++ b/src/megatron/energon/wrappers/gc_dataset.py @@ -101,6 +101,7 @@ def __len__(self): return len(self.dataset) def __iter__(self) -> Iterator[T_sample]: + trace_span = self.worker_config.worker_trace_span() in_worker = torch.utils.data.get_worker_info() is not None if in_worker and not _frozen_cuda_tensors_initialized: raise GcFreezeError( @@ -108,19 +109,28 @@ def __iter__(self) -> Iterator[T_sample]: ) if self.freeze: - gc.collect() - gc.freeze() - try: - iter = 0 - for sample in self.dataset: - yield sample - iter += 1 - if iter >= self.every_n_iter: - gc.collect() - iter = 0 - finally: - if self.freeze: - gc.unfreeze() + with trace_span.span("GcDataset.__iter__.gc.freeze", level=1): + gc.collect() + gc.freeze() + with trace_span.span("GcDataset.__iter__", args={"config": self._own_config()}, level=1): + try: + iter = 0 + for sample in self.dataset: + yield sample + iter += 1 + if iter >= self.every_n_iter: + with trace_span.span("GcDataset.__iter__.gc.collect", level=1): + gc.collect() + iter = 0 + finally: + if self.freeze: + gc.unfreeze() + + def _own_config(self) -> Dict[str, Any]: + return { + "every_n_iter": self.every_n_iter, + "freeze": self.freeze, + } def config(self) -> Dict[str, Any]: # This is transparent, no config to be saved (it does not affect the dataset) diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index b6907a50..c03fe4b8 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -119,81 +119,108 @@ def __len__(self): def __iter__(self) -> Iterator[T_batch]: buckets = self._buckets - if buckets is None: - buckets = 
self._buckets = dict() - - # Load saved state if available - for bucket in buckets.values(): - bucket.samples.worker_start() - - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial GroupBatchDataset state:\n", end="") - # for bucket_key, bucket in buckets.items(): - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{bucket_key}] (bs={bucket.batch_size}, len(samples)={len(bucket.samples)}):\n", end="") - # bucket.samples.debug_print(" ") - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial done\n", end="") - - def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: - # Debug print the state - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flush GroupBatchDataset state:\n", end="") - # for dbg_bucket_key, dbg_bucket in buckets.items(): - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{dbg_bucket_key}{'*' if dbg_bucket_key == bucket_key else ''}] (bs={dbg_bucket.batch_size}, len(samples)={len(dbg_bucket.samples)}):\n", end="") - # dbg_bucket.samples.debug_print(" ") - batch_items, sample_restore_keys = bucket.samples.flush() - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flushed: len(batch)={len(batch_items)} len(samples)={len(bucket.samples)}\n", end="") - try: - with self._batch_sample_index.ctx() as sample_idx: - batch_sample = self.batcher(batch_items) - assert not isinstance(batch_sample, Generator), ( - f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet." 
+ trace_span = self.worker_config.worker_trace_span() + with trace_span.span( + "GroupBatchDataset.__iter__", args={"config": self._own_config()}, level=1 + ): + if buckets is None: + buckets = self._buckets = dict() + + # Load saved state if available + for bucket in buckets.values(): + bucket.samples.worker_start() + + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial GroupBatchDataset state:\n", end="") + # for bucket_key, bucket in buckets.items(): + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{bucket_key}] (bs={bucket.batch_size}, len(samples)={len(bucket.samples)}):\n", end="") + # bucket.samples.debug_print(" ") + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial done\n", end="") + + def flush(key: Any, bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: + # Debug print the state + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flush GroupBatchDataset state:\n", end="") + # for dbg_bucket_key, dbg_bucket in buckets.items(): + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{dbg_bucket_key}{'*' if dbg_bucket_key == bucket_key else ''}] (bs={dbg_bucket.batch_size}, len(samples)={len(dbg_bucket.samples)}):\n", end="") + # dbg_bucket.samples.debug_print(" ") + batch_items, sample_restore_keys = bucket.samples.flush() + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flushed: len(batch)={len(batch_items)} len(samples)={len(bucket.samples)}\n", end="") + try: + with trace_span.span( + "GroupBatchDataset.flush", + args={ + "bucket": str(key), + "bucket_size": bucket.batch_size, + "bucket_len": len(batch_items), + }, + level=2, + ): + with self._batch_sample_index.ctx() as sample_idx: + batch_sample = self.batcher(batch_items) + assert not isinstance(batch_sample, Generator), ( + f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet." 
+ ) + set_sample_restore_key( + batch_sample, sample_idx, *sample_restore_keys, src=self + ) + yield batch_sample + except SkipSample: + trace_span.instant("GroupBatchDataset.flush.skip", level=2) + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(batch_items) + except Exception as e: + self.error_handler(e, batch_items) + trace_span.instant( + "GroupBatchDataset.flush.error/skip", + args={"exception": f"{type(e).__name__}: {str(e)}"}, + level=2, ) - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) - yield batch_sample - except SkipSample: - pass - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(batch_items) - except Exception as e: - self.error_handler(e, batch_items) - - # Add samples to the buckets - for sample in self.dataset: - try: - with self._group_key_sample_index.ctx(): - bucket_key, batch_size = self.sample_group_key(sample) - assert (batch_size is None) != (self.fixed_batch_size is None), ( - f"A sample in group for key {bucket_key} returned batch size {batch_size}, but fixed " - f"batch size is set to {self.fixed_batch_size}. One of the two should be None." + + # Add samples to the buckets + for sample in self.dataset: + try: + with self._group_key_sample_index.ctx(): + bucket_key, batch_size = self.sample_group_key(sample) + assert (batch_size is None) != (self.fixed_batch_size is None), ( + f"A sample in group for key {bucket_key} returned batch size {batch_size}, but fixed " + f"batch size is set to {self.fixed_batch_size}. One of the two should be None." 
+ ) + if self.fixed_batch_size is not None: + batch_size = self.fixed_batch_size + except SkipSample: + trace_span.instant("GroupBatchDataset.__iter__.skip", level=2) + continue + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(sample) + except Exception as e: + self.error_handler(e, [sample]) + trace_span.instant( + "GroupBatchDataset.__iter__.error/skip", + args={"exception": f"{type(e).__name__}: {str(e)}"}, + level=2, ) - if self.fixed_batch_size is not None: - batch_size = self.fixed_batch_size - except SkipSample: - continue - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(sample) - except Exception as e: - self.error_handler(e, [sample]) - continue - bucket = buckets.get(bucket_key) - if bucket is None: - assert batch_size is not None - buckets[bucket_key] = bucket = Bucket( - batch_size=batch_size, - samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config), - ) - else: - assert bucket.batch_size == batch_size, ( - f"Got different batch size for group {bucket_key}: {bucket.batch_size} != {batch_size}." - ) - bucket.samples.append(sample) - if len(bucket.samples) >= bucket.batch_size: - yield from flush(bucket) - # Flush out last samples - if not self.drop_last: - for bucket in buckets.values(): - if len(bucket.samples) > 0: - yield from flush(bucket) - # Clear the buckets - self._buckets.clear() + continue + bucket = buckets.get(bucket_key) + if bucket is None: + assert batch_size is not None + buckets[bucket_key] = bucket = Bucket( + batch_size=batch_size, + samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config), + ) + else: + assert bucket.batch_size == batch_size, ( + f"Got different batch size for group {bucket_key}: {bucket.batch_size} != {batch_size}." 
+ ) + bucket.samples.append(sample) + if len(bucket.samples) >= bucket.batch_size: + yield from flush(bucket_key, bucket) + # Flush out last samples + if not self.drop_last: + for bucket_key, bucket in buckets.items(): + if len(bucket.samples) > 0: + yield from flush(bucket_key, bucket) + # Clear the buckets + self._buckets.clear() + trace_span.instant("GroupBatchDataset.__iter__.done", level=1) def save_state(self) -> FlexState: return FlexState( @@ -234,9 +261,28 @@ def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch: set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) return batch_sample + def _own_config(self) -> Dict[str, Any]: + return { + "bucket": self._function_config(self.sample_group_key), + "batcher": self._function_config(self.batcher), + **( + { + "batcher_config": ( + self.batcher_config() + if callable(self.batcher_config) + else self.batcher_config + ) + } + if self.batcher_config + else {} + ), + "batcher_stateless": self.batcher_stateless, + "drop_last": self.drop_last, + "error_handler": self._function_config(self.error_handler), + } + def config(self) -> Dict[str, Any]: return { - "type": type(self).__qualname__, "bucket": self._function_config(self.sample_group_key), "batcher": self._function_config(self.batcher), **( @@ -253,8 +299,8 @@ def config(self) -> Dict[str, Any]: "batcher_stateless": self.batcher_stateless, "drop_last": self.drop_last, "error_handler": self._function_config(self.error_handler), - "worker_config": self.worker_config.config(), "dataset": self.dataset.config(), + "worker_config": self.worker_config.config(), } def __str__(self): diff --git a/src/megatron/energon/wrappers/iter_map_dataset.py b/src/megatron/energon/wrappers/iter_map_dataset.py index e4578ba3..f5632697 100644 --- a/src/megatron/energon/wrappers/iter_map_dataset.py +++ b/src/megatron/energon/wrappers/iter_map_dataset.py @@ -89,45 +89,56 @@ def __len__(self): return self.len_map_fn(len(self.dataset)) 
def __iter__(self) -> Iterator[T_sample_out]: - last_sample_wrapper = _LastSampleWrapper(self.dataset) - # The iter_map_fn is stateless. Thus we need to know which inner sample created the - # outer sample, and the relative outer sample index, so we can restore it. - - # This is the sample index within the currently yielded sample - iter_idx = 0 - sample_idx = 0 - sample_restore_keys = [] - - def reset_idx_iter() -> Generator[T_sample, None, None]: - # Resets the inner sample index - nonlocal iter_idx, sample_restore_keys - for entry in last_sample_wrapper: - iter_idx = 0 - sample_restore_keys.append(get_sample_restore_key(entry)) - yield entry + trace_span = self.worker_config.worker_trace_span() + with trace_span.span( + "IterMapDataset.__iter__", args={"config": self._own_config()}, level=1 + ): + last_sample_wrapper = _LastSampleWrapper(self.dataset) + # The iter_map_fn is stateless. Thus we need to know which inner sample created the + # outer sample, and the relative outer sample index, so we can restore it. 
+ + # This is the sample index within the currently yielded sample + iter_idx = 0 + sample_idx = 0 + sample_restore_keys = [] - ds_iter = iter(reset_idx_iter()) + def reset_idx_iter() -> Generator[T_sample, None, None]: + # Resets the inner sample index + nonlocal iter_idx, sample_restore_keys + for entry in last_sample_wrapper: + iter_idx = 0 + sample_restore_keys.append(get_sample_restore_key(entry)) + yield entry - # While True will break when the inner dataset is exhausted, but may continue on exception - while True: - iter_idx = 0 - try: - for sample_idx, sample in self._sample_index.iter_ctx(self.iter_map_fn(ds_iter)): - yield set_sample_restore_key( - sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, + ds_iter = iter(reset_idx_iter()) + + # While True will break when the inner dataset is exhausted, but may continue on exception + while True: + iter_idx = 0 + try: + for sample_idx, sample in self._sample_index.iter_ctx( + self.iter_map_fn(ds_iter) + ): + yield set_sample_restore_key( + sample, + sample_idx, + iter_idx, + *sample_restore_keys, + src=self, + ) + sample_restore_keys.clear() + iter_idx += 1 + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(last_sample_wrapper.last_sample) + except Exception as e: + self.error_handler(e, last_sample_wrapper.last_sample) + trace_span.instant( + "IterMapDataset.__iter__.error/retry", + args={"exception": f"{type(e).__name__}: {str(e)}"}, + level=1, ) - sample_restore_keys.clear() - iter_idx += 1 - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(last_sample_wrapper.last_sample) - except Exception as e: - self.error_handler(e, last_sample_wrapper.last_sample) - else: - break + else: + break def can_restore_sample(self) -> bool: return super().can_restore_sample() and self.stateless_iter_fn @@ -172,6 +183,24 @@ def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample: if hasattr(inner_iter, "close"): inner_iter.close() + def _own_config(self) -> 
Dict[str, Any]: + return { + "iter_map_fn": self._function_config(self.iter_map_fn), + **( + { + "iter_map_fn_config": ( + self.iter_map_fn_config() + if callable(self.iter_map_fn_config) + else self.iter_map_fn_config + ) + } + if self.iter_map_fn_config + else {} + ), + "len_map_fn": self._function_config(self.len_map_fn), + "error_handler": self._function_config(self.error_handler), + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/limit_dataset.py b/src/megatron/energon/wrappers/limit_dataset.py index d7a491ea..ab72ec8d 100644 --- a/src/megatron/energon/wrappers/limit_dataset.py +++ b/src/megatron/energon/wrappers/limit_dataset.py @@ -57,49 +57,41 @@ def __iter__(self) -> Iterator[T_sample]: if worker_id < self.length % self.worker_config.num_workers: local_limit += 1 - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "LimitDataset.start", - "r": self.worker_config.rank, - "w": worker_id, - "offset": self.current_offset, - "local_limit": local_limit, - "limit": self.length, - } - ) - - offset_range = list(range(self.current_offset, local_limit)) - # Only iterate self.dataset if there are samples to iterate - if len(offset_range) > 0: - for sample, offset in zip( - self.dataset, - offset_range, - ): - self.current_offset = offset + 1 - yield sample - - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "LimitDataset.done", - "r": self.worker_config.rank, - "w": worker_id, - "offset": self.current_offset, - "local_limit": local_limit, - "limit": self.length, - } - ) + trace_span = self.worker_config.worker_trace_span() + with trace_span.span( + "LimitDataset.__iter__", + args={ + "offset": self.current_offset, + "local_limit": local_limit, + "config": self._own_config(), + }, + level=2, + ): + offset_range = list(range(self.current_offset, local_limit)) + # Only iterate self.dataset if there are samples to iterate 
+ if len(offset_range) > 0: + for sample, offset in zip( + self.dataset, + offset_range, + ): + self.current_offset = offset + 1 + yield sample # Reset the inner dataset - self.dataset.reset_state_deep() self.current_offset = 0 if self.reset_after_epoch: - self.dataset.reset_state_deep() + with trace_span.span("LimitDataset.__iter__.reset_state_deep"): + self.dataset.reset_state_deep() def worker_has_samples(self) -> bool: return super().worker_has_samples() and self.length > 0 + def _own_config(self) -> Dict[str, Any]: + return { + "length": self.length, + "reset_after_epoch": self.reset_after_epoch, + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/log_sample_dataset.py b/src/megatron/energon/wrappers/log_sample_dataset.py index 5947b84a..2014c996 100644 --- a/src/megatron/energon/wrappers/log_sample_dataset.py +++ b/src/megatron/energon/wrappers/log_sample_dataset.py @@ -80,26 +80,33 @@ def reset_state_own(self) -> None: def __len__(self): return len(self.dataset) - def _log(self, sample: T_sample) -> None: - if self.worker_config.should_log(level=1): - log_entry = { - "t": "yield_batch", - "r": self.worker_config.rank, - "w": self.worker_config.global_worker_id(), - "m": self.mode, - "idx": self._step, - } - keys = self.get_keys_fn(sample) - if keys is not None: - log_entry["keys"] = keys - - self.worker_config.worker_log(log_entry) + def _log(self, sample: T_sample) -> dict: + log_entry = { + "idx": self._step, + } + keys = self.get_keys_fn(sample) + if keys is not None: + log_entry["keys"] = keys + + return log_entry def __iter__(self) -> Iterator[T_sample]: - for sample in self.dataset: - self._log(sample) - self._step += 1 - yield sample + trace_span = self.worker_config.worker_trace_span() + with trace_span.span( + "LogSampleDataset.__iter__", + args={ + "mode": self.mode, + }, + level=1, + ): + for sample in trace_span.iterable( + self.dataset, 
name="LogSampleDataset.__iter__.next", level=1 + ): + with trace_span.span( + "LogSampleDataset.__iter__.yield", args=self._log(sample), level=1 + ): + self._step += 1 + yield sample def config(self) -> Dict[str, Any]: # Transparent logger, it won't change the samples diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index dc626de0..1c623df3 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -87,48 +87,25 @@ def __len__(self): return len(self.dataset) def __iter__(self) -> Iterator[T_sample_out]: - if self._generator_sample_key is not None: - assert self._generator_offset is not None - sample = self.dataset.restore_sample(self._generator_sample_key) - # Do not increment the sample index, use previous index - with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: - mapped_sample = self.map_fn(sample) - assert isinstance(mapped_sample, Generator) - assert inspect.isgeneratorfunction(self.map_fn), ( - f"Generator in {self.map_fn} but not marked as such." 
- ) - target_offset = self._generator_offset - self._generator_offset = 0 - for idx, (sample_idx, inner_sample) in enumerate( - self._sample_index.iter_ctx(mapped_sample, sample_idx) - ): - # Skip other samples - if idx >= target_offset: - self._generator_offset = idx + 1 - yield add_sample_restore_key( - inner_sample, - sample_idx, - idx, - src=self, - ) - self._generator_sample_key = None - self._generator_offset = None - - for sample in self.dataset: - try: - with self._sample_index.ctx() as sample_idx: + trace_span = self.worker_config.worker_trace_span() + with trace_span.span("MapDataset.__iter__", args={"config": self._own_config()}, level=1): + if self._generator_sample_key is not None: + assert self._generator_offset is not None + sample = self.dataset.restore_sample(self._generator_sample_key) + # Do not increment the sample index, use previous index + with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: mapped_sample = self.map_fn(sample) - if isinstance(mapped_sample, Generator): - assert inspect.isgeneratorfunction(self.map_fn), ( - f"Generator in {self.map_fn} but not marked as such." - ) - self._generator_sample_key = get_sample_restore_key(sample) - self._generator_offset = 0 - # In case of a generator, additionally store the index of the yielded samples - # per input sample - for idx, (sample_idx, inner_sample) in enumerate( - self._sample_index.iter_ctx(mapped_sample, sample_idx) - ): + assert isinstance(mapped_sample, Generator) + assert inspect.isgeneratorfunction(self.map_fn), ( + f"Generator in {self.map_fn} but not marked as such." 
+ ) + target_offset = self._generator_offset + self._generator_offset = 0 + for idx, (sample_idx, inner_sample) in enumerate( + self._sample_index.iter_ctx(mapped_sample, sample_idx) + ): + # Skip other samples + if idx >= target_offset: self._generator_offset = idx + 1 yield add_sample_restore_key( inner_sample, @@ -136,20 +113,50 @@ def __iter__(self) -> Iterator[T_sample_out]: idx, src=self, ) - self._generator_sample_key = None - self._generator_offset = None - else: - yield add_sample_restore_key( - mapped_sample, - sample_idx, - src=self, + self._generator_sample_key = None + self._generator_offset = None + + for sample in self.dataset: + try: + with self._sample_index.ctx() as sample_idx: + mapped_sample = self.map_fn(sample) + if isinstance(mapped_sample, Generator): + assert inspect.isgeneratorfunction(self.map_fn), ( + f"Generator in {self.map_fn} but not marked as such." + ) + self._generator_sample_key = get_sample_restore_key(sample) + self._generator_offset = 0 + # In case of a generator, additionally store the index of the yielded samples + # per input sample + for idx, (sample_idx, inner_sample) in enumerate( + self._sample_index.iter_ctx(mapped_sample, sample_idx) + ): + self._generator_offset = idx + 1 + yield add_sample_restore_key( + inner_sample, + sample_idx, + idx, + src=self, + ) + self._generator_sample_key = None + self._generator_offset = None + else: + yield add_sample_restore_key( + mapped_sample, + sample_idx, + src=self, + ) + except SkipSample: + trace_span.instant("MapDataset.__iter__.skip", level=1) + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(sample) + except Exception as e: + self.error_handler(e, sample) + trace_span.instant( + "MapDataset.__iter__.error/skip", + args={"exception": f"{type(e).__name__}: {str(e)}"}, + level=1, ) - except SkipSample: - pass - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(sample) - except Exception as e: - self.error_handler(e, sample) def 
can_restore_sample(self) -> bool: return super().can_restore_sample() and self.stateless_map_fn @@ -189,6 +196,21 @@ def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample_ else: return add_sample_restore_key(mapped_sample, sample_idx, src=self) + def _own_config(self) -> Dict[str, Any]: + return { + "map_fn": self._function_config(self.map_fn), + **( + { + "map_fn_config": ( + self.map_fn_config() if callable(self.map_fn_config) else self.map_fn_config + ) + } + if self.map_fn_config + else {} + ), + "map_fn_stateless": self.stateless_map_fn, + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index 6b154e0f..0be345c6 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -179,40 +179,46 @@ def _fill_reading_buffer(self, source_iter: Iterator, log_progress: bool = False return False return True - def _encode_pack_samples(self, pack: List[T_sample]) -> List[T_encoded_sample]: - # Apply the sample encoder to the pack - if self.sample_encoder is None: - return pack - encoded_pack = [] - for sample in pack: - try: - with self._sample_encoder_sample_index.ctx() as encode_idx: - encoded_sample = self.sample_encoder(sample) - assert not isinstance(encoded_sample, Generator), "Generator not supported" - encoded_pack.append( - add_sample_restore_key( - encoded_sample, - encode_idx, - src=self, - ) - ) - except SkipSample: - pass - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(pack) - except Exception as e: - self.error_handler(e, pack) - return encoded_pack - def __iter__(self) -> Iterator[T_batch_sample]: - pre_packing_lengths = self._pre_packing_lengths - # The source dataset - src_iter = iter(self.dataset) - - self._pre_packing_buffer.worker_start() - self._reading_buffer.worker_start() - - is_initial_pack = True + 
trace_span = self.worker_config.worker_trace_span() + + def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: + # Apply the sample encoder to the pack + if self.sample_encoder is None: + return pack + encoded_pack = [] + with trace_span.span( + "PackingDataset._encode_pack_samples", args={"len": len(pack)}, level=2 + ): + for sample in pack: + try: + with trace_span.span( + "PackingDataset._encode_pack_samples.encode_sample", level=2 + ): + with self._sample_encoder_sample_index.ctx() as encode_idx: + encoded_sample = self.sample_encoder(sample) + assert not isinstance(encoded_sample, Generator), ( + "Generator not supported" + ) + encoded_pack.append( + add_sample_restore_key( + encoded_sample, + encode_idx, + src=self, + ) + ) + except SkipSample: + trace_span.instant("PackingDataset._encode_pack_samples.skip", level=2) + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(pack) + except Exception as e: + self.error_handler(e, pack) + trace_span.instant( + "PackingDataset._encode_pack_samples.error/skip", + args={"exception": f"{type(e).__name__}: {str(e)}"}, + level=2, + ) + return encoded_pack def next_pre_pack(): """Take the samples from the reading buffer and select groups of samples to be packed @@ -231,10 +237,16 @@ def next_pre_pack(): pre_packs = self.pre_packer(samples) except SkipSample: pre_packs = [] + trace_span.instant("PackingDataset.next_pre_pack.skip", level=2) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(samples) except Exception as e: self.error_handler(e, samples) + trace_span.instant( + "PackingDataset.next_pre_pack.error/skip", + args={"exception": f"{type(e).__name__}: {str(e)}"}, + level=2, + ) pre_packs = [] # Put the pre-packed samples into the pre_packing_buffer @@ -249,7 +261,10 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: """Yield the next packs from the buffer. 
The final packer is called on the fly.""" pack = list(self._pre_packing_buffer[: pre_packing_lengths[0]]) - pack = self._encode_pack_samples(pack) + pack = encode_pack_samples(pack) + if len(pack) == 0: + # All samples in the pack were skipped + return del self._pre_packing_buffer[: pre_packing_lengths[0]] del pre_packing_lengths[0] @@ -279,50 +294,84 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: src=self, ) except SkipSample: - pass + trace_span.instant("PackingDataset.next_final_pack.skip", level=2) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(pack) except Exception as e: self.error_handler(e, pack) + trace_span.instant( + "PackingDataset.next_final_pack.error/skip", + args={"exception": f"{type(e).__name__}: {str(e)}"}, + level=2, + ) - # Main loop: - pre_pack_round = 0 - while True: - if pre_pack_round > 10: - raise RuntimeError("Pre packer did not yield any packs after 10 rounds.") - # Fill a portion of the buffer - if not self._fill_reading_buffer(src_iter, log_progress=is_initial_pack): - # Break out of the main loop when the source is exhausted. 
- break - is_initial_pack = False - - # Create new pre packs if necessary - if len(pre_packing_lengths) == 0: - assert len(self._pre_packing_buffer) == 0 - assert len(self._reading_buffer) == self.buffer_size - next_pre_pack() + with trace_span.span( + "PackingDataset.__iter__", args={"config": self._own_config()}, level=1 + ): + pre_packing_lengths = self._pre_packing_lengths + # The source dataset + src_iter = iter(self.dataset) + + self._pre_packing_buffer.worker_start() + self._reading_buffer.worker_start() + + is_initial_pack = True + + pre_pack_round = 0 + # Main loop: + while True: + if pre_pack_round > 10: + raise RuntimeError("Pre packer did not yield any packs after 10 rounds.") + with trace_span.span( + "PackingDataset.__iter__.fill_reading_buffer", + args={ + "to_fill": self.buffer_size + - len(self._reading_buffer) + - len(self._pre_packing_buffer), + "reading_buffer": len(self._reading_buffer), + "pre_packing_buffer": len(self._pre_packing_buffer), + "buffer_size": self.buffer_size, + }, + level=2, + ): + # Fill a portion of the buffer + if not self._fill_reading_buffer(src_iter, log_progress=is_initial_pack): + # Break out of the main loop when the source is exhausted. + break + is_initial_pack = False + + # Create new pre packs if necessary if len(pre_packing_lengths) == 0: - # Retry packing, nothing was returned. - pre_pack_round += 1 - continue - - if len(pre_packing_lengths) > 0: + with trace_span.span("PackingDataset.__iter__.next_pre_pack", level=1): + assert len(self._pre_packing_buffer) == 0 + assert len(self._reading_buffer) == self.buffer_size + next_pre_pack() + if len(pre_packing_lengths) == 0: + # Retry packing, nothing was returned. 
+ pre_pack_round += 1 + continue + # Reset the pre pack round counter for failing pre_pack_round = 0 - yield from next_final_pack() + with trace_span.span("PackingDataset.__iter__.final_pack", level=2): + yield from next_final_pack() - # Yield the remaining packs, flushing the collecting buffer - while len(pre_packing_lengths) > 0: - yield from next_final_pack() + with trace_span.span("PackingDataset.__iter__.last", level=1): + # Yield the remaining packs, flushing the collecting buffer + while len(pre_packing_lengths) > 0: + with trace_span.span("PackingDataset.__iter__.last.final_pack", level=2): + yield from next_final_pack() - # If there are still samples in the partial reading buffer, pre-pack them and yield the - # resulting (partial) packs - if len(self._reading_buffer) > 0: - next_pre_pack() + # If there are still samples in the partial reading buffer, pre-pack them and yield the + # resulting (partial) packs + if len(self._reading_buffer) > 0: + with trace_span.span("PackingDataset.__iter__.last.next_pre_pack", level=1): + next_pre_pack() - # Yield the remaining packs, flushing the collecting buffer - while len(pre_packing_lengths) > 0: - yield from next_final_pack() + # Yield the remaining packs, flushing the collecting buffer + while len(pre_packing_lengths) > 0: + with trace_span.span("PackingDataset.__iter__.last.final_pack", level=2): + yield from next_final_pack() def can_restore_sample(self) -> bool: # Cannot really verify if the returned elements contain a __restore_key__. 
@@ -383,6 +432,24 @@ def restore_sample(self, restore_key: Any) -> T_sample: else: return set_sample_restore_key(final_pack, pack_idx, *pack_restore_keys, src=self) + def _own_config(self) -> Dict[str, Any]: + return { + "buffer_size": self.buffer_size, + "pre_packer": self._function_config(self.pre_packer), + "final_packer": self._function_config(self.final_packer), + "final_packer_stateless": self.final_packer_stateless, + **( + { + "packer_config": ( + self.packer_config() if callable(self.packer_config) else self.packer_config + ) + } + if self.packer_config + else {} + ), + "error_handler": self._function_config(self.error_handler), + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/repeat_dataset.py b/src/megatron/energon/wrappers/repeat_dataset.py index 1261adf8..cf11be6c 100644 --- a/src/megatron/energon/wrappers/repeat_dataset.py +++ b/src/megatron/energon/wrappers/repeat_dataset.py @@ -59,41 +59,60 @@ def __iter__(self) -> Iterator[T_sample]: ds_len = len(self.dataset) - while self.repeats is None or self._repetition < self.repeats: - if self.repeats is not None and self._repetition == math.floor(self.repeats): - # Last iteration, adjust the number of samples - fraction = self.repeats - math.floor(self.repeats) - stop_after = math.floor(ds_len * fraction) - if self._index >= stop_after: - # We restored an index and it is already past the stop_after - break - else: - stop_after = None - - for sample in self.dataset: - self._index += 1 - yield sample - if stop_after is not None and self._index >= stop_after: - break - - if self.worker_config.should_log(level=2): - self.worker_config.worker_log( - { - "t": "RepeatDataset.repeat", - "r": self.worker_config.rank, - "w": self.worker_config.rank_worker_id(), - "offset": self._repetition, + trace_span = self.worker_config.worker_trace_span() + with trace_span.span( + "RepeatDataset.__iter__", + args={ + "repetition": self._repetition, + 
"repeats": self.repeats, + "inner_len": ds_len, + }, + level=2, + ): + while self.repeats is None or self._repetition < self.repeats: + with trace_span.span( + "RepeatDataset.__iter__.repeat", + args={ + "repetition": self._repetition, "repeats": self.repeats, - } - ) - self._repetition += 1 - self._index = 0 - - if self.restart: - self._repetition = 0 - else: - # No more repeats - self._repetition = math.ceil(self.repeats) + }, + level=2, + ): + if self.repeats is not None and self._repetition == math.floor(self.repeats): + # Last iteration, adjust the number of samples + fraction = self.repeats - math.floor(self.repeats) + stop_after = math.floor(ds_len * fraction) + if self._index >= stop_after: + # We restored an index and it is already past the stop_after + trace_span.instant("RepeatDataset.__iter__.break(stop_after)", level=2) + break + else: + stop_after = None + + for sample in self.dataset: + with trace_span.span( + "RepeatDataset.__iter__.__iter__.yield", + args={ + "idx": self._index, + }, + level=2, + ): + self._index += 1 + yield sample + + if stop_after is not None and self._index >= stop_after: + trace_span.instant( + "RepeatDataset.__iter__.__iter__.break(stop_after)", level=2 + ) + break + self._repetition += 1 + self._index = 0 + + if self.restart: + self._repetition = 0 + else: + # No more repeats + self._repetition = math.ceil(self.repeats) def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py index e40d262c..7befb4cc 100644 --- a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py +++ b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, Generic, Iterator, Tuple, TypeVar, Union +from typing import Any, Dict, Generic, Iterator, List, Tuple, TypeVar, Union from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.rng import WorkerRng @@ -18,8 +18,10 @@ class ShuffleBufferDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sam size: int _worker_rng: WorkerRng _active_buffer: SavableSampleBuffer[T_sample] + _iterations: int + _sample_creation: List[int] - _savable_fields = ("_active_buffer", "_worker_rng") + _savable_fields = ("_active_buffer", "_worker_rng", "_iterations", "_sample_creation") def __init__( self, @@ -36,29 +38,61 @@ def __init__( def reset_state_own(self) -> None: self._worker_rng = WorkerRng(self.worker_config) self._active_buffer = SavableSampleBuffer(self.dataset, worker_config=self.worker_config) + self._iterations = 0 + self._sample_creation = [] def __len__(self) -> int: return len(self.dataset) def __iter__(self) -> Iterator[T_sample]: - self._active_buffer.worker_start() - it = iter(self._active_buffer.append_iter()) - while True: - if len(self._active_buffer) >= self.size: - pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) - yield self._active_buffer.pop(pop_idx) - else: - try: - next(it) - except StopIteration: - break - while len(self._active_buffer) > 0: - pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) - yield self._active_buffer.pop(pop_idx) + trace_span = self.worker_config.worker_trace_span() + with trace_span.span( + "ShuffleBufferDataset.__iter__", args={"config": self._own_config()}, level=1 + ): + self._active_buffer.worker_start() + it = iter(self._active_buffer.append_iter()) + while True: + if len(self._active_buffer) >= self.size: + pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) + sample_creation = self._sample_creation.pop(pop_idx) + trace_span.instant( + "ShuffleBufferDataset.__iter__.yield", + args={ + "idx": 
pop_idx, + "sample_creation": sample_creation, + "sample_age": self._iterations - sample_creation, + }, + level=2, + ) + yield self._active_buffer.pop(pop_idx) + else: + try: + next(it) + self._sample_creation.append(self._iterations) + trace_span.instant( + "ShuffleBufferDataset.__iter__.append", + args={ + "idx": len(self._sample_creation) - 1, + "sample_creation": self._iterations, + }, + level=2, + ) + self._iterations += 1 + except StopIteration: + break + with trace_span.span("ShuffleBufferDataset.__iter__.final_buffer", level=2): + while len(self._active_buffer) > 0: + pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) + yield self._active_buffer.pop(pop_idx) def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample: return self._active_buffer.restore_sample(index) + def _own_config(self) -> Dict[str, Any]: + return { + "size": self.size, + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/watchdog_dataset.py b/src/megatron/energon/wrappers/watchdog_dataset.py index e11e7b5d..e2f7f59b 100644 --- a/src/megatron/energon/wrappers/watchdog_dataset.py +++ b/src/megatron/energon/wrappers/watchdog_dataset.py @@ -45,6 +45,7 @@ def __len__(self): return len(self.dataset) def _watchdog_trigger(self) -> None: + self.worker_config.worker_trace_span().instant("WatchdogDataset._watchdog_trigger", level=2) if self.fail_on_timeout: # Raising an exception here will kill the whole process raise TimeoutError( @@ -57,16 +58,26 @@ def _watchdog_trigger(self) -> None: ) def __iter__(self) -> Iterator[T_sample]: - if self.timeout_seconds is None: - yield from self.dataset - else: - watchdog = Watchdog( - timeout=self.timeout_seconds, - initial_timeout=self.initial_timeout_seconds, - callback=self._watchdog_trigger, - enabled=False, - ) - yield from watchdog.watch_iter(self.dataset) + with self.worker_config.worker_trace_span().span( + "WatchdogDataset.__iter__", 
args={"config": self._own_config()}, level=1 + ): + if self.timeout_seconds is None: + yield from self.dataset + else: + watchdog = Watchdog( + timeout=self.timeout_seconds, + initial_timeout=self.initial_timeout_seconds, + callback=self._watchdog_trigger, + enabled=False, + ) + yield from watchdog.watch_iter(self.dataset) + + def _own_config(self) -> Dict[str, Any]: + return { + "timeout_seconds": self.timeout_seconds, + "initial_timeout_seconds": self.initial_timeout_seconds, + "fail_on_timeout": self.fail_on_timeout, + } def config(self) -> Dict[str, Any]: # Watchdog is transparent, it won't change the samples diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 50736d18..c809c061 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1683,7 +1683,7 @@ def test_debug_dataset(self): world_size=1, num_workers=2, worker_log_level=3, - worker_debug_path=str(self.dataset_path) + "/worker_debug/{worker_id}.jsonl", + worker_debug_path=str(self.dataset_path) + "/worker_debug/{worker_id}.json", ) # Reset this to 0 to make sure the test is deterministic @@ -1702,24 +1702,19 @@ def test_debug_dataset(self): samples = [[batch.__key__ for batch in loader] for _ in range(2)] print(samples) + del loader + gc.collect() debug_log_path = self.dataset_path / "worker_debug" - assert (debug_log_path / "0.jsonl").is_file() - assert (debug_log_path / "1.jsonl").is_file() - assert (debug_log_path / "2.jsonl").is_file() + assert (debug_log_path / "0.json").is_file(), f"{list(debug_log_path.iterdir())}" + assert (debug_log_path / "1.json").is_file(), f"{list(debug_log_path.iterdir())}" + assert (debug_log_path / "2.json").is_file(), f"{list(debug_log_path.iterdir())}" collected_keys_order = [[None] * 10 for _ in range(2)] - with (debug_log_path / "0.jsonl").open() as rf: - for line in rf: - line_data = json.loads(line) - if line_data["t"] == "SavableDataLoader.yield": - print(line_data) - for i in range(len(collected_keys_order)): - if 
collected_keys_order[i][line_data["idx"]] is None: - collected_keys_order[i][line_data["idx"]] = line_data["keys"] - break - else: - assert False, "Too many entries for key" + with (debug_log_path / "0.json").open() as rf: + raw = json.load(rf) + assert len(raw) > 0 + print(raw) print(collected_keys_order) assert collected_keys_order == samples From aa1455b7254c9f5f9c8e62d84f07285cc70bebeb Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Mon, 26 May 2025 12:02:33 +0200 Subject: [PATCH 2/7] Implement new tracing functionality for detailed debugging --- src/megatron/energon/cli/main.py | 2 + src/megatron/energon/flavors/base_dataset.py | 4 + .../flavors/webdataset/sample_loader.py | 72 +- src/megatron/energon/fork_hook.py | 6 +- src/megatron/energon/savable_loader.py | 848 ++++++++----- src/megatron/energon/tools/analyze_debug.py | 669 ++++------ .../energon/tools/analyze_debug_merge.py | 146 +++ src/megatron/energon/tracing.py | 1087 +++++++++++------ src/megatron/energon/watchdog.py | 34 +- src/megatron/energon/worker.py | 23 +- src/megatron/energon/wrappers/base.py | 21 +- .../energon/wrappers/batch_dataset.py | 228 ++-- .../energon/wrappers/blend_dataset.py | 113 +- .../energon/wrappers/epochize_dataset.py | 58 +- .../energon/wrappers/filter_dataset.py | 20 +- src/megatron/energon/wrappers/gc_dataset.py | 34 +- .../energon/wrappers/group_batch_dataset.py | 80 +- .../energon/wrappers/iter_map_dataset.py | 142 ++- .../energon/wrappers/log_sample_dataset.py | 68 +- src/megatron/energon/wrappers/map_dataset.py | 142 ++- .../energon/wrappers/packing_dataset.py | 320 +++-- .../energon/wrappers/repeat_dataset.py | 33 +- .../wrappers/shuffle_buffer_dataset.py | 67 +- tests/test_dataset.py | 23 +- 24 files changed, 2512 insertions(+), 1728 deletions(-) create mode 100644 src/megatron/energon/tools/analyze_debug_merge.py diff --git a/src/megatron/energon/cli/main.py b/src/megatron/energon/cli/main.py index d12cf01f..1bb8686e 
100644 --- a/src/megatron/energon/cli/main.py +++ b/src/megatron/energon/cli/main.py @@ -6,6 +6,7 @@ import click from megatron.energon.tools.analyze_debug import command as analyze_debug_command +from megatron.energon.tools.analyze_debug_merge import command as analyze_debug_merge_command from megatron.energon.tools.checkpoint import command as checkpoint_command from megatron.energon.tools.info import command as info_command from megatron.energon.tools.lint import command as lint_command @@ -28,6 +29,7 @@ def main(ctx): main.add_command(analyze_debug_command) +main.add_command(analyze_debug_merge_command) main.add_command(checkpoint_command) main.add_command(lint_command) main.add_command(info_command) diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index 239fc47f..370763f7 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -330,6 +330,10 @@ def _function_config(fn: Callable) -> str: mod_name = getattr(fn, "__module__", "") return f"{mod_name}.{getattr(fn, '__qualname__', getattr(fn, '__name__', ''))}" + @staticmethod + def _function_config_short(fn: Callable) -> str: + return getattr(fn, "__qualname__", getattr(fn, "__name__", "")) + @abstractmethod def config(self) -> Dict[str, Any]: """Return a config dict that can be used to check if datasets have the same settings. 
diff --git a/src/megatron/energon/flavors/webdataset/sample_loader.py b/src/megatron/energon/flavors/webdataset/sample_loader.py index 3ab8bb30..89cd773b 100644 --- a/src/megatron/energon/flavors/webdataset/sample_loader.py +++ b/src/megatron/energon/flavors/webdataset/sample_loader.py @@ -205,7 +205,18 @@ def _slices_iter(self) -> Generator[RawSampleData, None, None]: trace = self.worker_config.worker_trace_span() - with trace.span("WebdatasetSampleLoaderDataset._slices_iter", level=1) as fn_span: + with trace.span( + "WebdatasetSampleLoaderDataset._slices_iter", + args={ + "base_paths": [str(reader.base_path) for reader in self.join_readers], + "shuffle_over_epochs": self.shuffle_over_epochs, + "parallel_slice_iters": self.parallel_slice_iters, + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, + }, + level=1, + ) as fn_span: assert self.slice_offsets is not None active_slice_probs = torch.zeros(self.parallel_slice_iters, dtype=torch.float32) @@ -251,9 +262,6 @@ def slice_at(idx: int) -> SliceState: ) for state in active_slices ], - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, "probs": active_slice_probs.tolist(), } ) @@ -272,10 +280,12 @@ def slice_at(idx: int) -> SliceState: ) for state in active_slices ], - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, "probs": active_slice_probs.tolist(), + "shuffle_over_epochs": self.shuffle_over_epochs, + "parallel_slice_iters": self.parallel_slice_iters, }, level=1, ) @@ -289,22 +299,19 @@ def slice_at(idx: int) -> SliceState: { "mode": "next_epoch", "pending_slice_indexes": pending_slice_indexes, - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, "probs": 
active_slice_probs.tolist(), - "shuffle_over_epochs": self.shuffle_over_epochs, } ) trace.instant( "WebdatasetSampleLoaderDataset._slices_iter.next_epoch", args={ "pending_slice_indexes": pending_slice_indexes, - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, "probs": active_slice_probs.tolist(), "shuffle_over_epochs": self.shuffle_over_epochs, + "parallel_slice_iters": self.parallel_slice_iters, }, level=1, ) @@ -393,37 +400,38 @@ def slice_at(idx: int) -> SliceState: "WebdatasetSampleLoaderDataset._slices_iter.exhausted", args={ "remaining": len(pending_slice_indexes), - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, "probs": active_slice_probs.tolist(), }, level=2, ) if sample.data[0] is not None: # Otherwise the sample was skipped. 
- with trace.span( + trace.instant( "WebdatasetSampleLoaderDataset._slices_iter.yield", args={ - "index": sample.__restore_key__[1], + "base_path": str(self.join_readers[0].base_path), + "global_sample_index": sample.__restore_key__[1], "key": sample.data[0]["__key__"], "shard": sample.data[0]["__shard__"], - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, }, level=2, - ): - # Now, yield the sample - yield sample - del sample + ) + # Now, yield the sample + yield sample + del sample if self.worker_config.should_log(level=2): trace.instant( "WebdatasetSampleLoaderDataset._slices_iter.all_exhausted", args={ - "count": self._sample_count, - "epoch": self._epoch_count, - "epoch_count": self._epoch_sample_count, + "sample_count": self._sample_count, + "epoch_idx": self._epoch_count, + "epoch_sample_count": self._epoch_sample_count, }, level=2, ) @@ -454,12 +462,6 @@ def __iter__(self) -> Iterator[RawSampleData]: self.ensure_slice_offsets() assert self.slice_offsets is not None - self.worker_config.worker_trace_writer().instant( - "WebdatasetSampleLoaderDataset.__iter__", - args=self.config(), - level=1, - ) - if len(self.slice_offsets) <= 1: return diff --git a/src/megatron/energon/fork_hook.py b/src/megatron/energon/fork_hook.py index 087363be..72a3d4aa 100644 --- a/src/megatron/energon/fork_hook.py +++ b/src/megatron/energon/fork_hook.py @@ -1,11 +1,10 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause -from dataclasses import dataclass -import functools import os import weakref -from typing import Callable, Protocol, Type, TypeVar +from dataclasses import dataclass +from typing import Callable, TypeVar _after_in_child_fork_hooks = dict() _after_in_parent_fork_hooks = dict() @@ -96,6 +95,7 @@ def __before_fork__(self): """ pass + @dataclass class DataclassForkMixin: """ diff --git a/src/megatron/energon/savable_loader.py b/src/megatron/energon/savable_loader.py index 227e8e0c..ef802346 100644 --- a/src/megatron/energon/savable_loader.py +++ b/src/megatron/energon/savable_loader.py @@ -39,6 +39,7 @@ add_sample_restore_key, ) from megatron.energon.rng import SystemRng, SystemRngState +from megatron.energon.tracing import Flow from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset from megatron.energon.wrappers.batch_dataset import BatchDataset @@ -74,18 +75,27 @@ class SimpleSavableDatasetWrapper(BaseWrapperDataset[T, Tuple[int, int, T]], Gen _savable_fields = ("_sample_index",) def __init__( - self, dataset: SavableDataset[T], worker_config: WorkerConfig, cache_pool: CachePool + self, + dataset: SavableDataset[T], + worker_config: WorkerConfig, + cache_pool: CachePool, + dataloader_id: int, ): """ Args: dataset: The dataset to wrap. worker_config: The worker config to use for the dataset. cache_pool: The cache pool to use for the dataset. + dataloader_id: The id of the data loader for logging purposes. """ super().__init__(dataset, worker_config=worker_config) self.cache_pool = cache_pool + self.dataloader_id = dataloader_id self.reset_state_own() + # This must be removed, such that the outer dataloader does not use this + # from the second epoch on for training. 
+ del self.__len__ def reset_state_own(self) -> None: self._sample_index = 0 @@ -95,50 +105,88 @@ def __len__(self): return len(self.dataset) def __iter__(self): - self._state_restored = True - worker_id = self.worker_config.rank_worker_id() - global_worker_id = self.worker_config.global_worker_id() - while self._state_restored: - self._state_restored = False - self.worker_config.worker_activate(self._sample_index, cache_pool=self.cache_pool) - worker_active = True - try: - for src_data in self.dataset: - self.worker_config.worker_deactivate() - worker_active = False - sample_index = self._sample_index - src_data = add_sample_restore_key( - src_data, global_worker_id, sample_index, src=self - ) - self._sample_index += 1 - yield worker_id, sample_index, src_data - if self._state_restored: - # Restart iterator after restore - break - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - finally: - if worker_active: - self.worker_config.worker_deactivate() + trace_writer = self.worker_config.worker_trace_writer() + with trace_writer.span( + "SimpleSavableDatasetWrapper.__iter__", + args={ + "config": self.config(), + "loader_id": self.dataloader_id, + "rank": self.worker_config.rank, + "worker_id": self.worker_config.rank_worker_id(), + }, + level=1, + ): + self._state_restored = True + worker_id = self.worker_config.rank_worker_id() + global_worker_id = self.worker_config.global_worker_id() + while self._state_restored: + self._state_restored = False + self.worker_config.worker_activate(self._sample_index, cache_pool=self.cache_pool) + worker_active = True + try: + # For tracing, this contains the current sample flow + + trace_sample_flow: Flow = None + + def _next_span(): + # Trace the next sample flow + nonlocal trace_sample_flow + span = trace_writer.span( + name="SimpleSavableDatasetWrapper.__iter__.loop.dataset.next", + args={"sample_idx": self._sample_index}, + level=1, + ) + trace_sample_flow = 
trace_writer.flow( + f"w{global_worker_id}_s{self._sample_index}", + level=1, + ) + return span + + for src_data in trace_writer.iterable(self.dataset, next=_next_span): + self.worker_config.worker_deactivate() + worker_active = False + sample_index = self._sample_index + src_data = add_sample_restore_key( + src_data, global_worker_id, sample_index, src=self + ) + self._sample_index += 1 + trace_sample_flow.end(level=1) + trace_writer.instant( + "SimpleSavableDatasetWrapper.__iter__.loop.yield", + args={"sample_index": sample_index}, + level=1, + ) + yield worker_id, sample_index, src_data + if self._state_restored: + # Restart iterator after restore + break + self.worker_config.worker_activate( + self._sample_index, cache_pool=self.cache_pool + ) + worker_active = True + finally: + if worker_active: + self.worker_config.worker_deactivate() def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T: - id, global_worker_id, sample_idx = index[:3] - assert id == type(self).__name__ - index = index[3:] - self.worker_config.worker_activate( - sample_idx, override_global_rank=global_worker_id, cache_pool=self.cache_pool - ) - try: - return add_sample_restore_key( - self.dataset.restore_sample(index), - global_worker_id, - sample_idx, - src=self, + with self.worker_config.worker_trace_writer().span( + "SimpleSavableDatasetWrapper.restore_sample", args={"index": index}, level=1 + ): + id, global_worker_id, sample_idx = index[:3] + assert id == type(self).__name__ + index = index[3:] + self.worker_config.worker_activate( + sample_idx, override_global_rank=global_worker_id, cache_pool=self.cache_pool ) - finally: - self.worker_config.worker_deactivate() + try: + return add_sample_restore_key( + self.dataset.restore_sample(index), + global_worker_id, + sample_idx, + src=self, + ) + finally: + self.worker_config.worker_deactivate() def config(self) -> Dict[str, Any]: return self.dataset.config() @@ -147,6 +195,35 @@ def __str__(self): return 
f"SimpleSavableDatasetWrapper(dataset={self.dataset})" +class SimpleSavableDatasetWrapperWithoutLen(IterableDataset[Tuple[int, int, T]], Generic[T]): + """Wrapper for non-multiprocessing savable datasets. Restarts the inner dataset. This class is + not intended to be used directly.""" + + def __init__(self, dataset: SavableDataset[T]): + self.dataset = dataset + + def inner_len(self): + return len(self.dataset) + + def __iter__(self): + return self.dataset.__iter__() + + def save_state(self): + return self.dataset.save_state() + + def restore_state(self, state: FlexState): + return self.dataset.restore_state(state) + + def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T: + return self.dataset.restore_sample(index) + + def config(self): + return self.dataset.config() + + def __str__(self): + return f"SimpleSavableDatasetWrapperWithoutLen(dataset={self.dataset})" + + @dataclass_slots class SavableDatasetState(State): """State of the dataset wrapper. It stores the global random states and the index of the next @@ -235,6 +312,7 @@ def __init__( cmd_queues: List[torch.multiprocessing.Queue], result_queues: List[torch.multiprocessing.Queue], cache_pool: CachePool, + dataloader_id: int, ): """ Create the savable dataset wrapper for multiprocessing data loading. @@ -250,6 +328,7 @@ def __init__( cmd_queues: The command queues for communicating with the worker processes. result_queues: The result queues for communicating with the worker processes. cache_pool: The cache pool to use for the dataset. + dataloader_id: The id of the data loader for logging purposes. 
""" num_workers = max(worker_config.num_workers, 1) @@ -266,6 +345,7 @@ def __init__( self._cmd_queues = cmd_queues self._result_queues = result_queues self.cache_pool = cache_pool + self.dataloader_id = dataloader_id @staticmethod def _command_thread(self: "SavableDatasetWrapper"): @@ -276,35 +356,52 @@ def _command_thread(self: "SavableDatasetWrapper"): # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread starting") assert self._command_lock is not None - try: - while self._running: - try: - cmd_args = self._cmd_queues[self._worker_id].get(timeout=1) - except queue.Empty: - continue - # print(f"recv cmd {cmd_args}") - with self._command_lock: - cmd = cmd_args[0] - if cmd is None: - break + trace_writer = self.worker_config.worker_trace_writer() + trace_writer.metadata_thread_name("command_thread") + + with trace_writer.span( + "SavableDatasetWrapper._command_thread", args={"config": self.config()}, level=1 + ): + try: + while self._running: try: - fn = getattr(self, cmd) - self._result_queues[self._worker_id].put( - {self._worker_id: fn(*cmd_args[1:])} - ) - # print(f"result sent") - except Exception as e: - traceback.print_exc() - self._result_queues[self._worker_id].put({self._worker_id: e}) - # print(f"exc sent") - except BaseException: - traceback.print_exc() - raise - finally: - pass - # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread closing") + cmd_args = self._cmd_queues[self._worker_id].get(timeout=1) + except queue.Empty: + continue + # print(f"recv cmd {cmd_args}") + with self._command_lock: + cmd = cmd_args[0] + if cmd is None: + break + with trace_writer.span( + f"SavableDatasetWrapper._command_thread.{cmd}", level=1 + ): + try: + fn = getattr(self, cmd) + self._result_queues[self._worker_id].put( + {self._worker_id: fn(*cmd_args[1:])} + ) + # print(f"result sent") + except Exception as e: + traceback.print_exc() + self._result_queues[self._worker_id].put({self._worker_id: e}) + 
trace_writer.instant( + "SavableDatasetWrapper._command_thread.cmd_lock.exception", + args={ + "exc": f"{type(e).__name__}: {e}", + "tb": traceback.format_exc(), + }, + level=1, + ) + # print(f"exc sent") + except BaseException: + traceback.print_exc() + raise + finally: + pass + # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread closing") - def __len__(self): + def inner_len(self): return len(self.dataset) def __del__(self): @@ -316,97 +413,155 @@ def __del__(self): self._cmd_thread = None # print(f"{id(self)}:{multiprocessing.current_process().ident} Cmd thread closed") + def _trace_hierarchy(self, dataset: SavableDataset[T]): + trace_writer = self.worker_config.worker_trace_writer() + dataset = self.dataset + with trace_writer.span(type(dataset).__name__, level=1): + if isinstance(dataset, BaseWrapperDataset): + for dataset in dataset.datasets: + self._trace_hierarchy(dataset) + def __iter__(self): + trace_writer = self.worker_config.worker_trace_writer() # First: Set the worker offset globally for the current worker WorkerConfig.worker_id_offset = self._worker_offset self._worker_id = self.worker_config.rank_worker_id() global_worker_id = self.worker_config.global_worker_id() - if self._cmd_thread is None: - self._running = True - self._command_lock = threading.RLock() - weakref_self = weakref.proxy(self) - self._cmd_thread = threading.Thread( - target=SavableDatasetWrapper._command_thread, - name="command_thread", - args=(weakref_self,), - daemon=True, - ) - self._cmd_thread.start() - # atexit.register(lambda: weakref_self.__del__()) - try: - assert self._command_lock is not None - with self._command_lock: - if self._workers_restore_from[self._worker_id] is not None: - my_state = self._workers_restore_from[self._worker_id] - my_ds_state = my_state.dataset_state - assert my_state is not None - if my_ds_state is None: - self.dataset.reset_state_deep() + with trace_writer.span( + "SavableDatasetWrapper.__iter__", + args={ + 
"config": self.config(), + "loader_id": self.dataloader_id, + "rank": self.worker_config.rank, + "worker_id": self._worker_id, + "global_worker_id": global_worker_id, + }, + level=1, + ): + if self._cmd_thread is None: + self._running = True + self._command_lock = threading.RLock() + weakref_self = weakref.proxy(self) + self._cmd_thread = threading.Thread( + target=SavableDatasetWrapper._command_thread, + name="command_thread", + args=(weakref_self,), + daemon=True, + ) + self._cmd_thread.start() + # atexit.register(lambda: weakref_self.__del__()) + try: + assert self._command_lock is not None + with self._command_lock: + if self._workers_restore_from[self._worker_id] is not None: + with trace_writer.span("SavableDatasetWrapper.__iter__.restore", level=1): + my_state = self._workers_restore_from[self._worker_id] + my_ds_state = my_state.dataset_state + assert my_state is not None + if my_ds_state is None: + self.dataset.reset_state_deep() + else: + self.dataset.restore_state(my_ds_state) + self._restore_state(my_state) + self._workers_restore_from[self._worker_id] = None else: - self.dataset.restore_state(my_ds_state) - self._restore_state(my_state) - self._workers_restore_from[self._worker_id] = None - else: - # Store the initial state of the worker if we stop before the first sample - self._store_checkpoint() - # If skipping, also restart the iterator to reach the start of the restored - # checkpoint - last_was_skip = True - while last_was_skip: - dataset_has_samples = False - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - try: - for src_data in self.dataset: - self.worker_config.worker_deactivate() - worker_active = False - dataset_has_samples = True - if self._workers_skip_samples[self._worker_id] > 0: - # Skip ahead to reach the start of the restored checkpoint - # print(f"Skip [{self._sample_index}:{self._worker_id}] {src_data}") - self._workers_skip_samples[self._worker_id] -= 1 - 
self._sample_index += 1 - last_was_skip = True - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - continue - last_was_skip = False - sample_index = self._sample_index - add_sample_restore_key( - src_data, global_worker_id, sample_index, src=self - ) - self._sample_index += 1 - self._store_checkpoint() - try: - self._command_lock.release() - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock released") - # Commands may be executed only when data was yielded, not during - # iteration fetching. - # print(f"Yield next data [{sample_index}:{self._worker_id}] {src_data}") - yield self._worker_id, sample_index, src_data - finally: - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquiring") - self._command_lock.acquire() - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquired") - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - finally: - if worker_active: - self.worker_config.worker_deactivate() - - # If the dataset is empty, don't try again and again - if not dataset_has_samples: - break - finally: - # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker iter closing") - # Always store a final checkpoint (it's likely to be saved) - self._store_checkpoint(force=True) + # Store the initial state of the worker if we stop before the first sample + self._store_checkpoint() + + # If skipping, also restart the iterator to reach the start of the restored + # checkpoint + last_was_skip = True + while last_was_skip: + dataset_has_samples = False + self.worker_config.worker_activate( + self._sample_index, cache_pool=self.cache_pool + ) + worker_active = True + trace_sample_flow: dict = {} + try: + with trace_writer.span("SavableDatasetWrapper.__iter__.loop", level=1): + + def _trace_next(): + nonlocal trace_sample_flow + span = trace_writer.span( + 
"SavableDatasetWrapper.__iter__.loop.dataset.next", + args={ + "sample_index": self._sample_index, + }, + level=1, + ) + + trace_sample_flow = trace_writer.flow( + f"w{global_worker_id}_s{self._sample_index}", + level=1, + ).save() + return span + + for src_data in trace_writer.iterable( + self.dataset, + next=_trace_next, + ): + self.worker_config.worker_deactivate() + worker_active = False + dataset_has_samples = True + if self._workers_skip_samples[self._worker_id] > 0: + with trace_writer.span( + "SavableDatasetWrapper.__iter__.loop.skip", level=1 + ): + # Skip ahead to reach the start of the restored checkpoint + # print(f"Skip [{self._sample_index}:{self._worker_id}] {src_data}") + self._workers_skip_samples[self._worker_id] -= 1 + self._sample_index += 1 + last_was_skip = True + self.worker_config.worker_activate( + self._sample_index, cache_pool=self.cache_pool + ) + worker_active = True + continue + last_was_skip = False + sample_index = self._sample_index + add_sample_restore_key( + src_data, global_worker_id, sample_index, src=self + ) + self._sample_index += 1 + self._store_checkpoint() + try: + self._command_lock.release() + # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock released") + # Commands may be executed only when data was yielded, not during + # iteration fetching. 
+ # print(f"Yield next data [{sample_index}:{self._worker_id}] {src_data}") + trace_writer.instant( + "SavableDatasetWrapper.__iter__.loop.yield", + args={"sample_index": sample_index}, + level=1, + ) + yield ( + self._worker_id, + sample_index, + src_data, + trace_sample_flow, + ) + finally: + # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquiring") + self._command_lock.acquire() + # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquired") + self.worker_config.worker_activate( + self._sample_index, cache_pool=self.cache_pool + ) + worker_active = True + finally: + if worker_active: + self.worker_config.worker_deactivate() + + # If the dataset is empty, don't try again and again + if not dataset_has_samples: + trace_writer.instant("SavableDatasetWrapper.__iter__.break", level=1) + break + finally: + # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker iter closing") + # Always store a final checkpoint (it's likely to be saved) + self._store_checkpoint(force=True) def _store_checkpoint(self, force: bool = False) -> None: """ @@ -417,6 +572,7 @@ def _store_checkpoint(self, force: bool = False) -> None: Args: force: If true, ignore time or frequency condition. 
""" + trace_writer = self.worker_config.worker_trace_writer() if ( force or ( @@ -427,26 +583,24 @@ def _store_checkpoint(self, force: bool = False) -> None: ) or self._sample_index <= 1 ): - # print(f"Storing checkpoint at {self._worker_id}:{self._sample_index}") - self._last_checkpoints.append( - SavableCheckpoint( - state=self._save_state(), - checkpoint_time=time.perf_counter(), - sample_index=self._sample_index, + with trace_writer.span( + "SavableDatasetWrapper._store_checkpoint", + args={"force": force, "sample_index": self._sample_index}, + level=1, + ): + # print(f"Storing checkpoint at {self._worker_id}:{self._sample_index}") + self._last_checkpoints.append( + SavableCheckpoint( + state=self._save_state(), + checkpoint_time=time.perf_counter(), + sample_index=self._sample_index, + ) ) - ) - if len(self._last_checkpoints) > self.n_checkpoints: - self._last_checkpoints.pop(0) + if len(self._last_checkpoints) > self.n_checkpoints: + self._last_checkpoints.pop(0) def _save_state(self) -> SavableDatasetState: """Saves the internal state""" - ( - np_tp, - np_state, - pos, - has_gauss, - cached_gaussian, - ) = np.random.get_state() return SavableDatasetState( rng=SystemRng.save_state(), dataset_state=self.dataset.save_state(), @@ -557,19 +711,22 @@ def can_restore_sample(self) -> bool: return self.dataset.can_restore_sample() def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T: - id, global_worker_id, sample_idx = index[:3] - assert id == type(self).__name__ - index = index[3:] - self.worker_config.worker_activate(sample_idx, override_global_rank=global_worker_id) - try: - return add_sample_restore_key( - self.dataset.restore_sample(index), - global_worker_id, - sample_idx, - src=self, - ) - finally: - self.worker_config.worker_deactivate() + with self.worker_config.worker_trace_writer().span( + "SavableDatasetWrapper.restore_sample", args={"index": index}, level=1 + ): + id, global_worker_id, sample_idx = index[:3] + assert id == 
type(self).__name__ + index = index[3:] + self.worker_config.worker_activate(sample_idx, override_global_rank=global_worker_id) + try: + return add_sample_restore_key( + self.dataset.restore_sample(index), + global_worker_id, + sample_idx, + src=self, + ) + finally: + self.worker_config.worker_deactivate() def config(self) -> Dict[str, Any]: return self.dataset.config() @@ -643,7 +800,7 @@ class SavableDataLoader(DataLoader[T], Generic[T]): #: The worker config worker_config: WorkerConfig #: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper` - dataset: Union[SavableDatasetWrapper[T], SimpleSavableDatasetWrapper[T]] + dataset: Union[SavableDatasetWrapper[T], SimpleSavableDatasetWrapperWithoutLen[T]] #: The global ID counter _next_id: ClassVar[int] = 0 @@ -752,10 +909,13 @@ def __init__( cmd_queues=self.cmd_queues, result_queues=self.result_queues, cache_pool=cache_pool, + dataloader_id=self.id, ) else: - dataset = SimpleSavableDatasetWrapper( - dataset, self.worker_config, cache_pool=cache_pool + dataset = SimpleSavableDatasetWrapperWithoutLen( + SimpleSavableDatasetWrapper( + dataset, self.worker_config, cache_pool=cache_pool, dataloader_id=self.id + ) ) self._worker_sample_counters = [-1] * num_procs @@ -778,6 +938,17 @@ def __init__( self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers) ] + self.worker_config.worker_trace_writer().trace_object_async( + self, + "SavableDataLoader", + args={ + "loader_id": self.id, + "worker_config": self.worker_config.config(), + "config": dataset.config(), + }, + level=1, + ) + super().__init__( dataset, batch_size=None, @@ -788,106 +959,108 @@ def __init__( **kwargs, ) - self.worker_config.worker_trace_writer().trace_object( - self, - "SavableDataLoader", - args={ - "id": self.id, - "config": dataset.config(), - }, - level=1, - ) - @staticmethod def next_id() -> int: next_id = SavableDataLoader._next_id SavableDataLoader._next_id += 1 return next_id + 
def __len__(self): + return self.dataset.inner_len() + def __iter__(self): - outerself = self - - class InnerIterator: - """Internal class which keeps the iterator alive across multiple `iter()` calls. - If the inner iterator is exhausted, will also exhaust and a new instance is needed. - Also saves the last sample index and the next worker id. - """ - - finished: bool = False - iter_idx: int = 0 - id: int - - def __init__(self, iterator): - self._iterator = iterator - self.id = outerself.next_id() - outerself.worker_config.worker_trace_writer().trace_object( - self, - "SavableDataLoader.iter", + def _inner_generator(iterator): + iter_idx = 0 + id = self.next_id() + trace_writer = self.worker_config.worker_trace_writer() + trace_span = self.worker_config.worker_trace_span() + trace_writer.instant( + "SavableDataLoader.__iter__", + args={ + "world_size": self.worker_config.world_size, + "num_workers": self.worker_config.num_workers, + "loader_id": self.id, + "iter_id": id, + }, + level=1, + ) + with ( + trace_span.span( + "SavableDataLoader.__iter__", args={ - "id": outerself.id, - "iter_id": self.id, + "loader_id": self.id, + "iter_id": id, }, level=1, - ) - - # self._debugf = open( - # f"worker_samples_rank{outerself.worker_config.rank:02}_t{int(time.time())}.log", "w" - # ) - - def __iter__(self): - return self - - def __next__(self): + ), + trace_writer.generator( + "SavableDataLoader.__iter__.next", + level=1, + ) as trace_generator, + ): try: - worker_id, sample_idx, sample = next(self._iterator) - outerself._worker_sample_counters[worker_id] = sample_idx - # If the next sample will be from the first worker, we can safely resume - outerself._next_worker_id = (worker_id + 1) % max(outerself.num_workers, 1) - # self._debugf.write( - # f"[w={worker_id}, s={sample_idx}] {self._sample_str(sample)}\n" - # ) - # self._debugf.flush() - if outerself.worker_config.should_log(level=1): - keys = default_get_keys(sample) - 
outerself.worker_config.worker_trace_writer().instant( - "SavableDataLoader.yield", - args={ - "id": outerself.id, - "iter_id": self.id, + for worker_id, sample_idx, sample, trace_sample_flow in iterator: + self._worker_sample_counters[worker_id] = sample_idx + # If the next sample will be from the first worker, we can safely resume + self._next_worker_id = (worker_id + 1) % max(self.num_workers, 1) + if self.worker_config.should_log(level=1): + trace_writer.resume_flow(trace_sample_flow).end( + bind_enclosing_slice=True, level=1 + ) + keys = default_get_keys(sample) + trace_span.instant( + "SavableDataLoader.yield", + args={ + "loader_id": self.id, + "iter_id": id, + "worker_id": worker_id, + "worker_sample_idx": sample_idx, + "sample_idx": self._sample_idx, + "iter_idx": iter_idx, + "global_sample_idx": self._global_sample_idx, + **({} if keys is None else {"keys": keys}), + }, + level=1, + ) + with trace_generator.yield_( + last_args={ + "loader_id": self.id, + "iter_id": id, "worker_id": worker_id, - "worker_idx": sample_idx, - "idx": outerself._sample_idx, - "iter_idx": self.iter_idx, - "global_idx": outerself._global_sample_idx, + "worker_sample_idx": sample_idx, + "sample_idx": self._sample_idx, + "iter_idx": iter_idx, + "global_sample_idx": self._global_sample_idx, **({} if keys is None else {"keys": keys}), - }, - ) - outerself._sample_idx += 1 - outerself._global_sample_idx += 1 - self.iter_idx += 1 - return sample - except StopIteration: - self.finished = True - outerself._next_worker_id = 0 - outerself.worker_config.worker_trace_writer().instant( + } + ): + self._sample_idx += 1 + self._global_sample_idx += 1 + iter_idx += 1 + yield sample + finally: + self._persistent_iterator = None + self._next_worker_id = 0 + trace_span.instant( "SavableDataLoader.StopIteration", - args={ - "id": outerself.id, - "iter_id": self.id, - }, level=1, + args={"loader_id": self.id, "iter_id": id}, + ) + trace_writer.instant( + "SavableDataLoader.StopIteration", + level=1, 
+ args={"loader_id": self.id, "iter_id": id}, ) - raise if self.num_workers > 0: # Always keep same iterator alive, as long as it yields data - if self._persistent_iterator is None or self._persistent_iterator.finished: - self._persistent_iterator = InnerIterator(super().__iter__()) + if self._persistent_iterator is None: + self._persistent_iterator = _inner_generator(super().__iter__()) self._sample_idx = 0 # print("New Iterator", self._persistent_iterator) return self._persistent_iterator else: - return InnerIterator(super().__iter__()) + return _inner_generator(super().__iter__()) def _worker_command(self, *cmd_args) -> List[Any]: """Executes a command in all workers and returns the results.""" @@ -908,6 +1081,8 @@ def _get_batch_size(self) -> Optional[int]: """Try to infer micro batch size from the dataset""" if isinstance(self.dataset, SavableDatasetWrapper): dataset = self.dataset.dataset + elif isinstance(self.dataset, SimpleSavableDatasetWrapperWithoutLen): + dataset = self.dataset.dataset else: dataset = self.dataset @@ -931,7 +1106,7 @@ def save_state_rank(self) -> Optional[SavableDataLoaderState]: # Fetch current rank's worker's state if self.num_workers == 0: # No workers configured - assert isinstance(self.dataset, SimpleSavableDatasetWrapper) + assert isinstance(self.dataset, SimpleSavableDatasetWrapperWithoutLen) worker_states = [self.dataset.save_state()] assert self._next_worker_id == 0 elif self._persistent_iterator is None: @@ -972,7 +1147,9 @@ def restore_state_rank(self, state: Optional[SavableDataLoaderState]) -> None: old_micro_batch_size = state.micro_batch_size micro_batch_size = self._get_batch_size() - if isinstance(self.dataset, SavableDataset): + if self.num_workers == 0: + # No workers configured + assert isinstance(self.dataset, SimpleSavableDatasetWrapperWithoutLen) assert micro_batch_size == old_micro_batch_size, ( "Changing micro batch size is not allowed without workers" ) @@ -981,6 +1158,7 @@ def restore_state_rank(self, state: 
Optional[SavableDataLoaderState]) -> None: assert isinstance(state.worker_states[0], FlexState) self.dataset.restore_state(state.worker_states[0]) else: + # Workers configured assert isinstance(self.dataset, SavableDatasetWrapper) assert all(isinstance(s, SavableDatasetCheckpoint) for s in state.worker_states) @@ -1276,6 +1454,17 @@ def __init__( gc.collect() # This ensures that we don't include any old worker refs in the newly forked worker processes + self.worker_config.worker_trace_writer().trace_object_async( + self, + "BasicDataLoader", + args={ + "loader_id": self.id, + "worker_config": self.worker_config.config(), + "config": self.config(), + }, + level=1, + ) + super().__init__( dataset, batch_size=None, @@ -1285,82 +1474,89 @@ def __init__( worker_init_fn=partial(_init_worker, seed_per_worker), **kwargs, ) - self.worker_config.worker_trace_writer().trace_object( - self, - "BasicDataLoader", - args={ - "id": self.id, - "config": self.config(), - }, - level=1, - ) def __iter__(self): - outerself = self - - class InnerIterator: - """Internal class which keeps the iterator alive across multiple `iter()` calls. - If the inner iterator is exhausted, will also exhaust and a new instance is needed. - Also saves the last sample index and the next worker id. 
- """ - - iter_idx: int = 0 - id: int - - def __init__(self, iterator): - self._iterator = iterator - self.id = SavableDataLoader.next_id() + def _inner_generator(iterator): + iter_idx = 0 + id = SavableDataLoader.next_id() + + trace_writer = self.worker_config.worker_trace_writer() + trace_span = self.worker_config.worker_trace_span() + + trace_writer.instant( + "BasicDataLoader.__iter__", + args={ + "rank": self.worker_config.rank, + "world_size": self.worker_config.world_size, + "num_workers": self.worker_config.num_workers, + "loader_id": self.id, + "iter_id": id, + }, + level=1, + ) - outerself.worker_config.worker_trace_writer().trace_object( - self, + with ( + trace_span.span( "BasicDataLoader.iter", args={ - "id": outerself.id, - "iter_id": self.id, + "loader_id": self.id, + "iter_id": id, }, level=1, - ) - - def __iter__(self): - return self - - def __next__(self): + ), + trace_writer.generator( + "BasicDataLoader.iter", + level=1, + ) as trace_generator, + ): try: - worker_id, sample_idx, sample = next(self._iterator) - # If the next sample will be from the first worker, we can safely resume - self.next_worker_id = (worker_id + 1) % max(outerself.num_workers, 1) - if outerself.worker_config.should_log(level=1): - keys = default_get_keys(sample) - outerself.worker_config.worker_trace_writer().instant( - "BasicDataLoader.yield", - args={ - "id": outerself.id, - "iter_id": self.id, + for worker_id, sample_idx, sample, trace_sample_flow in iterator: + if self.worker_config.should_log(level=1): + trace_writer.resume_flow(trace_sample_flow).end( + bind_enclosing_slice=True, level=1 + ) + keys = default_get_keys(sample) + trace_span.instant( + "BasicDataLoader.yield", + args={ + "loader_id": self.id, + "iter_id": id, + "worker_id": worker_id, + "worker_sample_idx": sample_idx, + "sample_idx": iter_idx, + "iter_idx": iter_idx, + "global_sample_idx": self._sample_idx, + **({} if keys is None else {"keys": keys}), + }, + level=1, + ) + with trace_generator.yield_( 
+ last_args={ + "loader_id": self.id, + "iter_id": id, "worker_id": worker_id, - "worker_idx": sample_idx, - "idx": self.iter_idx, - "iter_idx": self.iter_idx, - "global_idx": outerself._sample_idx, - **({} if keys is None else {"keys": keys}), - }, - level=1, - ) - outerself._sample_idx += 1 - self.iter_idx += 1 - return sample - except StopIteration: - self.next_worker_id = 0 - outerself.worker_config.worker_trace_writer().instant( + "sample_idx": sample_idx, + } + ): + self._sample_idx += 1 + iter_idx += 1 + yield sample + finally: + trace_span.instant( "BasicDataLoader.StopIteration", args={ - "id": outerself.id, - "iter_id": self.id, + "loader_id": self.id, + "iter_id": id, }, level=1, ) - raise + trace_writer.instant( + "BasicDataLoader.StopIteration", + level=1, + args={"loader_id": self.id, "iter_id": id}, + ) - return InnerIterator(super().__iter__()) + return _inner_generator(super().__iter__()) def config(self): """Get the configuration, which defines the dataset. Useful in conjunction with `save_state` diff --git a/src/megatron/energon/tools/analyze_debug.py b/src/megatron/energon/tools/analyze_debug.py index 3114be19..6133fb5b 100644 --- a/src/megatron/energon/tools/analyze_debug.py +++ b/src/megatron/energon/tools/analyze_debug.py @@ -2,19 +2,13 @@ # SPDX-License-Identifier: BSD-3-Clause import json -import time -import traceback -from concurrent.futures.process import ProcessPoolExecutor +import re +from collections import defaultdict +from dataclasses import dataclass from pathlib import Path from typing import ( - Container, - Dict, Generator, - Iterable, - List, Literal, - Optional, - Tuple, TypedDict, Union, ) @@ -22,9 +16,8 @@ import click import numpy as np from PIL import Image -from tqdm import tqdm -from megatron.energon.dataclass_slots import dataclass_slots +from megatron.energon.tools.analyze_debug_merge import merge_log_reader cpal = np.array( [ @@ -98,54 +91,6 @@ ) -class YieldBatchLogLine(TypedDict): - # Json example: - # { - # "t": 
"yield_batch", - # "r": 1, - # "w": 1, - # "m": "train", - # "idx": 1, - # "keys": ["parts/data-train-000051.tar/528866", ...], - # } - t: Literal["yield_batch"] - r: int - w: int - m: Literal["train", "val"] - idx: int - keys: List[str] - - -class SampleLoaderYieldLogLine(TypedDict): - # Json example: - # { - # "t": "WebdatasetSampleLoaderDataset._slices_iter.yield", - # "r": 1, - # "w": 1, - # "index": 528800, - # "key": "parts/data-train-000051.tar/528866", - # "shard": "parts/data-train-000051.tar", - # "count": 633, - # "epoch": 0, - # "epoch_count": 633 - # } - t: Literal["WebdatasetSampleLoaderDataset._slices_iter.yield"] - r: int - w: int - #: The global index in the underlying dataset (concats of all shards) - index: int - #: The sample key from the shard, concatenated as f"{shard}/{key}" - key: str - #: Name of the shard - shard: str - #: Number of samples yielded from the sample loader over all epochs - count: int - #: Number of repetitions of the dataset (=epochs). First epoch is 0. - epoch: int - #: Number of samples yielded from the sample loader in the current epoch - epoch_count: int - - class AutosizingHeatmapWriter: """Writes a heatmap, automatically resizing it if necessary.""" @@ -166,6 +111,7 @@ def add(self, sample_id: int, step: int, src: int) -> None: Args: sample_id: The sample id (y-axis) step: The step (x-axis) + src: The source rank (colorizing) """ # Resize heatmap? 
while self.heatmap.shape[0] * self.heatmap_sample_factor <= sample_id: @@ -212,8 +158,7 @@ def save(self, path: Union[Path, str], gain: float): @click.command(name="analyze-debug") @click.argument( - "log_paths", - nargs=-1, + "log_path", type=click.Path(exists=True, file_okay=True, dir_okay=True, path_type=Path), ) @click.option( @@ -239,24 +184,6 @@ def save(self, path: Union[Path, str], gain: float): default=10, help="Gain (=multiplication factor) for the heatmap", ) -@click.option( - "--force-loading-order", - is_flag=True, - default=False, - help="If true, force using the dataloader loading order instead of batch data", -) -@click.option( - "--include-modality", - type=str, - default="train", - help="Choose which modality/modalities (train,val) to include. Comma separate for multiple.", -) -@click.option( - "--skip", - type=int, - default=0, - help="If >0, skip this many steps at the beginning of log file parsing.", -) @click.option( "--no-colors", is_flag=True, @@ -264,360 +191,286 @@ def save(self, path: Union[Path, str], gain: float): help="If set, disable colorizing ranks.", ) def command( - log_paths: List[Path], + log_path: Path, heatmap_path: Path, heatmap_steps: int, heatmap_samples: int, heatmap_gain: float, - force_loading_order: bool, - include_modality: str, - skip: int, no_colors: bool, ): """Internal tool to analyze randomness. 
The LOG_PATH should point to the folder with the debug log, or to a single log file.""" - if len(log_paths) == 0: - raise click.ClickException("No log paths specified") - log_files = [] - for log_path in log_paths: - if log_path.is_dir(): - log_files.extend(sorted(log_path.glob("*.jsonl"))) - elif log_path.is_file(): - log_files.append(log_path) - else: - raise click.ClickException(f"Invalid log path: {log_path}") + heatmap = AutosizingHeatmapWriter(heatmap_samples, heatmap_steps, colorize=not no_colors) - if len(log_files) == 0: - raise click.ClickException("No log files found") + print(f"Analyzing log {log_path}...") - heatmap = AutosizingHeatmapWriter(heatmap_samples, heatmap_steps, colorize=not no_colors) + if log_path.is_dir(): + log_paths = list(log_path.glob("*.json")) + else: + log_paths = [log_path] + + print(f"Analyzing {len(log_paths)} logs...") + + loader_log_loader = LogLoader(log_paths) - print(f"Analyzing {len(log_files)} logs...") - - modalities = [m.strip() for m in include_modality.split(",")] - - key_index = {} - count = 0 - if not force_loading_order: - loaders = [LoaderLogIter(log_file, start_idx=skip) for log_file in log_files] - loaders_by_id: Dict[int, Tuple[LoaderInfo, List[LoaderLogIter]]] = {} - with ProcessPoolExecutor(max_workers=16) as executor: - for loader, loader_info in tqdm( - executor.map(_proc_map_loader, loaders), total=len(loaders) - ): - for loader_id, loader_info in loader_info.items(): - if loader_id in loaders_by_id: - existing_loader_info, existing_loaders = loaders_by_id[loader_id] - assert ( - existing_loader_info.modality == loader_info.modality - and existing_loader_info.path == loader_info.path - ), ( - f"Found multiple loaders for {loader_id}: {existing_loader_info.modality, existing_loader_info.path} and {loader_info.modality, loader_info.path}" - ) - existing_loader_info.global_count = max( - existing_loader_info.global_count, loader_info.global_count - ) - existing_loaders.append(loader) - else: - 
loaders_by_id[loader_id] = (loader_info, [loader]) - print("Available loaders:") - selected_loader_id = None - must_select = False - for loader_id, (loader_info, _iters) in loaders_by_id.items(): + key_index: dict[str, int] = defaultdict(lambda: len(key_index)) + + for entry in loader_log_loader.read_entries(): + if isinstance(entry, LogLoader.LoaderIterator): print( - f" {loader_id}: {loader_info.modality} {loader_info.path} {loader_info.global_count} steps" + f"Loader rank={entry.rank} loader_id={entry.loader_id} iter_id={entry.iter_id} nw={entry.num_workers} ws={entry.world_size}" ) - if loader_info.modality in modalities: - if selected_loader_id is None: - selected_loader_id = loader_id - else: - # Have multiple loaders - must_select = True - if must_select: - while True: - loader_id_str = input("Choose loader id: ") - try: - selected_loader_id = int(loader_id_str) - except ValueError: - print(f"Invalid loader id {loader_id_str} 1") - continue - if selected_loader_id in loaders_by_id: - break - print(f"Invalid loader id {selected_loader_id}") - assert selected_loader_id is not None - selected_loader_info, selected_loader_readers = loaders_by_id[selected_loader_id] - print( - f"Reading for loader {selected_loader_id}: {selected_loader_info.modality} {selected_loader_info.path}" - ) - log_iters = [ - (idx, loader.log_entries(loader_ids={selected_loader_id})) - for idx, loader in enumerate(selected_loader_readers) - ] - with tqdm(total=selected_loader_info.global_count) as pbar: - while len(log_iters) > 0: - cur_count = 0 - # Iterate over all iterators for this count and put into heatmap - for src_idx, log_iter in tuple(log_iters): - # Iterate until None (=next count) is encountered - while True: - try: - log_keys = next(log_iter) - except StopIteration: - log_iters.remove((src_idx, log_iter)) - break - except OSError: - traceback.print_exc() - log_iters.remove((src_idx, log_iter)) - break - else: - if log_keys is None: - break - for log_key in log_keys: - key_id 
= key_index.setdefault(log_key, len(key_index)) - heatmap.add(key_id, count, src_idx) - cur_count += 1 - if cur_count == 0: - print(f"No data for step {count}") - count += 1 - pbar.update(1) + elif isinstance(entry, LogLoader.Worker): + print( + f"Worker rank={entry.loader.rank} loader_id={entry.loader.loader_id} iter_id={entry.loader.iter_id} worker_id={entry.worker_id}" + ) + # elif isinstance(entry, LogLoader.LoadSample): + # print(f"LoadSample {entry.worker.worker_id} {entry.worker.loader.loader_id} {entry.worker.loader.rank} {entry.worker.loader.num_workers} {entry.base_path} {entry.key} {entry.index} {entry.epoch} {entry.epoch_count}") + elif isinstance(entry, LogLoader.YieldSample): + # print(f"YieldSample rank={entry.worker.loader.rank} loader_id={entry.worker.loader.loader_id} iter_id={entry.worker.loader.iter_id} wrk_id={entry.worker.worker_id} sample_idx={entry.sample_idx} iter_idx={entry.iter_idx} global_sample_idx={entry.global_sample_idx} keys={entry.keys}") + if entry.keys is not None: + for key in entry.keys: + heatmap.add( + key_index[key], entry.global_sample_idx, src=entry.worker.loader.rank + ) + elif isinstance(entry, LogLoader.LoadNextEpoch): + # print(f"LoadNextEpoch rank={entry.worker.loader.rank} loader_id={entry.worker.loader.loader_id} iter_id={entry.worker.loader.iter_id} wrk_id={entry.worker.worker_id} epoch_idx={entry.epoch_idx} epoch_sample_count={entry.epoch_sample_count}") + pass + elif isinstance(entry, LogLoader.StopIteration): + # print(f"StopIteration rank={entry.loader.rank} loader_id={entry.loader.loader_id} iter_id={entry.loader.iter_id}") + pass if len(key_index) == 0: - if force_loading_order: - print("Forcing to use sample loader logs") - else: - print("No batch information in logs, trying sample loader logs...") - if modalities != {"train", "val"}: - print(" Data includes all modalities (train and val)") - print( - " Shuffle buffer and batching will not be considered, only the loading order from disk" - ) - log_iters = [ 
- _iter_sl_log_line_keys(_iter_sl_log_samples(log_file), start_idx=skip) - for log_file in log_files - ] - key_index = {} - count = 0 - start = time.time() - while len(log_iters) > 0: - cur_count = 0 - # Iterate over all iterators for this count and put into heatmap - for log_iter in tuple(log_iters): - # Iterate until None (=next count) is encountered - while True: - try: - log_key = next(log_iter) - except StopIteration: - log_iters.remove(log_iter) - break - except OSError: - traceback.print_exc() - log_iters.remove(log_iter) - break - else: - if log_key is None: - break - key_id = key_index.setdefault(log_key, len(key_index)) - heatmap.add(key_id, count) - cur_count += 1 - if cur_count == 0: - print(f"No data for step {count}") - if time.time() - start > 10: - print(f" Step {count}") - start = time.time() - count += 1 - - if count == 0: raise click.ClickException("No data found in logs") - print(f"Found {len(key_index)} unique sample keys, {count} steps") + print(f"Found {len(key_index)} unique sample keys, {heatmap.heatmap_step_max + 1} steps") # print(f"Heatmap factors: {heatmap_sample_factor} samples, {heatmap_step_factor} steps") # print(f"Heatmap max: {heatmap_sample_max} samples, {heatmap_step_max} steps") - n_samples, n_steps = heatmap.save(heatmap_path, heatmap_gain) + max_sample, max_step = heatmap.save(heatmap_path, heatmap_gain) print(f"Wrote heatmap to {heatmap_path}") print("Heatmap axes:") - print(f" x-axis: {n_steps} worker steps") - print(f" y-axis: {n_samples} samples") - - -class LoaderInitLogLine(TypedDict): - t: Literal["SavableLoader.__init__", "BasicDataLoader.__init__"] - r: int - w: None - id: int - config: dict - - -class LoaderIterLogLine(TypedDict): - t: Literal["SavableDataLoader.iter", "BasicDataLoader.iter"] - r: int - w: None - id: int - iter_id: int - - -class LoaderYieldLogLine(TypedDict): - t: Literal["SavableDataLoader.yield", "BasicDataLoader.yield"] - r: int - w: None + print(f" x-axis: {max_step + 1} worker steps") + 
print(f" y-axis: {max_sample + 1} samples") + + +class LogEntry(TypedDict): + """ + Chrome tracing log entry. + *ph*ase values: + - B: Begin + - E: End + - i: Instant + - b: Begin (async) + - e: End (async) + - n: Instant (async) + - C: Counter + - M: Metadata + - s: Flow start + - t: Flow step + - f: Flow end + """ + + ph: Literal["B", "E", "i", "b", "e", "n", "C", "M", "s", "t", "f"] + name: str id: int - iter_id: int - worker_id: int - worker_idx: int - idx: int - iter_idx: int - global_idx: int - keys: Optional[List[str]] - - -class LoaderStopLogLine(TypedDict): - t: Literal["SavableDataLoader.StopIteration", "BasicDataLoader.StopIteration"] - r: int - w: None - id: int - iter_id: int - - -LoaderLines = Union[ - LoaderInitLogLine, - LoaderIterLogLine, - LoaderYieldLogLine, - LoaderStopLogLine, -] - -LOADER_LOG_LINE_TYPES_T = ( - "SavableLoader.__init__", - "BasicDataLoader.__init__", - "SavableDataLoader.iter", - "BasicDataLoader.iter", - "SavableDataLoader.yield", - "BasicDataLoader.yield", - "SavableDataLoader.StopIteration", - "BasicDataLoader.StopIteration", -) - - -@dataclass_slots -class LoaderInfo: - id: int - modality: str - path: str - global_count: int - - -class LoaderLogIter: - def __init__(self, path: Path, start_idx: int = 0): - self._path = path - self._start_idx = start_idx - - def _iter_log_lines(self, which: Iterable[str]) -> Generator[LoaderLines, None, None]: - try: - with self._path.open("r") as rf: - for line in rf: - if any(f'"t": "{t}"' in line for t in which): - try: - yield json.loads(line.strip()) - except json.JSONDecodeError: - print("Cannot decode line", repr(line)) - except IOError as e: - print(f"Ignoring IOError: {e} for {self._path}") - - @staticmethod - def _find_config_modality(config: dict) -> Literal["train", "val"]: - assert isinstance(config, dict) - if "map_fn_config" in config and "training" in config["map_fn_config"]: - return "train" if config["map_fn_config"]["training"] else "val" - elif "dataset" in config: - 
return LoaderLogIter._find_config_modality(config["dataset"]) - elif "dataset_weights" in config: - return LoaderLogIter._find_config_modality(config["dataset_weights"][0][0]) - elif "datasets" in config: - return LoaderLogIter._find_config_modality(config["datasets"][0]) - assert False, f"Unrecognized config {config}" - - @staticmethod - def _find_config_path(config: dict) -> str: - assert isinstance(config, dict) - if "map_fn_config" in config and "_path" in config["map_fn_config"]: - return config["map_fn_config"]["_path"] - elif "dataset" in config: - return LoaderLogIter._find_config_path(config["dataset"]) - elif "dataset_weights" in config: - return LoaderLogIter._find_config_path(config["dataset_weights"][0][0]) - elif "datasets" in config: - return LoaderLogIter._find_config_path(config["datasets"][0]) - assert False, f"Unrecognized config {config}" - - def loaders(self) -> Dict[int, LoaderInfo]: - loaders = {} - for log_line in self._iter_log_lines( - ( - "SavableLoader.__init__", - "BasicDataLoader.__init__", - "SavableDataLoader.yield", - "BasicDataLoader.yield", - ) - ): - if log_line["t"] in ("SavableLoader.__init__", "BasicDataLoader.__init__"): - loaders[log_line["id"]] = LoaderInfo( - id=log_line["id"], - modality=self._find_config_modality(log_line["config"]), - path=self._find_config_path(log_line["config"]), - global_count=0, - ) - elif log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield"): - loaders[log_line["id"]].global_count = log_line["global_idx"] - return loaders - - def log_entries(self, loader_ids: Container[int]) -> Generator[Optional[List[str]], None, None]: - idx = self._start_idx - for log_line in self._iter_log_lines(("SavableDataLoader.yield", "BasicDataLoader.yield")): - if ( - log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield") - and log_line["id"] in loader_ids - ): - assert log_line["global_idx"] >= idx, ( - f"Found entry {log_line} with wrong idx <{idx}" - ) - while log_line["global_idx"] != 
idx: - yield None - idx += 1 - if "keys" in log_line: - yield log_line["keys"] - - def __repr__(self) -> str: - return f"log({str(self._path)})" - - -def _proc_map_loader(loader: LoaderLogIter) -> Tuple[LoaderLogIter, Dict[int, LoaderInfo]]: - return (loader, loader.loaders()) - - -def _iter_sl_log_line_keys( - log_lines: Iterable[SampleLoaderYieldLogLine], - start_idx: int = 0, -) -> Generator[Optional[str], None, None]: - count = start_idx - for log_line in log_lines: - if log_line["count"] < start_idx: - continue - assert log_line["count"] >= count - while log_line["count"] != count: - yield None - count += 1 - yield log_line["key"] - - -def _iter_sl_log_samples(path: Path) -> Generator[SampleLoaderYieldLogLine, None, None]: - with path.open("r") as rf: - for line in rf: - if '"t": "WebdatasetSampleLoaderDataset._slices_iter.yield"' in line: - try: - yield json.loads(line.strip()) - except json.JSONDecodeError: - print("Cannot decode line", repr(line)) + ts: int + pid: int + tid: int + args: dict + s: Literal["t", "p", "g"] + + +class LogLoader: + """Loads a chrome tracing log file. 
Extract specific information from it.""" + + _re_pname = re.compile(r"^dprank(\d+)(?:_worker(\d+))?$") + + def __init__(self, paths: list[Path]): + self._paths = paths + + def _log_reader(self, path: Path) -> Generator[LogEntry, None, None]: + """Reads a log file and yields a tuple of the line and the ts.""" + had_end = False + with open(path, "rb") as f: + assert f.read(2) == b"[\n", "Log file must start with a JSON array" + for line in f: + if not line: + assert had_end, "Log file must end with a JSON array" + if line.endswith(b"]\n"): + had_end = True + else: + assert line.endswith(b",\n"), f"Log file must be newline-terminated: {line}" + yield json.loads(line[:-2]) + assert had_end, "Log file must end with a JSON array" + + def _log_reader_all(self) -> Generator[LogEntry, None, None]: + """Reads all log files and yields a tuple of the line and the ts.""" + if len(self._paths) == 1: + yield from self._log_reader(self._paths[0]) + else: + for entry in merge_log_reader(self._paths): + yield json.loads(entry) + + @dataclass + class LoaderIterator: + world_size: int + rank: int + num_workers: int + loader_id: int + iter_id: int + + @dataclass + class Worker: + worker_id: int + loader: "LogLoader.LoaderIterator" + + @dataclass + class LoadSample: + worker: "LogLoader.Worker" + base_path: str + key: str + global_sample_index: int + sample_count: int + epoch_idx: int + epoch_sample_count: int + + @dataclass + class LoadNextEpoch: + worker: "LogLoader.Worker" + epoch_idx: int + epoch_sample_count: int + + @dataclass + class YieldSample: + worker: "LogLoader.Worker" + worker_sample_idx: int + sample_idx: int + iter_idx: int + global_sample_idx: int + keys: list[str] | None + + @dataclass + class StopIteration: + loader: "LogLoader.LoaderIterator" + + def read_entries(self): + # Maps pid to (rank, worker_id|None) + procs: dict[int, tuple[int, int | None]] = dict() + # Maps (pid, tid) to worker_id|None, only for main threads + proc_workers: dict[tuple[int, int], int | 
None] = dict() + # Maps (pid, tid) to worker + workers_by_pid_tid: dict[tuple[int, int], LogLoader.Worker] = dict() + # Maps (rank, loader_id, worker_id) to worker + workers_by_rank_loader_id_iter_id_worker_id: dict[ + tuple[int, int, int], LogLoader.Worker + ] = dict() + # Maps (rank, loader_id) to loader + loaders_by_rank_loader_id: dict[tuple[int, int], LogLoader.LoaderIterator] = dict() + # Maps (rank, loader_id, iter_id) to loader + loaders_by_rank_loader_id_iter_id: dict[tuple[int, int, int], LogLoader.LoaderIterator] = ( + dict() + ) + for log_entry in self._log_reader_all(): + ph = log_entry["ph"] + name = log_entry.get("name") + if ph == "M": + if name == "process_name": + pid = log_entry["pid"] + pname = log_entry["args"]["name"] + m = self._re_pname.match(pname) + if m: + rank = int(m.group(1)) + if m.group(2) is not None: + worker_id = int(m.group(2)) + else: + worker_id = None + procs[log_entry["pid"]] = (rank, worker_id) + if name == "thread_name": + thread_name = log_entry["args"]["name"] + pid = log_entry["pid"] + tid = log_entry["tid"] + if thread_name in ("main", "worker_main"): + proc_workers[(pid, tid)] = procs[pid][1] + if ph == "n": + if name == "WebdatasetSampleLoaderDataset._slices_iter.yield": + yield LogLoader.LoadSample( + worker=workers_by_pid_tid[(log_entry["pid"], log_entry["tid"])], + base_path=log_entry["args"]["base_path"], + key=log_entry["args"]["key"], + global_sample_index=log_entry["args"]["global_sample_index"], + sample_count=log_entry["args"]["sample_count"], + epoch_idx=log_entry["args"]["epoch_idx"], + epoch_sample_count=log_entry["args"]["epoch_sample_count"], + ) + elif name == "WebdatasetSampleLoaderDataset._slices_iter.next_epoch": + yield LogLoader.LoadNextEpoch( + worker=workers_by_pid_tid[(log_entry["pid"], log_entry["tid"])], + epoch_idx=log_entry["args"]["epoch_idx"], + epoch_sample_count=log_entry["args"]["epoch_sample_count"], + ) + elif name in ("SavableDataLoader.yield", "BasicDataLoader.yield"): + rank = 
procs[log_entry["pid"]][0] + yield LogLoader.YieldSample( + worker=workers_by_rank_loader_id_iter_id_worker_id[ + (rank, log_entry["args"]["loader_id"], log_entry["args"]["worker_id"]) + ], + worker_sample_idx=log_entry["args"]["worker_sample_idx"], + sample_idx=log_entry["args"]["sample_idx"], + iter_idx=log_entry["args"]["iter_idx"], + global_sample_idx=log_entry["args"]["global_sample_idx"], + keys=log_entry["args"].get("keys", None), + ) + elif name in ("SavableDataLoader.StopIteration", "BasicDataLoader.StopIteration"): + rank = procs[log_entry["pid"]][0] + yield LogLoader.StopIteration( + loader=loaders_by_rank_loader_id_iter_id[ + (rank, log_entry["args"]["loader_id"], log_entry["args"]["iter_id"]) + ], + ) + elif ph == "B": + if name in ( + "SavableDatasetWrapper.__iter__", + "SimpleSavableDatasetWrapper.__iter__", + ): + rank = procs[log_entry["pid"]][0] + # This is not 100% correct, but it's the best mapping we can get right now. + loader = loaders_by_rank_loader_id[(rank, log_entry["args"]["loader_id"])] + worker = LogLoader.Worker( + worker_id=log_entry["args"]["worker_id"], + loader=loader, + ) + workers_by_pid_tid[(log_entry["pid"], log_entry["tid"])] = worker + workers_by_rank_loader_id_iter_id_worker_id[ + (rank, loader.loader_id, worker.worker_id) + ] = worker + yield worker + elif ph == "b": + if name in ("SavableDataLoader.__iter__", "BasicDataLoader.__iter__"): + rank = procs[log_entry["pid"]][0] + loader = loaders_by_rank_loader_id[(rank, log_entry["args"]["loader_id"])] + loader.iter_id = log_entry["args"]["iter_id"] + loaders_by_rank_loader_id_iter_id[(rank, loader.loader_id, loader.iter_id)] = ( + loader + ) + yield loader + elif name in ("SavableDataLoader", "BasicDataLoader"): + cfg_rank = log_entry["args"]["worker_config"]["rank"] + rank = procs[log_entry["pid"]][0] + assert rank == cfg_rank, f"Rank mismatch: {rank} != {cfg_rank}" + + loader = LogLoader.LoaderIterator( + world_size=log_entry["args"]["worker_config"]["world_size"], + 
rank=rank, + num_workers=log_entry["args"]["worker_config"]["num_workers"], + loader_id=log_entry["args"]["loader_id"], + iter_id=-1, + ) + # This is not 100% correct, but it's the best mapping we can get right now. + loaders_by_rank_loader_id[(rank, log_entry["args"]["loader_id"])] = loader + yield loader if __name__ == "__main__": diff --git a/src/megatron/energon/tools/analyze_debug_merge.py b/src/megatron/energon/tools/analyze_debug_merge.py new file mode 100644 index 00000000..3fcddb91 --- /dev/null +++ b/src/megatron/energon/tools/analyze_debug_merge.py @@ -0,0 +1,146 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +import functools +import os +import re +from pathlib import Path +from typing import ( + Callable, + Generator, + List, +) + +import click + +# Regular expressions for parsing the log file efficiently +_re_ts = re.compile(rb'"ts":(\d+)') +_re_pid = re.compile(rb'"pid":(\d+)') + + +@click.command(name="analyze-debug-merge") +@click.argument( + "log_paths", + nargs=-1, + type=click.Path(exists=True, file_okay=True, dir_okay=True, path_type=Path), +) +@click.argument( + "output_path", + type=click.Path(exists=False, writable=True, dir_okay=False, path_type=Path), +) +def command( + log_paths: List[Path], + output_path: Path, +): + """Internal tool to merge multiple debug logs into a single file. 
def merge_log_reader(log_files: List[Path]) -> Generator[bytes, None, None]:
    """Merge several log files into a single stream of entries.

    Every file is read lazily through ``_log_reader``. Entries that carry no
    timestamp are forwarded as soon as they are read from their file. Entries
    with a timestamp are emitted in ascending ``ts`` order across all files;
    on equal timestamps the file that comes first in ``log_files`` wins.
    This assumes the timed entries inside each individual file are already
    ordered by ``ts``.

    Args:
        log_files: Paths of the log files to merge.

    Yields:
        One log entry (raw bytes, as produced by ``_log_reader``) at a time.
    """
    # Imported here so the function does not depend on a module-level import.
    import heapq

    # Map of (file_idx, pid) to new pid, numbered in order of first appearance.
    repid_map: dict[tuple[int, int], int] = {}

    def get_repid(file_idx: int, pid: int) -> int:
        # ``len(repid_map)`` is evaluated before a possible insertion, so a new
        # key receives the next free number.
        return repid_map.setdefault((file_idx, pid), len(repid_map))

    log_readers = [
        _log_reader(log_file, functools.partial(get_repid, idx))
        for idx, log_file in enumerate(log_files)
    ]

    # At most one pending ``(ts, file_idx, entry)`` per reader. Keeping the file
    # index next to the entry guarantees that the reader which gets advanced is
    # always the one the emitted entry came from.
    pending: list[tuple[int, int, bytes]] = []

    def advance(file_idx: int) -> Generator[bytes, None, None]:
        """Forward untimed entries of one reader and queue its next timed entry."""
        for entry, ts in log_readers[file_idx]:
            if ts is None:
                yield entry
            else:
                heapq.heappush(pending, (ts, file_idx, entry))
                # Leaving the loop keeps the reader suspended for the next call.
                return

    # Prime every reader, last file first.
    for file_idx in reversed(range(len(log_readers))):
        yield from advance(file_idx)

    # Repeatedly emit the entry with the smallest timestamp and refill from the
    # reader it was taken from.
    while pending:
        _, file_idx, entry = heapq.heappop(pending)
        yield entry
        yield from advance(file_idx)
def _cur_thread_id() -> int:
    """Return the current thread's identifier folded into at most 32 bits.

    The value of ``threading.get_ident()`` is reduced by XOR-ing its low 32
    bits with the remaining high bits until the result no longer exceeds
    ``0xFFFFFFFF``.
    """
    low_mask = 0xFFFFFFFF
    ident = threading.get_ident()
    while ident > low_mask:
        ident = (ident >> 32) ^ (ident & low_mask)
    return ident
""" - if level < self._log_level: + if level > self._log_level: return _NOOP_SPAN return Span(self, name=name, cat=cat, args=args) + def iterable( + self, + iterable: Iterable[T], + *, + name: Optional[str] = None, + next: Optional[Callable[[], ContextManager]] = None, + level: int = 0, + ) -> Iterable[T]: + """Wrap an iterable to emit trace events for each `next` call.""" + if level > self._log_level: + return iterable + assert (name is not None) != (next is not None), "Either name xor next must be provided" + if name is not None: + return iterable_wrapper(iterable, span=lambda: self.span(name)) + else: + assert next is not None + return iterable_wrapper(iterable, span=next) + def instant( self, name: str, *, cat: str | None = None, - scope: str = "t", + scope: Optional[Literal["t", "p", "g"]] = None, args: Optional[Dict[str, Any]] = None, level: int = 0, ) -> None: @@ -312,26 +339,39 @@ def instant( name: Display name. cat: Optional categories. scope: Trace-viewer scope selector – ``t`` (thread), ``p`` (process) - or ``g`` (global). + or ``g`` (global). Defaults to ``t``. args: Optional arguments payload. level: Logging level. 
""" - if level < self._log_level: + if level > self._log_level: return event = { "name": name, "ph": "i", "ts": _timestamp_us(), "pid": self._pid, - "tid": threading.get_ident(), - "s": scope, + "tid": _cur_thread_id(), } + if scope is not None: + event["s"] = scope if cat is not None: event["cat"] = cat if args: event["args"] = dict(args) self._emit(event) + def generator( + self, + name: str, + *, + cat: str | None = None, + next_args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> "GeneratorContext": + if level > self._log_level: + return _NOOP_GENERATOR_CONTEXT + return GeneratorContext(self, name=name, cat=cat, next_args=next_args) + # Async events -------------------------------------------------------- def async_begin( @@ -340,7 +380,6 @@ def async_begin( *, id: Union[int, str, None] = None, cat: str | None = None, - scope: str | None = None, args: Optional[Dict[str, Any]] = None, level: int = 0, ) -> Union[int, str]: @@ -350,14 +389,13 @@ def async_begin( name: Event display name. id: Correlation identifier (int or str). cat: Optional categories. - scope: Extra scope string to avoid id collisions. args: Optional argument object. level: Logging level. """ - if level < self._log_level: - return if id is None: id = self._next_id() + if level > self._log_level: + return id event = { "name": name, @@ -365,12 +403,10 @@ def async_begin( "id": id, "ts": _timestamp_us(), "pid": self._pid, - "tid": threading.get_ident(), + "tid": _cur_thread_id(), } if cat is not None: event["cat"] = cat - if scope is not None: - event["scope"] = scope # avoid clash with "s" used by instant events if args: event["args"] = dict(args) self._emit(event) @@ -382,7 +418,6 @@ def async_instant( *, id: Union[int, str, None] = None, cat: str | None = None, - scope: str | None = None, args: Optional[Dict[str, Any]] = None, level: int = 0, ) -> None: @@ -392,11 +427,10 @@ def async_instant( name: Event name. id: Correlation identifier. cat: Categories. 
- scope: Optional scope string. args: Additional arguments. level: Logging level. """ - if level < self._log_level: + if level > self._log_level: return if id is None: id = self._next_id() @@ -407,12 +441,10 @@ def async_instant( "id": id, "ts": _timestamp_us(), "pid": self._pid, - "tid": threading.get_ident(), + "tid": _cur_thread_id(), } if cat is not None: event["cat"] = cat - if scope is not None: - event["scope"] = scope if args: event["args"] = dict(args) self._emit(event) @@ -423,7 +455,6 @@ def async_end( *, id: Union[int, str], cat: str | None = None, - scope: str | None = None, args: Optional[Dict[str, Any]] = None, level: int = 0, ) -> None: @@ -432,23 +463,20 @@ def async_end( Args: id: Correlation identifier. cat: Categories. - scope: Optional scope string. args: Additional arguments. level: Logging level. """ - if level < self._log_level: + if level > self._log_level: return event = { "ph": "e", "id": id, "ts": _timestamp_us(), "pid": self._pid, - "tid": threading.get_ident(), + "tid": _cur_thread_id(), } if cat is not None: event["cat"] = cat - if scope is not None: - event["scope"] = scope if args: event["args"] = dict(args) self._emit(event) @@ -459,7 +487,6 @@ def async_span( *, id: Union[int, str, None] = None, cat: str | None = None, - scope: str | None = None, args: Optional[Dict[str, Any]] = None, level: int = 0, ) -> "AsyncSpan": @@ -469,14 +496,13 @@ def async_span( name: Display name. id: Correlation identifier to keep events together. cat: Categories. - scope: Optional scope string. args: Arguments attached to the begin event. level: Logging level. Returns: AsyncSpan context manager. 
""" - if level < self._log_level: + if level > self._log_level: return _NOOP_ASYNC_SPAN if id is None: id = self._next_id() @@ -486,7 +512,6 @@ def async_span( name=name, id=id, cat=cat, - scope=scope, args=args, ) @@ -495,27 +520,46 @@ def async_flow( *, id: Union[int, str, None] = None, cat: str | None = None, - scope: str | None = None, level: int = 0, - ) -> "AsyncFlow": + ) -> "AsyncContext": """Return an *AsyncFlow* context-manager for a nestable async chain. Args: id: Correlation identifier. cat: Categories. - scope: Optional scope string. level: Logging level. """ - if level < self._log_level: - return _NOOP_ASYNC_FLOW + if level > self._log_level: + return _NOOP_ASYNC_CONTEXT if id is None: id = self._next_id() - return AsyncFlow( + return AsyncContext( + self, + id=id, + cat=cat, + ) + + def async_generator( + self, + name: str, + *, + id: Union[int, str, None] = None, + cat: str | None = None, + next_args: Optional[Dict[str, Any]] = None, + level: int = 0, + ) -> "AsyncGeneratorContext": + """Emit an async *generator* (``ph='g'``) event within this async flow.""" + if level > self._log_level: + return _NOOP_ASYNC_GENERATOR_CONTEXT + if id is None: + id = self._next_id() + return AsyncGeneratorContext( self, + name=name, id=id, cat=cat, - scope=scope, + next_args=next_args, ) # Counter events ------------------------------------------------------ @@ -539,7 +583,7 @@ def counter( cat: Categories. level: Logging level. 
""" - if level < self._log_level: + if level > self._log_level: return if isinstance(value, Mapping): args_field = value @@ -553,7 +597,7 @@ def counter( "ph": "C", "ts": _timestamp_us(), "pid": self._pid, - "tid": threading.get_ident(), + "tid": _cur_thread_id(), "args": args_field, } if id is not None: @@ -562,132 +606,12 @@ def counter( event["cat"] = cat self._emit(event) - # Object events ------------------------------------------------------- - - def object_new( - self, - name: str, - *, - id: Union[int, str, None] = None, - cat: str | None = None, - scope: str | None = None, - level: int = 0, - ) -> None: - """Emit an object creation event (``ph='N'``). - - Args: - name: Object type/name displayed in UI. - id: Unique identifier (e.g. pointer address or GUID). - cat: Categories. - scope: Optional scope string to avoid id clashes. - level: Logging level. - """ - if level < self._log_level: - return - if id is None: - id = self._next_id() - - event = { - "name": name, - "ph": "N", - "id": id, - "ts": _timestamp_us(), - "pid": self._pid, - "tid": threading.get_ident(), - } - if cat is not None: - event["cat"] = cat - if scope is not None: - event["scope"] = scope - self._emit(event) - return id - - def object_snapshot( - self, - name: str, - *, - id: Union[int, str, None] = None, - snapshot: Dict[str, Any], - cat: str | None = None, - scope: str | None = None, - level: int = 0, - ) -> None: - """Emit an object *snapshot* (``ph='O'``). - - Args: - name: Object name. - id: Identifier matching a previously created object. - snapshot: Arbitrary JSON-serialisable state payload. - cat: Categories. - scope: Optional scope string. - level: Logging level. 
- """ - if level < self._log_level: - return - if id is None: - id = self._next_id() - - event = { - "name": name, - "ph": "O", - "id": id, - "ts": _timestamp_us(), - "pid": self._pid, - "tid": threading.get_ident(), - "args": {"snapshot": dict(snapshot)}, - } - if cat is not None: - event["cat"] = cat - if scope is not None: - event["scope"] = scope - self._emit(event) - - def object_delete( - self, - name: str, - *, - id: Union[int, str, None] = None, - cat: str | None = None, - scope: str | None = None, - level: int = 0, - ) -> None: - """Emit an object deletion event (``ph='D'``). - - Args: - name: Object name. - id: Identifier. - cat: Categories. - scope: Optional scope string. - level: Logging level. - """ - if level < self._log_level: - return - if id is None: - id = self._next_id() - - event = { - "name": name, - "ph": "D", - "id": id, - "ts": _timestamp_us(), - "pid": self._pid, - "tid": threading.get_ident(), - } - if cat is not None: - event["cat"] = cat - if scope is not None: - event["scope"] = scope - self._emit(event) - - # Helper -------------------------------------------------------------- - - def object_trace( + def async_object_trace( self, name: str, *, id: Union[int, str, None] = None, cat: str | None = None, - scope: str | None = None, snapshot: Optional[Dict[str, Any]] = None, level: int = 0, ) -> "ObjectTrace": @@ -697,14 +621,13 @@ def object_trace( name: Object type/name. id: Identifier to correlate with future snapshots/deletion. cat: Categories. - scope: Optional scope string. snapshot: Optional initial snapshot emitted right after ``N``. level: Logging level. Returns: - ObjectTrace instance. + AsyncObjectTrace instance. 
""" - if level < self._log_level: + if level > self._log_level: return _NOOP_OBJECT_TRACE if id is None: id = self._next_id() @@ -714,17 +637,15 @@ def object_trace( name=name, id=id, cat=cat, - scope=scope, initial_snapshot=snapshot, ) - def trace_object( + def trace_object_async( self, obj: Any, name: str, *, cat: str | None = None, - scope: str | None = None, args: Optional[Dict[str, Any]] = None, level: int = 0, ) -> "ObjectTrace": @@ -734,20 +655,17 @@ def trace_object( obj: Target instance to monitor. name: Trace-viewer object name. cat: Categories. - scope: Optional scope string. level: Logging level. Returns: - ObjectTrace handle. + AsyncObjectTrace handle. """ if not gc.is_tracked(obj): raise ValueError("Object is not tracked by the garbage collector") - if level < self._log_level: + if level > self._log_level: return _NOOP_OBJECT_TRACE - trace = self.object_trace(name, id=id(obj), cat=cat, scope=scope) + trace = self.async_object_trace(name, id=id(obj), cat=cat, snapshot=args) weakref.finalize(obj, trace.delete) - if args: - trace.snapshot(args) return trace # Metadata ------------------------------------------------------------ @@ -757,7 +675,6 @@ def metadata( name: str, *, args: Dict[str, Any], - pid: int | None = None, tid: int | None = None, ) -> None: """Emit a generic *metadata* event (``ph='M'``). @@ -765,13 +682,12 @@ def metadata( Args: name: Metadata event name (e.g. ``process_name``). args: Arguments dict as required by the spec. - pid: Override process id; defaults to writer.pid. tid: Thread id; required for thread metadata. 
""" event = { "name": name, "ph": "M", - "pid": pid if pid is not None else self._pid, + "pid": self._pid, } if tid is not None: event["tid"] = tid @@ -779,31 +695,165 @@ def metadata( event["args"] = dict(args) self._emit(event) - def metadata_process_name(self, name: str, *, pid: int | None = None) -> None: - self.metadata("process_name", args=dict(name=name), pid=pid) + def metadata_process_name(self, name: str) -> None: + """Set the current process name.""" + self.metadata("process_name", args=dict(name=name)) + + def metadata_thread_name(self, name: str) -> None: + """Set the current thread name.""" + self.metadata("thread_name", args=dict(name=name), tid=_cur_thread_id()) + + # Flow events -------------------------------------------------------- + + def flow_start( + self, + name: str, + *, + id: Optional[int] = None, + cat: str | None = None, + level: int = 0, + ) -> Union[int, str]: + """Emit a *flow start* (``ph='s'``) event. The flow is bound to the enclosing slice. + + Args: + name: Display name. + id: Correlation identifier. + cat: Categories. + args: Additional arguments. + level: Logging level. + """ + if id is None: + id = self._next_id() + if level > self._log_level: + return id + + event: Dict[str, Any] = { + "name": name, + "ph": "s", + "id": id, + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + self._emit(event) + return id + + def flow_step( + self, + name: str, + *, + id: int, + cat: str | None = None, + level: int = 0, + ) -> None: + """ + Emit a *flow step* (``ph='t'``) event. The flow is bound to the enclosing slice. - def metadata_process_labels(self, labels: str, *, pid: int | None = None) -> None: - self.metadata("process_labels", args=dict(labels=labels), pid=pid) + Args: + name: The name of the flow. + id: The id of the flow. + cat: The category of the flow. + level: The level of the flow. 
+ """ + if level > self._log_level: + return - def metadata_process_sort_index(self, sort_index: int, *, pid: int | None = None) -> None: - self.metadata("process_sort_index", args=dict(sort_index=sort_index), pid=pid) + event: Dict[str, Any] = { + "name": name, + "ph": "t", + "id": id, + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + self._emit(event) - def metadata_thread_name( - self, name: str, *, tid: int | None = None, pid: int | None = None + def flow_end( + self, + name: str, + *, + id: int, + cat: str | None = None, + bind_enclosing_slice: bool = False, + level: int = 0, ) -> None: - self.metadata( - "thread_name", args=dict(name=name), tid=tid or threading.get_ident(), pid=pid - ) + """Emit a *flow end* (``ph='f'``) event. The flow is finished either in the enclosing slice or at the next slice. + + Args: + name: The name of the flow. + id: The id of the flow. + cat: The category of the flow. + bind_enclosing_slice: If *True*, adds ``bp='e'`` to bind to the + enclosing slice (see Trace Event Format), otherwise binds to the next slice. + level: The level of the flow. 
+ """ + if level > self._log_level: + return + + event: Dict[str, Any] = { + "name": name, + "ph": "f", + "id": id, + "ts": _timestamp_us(), + "pid": self._pid, + "tid": _cur_thread_id(), + } + if cat is not None: + event["cat"] = cat + if bind_enclosing_slice: + event["bp"] = "e" + self._emit(event) + + def flow( + self, + name: str, + *, + id: Optional[int] = None, + cat: str | None = None, + level: int = 0, + ) -> "Flow": + """Emit a *flow* event.""" + if level > self._log_level: + return _NOOP_FLOW + if id is None: + id = self._next_id() + return Flow(self, name=name, id=id, cat=cat) - def metadata_thread_sort_index( - self, sort_index: int, *, tid: int | None = None, pid: int | None = None + def resume_flow(self, saved_flow: dict) -> "Flow": + """Resume a flow from a dictionary.""" + if len(saved_flow) == 0: + return _NOOP_FLOW + return Flow(self, **saved_flow, resuming=True) + + # Exception --------------------------------------------------------- + + def async_exc( + self, + *, + name: str, + id: Union[int, str, None] = None, + cat: str | None = None, + level: int = 0, ) -> None: - self.metadata( - "thread_sort_index", - args=dict(sort_index=sort_index), - tid=tid or threading.get_ident(), - pid=pid, - ) + """Emit an *exception* event as an async instant (``ph='n'``). + + This is primarily used by :class:`AsyncFlow` to surface + exceptions that happened inside a flow. + """ + + if id is None: + id = self._next_id() + if level > self._log_level: + return + + # Represent exception as string to keep JSON serialisable. 
+ exc_repr = traceback.format_exc().splitlines() + + self.async_instant(name, id=id, cat=cat, args={"exception": exc_repr}, level=level) # Context management --------------------------------------------------- @@ -830,7 +880,7 @@ class Span(AbstractContextManager): """ __slots__ = ("_writer", "_name", "_cat", "_args", "_begin_ts") - _writer: TraceWriter + _writer: Optional[TraceWriter] _name: str _cat: Optional[str] _args: Dict[str, Any] | None @@ -846,10 +896,7 @@ def __init__( self._writer = writer self._name = name self._cat = cat - self._args = args - - def begin(self) -> None: - self._writer.duration_begin(self._name, cat=self._cat, args=self._args) + self._writer.duration_begin(self._name, cat=self._cat, args=args or None) self._args = None def update_args(self, args: Dict[str, Any]) -> None: @@ -858,22 +905,46 @@ def update_args(self, args: Dict[str, Any]) -> None: else: self._args.update(args) - def end(self) -> None: - self._writer.duration_end(self._name, cat=self._cat, args=self._args or None) + def end(self, args: Optional[Dict[str, Any]] = None) -> None: + if self._writer is None: + return + if self._args and args: + self._args.update(args) + self._writer.duration_end(self._name, cat=self._cat, args=self._args or args or None) self._args = None + self._writer = None # ------------------------------------------------------------------ # Context management # ------------------------------------------------------------------ def __enter__(self): # noqa: D401 - self.begin() return self def __exit__(self, exc_type, exc, tb): # noqa: D401, N802 self.end() +class _NoopSpan(AbstractContextManager): + def begin(self, *args, **kwargs) -> None: + pass + + def update_args(self, *args, **kwargs) -> None: + pass + + def end(self, *args, **kwargs) -> None: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + pass + + +_NOOP_SPAN = cast(Span, _NoopSpan()) + + class AsyncSpan(AbstractContextManager): """Context manager for 
*nestable async* events. @@ -885,10 +956,15 @@ class AsyncSpan(AbstractContextManager): "_name", "_id", "_cat", - "_scope", "_args", ) + _writer: Optional[TraceWriter] + _name: str + _id: Union[int, str] + _cat: Optional[str] + _args: Optional[Dict[str, Any]] + def __init__( self, writer: TraceWriter, @@ -896,29 +972,15 @@ def __init__( name: str, id: Union[int, str], cat: str | None = None, - scope: str | None = None, args: Optional[Dict[str, Any]] = None, ) -> None: self._writer = writer self._name = name self._id = id self._cat = cat - self._scope = scope - self._args = args - - # ------------------------------------------------------------------ - # Context management - # ------------------------------------------------------------------ - - def begin(self) -> None: - self._writer.async_begin( - self._name, id=self._id, cat=self._cat, scope=self._scope, args=self._args or None - ) self._args = None - def __enter__(self): # noqa: D401 - self.begin() - return self + self._writer.async_begin(self._name, id=self._id, cat=self._cat, args=args or None) def update_args(self, args: Dict[str, Any]) -> None: if self._args is None: @@ -927,117 +989,143 @@ def update_args(self, args: Dict[str, Any]) -> None: self._args.update(args) def end(self, args: Optional[Dict[str, Any]] = None) -> None: + if self._writer is None: + return if self._args and args: self._args.update(args) - self._writer.async_end( - self._name, id=self._id, cat=self._cat, scope=self._scope, args=self._args or None - ) + self._writer.async_end(self._name, id=self._id, cat=self._cat, args=self._args or None) self._args = None + self._writer = None - def __exit__(self, exc_type, exc, tb): # noqa: D401, N802 - self.end() - + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self.end() + + +class NoopAsyncSpan(AbstractContextManager): + def update_args(self, *args, **kwargs) -> None: + pass -class AsyncFlow(AbstractContextManager): - """Context manager for *nestable async* 
events.""" + def end(self, *args, **kwargs) -> None: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + pass + + +_NOOP_ASYNC_SPAN = cast(AsyncSpan, NoopAsyncSpan()) + + +class AsyncContext: + """Context manager for *nestable async* events with the same id.""" __slots__ = ( "_writer", "_id", "_cat", - "_scope", ) + _writer: TraceWriter + _id: Union[int, str] + _cat: Optional[str] + def __init__( self, writer: TraceWriter, *, id: Union[int, str], cat: str | None = None, - scope: str | None = None, ) -> None: self._writer = writer self._id = id self._cat = cat - self._scope = scope - - # ------------------------------------------------------------------ - # Context management - # ------------------------------------------------------------------ - - def __enter__(self): # noqa: D401 - return self def instant(self, name: str, *, args: Optional[Dict[str, Any]] = None, level: int = 0) -> None: """Emit an async *instant* (``ph='n'``) event within this async flow.""" - self._writer.async_instant( - name, id=self._id, cat=self._cat, scope=self._scope, args=args, level=level - ) + self._writer.async_instant(name, id=self._id, cat=self._cat, args=args, level=level) def start(self, name: str, *, args: Optional[Dict[str, Any]] = None, level: int = 0) -> None: """Emit an async *start* (``ph='b'``) event within this async flow.""" - self._writer.async_begin( - name, id=self._id, cat=self._cat, scope=self._scope, args=args, level=level - ) + self._writer.async_begin(name, id=self._id, cat=self._cat, args=args, level=level) def end(self, name: str, *, args: Optional[Dict[str, Any]] = None, level: int = 0) -> None: """Emit an async *end* (``ph='e'``) event within this async flow.""" - self._writer.async_end( - name, id=self._id, cat=self._cat, scope=self._scope, args=args, level=level - ) + self._writer.async_end(name, id=self._id, cat=self._cat, args=args, level=level) def span( self, name: str, *, args: Optional[Dict[str, Any]] = None, 
class NoopAsyncContext:
    """Do-nothing stand-in for :class:`AsyncContext`.

    An instance of this class is what ``_NOOP_ASYNC_CONTEXT`` refers to, so it
    has to offer every method callers may invoke on a real ``AsyncContext``:
    ``instant``, ``start``, ``end``, ``span``, ``generator``, ``iterable`` and
    ``exc``. All of them accept arbitrary arguments and emit nothing.
    """

    def instant(self, *args, **kwargs) -> None:
        pass

    def start(self, *args, **kwargs) -> None:
        pass

    def end(self, *args, **kwargs) -> None:
        pass

    # Earlier spellings of ``start``/``end``; kept so existing callers keep working.
    async_start = start
    async_end = end

    def span(self, *args, **kwargs) -> AsyncSpan:
        return _NOOP_ASYNC_SPAN

    def generator(self, *args, **kwargs) -> "AsyncGeneratorContext":
        # Looked up at call time; the no-op generator context is defined
        # further down in the module.
        return _NOOP_ASYNC_GENERATOR_CONTEXT

    def iterable(self, iterable, *args, **kwargs) -> Iterable:
        # Nothing to trace: hand the iterable back untouched.
        return iterable

    def exc(self, *args, **kwargs) -> None:
        pass


_NOOP_ASYNC_CONTEXT = cast(AsyncContext, NoopAsyncContext())
+ Use like this:: + + with writer.async_generator_context(name="my_generator", next_args={"item_idx": 0}) as ctx: + for item_idx, item in enumerate(iterable): + ctx.instant("item", args={"item": item}) + with ctx.yield_(next_args={"item_idx": item_idx + 1}): + yield item """ __slots__ = ( @@ -1045,165 +1133,275 @@ class ObjectTrace(AbstractContextManager): "_name", "_id", "_cat", - "_scope", - "_deleted", + "_active_scope", ) + _writer: Optional[TraceWriter] + _name: str + _id: Union[int, str] + _cat: Optional[str] + + _active_scope: Optional[AsyncSpan] + def __init__( self, writer: TraceWriter, *, name: str, id: Union[int, str], - cat: str | None = None, - scope: str | None = None, - initial_snapshot: Optional[Dict[str, Any]] = None, - ) -> None: + cat: Optional[str] = None, + next_args: Optional[Dict[str, Any]] = None, + ): self._writer = writer self._name = name self._id = id self._cat = cat - self._scope = scope - self._deleted = False - # Emit object creation event - self._writer.object_new(name, id=id, cat=cat, scope=scope) + self._active_scope = self._writer.async_span(name, id=id, cat=cat, args=next_args) - if initial_snapshot is not None: - self.snapshot(initial_snapshot) + @contextmanager + def yield_( + self, + *, + last_args: Optional[Dict[str, Any]] = None, + next_args: Optional[Dict[str, Any]] = None, + ): + if self._writer is None: + return + assert self._active_scope is not None + self._active_scope.end(args=last_args) + self._active_scope = None + try: + yield self + finally: + assert self._active_scope is None + self._active_scope = self._writer.async_span( + self._name, id=self._id, cat=self._cat, args=next_args + ) - # ------------------------------------------------------------------ - # API - # ------------------------------------------------------------------ + def yield_from( + self, + iterable: Iterable[T], + *, + last_args: Optional[Dict[str, Any]] = None, + args: Optional[Dict[str, Any]] = None, + ) -> Iterable[T]: + """Wrap an iterable 
to emit trace events for each `next` call.""" + for item in iterable: + with self.yield_(last_args=last_args, next_args=args): + last_args = None + yield item - def snapshot(self, data: Dict[str, Any], *, level: int = 0) -> None: - """Emit snapshot for current state of the object.""" - if self._deleted: - raise RuntimeError("Cannot snapshot deleted traced object") - self._writer.object_snapshot( - self._name, - id=self._id, - snapshot=data, - cat=self._cat, - scope=self._scope, - level=level, - ) + def __enter__(self): + return self - def delete(self) -> None: - """Emit delete event if not already emitted.""" - if not self._deleted: - self._writer.object_delete( - self._name, - id=self._id, - cat=self._cat, - scope=self._scope, - ) - self._deleted = True + def __exit__(self, exc_type, exc, tb): + assert self._active_scope is not None + self._active_scope.end() + self._active_scope = None + self._writer = None - # ------------------------------------------------------------------ - # Context management - # ------------------------------------------------------------------ - def __enter__(self): # noqa: D401 +class DummyAsyncGeneratorContext(AbstractContextManager): + @contextmanager + def yield_(self, *args, **kwargs): + yield self + + def yield_from(self, iterable: Iterable[T], **kwargs) -> Iterable[T]: + return iterable + + def __enter__(self): return self - def __exit__(self, exc_type, exc, tb): # noqa: D401, N802 - self.delete() - # Do not suppress exceptions - return False + def __exit__(self, *args, **kwargs): + pass - def __del__(self): # noqa: D401 - # Ensure deletion event when object garbage-collected - try: - self.delete() - except Exception: - pass +_NOOP_ASYNC_GENERATOR_CONTEXT = cast(AsyncGeneratorContext, DummyAsyncGeneratorContext()) -# ------------------------------------------------------------------ -# Noop implementations -# ------------------------------------------------------------------ +class GeneratorContext(AbstractContextManager): + 
"""Context manager for a generator context, that interrupts when yielding. -class _NoopSpan(AbstractContextManager): - def begin(self, *args, **kwargs) -> None: - pass + Use like this:: - def update_args(self, *args, **kwargs) -> None: - pass + with writer.generator_context(name="my_generator", next_args={"item_idx": 0}) as ctx: + for item_idx, item in enumerate(iterable): + with ctx.yield_(next_args={"item_idx": item_idx + 1}): + yield item + """ - def end(self, *args, **kwargs) -> None: - pass + __slots__ = ( + "_writer", + "_name", + "_cat", + "_active_scope", + ) + + _writer: Optional[TraceWriter] + _name: str + _cat: Optional[str] + + _active_scope: Optional[Span] + + def __init__( + self, + writer: TraceWriter, + *, + name: str, + cat: Optional[str] = None, + next_args: Optional[Dict[str, Any]] = None, + ): + self._writer = writer + self._name = name + self._cat = cat + + self._active_scope = self._writer.span(name, cat=cat, args=next_args) + + @contextmanager + def yield_( + self, + *, + last_args: Optional[Dict[str, Any]] = None, + next_args: Optional[Dict[str, Any]] = None, + ): + if self._writer is None: + return + assert self._active_scope is not None + self._active_scope.end(args=last_args) + self._active_scope = None + try: + yield self + finally: + assert self._active_scope is None + self._active_scope = self._writer.span(self._name, cat=self._cat, args=next_args) + + def yield_from( + self, + iterable: Iterable[T], + *, + last_args: Optional[Dict[str, Any]] = None, + args: Optional[Dict[str, Any]] = None, + ) -> Iterable[T]: + """Wrap an iterable to emit trace events for each `next` call.""" + for item in iterable: + with self.yield_(last_args=last_args, next_args=args): + last_args = None + yield item def __enter__(self): return self def __exit__(self, exc_type, exc, tb): - pass + assert self._active_scope is not None + self._active_scope.end() + self._active_scope = None + self._writer = None -_NOOP_SPAN = cast(Span, _NoopSpan()) +class 
DummyGeneratorContext(AbstractContextManager): + @contextmanager + def yield_(self, *args, **kwargs): + yield self + def yield_from(self, iterable: Iterable[T], **kwargs) -> Iterable[T]: + return iterable -class NoopAsyncSpan(AbstractContextManager): - def begin(self, *args, **kwargs) -> None: - pass + def __enter__(self): + return self - def update_args(self, *args, **kwargs) -> None: + def __exit__(self, *args, **kwargs): pass - def end(self, *args, **kwargs) -> None: - pass - def __enter__(self): - return self +_NOOP_GENERATOR_CONTEXT = cast(GeneratorContext, DummyGeneratorContext()) - def __exit__(self, exc_type, exc, tb): - pass +def iterable_wrapper(iterable: Iterable[T], *, span: Callable[[], ContextManager]) -> Iterable[T]: + """A wrapper for an iterable that emits trace events for each `next` call.""" + ctx = span() + ctx.__enter__() + try: + for item in iterable: + ctx.__exit__(None, None, None) + yield item + ctx = span() + ctx.__enter__() + finally: + ctx.__exit__(None, None, None) -_NOOP_ASYNC_SPAN = cast(AsyncSpan, NoopAsyncSpan()) +class ObjectTrace: + """Lifecycle helper for Trace-Event objects, using async events to trace the object. -class NoopAsyncFlow(AbstractContextManager): - def instant(self, *args, **kwargs) -> None: - pass + Emits ``N`` on construction, :py:meth:`snapshot` for ``O`` and ``D`` upon + deletion, context exit, or garbage collection. 
+ """ - def async_start(self, *args, **kwargs) -> None: - pass + __slots__ = ( + "_writer", + "_name", + "_id", + "_cat", + ) - def async_end(self, *args, **kwargs) -> None: - pass + _writer: Optional[TraceWriter] + _name: str + _id: Union[int, str] + _cat: Optional[str] - def span(self, *args, **kwargs) -> AsyncSpan: - return _NOOP_ASYNC_SPAN - - def iterable(self, iterable, *args, **kwargs) -> Iterable: - return iterable + def __init__( + self, + writer: TraceWriter, + *, + name: str, + id: Union[int, str], + cat: str | None = None, + initial_snapshot: Optional[Dict[str, Any]] = None, + ) -> None: + self._writer = writer + self._name = name + self._id = id + self._cat = cat - def __enter__(self): - return self + # Emit object creation event + self._writer.async_begin(name, id=id, cat=cat, args=initial_snapshot) - def __exit__(self, exc_type, exc, tb): - pass + # ------------------------------------------------------------------ + # API + # ------------------------------------------------------------------ + def snapshot(self, data: Dict[str, Any], *, level: int = 0) -> None: + """Emit snapshot for current state of the object.""" + if self._writer is None: + raise RuntimeError("Cannot snapshot deleted traced object") + self._writer.async_instant( + self._name, + id=self._id, + args=data, + cat=self._cat, + level=level, + ) -_NOOP_ASYNC_FLOW = cast(AsyncFlow, NoopAsyncFlow()) + def delete(self) -> None: + """Emit delete event if not already emitted.""" + if self._writer is None: + return + self._writer.async_end( + self._name, + id=self._id, + cat=self._cat, + ) + self._writer = None -class NoopObjectTrace(AbstractContextManager): +class NoopObjectTrace: def snapshot(self, *args, **kwargs) -> None: pass def delete(self, *args, **kwargs) -> None: pass - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - pass - _NOOP_OBJECT_TRACE = cast(ObjectTrace, NoopObjectTrace()) @@ -1235,8 +1433,8 @@ def async_end(self, *args, **kwargs) -> None: def 
async_span(self, *args, **kwargs) -> "AsyncSpan": return _NOOP_ASYNC_SPAN - def async_flow(self, *args, **kwargs) -> "AsyncFlow": - return _NOOP_ASYNC_FLOW + def async_context(self, *args, **kwargs) -> "AsyncContext": + return _NOOP_ASYNC_CONTEXT def flow_start(self, *args, **kwargs) -> None: pass @@ -1247,16 +1445,10 @@ def flow_step(self, *args, **kwargs) -> None: def flow_end(self, *args, **kwargs) -> None: pass - def counter(self, *args, **kwargs) -> None: - pass - - def object_new(self, *args, **kwargs) -> None: - pass - - def object_snapshot(self, *args, **kwargs) -> None: - pass + def flow(self, *args, **kwargs) -> "Flow": + return _NOOP_FLOW - def object_delete(self, *args, **kwargs) -> None: + def counter(self, *args, **kwargs) -> None: pass def object_trace(self, *args, **kwargs) -> "ObjectTrace": @@ -1280,5 +1472,116 @@ def __exit__(self, exc_type, exc, tb): def __repr__(self) -> str: return "" + def exception(self, *args, **kwargs) -> None: + pass + NOOP_TRACE_WRITER: TraceWriter = cast(TraceWriter, NoopTraceWriter()) + + +# ------------------------------------------------------------------ +# Flow context manager +# ------------------------------------------------------------------ + + +class Flow: + """Context manager for *flow* events (``ph='s'``/``'t'``/``'f'``). + + Use :py:meth:`step` for intermediate *t* events inside the flow. + """ + + __slots__ = ( + "_writer", + "_name", + "_id", + "_cat", + ) + + _writer: Optional[TraceWriter] + _name: str + _id: Union[int, str] + _cat: Optional[str] + + def __init__( + self, + writer: TraceWriter, + *, + name: str, + id: int, + cat: str | None = None, + resuming: bool = False, + ) -> None: + self._writer = writer + self._name = name + self._id = id + self._cat = cat + + # Emit flow *start* event. + if not resuming: + self._writer.flow_start( + self._name, + id=self._id, + cat=self._cat, + ) + + def step(self, *, level: int = 0) -> None: + """Emit a *flow step* (``ph='t'``) event. 
The flow is bound to the enclosing slice.""" + writer = self._writer + if writer is None: + return + writer.flow_step( + self._name, + id=self._id, + cat=self._cat, + level=level, + ) + + def end( + self, + *, + level: int = 0, + bind_enclosing_slice: bool = False, + ) -> None: + """ + Emit the *flow end* (``ph='f'``) event. The flow is finished either in the enclosing slice or in the next slice. + + Args: + name: The name of the flow. + level: The level of the flow. + bind_enclosing_slice: Whether to bind the flow to the enclosing slice (otherwise bind to the next slice). + """ + writer = self._writer + if writer is None: + return + + writer.flow_end( + self._name, + id=self._id, + cat=self._cat, + bind_enclosing_slice=bind_enclosing_slice, + level=level, + ) + # Mark as closed to avoid further emissions. + self._writer = None + + def save(self) -> dict: + """Return a dictionary representation of the flow, allowing resuming the flow in another process.""" + return { + "name": self._name, + "id": self._id, + "cat": self._cat, + } + + +class NoopFlow: + def step(self, *args, **kwargs) -> None: + pass + + def end(self, *args, **kwargs) -> None: + pass + + def save(self) -> dict: + return {} + + +_NOOP_FLOW = cast(Flow, NoopFlow()) diff --git a/src/megatron/energon/watchdog.py b/src/megatron/energon/watchdog.py index cdb6c902..6239c090 100644 --- a/src/megatron/energon/watchdog.py +++ b/src/megatron/energon/watchdog.py @@ -291,21 +291,25 @@ def watch_iter(self, iterable: Iterable[T]) -> Iterator[T]: """ it = iter(iterable) - while True: - # Automatically resets timer - self.enable() - try: - item = next(it) - except StopIteration: - self.disable() - break - except: - # On any error, disable watchdog before re-raising - self.disable() - raise - else: - self.disable() - yield item + try: + while True: + # Automatically resets timer + self.enable() + try: + item = next(it) + except StopIteration: + self.disable() + break + except: + # On any error, disable watchdog 
before re-raising + self.disable() + raise + else: + self.disable() + yield item + finally: + if hasattr(it, "close"): + it.close() def repr_short(obj: Any) -> str: diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index cb72e24d..49dc3fc0 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -12,7 +12,7 @@ from megatron.energon.cache import CachePool from megatron.energon.dataclass_slots import dataclass_slots from megatron.energon.fork_hook import DataclassForkMixin -from megatron.energon.tracing import NOOP_TRACE_WRITER, AsyncFlow, TraceWriter +from megatron.energon.tracing import NOOP_TRACE_WRITER, AsyncContext, Flow, TraceWriter __all__ = ("WorkerConfig",) @@ -60,6 +60,8 @@ class WorkerConfig(DataclassForkMixin): worker_log_level: int = 0 #: The current trace writer for the worker. _worker_trace_writer: Optional[TraceWriter] = None + #: The current trace writer for the worker. + _worker_trace_sample_flow: Optional[Flow] = None #: The opened file for the current worker. Should not be set from outside. 
_worker_debug_file: Optional[TextIO] = None @@ -257,6 +259,7 @@ def __after_in_child_fork__(self): if self._worker_trace_writer is not None: self._worker_trace_writer.close() self._worker_trace_writer = None + self._worker_trace_sample_flow = None def __before_fork__(self): if self._worker_trace_writer is not None: @@ -284,12 +287,18 @@ def worker_trace_writer(self) -> TraceWriter: if in_worker: proc_name += f"_worker{self.rank_worker_id()}" self._worker_trace_writer = TraceWriter(path, log_level=self.worker_log_level) - self._worker_trace_writer.metadata_process_name(multiprocessing.current_process().name) - self._worker_trace_writer.metadata_process_labels(proc_name) - self._worker_trace_writer.metadata_process_sort_index(worker_id) - self._worker_trace_writer.metadata_thread_name("worker_main") - self._worker_trace_writer.metadata_thread_sort_index(0) + self._worker_trace_writer.metadata_process_name(proc_name) + if in_worker: + self._worker_trace_writer.metadata_thread_name("worker_main") + else: + self._worker_trace_writer.metadata_thread_name("main") + self._worker_trace_writer.flush() return self._worker_trace_writer - def worker_trace_span(self) -> AsyncFlow: + def worker_trace_span(self) -> AsyncContext: return self.worker_trace_writer().async_flow() + + def worker_trace_sample_flow(self, level: int): + if self._worker_trace_sample_flow is None: + return + self._worker_trace_sample_flow.step(level=level) diff --git a/src/megatron/energon/wrappers/base.py b/src/megatron/energon/wrappers/base.py index 30d38036..daef58dc 100644 --- a/src/megatron/energon/wrappers/base.py +++ b/src/megatron/energon/wrappers/base.py @@ -78,17 +78,26 @@ def _find_wrapped_dataset(self, cls: Type[SavableDataset]) -> Optional[SavableDa def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample_out: if len(self.datasets) == 1: - return self.datasets[0].restore_sample(index) + with self.worker_config.worker_trace_writer().span( + 
f"{type(self).__name__}.restore_sample", + level=1, + ): + return self.datasets[0].restore_sample(index) else: id, ds_idx = index[:2] assert id == type(self).__name__ index = index[2:] assert isinstance(ds_idx, int) - return add_sample_restore_key( - self.datasets[ds_idx].restore_sample(index), - ds_idx, - src=self, - ) + with self.worker_config.worker_trace_writer().span( + f"{type(self).__name__}.restore_sample", + args={"ds_idx": ds_idx}, + level=1, + ): + return add_sample_restore_key( + self.datasets[ds_idx].restore_sample(index), + ds_idx, + src=self, + ) def save_state(self) -> FlexState: own_state = super().save_state() diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index 0ce8c9f6..40fd6612 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -102,8 +102,65 @@ def __len__(self): ) def __iter__(self) -> Iterator[T_batch]: + batcher_name = self._function_config(self.batcher) trace_span = self.worker_config.worker_trace_span() - with trace_span.span("BatchDataset.__iter__", args={"config": self._own_config()}, level=1): + + def flush() -> Generator[T_batch, None, None]: + try: + with ( + self._sample_index.ctx() as sample_idx, + trace_span.span( + batcher_name, args={"sample_idx": sample_idx, "len": len(batch)}, level=2 + ), + ): + batch_sample = self.batcher(batch) + if isinstance(batch_sample, Generator): + assert inspect.isgeneratorfunction(self.batcher), ( + f"Generator in {self.batcher} but not marked as such." 
+ ) + self._generator_sample_keys = sample_restore_keys + self._generator_offset = 0 + for batch_sub_idx, (sample_idx, inner_batch_sample) in trace_span.iterable( + self._sample_index.iter_ctx(batch_sample, sample_idx), + name=f"{batcher_name}.next", + level=2, + ): + self._generator_offset = batch_sub_idx + 1 + with trace_gen.yield_(next_args={"sample_idx": sample_idx}): + yield set_sample_restore_key( + inner_batch_sample, + sample_idx, + batch_sub_idx, + *sample_restore_keys, + src=self, + ) + self._generator_sample_keys = None + self._generator_offset = None + else: + set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + with trace_gen.yield_(next_args={"sample_idx": sample_idx}): + yield batch_sample + sample_restore_keys.clear() + except SkipSample: + trace_span.instant("BatchDataset.__iter__.skip", level=2) + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(batch) + except Exception as e: + self.error_handler(e, batch) + trace_span.instant( + "BatchDataset.__iter__.error/skip", + args={"exception": f"{type(e).__name__}: {str(e)}"}, + level=2, + ) + + with ( + trace_span.span("BatchDataset.__iter__", args={"config": self._own_config()}, level=1), + self.worker_config.worker_trace_writer().generator( + "BatchDataset.__iter__.next", + next_args={"sample_idx": self._sample_index.current_idx}, + level=2, + ) as trace_gen, + ): batch: List[T_batch_sample] = [] sample_restore_keys = [] @@ -113,7 +170,12 @@ def __iter__(self) -> Iterator[T_batch]: batch = [ self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys ] - with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: + with ( + self._sample_index.ctx(self._sample_index.current_idx) as sample_idx, + trace_span.span( + batcher_name, args={"sample_idx": sample_idx, "len": len(batch)}, level=2 + ), + ): batch_sample = self.batcher(batch) assert isinstance(batch_sample, Generator) assert inspect.isgeneratorfunction(self.batcher), ( @@ 
-121,38 +183,15 @@ def __iter__(self) -> Iterator[T_batch]: ) target_offset = self._generator_offset self._generator_offset = 0 - for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( - self._sample_index.iter_ctx(batch_sample, sample_idx) + for batch_sub_idx, (sample_idx, inner_batch_sample) in trace_span.iterable( + self._sample_index.iter_ctx(batch_sample, sample_idx), + name=f"{batcher_name}.next", + level=2, ): # Skip other samples if batch_sub_idx >= target_offset: self._generator_offset = batch_sub_idx + 1 - yield set_sample_restore_key( - inner_batch_sample, - sample_idx, - batch_sub_idx, - *sample_restore_keys, - src=self, - ) - self._generator_sample_keys = None - self._generator_offset = None - batch.clear() - sample_restore_keys = [] - - def flush() -> Generator[T_batch, None, None]: - try: - with self._sample_index.ctx() as sample_idx: - batch_sample = self.batcher(batch) - if isinstance(batch_sample, Generator): - assert inspect.isgeneratorfunction(self.batcher), ( - f"Generator in {self.batcher} but not marked as such." 
- ) - self._generator_sample_keys = sample_restore_keys - self._generator_offset = 0 - for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( - self._sample_index.iter_ctx(batch_sample, sample_idx) - ): - self._generator_offset = batch_sub_idx + 1 + with trace_gen.yield_(next_args={"sample_idx": sample_idx}): yield set_sample_restore_key( inner_batch_sample, sample_idx, @@ -160,43 +199,18 @@ def flush() -> Generator[T_batch, None, None]: *sample_restore_keys, src=self, ) - self._generator_sample_keys = None - self._generator_offset = None - else: - set_sample_restore_key( - batch_sample, sample_idx, *sample_restore_keys, src=self - ) - yield batch_sample - sample_restore_keys.clear() - except SkipSample: - trace_span.instant("BatchDataset.__iter__.skip", level=2) - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(batch) - except Exception as e: - self.error_handler(e, batch) - trace_span.instant( - "BatchDataset.__iter__.error/skip", - args={"exception": f"{type(e).__name__}: {str(e)}"}, - level=2, - ) - - batch_span = trace_span.span("BatchDataset.__iter__.collect", level=2) + self._generator_sample_keys = None + self._generator_offset = None + batch.clear() + sample_restore_keys = [] - try: - for sample in self.dataset: - batch.append(sample) - sample_restore_keys.append(get_sample_restore_key(sample)) - if len(batch) == self.batch_size: - batch_span.end() - yield from flush() - batch = [] - batch_span = trace_span.span("BatchDataset.__iter__.collect", level=2) - finally: - batch_span.end() + for sample in self.dataset: + batch.append(sample) + sample_restore_keys.append(get_sample_restore_key(sample)) + if len(batch) == self.batch_size: + yield from flush() + batch = [] if len(batch) > 0 and not self.drop_last: - batch_span = trace_span.span( - "BatchDataset.__iter__.last", args={"batch_size": len(batch)}, level=1 - ) yield from flush() def can_restore_sample(self) -> bool: @@ -211,40 +225,58 @@ def assert_can_restore(self) -> None: 
super().assert_can_restore() def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch: - # We need to store multiple indices to restore a batch. - self.assert_can_restore() - if inspect.isgeneratorfunction(self.batcher): - id, sample_idx, batch_sub_idx, *samples_restore_keys = index - assert id == type(self).__name__ - else: - id, sample_idx, *samples_restore_keys = index - assert id == type(self).__name__ - batch = [self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys] - with self._sample_index.ctx(sample_idx): - batch_sample = self.batcher(batch) - if isinstance(batch_sample, Generator): - assert inspect.isgeneratorfunction(self.batcher), ( - f"Generator in {self.batcher} but not marked as such." - ) - for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( - self._sample_index.iter_ctx(batch_sample, sample_idx) + trace_span = self.worker_config.worker_trace_span() + with trace_span.span("BatchDataset.restore_sample", args={"index": index}, level=1): + # We need to store multiple indices to restore a batch. 
+ self.assert_can_restore() + if inspect.isgeneratorfunction(self.batcher): + id, sample_idx, batch_sub_idx, *samples_restore_keys = index + assert id == type(self).__name__ + else: + id, sample_idx, *samples_restore_keys = index + assert id == type(self).__name__ + with trace_span.span( + "BatchDataset.restore_sample.restore", + args={"len": len(samples_restore_keys)}, + level=2, ): - if cur_batch_sub_idx == batch_sub_idx: - return set_sample_restore_key( - inner_batch_sample, - sample_idx, - batch_sub_idx, - *samples_restore_keys, - src=self, - ) - assert False, f"Batch sub-index {batch_sub_idx} not found in batch" - else: - return set_sample_restore_key( - batch_sample, - sample_idx, - *samples_restore_keys, - src=self, - ) + batch = [ + self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys + ] + with ( + self._sample_index.ctx(sample_idx), + trace_span.span( + f"BatchDataset.restore_sample.batcher:{self._function_config(self.batcher)}", + args={"sample_idx": sample_idx, "len": len(batch)}, + level=2, + ), + ): + batch_sample = self.batcher(batch) + if isinstance(batch_sample, Generator): + assert inspect.isgeneratorfunction(self.batcher), ( + f"Generator in {self.batcher} but not marked as such." 
+ ) + for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in trace_span.iterable( + self._sample_index.iter_ctx(batch_sample, sample_idx), + name=f"BatchDataset.restore_sample.batcher:{self._function_config(self.batcher)}.next", + level=2, + ): + if cur_batch_sub_idx == batch_sub_idx: + return set_sample_restore_key( + inner_batch_sample, + sample_idx, + batch_sub_idx, + *samples_restore_keys, + src=self, + ) + assert False, f"Batch sub-index {batch_sub_idx} not found in batch" + else: + return set_sample_restore_key( + batch_sample, + sample_idx, + *samples_restore_keys, + src=self, + ) def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/blend_dataset.py b/src/megatron/energon/wrappers/blend_dataset.py index 69cebfcc..6d041994 100644 --- a/src/megatron/energon/wrappers/blend_dataset.py +++ b/src/megatron/energon/wrappers/blend_dataset.py @@ -58,64 +58,75 @@ def __iter__(self) -> Iterator[T_sample]: assert self.worker_has_samples(), "Cannot blend all empty datasets" trace_span = self.worker_config.worker_trace_span() - with trace_span.span("BlendDataset.__iter__", args={"config": self._own_config()}, level=1): + with ( + trace_span.span("BlendDataset.__iter__", args={"config": self._own_config()}, level=1), + self.worker_config.worker_trace_writer().generator( + "BlendDataset.__iter__.next", level=2 + ) as trace_gen, + ): # Create a list of datasets and their weights, but # set the weight to 0 if the dataset has no samples on this worker. dataset_iters = [] weights = [] - for idx, (dataset, weight) in enumerate(self.dataset_weights): - assert weight > 0, "All blending weights must be > 0" - - if dataset.worker_has_samples(): - dataset_iters.append(iter(dataset)) - weights.append(weight) - else: - dataset_iters.append(None) - weights.append(0) - - weights = torch.tensor(weights, dtype=torch.float32) - if weights.sum() == 0: - raise RuntimeError( - "There is a worker with no samples in any of the blended datasets. 
" - "This can happen if you have a lot of workers and your dataset is too small. " - "Currently this case is not supported." - ) - - # Some may already be exhausted on this worker when restoring a state. - for idx, exhausted in enumerate(self.exhausted): - if exhausted: - weights[idx] = 0 - dataset_iters[idx] = None - - while True: - ds_idx = self._worker_rng.choice_idx(probs=weights) - trace_span.instant( - "BlendDataset.__iter__.sample", - args={"weights": weights, "ds_idx": ds_idx}, - level=2, - ) - - if dataset_iters[ds_idx] is None: - if all(dataset_iter is None for dataset_iter in dataset_iters): - break - continue - try: - sample = next(dataset_iters[ds_idx]) - except StopIteration: + try: + for idx, (dataset, weight) in enumerate(self.dataset_weights): + assert weight > 0, "All blending weights must be > 0" + + if dataset.worker_has_samples(): + dataset_iters.append(iter(dataset)) + weights.append(weight) + else: + dataset_iters.append(None) + weights.append(0) + + weights = torch.tensor(weights, dtype=torch.float32) + if weights.sum() == 0: + raise RuntimeError( + "There is a worker with no samples in any of the blended datasets. " + "This can happen if you have a lot of workers and your dataset is too small. " + "Currently this case is not supported." + ) + + # Some may already be exhausted on this worker when restoring a state. 
+ for idx, exhausted in enumerate(self.exhausted): + if exhausted: + weights[idx] = 0 + dataset_iters[idx] = None + + while True: + ds_idx = self._worker_rng.choice_idx(probs=weights) trace_span.instant( - "BlendDataset.__iter__.exhausted", args={"ds_idx": ds_idx}, level=1 + "BlendDataset.__iter__.sample", + args={"weights": weights, "ds_idx": ds_idx}, + level=2, ) - dataset_iters[ds_idx] = None - weights[ds_idx] = 0 - self.exhausted[ds_idx] = True - if all(dataset_iter is None for dataset_iter in dataset_iters): - break - else: - yield add_sample_restore_key(sample, ds_idx, src=self) - - trace_span.instant("BlendDataset.__iter__.reset", level=1) - self.exhausted = [False] * len(self.dataset_weights) + + if dataset_iters[ds_idx] is None: + if all(dataset_iter is None for dataset_iter in dataset_iters): + break + continue + try: + sample = next(dataset_iters[ds_idx]) + except StopIteration: + trace_span.instant( + "BlendDataset.__iter__.exhausted", args={"ds_idx": ds_idx}, level=1 + ) + dataset_iters[ds_idx] = None + weights[ds_idx] = 0 + self.exhausted[ds_idx] = True + if all(dataset_iter is None for dataset_iter in dataset_iters): + break + else: + with trace_gen.yield_(): + yield add_sample_restore_key(sample, ds_idx, src=self) + + trace_span.instant("BlendDataset.__iter__.reset", level=1) + self.exhausted = [False] * len(self.dataset_weights) + finally: + for it in dataset_iters: + if hasattr(it, "close"): + it.close() def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/epochize_dataset.py b/src/megatron/energon/wrappers/epochize_dataset.py index 8d85ea73..3112ff4a 100644 --- a/src/megatron/energon/wrappers/epochize_dataset.py +++ b/src/megatron/energon/wrappers/epochize_dataset.py @@ -58,30 +58,42 @@ def __iter__(self) -> Iterator[T_sample]: if self.worker_config.rank_worker_id() < self.length % self.worker_config.num_workers: local_length += 1 - with trace_span.span( - "EpochizeDataset.__iter__", - args={ - "offset": 
self._offset, - "local_length": local_length, - "config": self._own_config(), - }, - level=1, + with ( + trace_span.span( + "EpochizeDataset.__iter__", + args={ + "offset": self._offset, + "local_length": local_length, + "config": self._own_config(), + }, + level=1, + ), + self.worker_config.worker_trace_writer().generator( + "EpochizeDataset.__iter__.next", level=2 + ) as trace_gen, ): - offset_range = list(range(self._offset, local_length)) - - # Only iterate if there are samples to iterate - if len(offset_range) > 0: - if self._active_iter is None: - self._active_iter = iter(self.dataset) - - for idx in offset_range: - self._offset = (idx + 1) % local_length - try: - sample = next(self._active_iter) - except StopIteration: - break - yield sample - trace_span.instant("EpochizeDataset.__iter__.done", level=1) + try: + offset_range = list(range(self._offset, local_length)) + + # Only iterate if there are samples to iterate + if len(offset_range) > 0: + if self._active_iter is None: + self._active_iter = iter(self.dataset) + + for idx in offset_range: + self._offset = (idx + 1) % local_length + try: + sample = next(self._active_iter) + except StopIteration: + break + with trace_gen.yield_(): + yield sample + trace_span.instant("EpochizeDataset.__iter__.done", level=1) + except GeneratorExit: + if self._active_iter is not None and hasattr(self._active_iter, "close"): + self._active_iter.close() + self._active_iter = None + raise def __len__(self) -> int: return self.length diff --git a/src/megatron/energon/wrappers/filter_dataset.py b/src/megatron/energon/wrappers/filter_dataset.py index 93ecdde8..e8cf7e6a 100644 --- a/src/megatron/energon/wrappers/filter_dataset.py +++ b/src/megatron/energon/wrappers/filter_dataset.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Callable, Dict, Generic, Iterator, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Iterator, Optional, TypeVar, Union from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.worker import WorkerConfig @@ -51,22 +51,26 @@ def __len__(self): return len(self.dataset) def __iter__(self) -> Iterator[T_sample]: - with self.worker_config.worker_trace_span().span( - "FilterDataset.__iter__", args={"config": self._own_config()}, level=1 + trace_span = self.worker_config.worker_trace_span() + filter_name = self._function_config(self.filter_fn) + with ( + trace_span.span("FilterDataset.__iter__", args={"config": self._own_config()}, level=1), + trace_span.generator("FilterDataset.__iter__.next", level=2) as trace_gen, ): for sample in self.dataset: - with self._sample_index.ctx(): + with ( + self._sample_index.ctx(), + trace_span.span(filter_name, args={"sample": sample}, level=2), + ): filter_res = self.filter_fn(sample) if filter_res: - yield sample + with trace_gen.yield_(): + yield sample else: self.worker_config.worker_trace_span().instant( "FilterDataset.__iter__.reject", level=3 ) - def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample: - return self.dataset.restore_sample(index) - def _own_config(self) -> Dict[str, Any]: return { "filter_fn": self._function_config(self.filter_fn), diff --git a/src/megatron/energon/wrappers/gc_dataset.py b/src/megatron/energon/wrappers/gc_dataset.py index ac5a379b..5ab3c807 100644 --- a/src/megatron/energon/wrappers/gc_dataset.py +++ b/src/megatron/energon/wrappers/gc_dataset.py @@ -102,21 +102,28 @@ def __len__(self): def __iter__(self) -> Iterator[T_sample]: trace_span = self.worker_config.worker_trace_span() - in_worker = torch.utils.data.get_worker_info() is not None - if in_worker and not _frozen_cuda_tensors_initialized: - raise GcFreezeError( - "You are using GcDataset with 
multiple workers, but forgot to call gc_init_worker() in at least one forked worker process." - ) - - if self.freeze: - with trace_span.span("GcDataset.__iter__.gc.freeze", level=1): - gc.collect() - gc.freeze() - with trace_span.span("GcDataset.__iter__", args={"config": self._own_config()}, level=1): + with ( + trace_span.span("GcDataset.__iter__", args={"config": self._own_config()}, level=1), + self.worker_config.worker_trace_writer().generator( + "GcDataset.__iter__.next", level=2 + ) as trace_gen, + ): + in_worker = torch.utils.data.get_worker_info() is not None + if in_worker and not _frozen_cuda_tensors_initialized: + raise GcFreezeError( + "You are using GcDataset with multiple workers, but forgot to call gc_init_worker() in at least one forked worker process." + ) + + if self.freeze: + with trace_span.span("GcDataset.__iter__.gc.collect", level=1): + gc.collect() + with trace_span.span("GcDataset.__iter__.gc.freeze", level=1): + gc.freeze() try: iter = 0 for sample in self.dataset: - yield sample + with trace_gen.yield_(): + yield sample iter += 1 if iter >= self.every_n_iter: with trace_span.span("GcDataset.__iter__.gc.collect", level=1): @@ -124,7 +131,8 @@ def __iter__(self) -> Iterator[T_sample]: iter = 0 finally: if self.freeze: - gc.unfreeze() + with trace_span.span("GcDataset.__iter__.gc.unfreeze", level=1): + gc.unfreeze() def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index c03fe4b8..2cb1bc31 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -119,9 +119,15 @@ def __len__(self): def __iter__(self) -> Iterator[T_batch]: buckets = self._buckets + batcher_name = self._function_config(self.batcher) trace_span = self.worker_config.worker_trace_span() - with trace_span.span( - "GroupBatchDataset.__iter__", args={"config": self._own_config()}, level=1 + 
with ( + trace_span.span( + "GroupBatchDataset.__iter__", args={"config": self._own_config()}, level=1 + ), + self.worker_config.worker_trace_writer().generator( + "GroupBatchDataset.__iter__.next", level=2 + ) as trace_gen, ): if buckets is None: buckets = self._buckets = dict() @@ -145,24 +151,26 @@ def flush(key: Any, bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, batch_items, sample_restore_keys = bucket.samples.flush() # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flushed: len(batch)={len(batch_items)} len(samples)={len(bucket.samples)}\n", end="") try: - with trace_span.span( - "GroupBatchDataset.flush", - args={ - "bucket": str(key), - "bucket_size": bucket.batch_size, - "bucket_len": len(batch_items), - }, - level=2, + with ( + self._batch_sample_index.ctx() as sample_idx, + trace_span.span( + batcher_name, + args={ + "bucket": str(key), + "bucket_size": bucket.batch_size, + "sample_idx": sample_idx, + "len": len(batch_items), + }, + level=2, + ), ): - with self._batch_sample_index.ctx() as sample_idx: - batch_sample = self.batcher(batch_items) - assert not isinstance(batch_sample, Generator), ( - f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet." - ) - set_sample_restore_key( - batch_sample, sample_idx, *sample_restore_keys, src=self + batch_sample = self.batcher(batch_items) + assert not isinstance(batch_sample, Generator), ( + f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet." 
) - yield batch_sample + set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + with trace_gen.yield_(): + yield batch_sample except SkipSample: trace_span.instant("GroupBatchDataset.flush.skip", level=2) except SYSTEM_EXCEPTIONS: @@ -252,14 +260,34 @@ def assert_can_restore(self) -> None: super().assert_can_restore() def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch: - self.assert_can_restore() - id, sample_idx, *sample_restore_keys = index - assert id == type(self).__name__ - batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys] - with self._batch_sample_index.ctx(sample_idx): - batch_sample = self.batcher(batch) - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) - return batch_sample + trace_span = self.worker_config.worker_trace_span() + with trace_span.span( + "GroupBatchDataset.restore_sample", + args={"index": index}, + level=1, + ): + self.assert_can_restore() + id, sample_idx, *sample_restore_keys = index + assert id == type(self).__name__ + with trace_span.span( + "GroupBatchDataset.restore_sample.dataset", + args={"sample_idx": sample_idx, "len": len(sample_restore_keys)}, + level=2, + ): + batch = [ + self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys + ] + with ( + self._batch_sample_index.ctx(sample_idx), + trace_span.span( + f"GroupBatchDataset.restore_sample.batcher:{self._function_config(self.batcher)}", + args={"sample_idx": sample_idx, "len": len(batch)}, + level=2, + ), + ): + batch_sample = self.batcher(batch) + set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + return batch_sample def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/iter_map_dataset.py b/src/megatron/energon/wrappers/iter_map_dataset.py index f5632697..0637a993 100644 --- a/src/megatron/energon/wrappers/iter_map_dataset.py +++ 
b/src/megatron/energon/wrappers/iter_map_dataset.py @@ -90,8 +90,14 @@ def __len__(self): def __iter__(self) -> Iterator[T_sample_out]: trace_span = self.worker_config.worker_trace_span() - with trace_span.span( - "IterMapDataset.__iter__", args={"config": self._own_config()}, level=1 + iter_name = f"IterMapDataset.__iter__.iter_map_fn:{self._function_config(self.iter_map_fn)}" + with ( + trace_span.span( + "IterMapDataset.__iter__", args={"config": self._own_config()}, level=1 + ), + self.worker_config.worker_trace_writer().generator( + "IterMapDataset.__iter__.next", level=2 + ) as trace_gen, ): last_sample_wrapper = _LastSampleWrapper(self.dataset) # The iter_map_fn is stateless. Thus we need to know which inner sample created the @@ -112,33 +118,39 @@ def reset_idx_iter() -> Generator[T_sample, None, None]: ds_iter = iter(reset_idx_iter()) - # While True will break when the inner dataset is exhausted, but may continue on exception - while True: - iter_idx = 0 - try: - for sample_idx, sample in self._sample_index.iter_ctx( - self.iter_map_fn(ds_iter) - ): - yield set_sample_restore_key( - sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, + try: + # While True will break when the inner dataset is exhausted, but may continue on exception + while True: + iter_idx = 0 + try: + for sample_idx, sample in trace_span.iterable( + self._sample_index.iter_ctx(self.iter_map_fn(ds_iter)), + name=iter_name, + level=1, + ): + with trace_gen.yield_(): + yield set_sample_restore_key( + sample, + sample_idx, + iter_idx, + *sample_restore_keys, + src=self, + ) + sample_restore_keys.clear() + iter_idx += 1 + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(last_sample_wrapper.last_sample) + except Exception as e: + self.error_handler(e, last_sample_wrapper.last_sample) + trace_span.instant( + "IterMapDataset.__iter__.error/retry", + args={"exception": f"{type(e).__name__}: {str(e)}"}, + level=1, ) - sample_restore_keys.clear() - iter_idx += 1 - 
except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(last_sample_wrapper.last_sample) - except Exception as e: - self.error_handler(e, last_sample_wrapper.last_sample) - trace_span.instant( - "IterMapDataset.__iter__.error/retry", - args={"exception": f"{type(e).__name__}: {str(e)}"}, - level=1, - ) - else: - break + else: + break + finally: + ds_iter.close() def can_restore_sample(self) -> bool: return super().can_restore_sample() and self.stateless_iter_fn @@ -150,38 +162,52 @@ def assert_can_restore(self) -> None: super().assert_can_restore() def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample: + trace_span = self.worker_config.worker_trace_span() + iter_name = self._function_config(self.iter_map_fn) self.assert_can_restore() - id, sample_idx, iter_idx, *sample_restore_keys = index - assert id == type(self).__name__ - assert isinstance(iter_idx, int) - inner_iter = iter( - self.iter_map_fn( - (self.dataset.restore_sample(inner_index) for inner_index in sample_restore_keys) - ) - ) - try: - # Skip inner yielded samples to get the correct sample - for skip_idx in range(iter_idx): - with self._sample_index.ctx(sample_idx - iter_idx + skip_idx): - next(inner_iter) - # This is the sample to restore - with self._sample_index.ctx(sample_idx): - sample = next(inner_iter) - return set_sample_restore_key( - sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, - ) - except StopIteration: - raise RuntimeError( - "Generator did not yield enough samples, but is marked stateless/deterministic." 
+ with trace_span.span( + "IterMapDataset.restore_sample", + args={"index": index}, + level=1, + ): + id, sample_idx, iter_idx, *sample_restore_keys = index + assert id == type(self).__name__ + assert isinstance(iter_idx, int) + inner_iter = iter( + trace_span.iterable( + self.iter_map_fn( + ( + self.dataset.restore_sample(inner_index) + for inner_index in sample_restore_keys + ) + ), + name=f"{iter_name}.next", + level=2, + ) ) - finally: - # Properly close if it's a generator - if hasattr(inner_iter, "close"): - inner_iter.close() + try: + # Skip inner yielded samples to get the correct sample + for skip_idx in range(iter_idx): + with self._sample_index.ctx(sample_idx - iter_idx + skip_idx): + next(inner_iter) + # This is the sample to restore + with self._sample_index.ctx(sample_idx): + sample = next(inner_iter) + return set_sample_restore_key( + sample, + sample_idx, + iter_idx, + *sample_restore_keys, + src=self, + ) + except StopIteration: + raise RuntimeError( + "Generator did not yield enough samples, but is marked stateless/deterministic." + ) + finally: + # Properly close if it's a generator + if hasattr(inner_iter, "close"): + inner_iter.close() def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/log_sample_dataset.py b/src/megatron/energon/wrappers/log_sample_dataset.py index 2014c996..70145650 100644 --- a/src/megatron/energon/wrappers/log_sample_dataset.py +++ b/src/megatron/energon/wrappers/log_sample_dataset.py @@ -10,40 +10,46 @@ T_sample = TypeVar("T_sample") +def _flatten_str_list(keys: Any) -> Iterator[Optional[str]]: + """Flatten a list of keys into a list of strings.""" + if isinstance(keys, str): + yield keys + elif isinstance(keys, (list, tuple)): + for key in keys: + yield from _flatten_str_list(key) + else: + yield None + + +def _flatten_str_list_or_none(keys: Any) -> Optional[List[str]]: + """Flatten a list of keys into a list of strings. 
If this cannot be fetched, return None.""" + keys = list(_flatten_str_list(keys)) + if any(k is None for k in keys): + return None + return keys + + def default_get_keys(batch: Any) -> Optional[List[str]]: """Default get_keys, which has some heuristics to find the sample keys.""" if isinstance(batch, list): - batch = batch[0] - if ( - hasattr(batch, "__key__") - and isinstance(batch.__key__, list) - and all(isinstance(k, str) for k in batch.__key__) - ): - return batch.__key__ - elif ( - hasattr(batch, "__keys__") - and isinstance(batch.__keys__, list) - and all(isinstance(k, str) for k in batch.__keys__) - ): - return batch.__keys__ - elif ( - isinstance(batch, dict) - and "__key__" in batch - and all(isinstance(k, str) for k in batch["__key__"]) - ): - return batch["__key__"] - elif ( - isinstance(batch, dict) - and "__keys__" in batch - and all(isinstance(k, str) for k in batch["__keys__"]) - ): - return batch["__keys__"] - elif ( - isinstance(batch, dict) - and "keys" in batch - and all(isinstance(k, str) for k in batch["keys"]) - ): - return batch["keys"] + all_keys = [] + for b in batch: + k = default_get_keys(b) + if k is None: + return None + all_keys.extend(k) + return all_keys + if hasattr(batch, "__key__"): + return _flatten_str_list_or_none(batch.__key__) + elif hasattr(batch, "__keys__"): + return _flatten_str_list_or_none(batch.__keys__) + elif isinstance(batch, dict): + if "__key__" in batch: + return _flatten_str_list_or_none(batch["__key__"]) + elif "__keys__" in batch: + return _flatten_str_list_or_none(batch["__keys__"]) + elif "keys" in batch: + return _flatten_str_list_or_none(batch["keys"]) return None diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index 1c623df3..869e2126 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -88,12 +88,24 @@ def __len__(self): def __iter__(self) -> Iterator[T_sample_out]: trace_span = 
self.worker_config.worker_trace_span() - with trace_span.span("MapDataset.__iter__", args={"config": self._own_config()}, level=1): + map_dataset_prefix = f"MapDataset({self._function_config_short(self.map_fn)})" + fn_span = self._function_config(self.map_fn) + with ( + trace_span.span( + f"{map_dataset_prefix}.__iter__", args={"config": self._own_config()}, level=1 + ), + self.worker_config.worker_trace_writer().generator( + "MapDataset.__iter__.next", level=2 + ) as trace_gen, + ): if self._generator_sample_key is not None: assert self._generator_offset is not None sample = self.dataset.restore_sample(self._generator_sample_key) # Do not increment the sample index, use previous index - with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: + with ( + self._sample_index.ctx(self._sample_index.current_idx) as sample_idx, + trace_span.span(fn_span, args={"sample_idx": sample_idx}, level=2), + ): mapped_sample = self.map_fn(sample) assert isinstance(mapped_sample, Generator) assert inspect.isgeneratorfunction(self.map_fn), ( @@ -101,24 +113,30 @@ def __iter__(self) -> Iterator[T_sample_out]: ) target_offset = self._generator_offset self._generator_offset = 0 - for idx, (sample_idx, inner_sample) in enumerate( - self._sample_index.iter_ctx(mapped_sample, sample_idx) + for idx, (sample_idx, inner_sample) in trace_span.iterable( + self._sample_index.iter_ctx(mapped_sample, sample_idx), + name=f"{fn_span}.next", + level=2, ): # Skip other samples if idx >= target_offset: self._generator_offset = idx + 1 - yield add_sample_restore_key( - inner_sample, - sample_idx, - idx, - src=self, - ) + with trace_gen.yield_(last_args={"sample_idx": sample_idx, "idx": idx}): + yield add_sample_restore_key( + inner_sample, + sample_idx, + idx, + src=self, + ) self._generator_sample_key = None self._generator_offset = None for sample in self.dataset: try: - with self._sample_index.ctx() as sample_idx: + with ( + self._sample_index.ctx() as sample_idx, + 
trace_span.span(fn_span, args={"sample_idx": sample_idx}, level=2), + ): mapped_sample = self.map_fn(sample) if isinstance(mapped_sample, Generator): assert inspect.isgeneratorfunction(self.map_fn), ( @@ -128,32 +146,36 @@ def __iter__(self) -> Iterator[T_sample_out]: self._generator_offset = 0 # In case of a generator, additionally store the index of the yielded samples # per input sample - for idx, (sample_idx, inner_sample) in enumerate( - self._sample_index.iter_ctx(mapped_sample, sample_idx) + for idx, (sample_idx, inner_sample) in trace_span.iterable( + self._sample_index.iter_ctx(mapped_sample, sample_idx), + name=f"{fn_span}.next", + level=2, ): self._generator_offset = idx + 1 + with trace_gen.yield_(last_args={"sample_idx": sample_idx, "idx": idx}): + yield add_sample_restore_key( + inner_sample, + sample_idx, + idx, + src=self, + ) + self._generator_sample_key = None + self._generator_offset = None + else: + with trace_gen.yield_(last_args={"sample_idx": sample_idx}): yield add_sample_restore_key( - inner_sample, + mapped_sample, sample_idx, - idx, src=self, ) - self._generator_sample_key = None - self._generator_offset = None - else: - yield add_sample_restore_key( - mapped_sample, - sample_idx, - src=self, - ) except SkipSample: - trace_span.instant("MapDataset.__iter__.skip", level=1) + trace_span.instant(f"{map_dataset_prefix}.__iter__.skip", level=1) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(sample) except Exception as e: self.error_handler(e, sample) trace_span.instant( - "MapDataset.__iter__.error/skip", + f"{map_dataset_prefix}.__iter__.error/skip", args={"exception": f"{type(e).__name__}: {str(e)}"}, level=1, ) @@ -168,33 +190,53 @@ def assert_can_restore(self) -> None: super().assert_can_restore() def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample_out: + trace_span = self.worker_config.worker_trace_span() self.assert_can_restore() - if inspect.isgeneratorfunction(self.map_fn): - id, sample_idx, 
local_idx = index[:3] - assert id == type(self).__name__ - index = index[3:] - assert isinstance(local_idx, int) - else: - id, sample_idx = index[:2] - assert id == type(self).__name__ - index = index[2:] - inner_sample = self.dataset.restore_sample(index) - with self._sample_index.ctx(sample_idx): - mapped_sample = self.map_fn(inner_sample) - if isinstance(mapped_sample, Generator): - assert inspect.isgeneratorfunction(self.map_fn), ( - f"Generator in {self.map_fn} but not marked as such." - ) - for idx, (sample_idx, res_sample) in enumerate( - self._sample_index.iter_ctx(mapped_sample, sample_idx) + with trace_span.span( + "MapDataset.restore_sample", + args={"index": index}, + level=1, + ): + if inspect.isgeneratorfunction(self.map_fn): + id, sample_idx, local_idx = index[:3] + assert id == type(self).__name__ + index = index[3:] + assert isinstance(local_idx, int) + else: + id, sample_idx = index[:2] + assert id == type(self).__name__ + index = index[2:] + with trace_span.span( + "MapDataset.restore_sample.dataset", + args={"index": index}, + level=2, ): - if idx == local_idx: - return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self) - assert False, ( - "Generator did not yield enough samples, but is marked stateless/deterministic." - ) - else: - return add_sample_restore_key(mapped_sample, sample_idx, src=self) + inner_sample = self.dataset.restore_sample(index) + with ( + self._sample_index.ctx(sample_idx), + trace_span.span( + f"MapDataset.restore_sample.map_fn:{self._function_config(self.map_fn)}", + args={"sample_idx": sample_idx}, + level=2, + ), + ): + mapped_sample = self.map_fn(inner_sample) + if isinstance(mapped_sample, Generator): + assert inspect.isgeneratorfunction(self.map_fn), ( + f"Generator in {self.map_fn} but not marked as such." 
+ ) + for idx, (sample_idx, res_sample) in trace_span.iterable( + self._sample_index.iter_ctx(mapped_sample, sample_idx), + name=f"MapDataset.restore_sample.map_fn:{self._function_config(self.map_fn)}.next", + level=2, + ): + if idx == local_idx: + return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self) + assert False, ( + "Generator did not yield enough samples, but is marked stateless/deterministic." + ) + else: + return add_sample_restore_key(mapped_sample, sample_idx, src=self) def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index 0be345c6..5faae330 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -60,7 +60,7 @@ class PackingDataset( #: The samples are stored sequentially in the pre_packing_buffer because #: SavableSampleBuffer doesn't support nesting. But to keep the groups #: separate, we need to store the lengths of the groups here. 
- _pre_packing_lengths: List[List[int]] + _pre_packing_lengths: List[int] #: Sample index for the pre_packer _pre_packing_sample_index: SampleIndex @@ -88,7 +88,7 @@ def __init__( final_packer: Callable[[List[T_encoded_sample]], T_batch_sample], *, final_packer_stateless: bool = False, - sample_encoder: Optional[Callable[[List[T_sample]], T_encoded_sample]] = None, + sample_encoder: Optional[Callable[[T_sample], T_encoded_sample]] = None, sample_encoder_stateless: bool = False, packer_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, error_handler: Callable[[Exception, List[T_sample]], None] = log_exception, @@ -181,6 +181,10 @@ def _fill_reading_buffer(self, source_iter: Iterator, log_progress: bool = False def __iter__(self) -> Iterator[T_batch_sample]: trace_span = self.worker_config.worker_trace_span() + if self.sample_encoder is not None: + encode_name = self._function_config(self.sample_encoder) + pre_packer_name = self._function_config(self.pre_packer) + final_packer_name = self._function_config(self.final_packer) def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: # Apply the sample encoder to the pack @@ -192,21 +196,19 @@ def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: ): for sample in pack: try: - with trace_span.span( - "PackingDataset._encode_pack_samples.encode_sample", level=2 + with ( + self._sample_encoder_sample_index.ctx() as encode_idx, + trace_span.span(encode_name, args={"sample_idx": encode_idx}, level=2), ): - with self._sample_encoder_sample_index.ctx() as encode_idx: - encoded_sample = self.sample_encoder(sample) - assert not isinstance(encoded_sample, Generator), ( - "Generator not supported" - ) - encoded_pack.append( - add_sample_restore_key( - encoded_sample, - encode_idx, - src=self, - ) + encoded_sample = self.sample_encoder(sample) + assert not isinstance(encoded_sample, Generator), "Generator not supported" + encoded_pack.append( + add_sample_restore_key( + 
encoded_sample, + encode_idx, + src=self, ) + ) except SkipSample: trace_span.instant("PackingDataset._encode_pack_samples.skip", level=2) except SYSTEM_EXCEPTIONS: @@ -230,10 +232,17 @@ def next_pre_pack(): samples = list(self._reading_buffer) # Clear buffer and pre_packing_lengths self._reading_buffer.clear() - pre_packing_lengths.clear() + self._pre_packing_lengths.clear() # Now pre pack the samples try: - with self._pre_packing_sample_index.ctx(): + with ( + self._pre_packing_sample_index.ctx() as pre_pack_idx, + trace_span.span( + pre_packer_name, + args={"pre_pack_idx": pre_pack_idx, "len": len(samples)}, + level=2, + ), + ): pre_packs = self.pre_packer(samples) except SkipSample: pre_packs = [] @@ -255,44 +264,57 @@ def next_pre_pack(): # so that the groups can be separated later for pre_pack in pre_packs: self._pre_packing_buffer.extend(pre_pack) - pre_packing_lengths.append(len(pre_pack)) + self._pre_packing_lengths.append(len(pre_pack)) def next_final_pack() -> Generator[T_batch_sample, None, None]: """Yield the next packs from the buffer. 
The final packer is called on the fly.""" - pack = list(self._pre_packing_buffer[: pre_packing_lengths[0]]) + pack = list(self._pre_packing_buffer[: self._pre_packing_lengths[0]]) pack = encode_pack_samples(pack) if len(pack) == 0: # All samples in the pack were skipped return - del self._pre_packing_buffer[: pre_packing_lengths[0]] - del pre_packing_lengths[0] + del self._pre_packing_buffer[: self._pre_packing_lengths[0]] + del self._pre_packing_lengths[0] try: pack_restore_keys = tuple(get_sample_restore_key(sample) for sample in pack) - with self._final_packing_sample_index.ctx() as pack_idx: + with ( + self._final_packing_sample_index.ctx() as pack_idx, + trace_span.span( + final_packer_name, args={"pack_idx": pack_idx, "len": len(pack)}, level=2 + ), + ): final_packed_sample = self.final_packer(pack) if isinstance(final_packed_sample, Generator): assert inspect.isgeneratorfunction(self.final_packer), ( f"Generator in {self.final_packer} but not marked as such." ) - for pack_sub_idx, (pack_idx, inner_batch_sample) in enumerate( - self._final_packing_sample_index.iter_ctx(final_packed_sample, pack_idx) + for pack_sub_idx, (pack_idx, inner_batch_sample) in trace_span.iterable( + enumerate( + self._final_packing_sample_index.iter_ctx(final_packed_sample, pack_idx) + ), + name=f"{final_packer_name}.next", + level=2, ): + with trace_gen.yield_( + last_args={"pack_idx": pack_idx, "pack_sub_idx": pack_sub_idx} + ): + yield set_sample_restore_key( + inner_batch_sample, + pack_idx, + pack_sub_idx, + *pack_restore_keys, + src=self, + ) + else: + with trace_gen.yield_(last_args={"pack_idx": pack_idx}): yield set_sample_restore_key( - inner_batch_sample, + final_packed_sample, pack_idx, - pack_sub_idx, *pack_restore_keys, src=self, ) - else: - yield set_sample_restore_key( - final_packed_sample, - pack_idx, - *pack_restore_keys, - src=self, - ) except SkipSample: trace_span.instant("PackingDataset.next_final_pack.skip", level=2) except SYSTEM_EXCEPTIONS: @@ -305,73 +327,81 
@@ def next_final_pack() -> Generator[T_batch_sample, None, None]: level=2, ) - with trace_span.span( - "PackingDataset.__iter__", args={"config": self._own_config()}, level=1 + with ( + trace_span.span( + "PackingDataset.__iter__", args={"config": self._own_config()}, level=1 + ), + self.worker_config.worker_trace_writer().generator( + "PackingDataset.__iter__.next", level=2 + ) as trace_gen, ): - pre_packing_lengths = self._pre_packing_lengths # The source dataset src_iter = iter(self.dataset) - self._pre_packing_buffer.worker_start() - self._reading_buffer.worker_start() - - is_initial_pack = True - - pre_pack_round = 0 - # Main loop: - while True: - if pre_pack_round > 10: - raise RuntimeError("Pre packer did not yield any packs after 10 rounds.") - with trace_span.span( - "PackingDataset.__iter__.fill_reading_buffer", - args={ - "to_fill": self.buffer_size - - len(self._reading_buffer) - - len(self._pre_packing_buffer), - "reading_buffer": len(self._reading_buffer), - "pre_packing_buffer": len(self._pre_packing_buffer), - "buffer_size": self.buffer_size, - }, - level=2, - ): - # Fill a portion of the buffer - if not self._fill_reading_buffer(src_iter, log_progress=is_initial_pack): - # Break out of the main loop when the source is exhausted. - break - is_initial_pack = False - - # Create new pre packs if necessary - if len(pre_packing_lengths) == 0: - with trace_span.span("PackingDataset.__iter__.next_pre_pack", level=1): - assert len(self._pre_packing_buffer) == 0 - assert len(self._reading_buffer) == self.buffer_size - next_pre_pack() - if len(pre_packing_lengths) == 0: - # Retry packing, nothing was returned. 
- pre_pack_round += 1 - continue - # Reset the pre pack round counter for failing - pre_pack_round = 0 + try: + self._pre_packing_buffer.worker_start() + self._reading_buffer.worker_start() - with trace_span.span("PackingDataset.__iter__.final_pack", level=2): - yield from next_final_pack() + is_initial_pack = True - with trace_span.span("PackingDataset.__iter__.last", level=1): - # Yield the remaining packs, flushing the collecting buffer - while len(pre_packing_lengths) > 0: - with trace_span.span("PackingDataset.__iter__.last.final_pack", level=2): + pre_pack_round = 0 + # Main loop: + while True: + if pre_pack_round > 10: + raise RuntimeError("Pre packer did not yield any packs after 10 rounds.") + with trace_span.span( + "PackingDataset.__iter__.fill_reading_buffer", + args={ + "to_fill": self.buffer_size + - len(self._reading_buffer) + - len(self._pre_packing_buffer), + "reading_buffer": len(self._reading_buffer), + "pre_packing_buffer": len(self._pre_packing_buffer), + "buffer_size": self.buffer_size, + }, + level=2, + ): + # Fill a portion of the buffer + if not self._fill_reading_buffer(src_iter, log_progress=is_initial_pack): + # Break out of the main loop when the source is exhausted. + break + is_initial_pack = False + + # Create new pre packs if necessary + if len(self._pre_packing_lengths) == 0: + with trace_span.span("PackingDataset.__iter__.next_pre_pack", level=1): + assert len(self._pre_packing_buffer) == 0 + assert len(self._reading_buffer) == self.buffer_size + next_pre_pack() + if len(self._pre_packing_lengths) == 0: + # Retry packing, nothing was returned. 
+ pre_pack_round += 1 + continue + # Reset the pre pack round counter for failing + pre_pack_round = 0 + + with trace_span.span("PackingDataset.__iter__.final_pack", level=2): yield from next_final_pack() - # If there are still samples in the partial reading buffer, pre-pack them and yield the - # resulting (partial) packs - if len(self._reading_buffer) > 0: - with trace_span.span("PackingDataset.__iter__.last.next_pre_pack", level=1): - next_pre_pack() - - # Yield the remaining packs, flushing the collecting buffer - while len(pre_packing_lengths) > 0: - with trace_span.span("PackingDataset.__iter__.last.final_pack", level=2): - yield from next_final_pack() + with trace_span.span("PackingDataset.__iter__.last", level=1): + # Yield the remaining packs, flushing the collecting buffer + while len(self._pre_packing_lengths) > 0: + with trace_span.span("PackingDataset.__iter__.last.final_pack", level=2): + yield from next_final_pack() + + # If there are still samples in the partial reading buffer, pre-pack them and yield the + # resulting (partial) packs + if len(self._reading_buffer) > 0: + with trace_span.span("PackingDataset.__iter__.last.next_pre_pack", level=1): + next_pre_pack() + + # Yield the remaining packs, flushing the collecting buffer + while len(self._pre_packing_lengths) > 0: + with trace_span.span("PackingDataset.__iter__.last.final_pack", level=2): + yield from next_final_pack() + finally: + if hasattr(src_iter, "close"): + src_iter.close() def can_restore_sample(self) -> bool: # Cannot really verify if the returned elements contain a __restore_key__. @@ -389,48 +419,78 @@ def assert_can_restore(self): super().assert_can_restore() def restore_sample(self, restore_key: Any) -> T_sample: + trace_span = self.worker_config.worker_trace_span() # We need to store multiple indices to restore a batch. 
self.assert_can_restore() - if inspect.isgeneratorfunction(self.final_packer): - id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key - assert id == type(self).__name__ - else: - id, pack_idx, *pack_restore_keys = restore_key - assert id == type(self).__name__ - - pack = [] - for inner_idx in pack_restore_keys: - if self.sample_encoder is not None: - id, sample_idx, *inner_idx = inner_idx + with trace_span.span( + "PackingDataset.restore_sample", args={"restore_key": restore_key}, level=1 + ): + if inspect.isgeneratorfunction(self.final_packer): + id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key + assert id == type(self).__name__ + else: + id, pack_idx, *pack_restore_keys = restore_key assert id == type(self).__name__ - assert isinstance(sample_idx, int) - sample = self.dataset.restore_sample(inner_idx) - if self.sample_encoder is not None: - with self._sample_encoder_sample_index.ctx(sample_idx): - sample = self.sample_encoder(sample) - assert not isinstance(sample, Generator), "Generator not supported" - sample = add_sample_restore_key(sample, sample_idx, src=self) - pack.append(sample) - with self._final_packing_sample_index.ctx(pack_idx): - final_pack = self.final_packer(pack) - if isinstance(final_pack, Generator): - assert inspect.isgeneratorfunction(self.final_packer), ( - f"Generator in {self.final_packer} but not marked as such." 
- ) - for cur_batch_sub_idx, (pack_idx, inner_batch_sample) in enumerate( - self._final_packing_sample_index.iter_ctx(final_pack, pack_idx) + + with trace_span.span( + "PackingDataset.restore_sample.restore_samples", + args={"len": len(pack_restore_keys)}, + level=2, ): - if cur_batch_sub_idx == pack_sub_idx: - return set_sample_restore_key( - inner_batch_sample, - pack_idx, - pack_sub_idx, - *pack_restore_keys, - src=self, - ) - assert False, f"Pack sub-index {pack_sub_idx} not found in pack" - else: - return set_sample_restore_key(final_pack, pack_idx, *pack_restore_keys, src=self) + pack = [] + for inner_sample_idx, inner_idx in enumerate(pack_restore_keys): + if self.sample_encoder is not None: + id, sample_idx, *inner_idx = inner_idx + assert id == type(self).__name__ + assert isinstance(sample_idx, int) + with trace_span.span( + "PackingDataset.restore_sample.dataset", + args={"sample_idx": inner_sample_idx}, + level=2, + ): + sample = self.dataset.restore_sample(inner_idx) + if self.sample_encoder is not None: + with ( + self._sample_encoder_sample_index.ctx(sample_idx), + trace_span.span( + f"PackingDataset.restore_sample.sample_encoder:{self._function_config(self.sample_encoder)}", + args={"sample_idx": sample_idx}, + level=2, + ), + ): + sample = self.sample_encoder(sample) + assert not isinstance(sample, Generator), "Generator not supported" + sample = add_sample_restore_key(sample, sample_idx, src=self) + pack.append(sample) + with ( + self._final_packing_sample_index.ctx(pack_idx), + trace_span.span( + f"PackingDataset.restore_sample.final_packer:{self._function_config(self.final_packer)}", + args={"pack_idx": pack_idx}, + level=2, + ), + ): + final_pack = self.final_packer(pack) + if isinstance(final_pack, Generator): + assert inspect.isgeneratorfunction(self.final_packer), ( + f"Generator in {self.final_packer} but not marked as such." 
+ ) + for cur_batch_sub_idx, (pack_idx, inner_batch_sample) in trace_span.iterable( + self._final_packing_sample_index.iter_ctx(final_pack, pack_idx), + name=f"PackingDataset.restore_sample.final_packer:{self._function_config(self.final_packer)}.next", + level=2, + ): + if cur_batch_sub_idx == pack_sub_idx: + return set_sample_restore_key( + inner_batch_sample, + pack_idx, + pack_sub_idx, + *pack_restore_keys, + src=self, + ) + assert False, f"Pack sub-index {pack_sub_idx} not found in pack" + else: + return set_sample_restore_key(final_pack, pack_idx, *pack_restore_keys, src=self) def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/repeat_dataset.py b/src/megatron/energon/wrappers/repeat_dataset.py index cf11be6c..557702ed 100644 --- a/src/megatron/energon/wrappers/repeat_dataset.py +++ b/src/megatron/energon/wrappers/repeat_dataset.py @@ -60,14 +60,19 @@ def __iter__(self) -> Iterator[T_sample]: ds_len = len(self.dataset) trace_span = self.worker_config.worker_trace_span() - with trace_span.span( - "RepeatDataset.__iter__", - args={ - "repetition": self._repetition, - "repeats": self.repeats, - "inner_len": ds_len, - }, - level=2, + with ( + trace_span.span( + "RepeatDataset.__iter__", + args={ + "repetition": self._repetition, + "inner_len": ds_len, + "config": self._own_config(), + }, + level=2, + ), + self.worker_config.worker_trace_writer().generator( + "RepeatDataset.__iter__.next", next_args={"idx": self._index}, level=2 + ) as trace_gen, ): while self.repeats is None or self._repetition < self.repeats: with trace_span.span( @@ -90,14 +95,15 @@ def __iter__(self) -> Iterator[T_sample]: stop_after = None for sample in self.dataset: - with trace_span.span( + trace_span.instant( "RepeatDataset.__iter__.__iter__.yield", args={ "idx": self._index, }, level=2, - ): - self._index += 1 + ) + self._index += 1 + with trace_gen.yield_(next_args={"idx": self._index}): yield sample if stop_after is not None and self._index >= 
stop_after: @@ -114,6 +120,11 @@ def __iter__(self) -> Iterator[T_sample]: # No more repeats self._repetition = math.ceil(self.repeats) + def _own_config(self) -> Dict[str, Any]: + return { + "repeats": self.repeats, + } + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py index 7befb4cc..6c930366 100644 --- a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py +++ b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, Generic, Iterator, List, Tuple, TypeVar, Union +from typing import Any, Dict, Generic, Iterator, List, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.rng import WorkerRng @@ -46,48 +46,55 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[T_sample]: trace_span = self.worker_config.worker_trace_span() - with trace_span.span( - "ShuffleBufferDataset.__iter__", args={"config": self._own_config()}, level=1 + with ( + trace_span.span( + "ShuffleBufferDataset.__iter__", args={"config": self._own_config()}, level=1 + ), + self.worker_config.worker_trace_writer().generator( + "ShuffleBufferDataset.__iter__.next", level=2 + ) as trace_gen, ): self._active_buffer.worker_start() it = iter(self._active_buffer.append_iter()) - while True: - if len(self._active_buffer) >= self.size: - pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) - sample_creation = self._sample_creation.pop(pop_idx) - trace_span.instant( - "ShuffleBufferDataset.__iter__.yield", - args={ - "idx": pop_idx, - "sample_creation": sample_creation, - "sample_age": self._iterations - sample_creation, - }, - level=2, - ) - yield self._active_buffer.pop(pop_idx) - else: - try: - next(it) - self._sample_creation.append(self._iterations) + try: + 
while True: + if len(self._active_buffer) >= self.size: + pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) + sample_creation = self._sample_creation.pop(pop_idx) trace_span.instant( - "ShuffleBufferDataset.__iter__.append", + "ShuffleBufferDataset.__iter__.yield", args={ - "idx": len(self._sample_creation) - 1, - "sample_creation": self._iterations, + "idx": pop_idx, + "sample_creation": sample_creation, + "sample_age": self._iterations - sample_creation, }, level=2, ) - self._iterations += 1 - except StopIteration: - break + with trace_gen.yield_(last_args={"idx": pop_idx}): + yield self._active_buffer.pop(pop_idx) + else: + try: + next(it) + self._sample_creation.append(self._iterations) + trace_span.instant( + "ShuffleBufferDataset.__iter__.append", + args={ + "idx": len(self._sample_creation) - 1, + "sample_creation": self._iterations, + }, + level=2, + ) + self._iterations += 1 + except StopIteration: + break + finally: + if hasattr(it, "close"): + it.close() with trace_span.span("ShuffleBufferDataset.__iter__.final_buffer", level=2): while len(self._active_buffer) > 0: pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) yield self._active_buffer.pop(pop_idx) - def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_sample: - return self._active_buffer.restore_sample(index) - def _own_config(self) -> Dict[str, Any]: return { "size": self.size, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c809c061..721fcaf9 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1684,17 +1684,21 @@ def test_debug_dataset(self): num_workers=2, worker_log_level=3, worker_debug_path=str(self.dataset_path) + "/worker_debug/{worker_id}.json", + # worker_debug_path="./tmp_worker_debug/{worker_id}.json", ) # Reset this to 0 to make sure the test is deterministic SavableDataLoader._next_id = 0 loader = get_savable_loader( - get_val_dataset( + get_train_dataset( self.dataset_path, split_part="train", batch_size=5, 
worker_config=worker_config, + shuffle_buffer_size=10, + max_samples_per_sequence=None, + virtual_epoch_length=10, ), ) @@ -1710,22 +1714,27 @@ def test_debug_dataset(self): assert (debug_log_path / "1.json").is_file(), f"{list(debug_log_path.iterdir())}" assert (debug_log_path / "2.json").is_file(), f"{list(debug_log_path.iterdir())}" - collected_keys_order = [[None] * 10 for _ in range(2)] + collected_keys = defaultdict(list) with (debug_log_path / "0.json").open() as rf: raw = json.load(rf) + for entry in raw: + if entry["ph"] == "n" and entry["name"] == "SavableDataLoader.yield": + # print(entry) + collected_keys[entry["args"]["global_sample_idx"]].extend(entry["args"]["keys"]) assert len(raw) > 0 - print(raw) + dst_keys = [ + [collected_keys[i] for i in range(10)], + [collected_keys[i] for i in range(10, 20)], + ] - print(collected_keys_order) - assert collected_keys_order == samples + print(dst_keys) + assert dst_keys == samples runner = CliRunner() result = runner.invoke( analyze_debug_command, [ str(debug_log_path), - "--include-modality", - "train,val", "--heatmap-path", str(self.dataset_path / "heatmap.png"), ], From a67072f879279f3e6253b1056efd103bc06f6951 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 27 May 2025 13:37:39 +0200 Subject: [PATCH 3/7] Small fixes for new tracing --- src/megatron/energon/savable_loader.py | 208 +++++++++---------- src/megatron/energon/tracing.py | 35 +++- src/megatron/energon/wrappers/map_dataset.py | 6 +- 3 files changed, 128 insertions(+), 121 deletions(-) diff --git a/src/megatron/energon/savable_loader.py b/src/megatron/energon/savable_loader.py index 40f48376..c98e5452 100644 --- a/src/megatron/energon/savable_loader.py +++ b/src/megatron/energon/savable_loader.py @@ -39,7 +39,6 @@ add_sample_restore_key, ) from megatron.energon.rng import SystemRng, SystemRngState -from megatron.energon.tracing import Flow from megatron.energon.worker import WorkerConfig from 
megatron.energon.wrappers.base import BaseWrapperDataset from megatron.energon.wrappers.batch_dataset import BatchDataset @@ -62,7 +61,24 @@ def _init_worker(seed_per_worker: List[int], worker_id: int): SystemRng.seed(worker_seed) -class SimpleSavableDatasetWrapper(BaseWrapperDataset[T, Tuple[int, int, T]], Generic[T]): +@edataclass +class SavableDatasetState(State): + """State of the dataset wrapper. It stores the global random states and the index of the next + sample to be returned from the dataset. This class is not intended to be used directly, but by + :class:`megatron.energon.SavableDatasetWrapper`.""" + + #: The state of all the system random number generators + rng: SystemRngState + #: The state of the savable dataset + dataset_state: FlexState + #: Index of the next sample to be returned from the dataset + sample_index: int + + def __repr__(self): + return f"SavableDatasetState(rng={self.rng!r}, sample_index={self.sample_index})" + + +class SimpleSavableDatasetWrapper(IterableDataset[Tuple[int, int, T]], Generic[T]): """Wrapper for non-multiprocessing savable datasets. Restarts the inner dataset. This class is not intended to be used directly.""" @@ -72,8 +88,6 @@ class SimpleSavableDatasetWrapper(BaseWrapperDataset[T, Tuple[int, int, T]], Gen _state_restored: bool _sample_index: int - _savable_fields = ("_sample_index",) - def __init__( self, dataset: SavableDataset[T], @@ -88,20 +102,18 @@ def __init__( cache_pool: The cache pool to use for the dataset. dataloader_id: The id of the data loader for logging purposes. """ - super().__init__(dataset, worker_config=worker_config) + self.dataset = dataset + self.worker_config = worker_config self.cache_pool = cache_pool self.dataloader_id = dataloader_id self.reset_state_own() - # This must be removed, such that the outer dataloader does not use this - # from the second epoch on for training. 
- del self.__len__ def reset_state_own(self) -> None: self._sample_index = 0 self._state_restored = False - def __len__(self): + def inner_len(self): return len(self.dataset) def __iter__(self): @@ -123,26 +135,25 @@ def __iter__(self): self._state_restored = False self.worker_config.worker_activate(self._sample_index, cache_pool=self.cache_pool) worker_active = True + trace_sample_flow: dict = {} try: # For tracing, this contains the current sample flow - trace_sample_flow: Flow = None - - def _next_span(): + def _trace_next(): # Trace the next sample flow nonlocal trace_sample_flow span = trace_writer.span( name="SimpleSavableDatasetWrapper.__iter__.loop.dataset.next", - args={"sample_idx": self._sample_index}, + args={"sample_index": self._sample_index}, level=1, ) trace_sample_flow = trace_writer.flow( f"w{global_worker_id}_s{self._sample_index}", level=1, - ) + ).save() return span - for src_data in trace_writer.iterable(self.dataset, next=_next_span): + for src_data in trace_writer.iterable(self.dataset, next=_trace_next): self.worker_config.worker_deactivate() worker_active = False sample_index = self._sample_index @@ -150,13 +161,12 @@ def _next_span(): src_data, global_worker_id, sample_index, src=self ) self._sample_index += 1 - trace_sample_flow.end(level=1) trace_writer.instant( "SimpleSavableDatasetWrapper.__iter__.loop.yield", args={"sample_index": sample_index}, level=1, ) - yield worker_id, sample_index, src_data + yield worker_id, sample_index, src_data, trace_sample_flow if self._state_restored: # Restart iterator after restore break @@ -167,6 +177,21 @@ def _next_span(): finally: if worker_active: self.worker_config.worker_deactivate() + trace_writer.instant("SimpleSavableDatasetWrapper.__iter__.break", level=1) + + def save_state(self): + return SavableDatasetState( + rng=None, + dataset_state=self.dataset.save_state(), + sample_index=self._sample_index, + ) + + def restore_state(self, state: SavableDatasetState): + self._sample_index = 
state.sample_index + self.dataset.restore_state(state.dataset_state) + + def can_restore_sample(self) -> bool: + return self.dataset.can_restore_sample() def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: with self.worker_config.worker_trace_writer().span( @@ -195,52 +220,6 @@ def __str__(self): return f"SimpleSavableDatasetWrapper(dataset={self.dataset})" -class SimpleSavableDatasetWrapperWithoutLen(IterableDataset[Tuple[int, int, T]], Generic[T]): - """Wrapper for non-multiprocessing savable datasets. Restarts the inner dataset. This class is - not intended to be used directly.""" - - def __init__(self, dataset: SavableDataset[T]): - self.dataset = dataset - - def inner_len(self): - return len(self.dataset) - - def __iter__(self): - return self.dataset.__iter__() - - def save_state(self): - return self.dataset.save_state() - - def restore_state(self, state: FlexState): - return self.dataset.restore_state(state) - - def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T: - return self.dataset.restore_sample(index) - - def config(self): - return self.dataset.config() - - def __str__(self): - return f"SimpleSavableDatasetWrapperWithoutLen(dataset={self.dataset})" - - -@edataclass -class SavableDatasetState(State): - """State of the dataset wrapper. It stores the global random states and the index of the next - sample to be returned from the dataset. This class is not intended to be used directly, but by - :class:`megatron.energon.SavableDatasetWrapper`.""" - - #: The state of all the system random number generators - rng: SystemRngState - #: The state of the savable dataset - dataset_state: FlexState - #: Index of the next sample to be returned from the dataset - sample_index: int - - def __repr__(self): - return f"SavableDatasetState(rng={self.rng!r}, sample_index={self.sample_index})" - - @edataclass class SavableCheckpoint: """Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. 
An instance is created @@ -413,14 +392,6 @@ def __del__(self): self._cmd_thread = None # print(f"{id(self)}:{multiprocessing.current_process().ident} Cmd thread closed") - def _trace_hierarchy(self, dataset: SavableDataset[T]): - trace_writer = self.worker_config.worker_trace_writer() - dataset = self.dataset - with trace_writer.span(type(dataset).__name__, level=1): - if isinstance(dataset, BaseWrapperDataset): - for dataset in dataset.datasets: - self._trace_hierarchy(dataset) - def __iter__(self): trace_writer = self.worker_config.worker_trace_writer() # First: Set the worker offset globally for the current worker @@ -717,7 +688,9 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: id, global_worker_id, sample_idx = restore_key[:3] assert id == type(self).__name__ restore_key = restore_key[3:] - self.worker_config.worker_activate(sample_idx, override_global_rank=global_worker_id) + self.worker_config.worker_activate( + sample_idx, override_global_rank=global_worker_id, cache_pool=self.cache_pool + ) try: return add_sample_restore_key( self.dataset.restore_sample(restore_key), @@ -741,7 +714,7 @@ class SavableDataLoaderState(State): processed of a single rank.""" #: The internal state of the dataset (for each worker process) - worker_states: List[Union[SavableDatasetCheckpoint, FlexState]] + worker_states: Union[List[SavableDatasetCheckpoint], List[SavableDatasetState]] #: Which worker will be the next to emit a sample. Used to restore the proper order next_worker_id: int @@ -800,7 +773,7 @@ class SavableDataLoader(DataLoader[T], Generic[T]): #: The worker config worker_config: WorkerConfig #: The wrapped dataset. 
For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper` - dataset: Union[SavableDatasetWrapper[T], SimpleSavableDatasetWrapperWithoutLen[T]] + dataset: Union[SavableDatasetWrapper[T], SimpleSavableDatasetWrapper[T]] #: The global ID counter _next_id: ClassVar[int] = 0 @@ -815,6 +788,8 @@ class SavableDataLoader(DataLoader[T], Generic[T]): #: Instance of the current data iterator. There shall be only one active iterator, such that the # dataset is not iterated multiple times in parallel. The state will proceed. _persistent_iterator: Optional[Iterator[T]] = None + #: Whether the dataloader has running workers. + _has_workers: bool = False #: The index of the current worker. -1 if not started yet. _worker_sample_counters: List[int] #: Id of the next worker to retrieve data from @@ -912,10 +887,8 @@ def __init__( dataloader_id=self.id, ) else: - dataset = SimpleSavableDatasetWrapperWithoutLen( - SimpleSavableDatasetWrapper( - dataset, self.worker_config, cache_pool=cache_pool, dataloader_id=self.id - ) + dataset = SimpleSavableDatasetWrapper( + dataset, self.worker_config, cache_pool=cache_pool, dataloader_id=self.id ) self._worker_sample_counters = [-1] * num_procs @@ -938,6 +911,16 @@ def __init__( self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers) ] + super().__init__( + dataset, + batch_size=None, + shuffle=False, + num_workers=self.worker_config.num_workers, + pin_memory=True, + worker_init_fn=partial(_init_worker, seed_per_worker), + **kwargs, + ) + self.worker_config.worker_trace_writer().trace_object_async( self, "SavableDataLoader", @@ -949,16 +932,6 @@ def __init__( level=1, ) - super().__init__( - dataset, - batch_size=None, - shuffle=False, - num_workers=self.worker_config.num_workers, - pin_memory=True, - worker_init_fn=partial(_init_worker, seed_per_worker), - **kwargs, - ) - @staticmethod def next_id() -> int: next_id = SavableDataLoader._next_id @@ -966,6 +939,7 @@ def next_id() -> int: return next_id 
def __len__(self): + # We override this, because otherwise we'll see warnings return self.dataset.inner_len() def __iter__(self): @@ -1022,6 +996,8 @@ def _inner_generator(iterator): }, level=1, ) + else: + keys = None with trace_generator.yield_( last_args={ "loader_id": self.id, @@ -1038,9 +1014,10 @@ def _inner_generator(iterator): self._global_sample_idx += 1 iter_idx += 1 yield sample - finally: + # After the source is exhausted, not for GeneratorExit. self._persistent_iterator = None self._next_worker_id = 0 + finally: trace_span.instant( "SavableDataLoader.StopIteration", level=1, @@ -1057,6 +1034,7 @@ def _inner_generator(iterator): if self._persistent_iterator is None: self._persistent_iterator = _inner_generator(super().__iter__()) self._sample_idx = 0 + self._has_workers = True # print("New Iterator", self._persistent_iterator) return self._persistent_iterator else: @@ -1079,9 +1057,7 @@ def _worker_command(self, *cmd_args) -> List[Any]: def _get_batch_size(self) -> Optional[int]: """Try to infer micro batch size from the dataset""" - if isinstance(self.dataset, SavableDatasetWrapper): - dataset = self.dataset.dataset - elif isinstance(self.dataset, SimpleSavableDatasetWrapperWithoutLen): + if isinstance(self.dataset, (SavableDatasetWrapper, SimpleSavableDatasetWrapper)): dataset = self.dataset.dataset else: dataset = self.dataset @@ -1106,17 +1082,17 @@ def save_state_rank(self) -> Optional[SavableDataLoaderState]: # Fetch current rank's worker's state if self.num_workers == 0: # No workers configured - assert isinstance(self.dataset, SimpleSavableDatasetWrapperWithoutLen) + assert isinstance(self.dataset, SimpleSavableDatasetWrapper) worker_states = [self.dataset.save_state()] assert self._next_worker_id == 0 - elif self._persistent_iterator is None: + elif self._has_workers: + # Fetch from worker processes + worker_states = self._worker_command("get_checkpoint", self._worker_sample_counters) + else: # Workers configured, but not started yet. 
# If a state has already been restored, it will be returned. assert isinstance(self.dataset, SavableDatasetWrapper) worker_states = self.dataset.get_initial_checkpoint() - else: - # Fetch from worker processes - worker_states = self._worker_command("get_checkpoint", self._worker_sample_counters) if worker_states is None: return None @@ -1149,13 +1125,13 @@ def restore_state_rank(self, state: Optional[SavableDataLoaderState]) -> None: if self.num_workers == 0: # No workers configured - assert isinstance(self.dataset, SimpleSavableDatasetWrapperWithoutLen) + assert isinstance(self.dataset, SimpleSavableDatasetWrapper) assert micro_batch_size == old_micro_batch_size, ( "Changing micro batch size is not allowed without workers" ) assert len(state.worker_states) == 1 - assert isinstance(state.worker_states[0], FlexState) + assert isinstance(state.worker_states[0], SavableDatasetState) self.dataset.restore_state(state.worker_states[0]) else: # Workers configured @@ -1437,7 +1413,7 @@ def __init__( ) dataset = SimpleSavableDatasetWrapper( - dataset, worker_config=self.worker_config, cache_pool=cache_pool + dataset, worker_config=self.worker_config, dataloader_id=self.id, cache_pool=cache_pool ) self._worker_sample_counters = [0] * max(self.worker_config.num_workers, 1) @@ -1454,6 +1430,16 @@ def __init__( gc.collect() # This ensures that we don't include any old worker refs in the newly forked worker processes + super().__init__( + dataset, + batch_size=None, + shuffle=False, + num_workers=self.worker_config.num_workers, + pin_memory=True, + worker_init_fn=partial(_init_worker, seed_per_worker), + **kwargs, + ) + self.worker_config.worker_trace_writer().trace_object_async( self, "BasicDataLoader", @@ -1465,15 +1451,9 @@ def __init__( level=1, ) - super().__init__( - dataset, - batch_size=None, - shuffle=False, - num_workers=self.worker_config.num_workers, - pin_memory=True, - worker_init_fn=partial(_init_worker, seed_per_worker), - **kwargs, - ) + def __len__(self): + # 
We override this, because otherwise we'll see warnings + return self.dataset.inner_len() def __iter__(self): def _inner_generator(iterator): @@ -1530,12 +1510,18 @@ def _inner_generator(iterator): }, level=1, ) + else: + keys = None with trace_generator.yield_( last_args={ "loader_id": self.id, "iter_id": id, "worker_id": worker_id, - "sample_idx": sample_idx, + "worker_sample_idx": sample_idx, + "sample_idx": iter_idx, + "iter_idx": iter_idx, + "global_sample_idx": self._sample_idx, + **({} if keys is None else {"keys": keys}), } ): self._sample_idx += 1 @@ -1563,7 +1549,7 @@ def config(self): and `restore_state` to match the configuration as well.""" return { "type": type(self).__qualname__, - "num_workers": self.num_workers, + "num_workers": self.worker_config.num_workers, "persistent_workers": self.persistent_workers, "pin_memory": self.pin_memory, "prefetch_factor": None if self.num_workers == 0 else self.prefetch_factor, diff --git a/src/megatron/energon/tracing.py b/src/megatron/energon/tracing.py index 0b0563a8..67472c7e 100644 --- a/src/megatron/energon/tracing.py +++ b/src/megatron/energon/tracing.py @@ -1409,6 +1409,12 @@ def delete(self, *args, **kwargs) -> None: class NoopTraceWriter: """A trace writer that does nothing. 
Used when tracing is disabled.""" + def close(self) -> None: + pass + + def flush(self) -> None: + pass + def duration_begin(self, *args, **kwargs) -> None: pass @@ -1421,6 +1427,12 @@ def span(self, *args, **kwargs) -> "Span": def instant(self, *args, **kwargs) -> None: pass + def iterable(self, iterable: Iterable[T], *args, **kwargs) -> Iterable[T]: + return iterable + + def generator(self, *args, **kwargs) -> "GeneratorContext": + return _NOOP_GENERATOR_CONTEXT + def async_begin(self, *args, **kwargs) -> None: pass @@ -1433,9 +1445,12 @@ def async_end(self, *args, **kwargs) -> None: def async_span(self, *args, **kwargs) -> "AsyncSpan": return _NOOP_ASYNC_SPAN - def async_context(self, *args, **kwargs) -> "AsyncContext": + def async_flow(self, *args, **kwargs) -> "AsyncContext": return _NOOP_ASYNC_CONTEXT + def async_generator(self, *args, **kwargs) -> "AsyncGeneratorContext": + return _NOOP_ASYNC_GENERATOR_CONTEXT + def flow_start(self, *args, **kwargs) -> None: pass @@ -1448,19 +1463,28 @@ def flow_end(self, *args, **kwargs) -> None: def flow(self, *args, **kwargs) -> "Flow": return _NOOP_FLOW + def resume_flow(self, saved_flow: dict) -> "Flow": + return _NOOP_FLOW + def counter(self, *args, **kwargs) -> None: pass - def object_trace(self, *args, **kwargs) -> "ObjectTrace": + def async_object_trace(self, *args, **kwargs) -> "ObjectTrace": return _NOOP_OBJECT_TRACE - def trace_object(self, *args, **kwargs) -> "ObjectTrace": + def trace_object_async(self, *args, **kwargs) -> "ObjectTrace": return _NOOP_OBJECT_TRACE def metadata(self, *args, **kwargs) -> None: pass - def close(self) -> None: + def metadata_process_name(self, name: str) -> None: + pass + + def metadata_thread_name(self, name: str) -> None: + pass + + def async_exc(self, *args, **kwargs) -> None: pass def __enter__(self): @@ -1472,9 +1496,6 @@ def __exit__(self, exc_type, exc, tb): def __repr__(self) -> str: return "" - def exception(self, *args, **kwargs) -> None: - pass - NOOP_TRACE_WRITER: 
TraceWriter = cast(TraceWriter, NoopTraceWriter()) diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index 3f3b26de..97f2d86d 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -115,7 +115,7 @@ def __iter__(self) -> Iterator[T_sample_out]: target_offset = self._generator_offset self._generator_offset = 0 for idx, (sample_idx, inner_sample) in trace_span.iterable( - self._sample_index.iter_ctx(mapped_sample, sample_idx), + enumerate(self._sample_index.iter_ctx(mapped_sample, sample_idx)), name=f"{fn_span}.next", level=2, ): @@ -149,7 +149,7 @@ def __iter__(self) -> Iterator[T_sample_out]: # In case of a generator, additionally store the index of the yielded samples # per input sample for idx, (sample_idx, inner_sample) in trace_span.iterable( - self._sample_index.iter_ctx(mapped_sample, sample_idx), + enumerate(self._sample_index.iter_ctx(mapped_sample, sample_idx)), name=f"{fn_span}.next", level=2, ): @@ -228,7 +228,7 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s f"Generator in {self.map_fn} but not marked as such." 
) for idx, (sample_idx, res_sample) in trace_span.iterable( - self._sample_index.iter_ctx(mapped_sample, sample_idx), + enumerate(self._sample_index.iter_ctx(mapped_sample, sample_idx)), name=f"MapDataset.restore_sample.map_fn:{self._function_config(self.map_fn)}.next", level=2, ): From e5993ca834fa86b18666a8496e20058065998154 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Wed, 28 May 2025 12:53:14 +0200 Subject: [PATCH 4/7] Fix license header --- src/megatron/energon/tracing.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/megatron/energon/tracing.py b/src/megatron/energon/tracing.py index 67472c7e..dc367784 100644 --- a/src/megatron/energon/tracing.py +++ b/src/megatron/energon/tracing.py @@ -1,4 +1,5 @@ -from __future__ import annotations +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause import gc import json From efa054177239dbc1e2c34de4a8c7c0f3e8923406 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Wed, 28 May 2025 12:59:31 +0200 Subject: [PATCH 5/7] Update formatting --- tests/test_epathlib.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_epathlib.py b/tests/test_epathlib.py index eb5e4898..e98c2156 100644 --- a/tests/test_epathlib.py +++ b/tests/test_epathlib.py @@ -136,7 +136,9 @@ def test_s3_path_resolution(self): os.environ["HOME"] = "/tmp/XDG_CONFIG_HOME" # Hack to clear the cache of the rclone config for msc to get the "s3" profile from multistorageclient.rclone import read_rclone_config + read_rclone_config.cache_clear() + try: # Test globbing p = EPath("msc://s3/tmp/path/subpath.txt") From ef55a2f1470f319b9cba6b46a6cef045f950a924 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 3 Jun 2025 09:42:28 +0200 Subject: [PATCH 6/7] WIP: Cleanup tracing --- src/megatron/energon/flavors/base_dataset.py | 4 +- 
src/megatron/energon/flavors/trace.py | 185 +++++++++ src/megatron/energon/savable_loader.py | 2 +- src/megatron/energon/source_info.py | 13 +- .../energon/wrappers/_log_exception.py | 4 +- .../energon/wrappers/batch_dataset.py | 146 +++---- .../energon/wrappers/blend_dataset.py | 140 +++---- .../energon/wrappers/concat_dataset.py | 37 +- .../energon/wrappers/epochize_dataset.py | 65 ++- .../energon/wrappers/filter_dataset.py | 47 ++- src/megatron/energon/wrappers/gc_dataset.py | 70 ++-- .../energon/wrappers/group_batch_dataset.py | 232 +++++------ .../energon/wrappers/iter_map_dataset.py | 113 +++--- .../energon/wrappers/limit_dataset.py | 39 +- .../energon/wrappers/log_sample_dataset.py | 81 +--- src/megatron/energon/wrappers/map_dataset.py | 202 +++++---- .../energon/wrappers/packing_dataset.py | 383 ++++++++---------- .../energon/wrappers/repeat_dataset.py | 106 +++-- .../wrappers/shuffle_buffer_dataset.py | 85 ++-- .../energon/wrappers/watchdog_dataset.py | 32 +- 20 files changed, 1019 insertions(+), 967 deletions(-) create mode 100644 src/megatron/energon/flavors/trace.py diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index 3331f01d..56f7fb54 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import dataclasses +import functools import inspect import typing from abc import ABC, abstractmethod @@ -449,9 +450,6 @@ def legacy_handler( ], ) -> Callable[[Exception, Optional[str], Optional[list[SourceInfo]]], None]: """Safely returns the new style three argument handler. 
If the handler takes 2 arguments, it wraps it.""" - import functools - import inspect - handler_sig = inspect.signature(handler) if len(handler_sig.parameters) != 3: original_handler = handler diff --git a/src/megatron/energon/flavors/trace.py b/src/megatron/energon/flavors/trace.py new file mode 100644 index 00000000..aae0be52 --- /dev/null +++ b/src/megatron/energon/flavors/trace.py @@ -0,0 +1,185 @@ +import functools +from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, TypeVar, Union + +from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.tracing import AsyncContext + +T_dataset = TypeVar("T_dataset", bound=SavableDataset) +T_call = TypeVar("T_call", bound=Callable) +T_sample = TypeVar("T_sample") + + +def _flatten_str_list(keys: Any) -> Iterator[Optional[str]]: + """Flatten a list of keys into a list of strings.""" + if isinstance(keys, str): + yield keys + elif isinstance(keys, (list, tuple)): + for key in keys: + yield from _flatten_str_list(key) + else: + yield None + + +def _flatten_str_list_or_none(keys: Any) -> Optional[List[str]]: + """Flatten a list of keys into a list of strings. 
If this cannot be fetched, return None.""" + keys = list(_flatten_str_list(keys)) + if any(k is None for k in keys): + return None + return keys + + +def default_get_keys(batch: Any) -> Optional[List[str]]: + """Default get_keys, which has some heuristics to find the sample keys.""" + if isinstance(batch, list): + all_keys = [] + for b in batch: + k = default_get_keys(b) + if k is None: + return None + all_keys.extend(k) + return all_keys + if hasattr(batch, "__key__"): + return _flatten_str_list_or_none(batch.__key__) + elif hasattr(batch, "__keys__"): + return _flatten_str_list_or_none(batch.__keys__) + elif isinstance(batch, dict): + if "__key__" in batch: + return _flatten_str_list_or_none(batch["__key__"]) + elif "__keys__" in batch: + return _flatten_str_list_or_none(batch["__keys__"]) + elif "keys" in batch: + return _flatten_str_list_or_none(batch["keys"]) + return None + + +class TraceIter: + last_args: Dict[str, Any] = {} + + def __init__( + self, + outer_self: T_dataset, + name: str, + trace_span: AsyncContext, + call_args: Dict[str, Union[str, Callable[[T_dataset], Any]]], + ): + self.outer_self = outer_self + self.name = name + self.trace_span = trace_span + self.call_args = call_args + + def sample_exception( + self, exception: Exception, samples: Union[T_sample, Sequence[T_sample]] + ) -> None: + self.trace_span.instant( + f"{self.name}.error/skip", + args={ + "exception": f"{type(exception).__name__}: {str(exception)}", + "sample_keys": default_get_keys(samples), + **{ + arg_name: arg_value(self.outer_self) if callable(arg_value) else arg_value + for arg_name, arg_value in self.call_args.items() + }, + }, + level=1, + ) + + def skip_sample(self, samples: Sequence[T_sample]) -> None: + self.trace_span.instant( + f"{self.name}.skip", + args={ + "sample_keys": default_get_keys(samples), + }, + level=1, + ) + + def sample( + self, sample: Union[T_sample, Sequence[T_sample]], args: Dict[str, Any] = {} + ) -> None: + self.last_args["sample_keys"] = 
default_get_keys(sample) + self.last_args.update(args) + + def wrap_fn(self, fn: T_call) -> T_call: + fn_name = getattr(fn, "__qualname__", getattr(fn, "__name__", "")) + + @functools.wraps(fn) + def wrapped_fn(*args, **kwargs): + with self.trace_span.span( + f"{self.name}.{fn_name}.call", + args={ + arg_name: arg_value(self.outer_self) if callable(arg_value) else arg_value + for arg_name, arg_value in self.call_args.items() + }, + level=2, + ): + return fn(*args, **kwargs) + + return wrapped_fn + + def wrap_inner(self, call_args: Callable[..., Dict[str, Any]] = lambda *args, **kwargs: {}): + def decorator(fn): + fn_name = getattr(fn, "__qualname__", getattr(fn, "__name__", "")) + + @functools.wraps(fn) + def wrapped_inner_gen(*args, **kwargs): + with self.trace_span.span( + f"{self.name}.{fn_name}.__iter__", args=call_args(*args, **kwargs), level=2 + ): + return fn(*args, **kwargs) + + return wrapped_inner_gen + + return decorator + + +def trace_iter( + name: Callable[[T_dataset], str] = lambda ds: type(ds).__name__, + call_args: Dict[str, Union[str, Callable[[T_dataset], Any]]] = {}, + next_args: Dict[str, Union[str, Callable[[T_dataset], Any]]] = {}, +) -> Callable[ + [Callable[[T_dataset, TraceIter], Iterator[T_sample]]], + Callable[[T_dataset], Iterator[T_sample]], +]: + """Decorator for SavableDataset.__iter__ to trace the iteration using the worker config.""" + + def decorator( + iter_fn: Callable[[T_dataset, TraceIter], Iterator[T_sample]], + ) -> Callable[[T_dataset], Iterator[T_sample]]: + @functools.wraps(iter_fn) + def wrapper(self: T_dataset) -> Iterator[T_sample]: + trace_span = self.worker_config.worker_trace_span() + span_name = name(self) + trace_iter = TraceIter(self, span_name, trace_span, call_args) + with ( + trace_span.span( + f"{span_name}.__iter__", + args={ + arg_name: arg_value(self) if callable(arg_value) else arg_value + for arg_name, arg_value in call_args.items() + }, + level=1, + ), + 
self.worker_config.worker_trace_writer().generator( + f"{span_name}.__iter__.next", + level=2, + ) as trace_gen, + ): + for sample in trace_span.iterable( + iter_fn(self, trace_iter), + name=f"{span_name}.__iter__.loop", + level=2, + ): + with trace_gen.yield_( + last_args={ + **{ + arg_name: arg_value(self) if callable(arg_value) else arg_value + for arg_name, arg_value in next_args.items() + }, + **trace_iter.last_args, + }, + ): + trace_iter.last_args.clear() + yield sample + + return wrapper + + return decorator diff --git a/src/megatron/energon/savable_loader.py b/src/megatron/energon/savable_loader.py index c98e5452..a4c130d3 100644 --- a/src/megatron/energon/savable_loader.py +++ b/src/megatron/energon/savable_loader.py @@ -38,12 +38,12 @@ State, add_sample_restore_key, ) +from megatron.energon.flavors.trace import default_get_keys from megatron.energon.rng import SystemRng, SystemRngState from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset from megatron.energon.wrappers.batch_dataset import BatchDataset from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset, gc_init_worker -from megatron.energon.wrappers.log_sample_dataset import default_get_keys from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset T = TypeVar("T") diff --git a/src/megatron/energon/source_info.py b/src/megatron/energon/source_info.py index 1a90c228..8bf1153b 100644 --- a/src/megatron/energon/source_info.py +++ b/src/megatron/energon/source_info.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause from dataclasses import dataclass -from typing import Any, Union +from typing import Any, Sequence, Union from megatron.energon.epathlib import EPath @@ -36,3 +36,14 @@ def add_source_info(sample: Any, *source_infos: SourceInfo) -> Any: else: sample["__sources__"] = (*sample["__sources__"], *source_infos) return sample + + +def get_source_info(sample: Any) -> Sequence[SourceInfo]: + 
"""Get the source info of a sample.""" + from megatron.energon.flavors.base_dataset import Sample + + if isinstance(sample, Sample) or hasattr(sample, "__sources__"): + return sample.__sources__ or () + elif isinstance(sample, dict): + return sample.get("__sources__") or () + return () diff --git a/src/megatron/energon/wrappers/_log_exception.py b/src/megatron/energon/wrappers/_log_exception.py index a20959aa..791be674 100644 --- a/src/megatron/energon/wrappers/_log_exception.py +++ b/src/megatron/energon/wrappers/_log_exception.py @@ -2,12 +2,12 @@ # SPDX-License-Identifier: BSD-3-Clause import traceback -from typing import Optional +from typing import Optional, Sequence from megatron.energon.source_info import SourceInfo -def log_exception(_e: Exception, sample, sources: Optional[list[SourceInfo]] = None): +def log_exception(_e: Exception, sample, sources: Optional[Sequence[SourceInfo]] = None): traceback.print_exc() print("-" * 10) diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index 6c40c302..695f90dd 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -11,14 +11,19 @@ Iterator, List, Optional, + Sequence, Tuple, TypeVar, Union, ) from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError -from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key -from megatron.energon.source_info import SourceInfo +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + set_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter +from megatron.energon.source_info import SourceInfo, get_source_info from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key @@ -34,7 +39,7 @@ class 
BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_ batch_size: int batcher: Callable[[List[T_batch_sample]], T_batch] drop_last: bool - error_handler: Callable[[Exception, list[T_batch_sample], list[SourceInfo]], None] + error_handler: Callable[[Exception, List[T_batch_sample], Sequence[SourceInfo]], None] _sample_index: SampleIndex _generator_sample_keys: Optional[Any] _generator_offset: Optional[int] @@ -51,7 +56,7 @@ def __init__( batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, drop_last: bool = False, error_handler: Callable[ - [Exception, List[T_batch_sample], List[SourceInfo]], None + [Exception, List[T_batch_sample], Sequence[SourceInfo]], None ] = log_exception, failure_tolerance: Optional[int] = 100, worker_config: WorkerConfig, @@ -107,36 +112,37 @@ def __len__(self): + n_batches_per_worker_ceil * remaining_n_sample_workers ) - def __iter__(self) -> Iterator[T_batch]: + @trace_iter( + name=lambda self: f"BatchDataset({self._function_config(self.batcher)})", + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "idx": lambda self: self._sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_batch]: batch: List[T_batch_sample] = [] sample_restore_keys = [] last_batch_failures = 0 - batcher_name = self._function_config(self.batcher) - trace_span = self.worker_config.worker_trace_span() + batcher = trace_iter.wrap_fn(self.batcher) def flush() -> Generator[T_batch, None, None]: nonlocal last_batch_failures try: - with ( - self._sample_index.ctx() as sample_idx, - trace_span.span( - batcher_name, args={"sample_idx": sample_idx, "len": len(batch)}, level=2 - ), - ): - batch_sample = self.batcher(batch) + with self._sample_index.ctx() as sample_idx: + batch_sample = batcher(batch) if isinstance(batch_sample, Generator): assert inspect.isgeneratorfunction(self.batcher), ( f"Generator in {self.batcher} but not marked as such." 
) self._generator_sample_keys = sample_restore_keys self._generator_offset = 0 - for batch_sub_idx, (sample_idx, inner_batch_sample) in trace_span.iterable( - enumerate(self._sample_index.iter_ctx(batch_sample, sample_idx)), - name=f"{batcher_name}.next", - level=2, + for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( + self._sample_index.iter_ctx(batch_sample, sample_idx) ): last_batch_failures = 0 self._generator_offset = batch_sub_idx + 1 @@ -152,22 +158,18 @@ def flush() -> Generator[T_batch, None, None]: else: last_batch_failures = 0 set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) - with trace_gen.yield_(next_args={"sample_idx": sample_idx}): - yield batch_sample + trace_iter.sample(batch_sample, {"sample_idx": sample_idx}) + yield batch_sample sample_restore_keys.clear() except GeneratorExit: raise except SkipSample: - trace_span.instant("BatchDataset.__iter__.skip", level=2) + trace_iter.skip_sample(batch) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(batch) except Exception as e: - self.error_handler(e, batch) - trace_span.instant( - "BatchDataset.__iter__.error/skip", - args={"exception": f"{type(e).__name__}: {str(e)}"}, - level=2, - ) + self.error_handler(e, batch, get_source_info(batch)) + trace_iter.sample_exception(e, batch) last_batch_failures += 1 if ( self.failure_tolerance is not None @@ -178,62 +180,44 @@ def flush() -> Generator[T_batch, None, None]: f"BatchDataset {self.batcher} failed {last_batch_failures} times in a row. 
Likely your code or dataset are broken.", ) - with ( - trace_span.span("BatchDataset.__iter__", args={"config": self._own_config()}, level=1), - self.worker_config.worker_trace_writer().generator( - "BatchDataset.__iter__.next", - next_args={"sample_idx": self._sample_index.current_idx}, - level=2, - ) as trace_gen, - ): - if self._generator_sample_keys is not None: - sample_restore_keys = self._generator_sample_keys - assert self._generator_offset is not None - batch = [ - self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys - ] - with ( - self._sample_index.ctx(self._sample_index.current_idx) as sample_idx, - trace_span.span( - batcher_name, args={"sample_idx": sample_idx, "len": len(batch)}, level=2 - ), - ): - batch_sample = self.batcher(batch) - assert isinstance(batch_sample, Generator) - assert inspect.isgeneratorfunction(self.batcher), ( - f"Generator in {self.batcher} but not marked as such." - ) - target_offset = self._generator_offset - self._generator_offset = 0 - for batch_sub_idx, (sample_idx, inner_batch_sample) in trace_span.iterable( - self._sample_index.iter_ctx(batch_sample, sample_idx), - name=f"{batcher_name}.next", - level=2, - ): - # Skip other samples - if batch_sub_idx >= target_offset: - self._generator_offset = batch_sub_idx + 1 - with trace_gen.yield_(next_args={"sample_idx": sample_idx}): - yield set_sample_restore_key( - inner_batch_sample, - sample_idx, - batch_sub_idx, - *sample_restore_keys, - src=self, - ) - self._generator_sample_keys = None - self._generator_offset = None - batch.clear() - sample_restore_keys = [] + if self._generator_sample_keys is not None: + sample_restore_keys = self._generator_sample_keys + assert self._generator_offset is not None + batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys] + with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: + batch_sample = batcher(batch) + assert isinstance(batch_sample, Generator) + assert 
inspect.isgeneratorfunction(self.batcher), ( + f"Generator in {self.batcher} but not marked as such." + ) + target_offset = self._generator_offset + self._generator_offset = 0 + for batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( + self._sample_index.iter_ctx(batch_sample, sample_idx) + ): + # Skip other samples + if batch_sub_idx >= target_offset: + self._generator_offset = batch_sub_idx + 1 + yield set_sample_restore_key( + inner_batch_sample, + sample_idx, + batch_sub_idx, + *sample_restore_keys, + src=self, + ) + self._generator_sample_keys = None + self._generator_offset = None + batch.clear() + sample_restore_keys = [] - for sample in self.dataset: - batch.append(sample) - sample_restore_keys.append(get_sample_restore_key(sample)) - if len(batch) == self.batch_size: - yield from flush() - batch = [] - if len(batch) > 0 and not self.drop_last: + for sample in self.dataset: + batch.append(sample) + sample_restore_keys.append(get_sample_restore_key(sample)) + if len(batch) == self.batch_size: yield from flush() + batch = [] + if len(batch) > 0 and not self.drop_last: + yield from flush() def can_restore_sample(self) -> bool: # Cannot really verify if the returned elements contain a __restore_key__. diff --git a/src/megatron/energon/wrappers/blend_dataset.py b/src/megatron/energon/wrappers/blend_dataset.py index 6d041994..f4e5640b 100644 --- a/src/megatron/energon/wrappers/blend_dataset.py +++ b/src/megatron/energon/wrappers/blend_dataset.py @@ -1,11 +1,15 @@ # Copyright (c) 2025, NVIDIA CORPORATION. 
# SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, Iterator, List, Tuple, TypeVar +from typing import Any, Dict, Iterator, List, Sequence, Tuple, TypeVar import torch -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + add_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.rng import WorkerRng from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -19,7 +23,9 @@ class BlendDataset(BaseWrapperDataset[T_sample, T_sample]): The datasets may be infinite. This dataset is always infinite. """ + datasets: List[SavableDataset[T_sample]] weights: Tuple[float, ...] + dataset_weights: Sequence[Tuple[SavableDataset[T_sample], float]] exhausted: List[bool] _worker_rng: WorkerRng @@ -52,85 +58,69 @@ def reset_state_own(self) -> None: def __len__(self) -> int: # Give the number of samples in inner datasets, disregarding the weight - return sum(len(dataset) for dataset, weight in self.dataset_weights) - - def __iter__(self) -> Iterator[T_sample]: + return sum(len(dataset) for dataset in self.datasets) + + @trace_iter( + name=lambda self: "BlendDataset", + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: assert self.worker_has_samples(), "Cannot blend all empty datasets" - trace_span = self.worker_config.worker_trace_span() - with ( - trace_span.span("BlendDataset.__iter__", args={"config": self._own_config()}, level=1), - self.worker_config.worker_trace_writer().generator( - "BlendDataset.__iter__.next", level=2 - ) as trace_gen, - ): - # Create a list of datasets and their weights, but - # set the weight to 0 if the dataset has no samples on this worker. 
- - dataset_iters = [] - weights = [] + # Create a list of datasets and their weights, but + # set the weight to 0 if the dataset has no samples on this worker. + + dataset_iters = [] + weights = [] + for idx, (dataset, weight) in enumerate(self.dataset_weights): + assert weight > 0, "All blending weights must be > 0" + + if dataset.worker_has_samples(): + dataset_iters.append(iter(dataset)) + weights.append(weight) + else: + dataset_iters.append(None) + weights.append(0) + + weights = torch.tensor(weights, dtype=torch.float32) + if weights.sum() == 0: + raise RuntimeError( + "There is a worker with no samples in any of the blended datasets. " + "This can happen if you have a lot of workers and your dataset is too small. " + "Currently this case is not supported." + ) + + # Some may already be exhausted on this worker when restoring a state. + for idx, exhausted in enumerate(self.exhausted): + if exhausted: + weights[idx] = 0 + dataset_iters[idx] = None + + while True: + ds_idx = self._worker_rng.choice_idx(probs=weights) + + if dataset_iters[ds_idx] is None: + if all(dataset_iter is None for dataset_iter in dataset_iters): + break + continue try: - for idx, (dataset, weight) in enumerate(self.dataset_weights): - assert weight > 0, "All blending weights must be > 0" - - if dataset.worker_has_samples(): - dataset_iters.append(iter(dataset)) - weights.append(weight) - else: - dataset_iters.append(None) - weights.append(0) - - weights = torch.tensor(weights, dtype=torch.float32) - if weights.sum() == 0: - raise RuntimeError( - "There is a worker with no samples in any of the blended datasets. " - "This can happen if you have a lot of workers and your dataset is too small. " - "Currently this case is not supported." - ) - - # Some may already be exhausted on this worker when restoring a state. 
- for idx, exhausted in enumerate(self.exhausted): - if exhausted: - weights[idx] = 0 - dataset_iters[idx] = None - - while True: - ds_idx = self._worker_rng.choice_idx(probs=weights) - trace_span.instant( - "BlendDataset.__iter__.sample", - args={"weights": weights, "ds_idx": ds_idx}, - level=2, - ) - - if dataset_iters[ds_idx] is None: - if all(dataset_iter is None for dataset_iter in dataset_iters): - break - continue - try: - sample = next(dataset_iters[ds_idx]) - except StopIteration: - trace_span.instant( - "BlendDataset.__iter__.exhausted", args={"ds_idx": ds_idx}, level=1 - ) - dataset_iters[ds_idx] = None - weights[ds_idx] = 0 - self.exhausted[ds_idx] = True - if all(dataset_iter is None for dataset_iter in dataset_iters): - break - else: - with trace_gen.yield_(): - yield add_sample_restore_key(sample, ds_idx, src=self) - - trace_span.instant("BlendDataset.__iter__.reset", level=1) - self.exhausted = [False] * len(self.dataset_weights) - finally: - for it in dataset_iters: - if hasattr(it, "close"): - it.close() + sample = next(dataset_iters[ds_idx]) + except StopIteration: + dataset_iters[ds_idx] = None + weights[ds_idx] = 0 + self.exhausted[ds_idx] = True + if all(dataset_iter is None for dataset_iter in dataset_iters): + break + else: + yield add_sample_restore_key(sample, ds_idx, src=self) + + self.exhausted = [False] * len(self.dataset_weights) def _own_config(self) -> Dict[str, Any]: return { - "dataset_weights": [weight for _, weight in self.dataset_weights], + "weights": self.weights, } def config(self) -> Dict[str, Any]: diff --git a/src/megatron/energon/wrappers/concat_dataset.py b/src/megatron/energon/wrappers/concat_dataset.py index a76d365b..b997987c 100644 --- a/src/megatron/energon/wrappers/concat_dataset.py +++ b/src/megatron/energon/wrappers/concat_dataset.py @@ -3,7 +3,11 @@ from typing import Any, Dict, Generic, Iterator, TypeVar -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from 
megatron.energon.flavors.base_dataset import ( + SavableDataset, + add_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -32,20 +36,23 @@ def reset_state_own(self) -> None: def __len__(self): return sum(len(dataset) for dataset in self.datasets) - def __iter__(self) -> Iterator[T_sample]: - trace_span = self.worker_config.worker_trace_span() - with trace_span.span("ConcatDataset.__iter__", level=1): - for ds_idx, dataset in enumerate(self.datasets): - with trace_span.span( - "ConcatDataset.next_dataset.yield_from", args={"ds_idx": ds_idx}, level=1 - ): - for sample in dataset: - yield add_sample_restore_key( - sample, - ds_idx, - src=self, - ) - trace_span.instant("ConcatDataset.__iter__.done", level=1) + @trace_iter( + name=lambda self: "ConcatDataset", + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: + for ds_idx, dataset in enumerate(self.datasets): + for sample in dataset: + yield add_sample_restore_key( + sample, + ds_idx, + src=self, + ) + + def _own_config(self) -> Dict[str, Any]: + return {} def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/epochize_dataset.py b/src/megatron/energon/wrappers/epochize_dataset.py index 3112ff4a..5b6d4688 100644 --- a/src/megatron/energon/wrappers/epochize_dataset.py +++ b/src/megatron/energon/wrappers/epochize_dataset.py @@ -3,7 +3,11 @@ from typing import Any, Dict, Generic, Iterator, Optional, TypeVar -from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + add_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -48,8 
+52,13 @@ def __init__( def reset_state_own(self) -> None: self._offset = 0 - def __iter__(self) -> Iterator[T_sample]: - trace_span = self.worker_config.worker_trace_span() + @trace_iter( + name=lambda self: "EpochizeDataset", + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: # Compute the local length for this worker, i.e. all worker's lengths sum up to the total if self.worker_config.num_workers <= 1: local_length = self.length @@ -58,42 +67,31 @@ def __iter__(self) -> Iterator[T_sample]: if self.worker_config.rank_worker_id() < self.length % self.worker_config.num_workers: local_length += 1 - with ( - trace_span.span( - "EpochizeDataset.__iter__", - args={ - "offset": self._offset, - "local_length": local_length, - "config": self._own_config(), - }, - level=1, - ), - self.worker_config.worker_trace_writer().generator( - "EpochizeDataset.__iter__.next", level=2 - ) as trace_gen, - ): + while self._offset < local_length: try: - offset_range = list(range(self._offset, local_length)) - - # Only iterate if there are samples to iterate - if len(offset_range) > 0: - if self._active_iter is None: - self._active_iter = iter(self.dataset) - - for idx in offset_range: - self._offset = (idx + 1) % local_length - try: - sample = next(self._active_iter) - except StopIteration: - break - with trace_gen.yield_(): - yield sample - trace_span.instant("EpochizeDataset.__iter__.done", level=1) + if self._active_iter is None: + self._active_iter = iter(self.dataset) + + sample_offset = self._offset + self._offset += 1 + try: + sample = next(self._active_iter) + except StopIteration: + self._active_iter = None + break + + yield add_sample_restore_key( + sample, + sample_offset, + src=self, + ) except GeneratorExit: if self._active_iter is not None and hasattr(self._active_iter, "close"): self._active_iter.close() self._active_iter = None raise + if self._offset >= local_length: + self._offset = 0 
def __len__(self) -> int: return self.length @@ -108,6 +106,7 @@ def config(self) -> Dict[str, Any]: "type": type(self).__qualname__, "dataset": self.dataset.config(), "length": self.length, + "worker_config": self.worker_config.config(), } def __str__(self): diff --git a/src/megatron/energon/wrappers/filter_dataset.py b/src/megatron/energon/wrappers/filter_dataset.py index e8cf7e6a..5c0deda8 100644 --- a/src/megatron/energon/wrappers/filter_dataset.py +++ b/src/megatron/energon/wrappers/filter_dataset.py @@ -3,7 +3,11 @@ from typing import Any, Callable, Dict, Generic, Iterator, Optional, TypeVar, Union -from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + add_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex @@ -50,26 +54,27 @@ def reset_state_own(self) -> None: def __len__(self): return len(self.dataset) - def __iter__(self) -> Iterator[T_sample]: - trace_span = self.worker_config.worker_trace_span() - filter_name = self._function_config(self.filter_fn) - with ( - trace_span.span("FilterDataset.__iter__", args={"config": self._own_config()}, level=1), - trace_span.generator("FilterDataset.__iter__.next", level=2) as trace_gen, - ): - for sample in self.dataset: - with ( - self._sample_index.ctx(), - trace_span.span(filter_name, args={"sample": sample}, level=2), - ): - filter_res = self.filter_fn(sample) - if filter_res: - with trace_gen.yield_(): - yield sample - else: - self.worker_config.worker_trace_span().instant( - "FilterDataset.__iter__.reject", level=3 - ) + @trace_iter( + name=lambda self: f"FilterDataset({self._function_config_short(self.filter_fn)})", + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "sample_idx": lambda self: self._sample_index.current_idx, + }, + ) + def 
__iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: + filter_fn = trace_iter.wrap_fn(self.filter_fn) + + for sample in self.dataset: + with self._sample_index.ctx() as sample_idx: + filter_res = filter_fn(sample) + if filter_res: + yield add_sample_restore_key( + sample, + sample_idx, + src=self, + ) def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/gc_dataset.py b/src/megatron/energon/wrappers/gc_dataset.py index 5ab3c807..e4fc670a 100644 --- a/src/megatron/energon/wrappers/gc_dataset.py +++ b/src/megatron/energon/wrappers/gc_dataset.py @@ -11,6 +11,7 @@ from torch.distributed.distributed_c10d import reduce_op from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -100,39 +101,44 @@ def reset_state_own(self) -> None: def __len__(self): return len(self.dataset) - def __iter__(self) -> Iterator[T_sample]: - trace_span = self.worker_config.worker_trace_span() - with ( - trace_span.span("GcDataset.__iter__", args={"config": self._own_config()}, level=1), - self.worker_config.worker_trace_writer().generator( - "GcDataset.__iter__.next", level=2 - ) as trace_gen, - ): - in_worker = torch.utils.data.get_worker_info() is not None - if in_worker and not _frozen_cuda_tensors_initialized: - raise GcFreezeError( - "You are using GcDataset with multiple workers, but forgot to call gc_init_worker() in at least one forked worker process." 
- ) - + @trace_iter( + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: + in_worker = torch.utils.data.get_worker_info() is not None + if in_worker and not _frozen_cuda_tensors_initialized: + raise GcFreezeError( + "You are using GcDataset with multiple workers, but forgot to call gc_init_worker() in at least one forked worker process." + ) + + @trace_iter.wrap_inner() + def gc_freeze(): + gc.freeze() + + @trace_iter.wrap_inner() + def gc_collect(): + gc.collect() + + @trace_iter.wrap_inner() + def gc_unfreeze(): + gc.unfreeze() + + if self.freeze: + gc_collect() + gc_freeze() + try: + iter = 0 + for sample in self.dataset: + yield sample + iter += 1 + if iter >= self.every_n_iter: + gc_collect() + iter = 0 + finally: if self.freeze: - with trace_span.span("GcDataset.__iter__.gc.collect", level=1): - gc.collect() - with trace_span.span("GcDataset.__iter__.gc.freeze", level=1): - gc.freeze() - try: - iter = 0 - for sample in self.dataset: - with trace_gen.yield_(): - yield sample - iter += 1 - if iter >= self.every_n_iter: - with trace_span.span("GcDataset.__iter__.gc.collect", level=1): - gc.collect() - iter = 0 - finally: - if self.freeze: - with trace_span.span("GcDataset.__iter__.gc.unfreeze", level=1): - gc.unfreeze() + gc_unfreeze() def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index 56350973..e16c56d6 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -12,6 +12,7 @@ Iterator, List, Optional, + Sequence, Tuple, TypeVar, Union, @@ -24,8 +25,9 @@ SavableDataset, set_sample_restore_key, ) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.savable import Savable -from megatron.energon.source_info import SourceInfo +from 
megatron.energon.source_info import SourceInfo, get_source_info from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex @@ -67,7 +69,7 @@ class GroupBatchDataset( sample_group_key: Callable[[T_batch_sample], Tuple[Hashable, Optional[int]]] batcher: Callable[[List[T_batch_sample]], T_batch] drop_last: bool - error_handler: Callable[[Exception, List[T_batch_sample], list[SourceInfo]], None] + error_handler: Callable[[Exception, List[T_batch_sample], Sequence[SourceInfo]], None] _group_key_sample_index: SampleIndex _batch_sample_index: SampleIndex _buckets: Dict[Hashable, Bucket[T_batch_sample]] @@ -83,7 +85,7 @@ def __init__( batcher_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, drop_last: bool = False, error_handler: Callable[ - [Exception, List[T_batch_sample], list[SourceInfo]], None + [Exception, List[T_batch_sample], Sequence[SourceInfo]], None ] = log_exception, failure_tolerance: Optional[int] = 100, worker_config: WorkerConfig, @@ -126,132 +128,122 @@ def __len__(self): # Return an upper bound. This is for sure not correct. 
return len(self.dataset) - def __iter__(self) -> Iterator[T_batch]: + @trace_iter( + next_args={ + "idx": lambda self: self._batch_sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_batch]: buckets = self._buckets last_batch_failures = 0 - batcher_name = self._function_config(self.batcher) - trace_span = self.worker_config.worker_trace_span() - with ( - trace_span.span( - "GroupBatchDataset.__iter__", args={"config": self._own_config()}, level=1 - ), - self.worker_config.worker_trace_writer().generator( - "GroupBatchDataset.__iter__.next", level=2 - ) as trace_gen, - ): - if buckets is None: - buckets = self._buckets = dict() - - # Load saved state if available - for bucket in buckets.values(): - bucket.samples.worker_start() - - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial GroupBatchDataset state:\n", end="") - # for bucket_key, bucket in buckets.items(): - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{bucket_key}] (bs={bucket.batch_size}, len(samples)={len(bucket.samples)}):\n", end="") - # bucket.samples.debug_print(" ") - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial done\n", end="") - - def flush(key: Any, bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: - nonlocal last_batch_failures - # Debug print the state - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flush GroupBatchDataset state:\n", end="") - # for dbg_bucket_key, dbg_bucket in buckets.items(): - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{dbg_bucket_key}{'*' if dbg_bucket_key == bucket_key else ''}] (bs={dbg_bucket.batch_size}, len(samples)={len(dbg_bucket.samples)}):\n", end="") - # dbg_bucket.samples.debug_print(" ") - batch_items, sample_restore_keys = bucket.samples.flush() - # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flushed: 
len(batch)={len(batch_items)} len(samples)={len(bucket.samples)}\n", end="") - try: - with ( - self._batch_sample_index.ctx() as sample_idx, - trace_span.span( - batcher_name, - args={ - "bucket": str(key), - "bucket_size": bucket.batch_size, - "sample_idx": sample_idx, - "len": len(batch_items), - }, - level=2, - ), - ): - batch_sample = self.batcher(batch_items) - assert not isinstance(batch_sample, Generator), ( - f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet." - ) - last_batch_failures = 0 - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) - with trace_gen.yield_(): - yield batch_sample - except SkipSample: - trace_span.instant("GroupBatchDataset.flush.skip", level=2) - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(batch_items) - except Exception as e: - self.error_handler(e, batch_items) - trace_span.instant( - "GroupBatchDataset.flush.error/skip", - args={"exception": f"{type(e).__name__}: {str(e)}"}, - level=2, + if buckets is None: + buckets = self._buckets = dict() + + batcher = trace_iter.wrap_fn(self.batcher) + + # Load saved state if available + for bucket in buckets.values(): + bucket.samples.worker_start() + + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial GroupBatchDataset state:\n", end="") + # for bucket_key, bucket in buckets.items(): + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{bucket_key}] (bs={bucket.batch_size}, len(samples)={len(bucket.samples)}):\n", end="") + # bucket.samples.debug_print(" ") + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial done\n", end="") + + @trace_iter.wrap_inner( + call_args=lambda key, bucket: { + "key": key, + "len": len(bucket.samples), + }, + ) + def flush(key: Any, bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: + nonlocal last_batch_failures + # Debug print the state + # print(f"[wrk={worker_idx}, 
s={self._batch_sample_index.current_idx}] flush GroupBatchDataset state:\n", end="") + # for dbg_bucket_key, dbg_bucket in buckets.items(): + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{dbg_bucket_key}{'*' if dbg_bucket_key == bucket_key else ''}] (bs={dbg_bucket.batch_size}, len(samples)={len(dbg_bucket.samples)}):\n", end="") + # dbg_bucket.samples.debug_print(" ") + batch_items, sample_restore_keys = bucket.samples.flush() + # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flushed: len(batch)={len(batch_items)} len(samples)={len(bucket.samples)}\n", end="") + try: + with self._batch_sample_index.ctx() as sample_idx: + trace_iter.sample( + batch_items, + { + "bucket": str(key), + "bucket_size": bucket.batch_size, + "sample_idx": sample_idx, + "len": len(batch_items), + }, ) - last_batch_failures += 1 - if ( - self.failure_tolerance is not None - and last_batch_failures >= self.failure_tolerance - ): - raise FatalSampleError.from_sample( - batch_items, - f"GroupBatchDataset {self.batcher} failed {last_batch_failures} times in a row. Likely your code or dataset are broken.", - ) - - # Add samples to the buckets - for sample in self.dataset: - try: - with self._group_key_sample_index.ctx(): - bucket_key, batch_size = self.sample_group_key(sample) - assert (batch_size is None) != (self.fixed_batch_size is None), ( - f"A sample in group for key {bucket_key} returned batch size {batch_size}, but fixed " - f"batch size is set to {self.fixed_batch_size}. One of the two should be None." 
- ) - if self.fixed_batch_size is not None: - batch_size = self.fixed_batch_size - except SkipSample: - trace_span.instant("GroupBatchDataset.__iter__.skip", level=2) - continue - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(sample) - except Exception as e: - self.error_handler(e, [sample]) - trace_span.instant( - "GroupBatchDataset.__iter__.error/skip", - args={"exception": f"{type(e).__name__}: {str(e)}"}, - level=2, + batch_sample = batcher(batch_items) + assert not isinstance(batch_sample, Generator), ( + f"Batcher {batcher} returned a generator, which is not supported for grouped batching yet." ) - continue - bucket = buckets.get(bucket_key) - if bucket is None: - assert batch_size is not None - buckets[bucket_key] = bucket = Bucket( - batch_size=batch_size, - samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config), + last_batch_failures = 0 + set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + yield batch_sample + except SkipSample: + trace_iter.skip_sample(batch_items) + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(batch_items) + except Exception as e: + self.error_handler(e, batch_items, get_source_info(batch_items)) + trace_iter.sample_exception(e, batch_items) + last_batch_failures += 1 + if ( + self.failure_tolerance is not None + and last_batch_failures >= self.failure_tolerance + ): + raise FatalSampleError.from_sample( + batch_items, + f"GroupBatchDataset {self.batcher} failed {last_batch_failures} times in a row. Likely your code or dataset are broken.", ) - else: - assert bucket.batch_size == batch_size, ( - f"Got different batch size for group {bucket_key}: {bucket.batch_size} != {batch_size}." 
+ + # Add samples to the buckets + for sample in self.dataset: + try: + with self._group_key_sample_index.ctx(): + bucket_key, batch_size = self.sample_group_key(sample) + assert (batch_size is None) != (self.fixed_batch_size is None), ( + f"A sample in group for key {bucket_key} returned batch size {batch_size}, but fixed " + f"batch size is set to {self.fixed_batch_size}. One of the two should be None." ) - bucket.samples.append(sample) - if len(bucket.samples) >= bucket.batch_size: + if self.fixed_batch_size is not None: + batch_size = self.fixed_batch_size + except SkipSample: + trace_iter.skip_sample(sample) + continue + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(sample) + except Exception as e: + self.error_handler(e, [sample], get_source_info(sample)) + trace_iter.sample_exception(e, sample) + continue + bucket = buckets.get(bucket_key) + if bucket is None: + assert batch_size is not None + buckets[bucket_key] = bucket = Bucket( + batch_size=batch_size, + samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config), + ) + else: + assert bucket.batch_size == batch_size, ( + f"Got different batch size for group {bucket_key}: {bucket.batch_size} != {batch_size}." 
+ ) + bucket.samples.append(sample) + if len(bucket.samples) >= bucket.batch_size: + yield from flush(bucket_key, bucket) + # Flush out last samples + if not self.drop_last: + for bucket_key, bucket in buckets.items(): + if len(bucket.samples) > 0: yield from flush(bucket_key, bucket) - # Flush out last samples - if not self.drop_last: - for bucket_key, bucket in buckets.items(): - if len(bucket.samples) > 0: - yield from flush(bucket_key, bucket) - # Clear the buckets - self._buckets.clear() - trace_span.instant("GroupBatchDataset.__iter__.done", level=1) + # Clear the buckets + self._buckets.clear() def save_state(self) -> FlexState: return FlexState( diff --git a/src/megatron/energon/wrappers/iter_map_dataset.py b/src/megatron/energon/wrappers/iter_map_dataset.py index 58f0f056..f0434e0b 100644 --- a/src/megatron/energon/wrappers/iter_map_dataset.py +++ b/src/megatron/energon/wrappers/iter_map_dataset.py @@ -17,7 +17,11 @@ from torch.utils.data import IterableDataset from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError -from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + set_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception @@ -91,69 +95,60 @@ def reset_state_own(self) -> None: def __len__(self): return self.len_map_fn(len(self.dataset)) - def __iter__(self) -> Iterator[T_sample_out]: - trace_span = self.worker_config.worker_trace_span() - iter_name = f"IterMapDataset.__iter__.iter_map_fn:{self._function_config(self.iter_map_fn)}" - with ( - trace_span.span( - "IterMapDataset.__iter__", args={"config": self._own_config()}, level=1 - ), - self.worker_config.worker_trace_writer().generator( - "IterMapDataset.__iter__.next", level=2 - 
) as trace_gen, - ): - last_sample_wrapper = _LastSampleWrapper(self.dataset) - # The iter_map_fn is stateless. Thus we need to know which inner sample created the - # outer sample, and the relative outer sample index, so we can restore it. + @trace_iter( + name=lambda self: f"IterMapDataset.__iter__.iter_map_fn:{self._function_config(self.iter_map_fn)}", + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "idx": lambda self: self._sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample_out]: + iter_map_fn = trace_iter.wrap_fn(self.iter_map_fn) + last_sample_wrapper = _LastSampleWrapper(self.dataset) + # The iter_map_fn is stateless. Thus we need to know which inner sample created the + # outer sample, and the relative outer sample index, so we can restore it. - # This is the sample index within the currently yielded sample - iter_idx = 0 - sample_idx = 0 - sample_restore_keys = [] + # This is the sample index within the currently yielded sample + iter_idx = 0 + sample_idx = 0 + sample_restore_keys = [] - def reset_idx_iter() -> Generator[T_sample, None, None]: - # Resets the inner sample index - nonlocal iter_idx, sample_restore_keys - for entry in last_sample_wrapper: - iter_idx = 0 - sample_restore_keys.append(get_sample_restore_key(entry)) - yield entry + def reset_idx_iter() -> Generator[T_sample, None, None]: + # Resets the inner sample index + nonlocal iter_idx, sample_restore_keys + for entry in last_sample_wrapper: + iter_idx = 0 + sample_restore_keys.append(get_sample_restore_key(entry)) + yield entry - ds_iter = iter(reset_idx_iter()) + ds_iter = iter(reset_idx_iter()) - try: - # While True will break when the inner dataset is exhausted, but may continue on exception - while True: - iter_idx = 0 - try: - for sample_idx, sample in trace_span.iterable( - self._sample_index.iter_ctx(self.iter_map_fn(ds_iter)), - name=iter_name, - level=1, - ): - with trace_gen.yield_(): - yield 
set_sample_restore_key( - sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, - ) - sample_restore_keys.clear() - iter_idx += 1 - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(last_sample_wrapper.last_sample) - except Exception as e: - self.error_handler(e, last_sample_wrapper.last_sample) - trace_span.instant( - "IterMapDataset.__iter__.error/retry", - args={"exception": f"{type(e).__name__}: {str(e)}"}, - level=1, + try: + # While True will break when the inner dataset is exhausted, but may continue on exception + while True: + iter_idx = 0 + try: + for sample_idx, sample in self._sample_index.iter_ctx(iter_map_fn(ds_iter)): + yield set_sample_restore_key( + sample, + sample_idx, + iter_idx, + *sample_restore_keys, + src=self, ) - else: - break - finally: - ds_iter.close() + sample_restore_keys.clear() + iter_idx += 1 + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(last_sample_wrapper.last_sample) + except Exception as e: + self.error_handler(e, last_sample_wrapper.last_sample) + trace_iter.sample_exception(e, last_sample_wrapper.last_sample) + else: + break + finally: + ds_iter.close() def can_restore_sample(self) -> bool: return super().can_restore_sample() and self.stateless_iter_fn diff --git a/src/megatron/energon/wrappers/limit_dataset.py b/src/megatron/energon/wrappers/limit_dataset.py index ab72ec8d..a63270cb 100644 --- a/src/megatron/energon/wrappers/limit_dataset.py +++ b/src/megatron/energon/wrappers/limit_dataset.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Generic, Iterator, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -46,7 +47,12 @@ def reset_state_own(self) -> None: def __len__(self) -> int: return min(self.length, len(self.dataset)) - def __iter__(self) -> Iterator[T_sample]: + 
@trace_iter( + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: worker_id = self.worker_config.rank_worker_id() # Compute the local limit for this worker, i.e. all worker's limits sum up to the total @@ -57,31 +63,20 @@ def __iter__(self) -> Iterator[T_sample]: if worker_id < self.length % self.worker_config.num_workers: local_limit += 1 - trace_span = self.worker_config.worker_trace_span() - with trace_span.span( - "LimitDataset.__iter__", - args={ - "offset": self.current_offset, - "local_limit": local_limit, - "config": self._own_config(), - }, - level=2, - ): - offset_range = list(range(self.current_offset, local_limit)) - # Only iterate self.dataset if there are samples to iterate - if len(offset_range) > 0: - for sample, offset in zip( - self.dataset, - offset_range, - ): - self.current_offset = offset + 1 - yield sample + offset_range = list(range(self.current_offset, local_limit)) + # Only iterate self.dataset if there are samples to iterate + if len(offset_range) > 0: + for sample, offset in zip( + self.dataset, + offset_range, + ): + self.current_offset = offset + 1 + yield sample # Reset the inner dataset self.current_offset = 0 if self.reset_after_epoch: - with trace_span.span("LimitDataset.__iter__.reset_state_deep"): - self.dataset.reset_state_deep() + self.dataset.reset_state_deep() def worker_has_samples(self) -> bool: return super().worker_has_samples() and self.length > 0 diff --git a/src/megatron/energon/wrappers/log_sample_dataset.py b/src/megatron/energon/wrappers/log_sample_dataset.py index 70145650..1caae23b 100644 --- a/src/megatron/energon/wrappers/log_sample_dataset.py +++ b/src/megatron/energon/wrappers/log_sample_dataset.py @@ -4,55 +4,13 @@ from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, 
default_get_keys, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset T_sample = TypeVar("T_sample") -def _flatten_str_list(keys: Any) -> Iterator[Optional[str]]: - """Flatten a list of keys into a list of strings.""" - if isinstance(keys, str): - yield keys - elif isinstance(keys, (list, tuple)): - for key in keys: - yield from _flatten_str_list(key) - else: - yield None - - -def _flatten_str_list_or_none(keys: Any) -> Optional[List[str]]: - """Flatten a list of keys into a list of strings. If this cannot be fetched, return None.""" - keys = list(_flatten_str_list(keys)) - if any(k is None for k in keys): - return None - return keys - - -def default_get_keys(batch: Any) -> Optional[List[str]]: - """Default get_keys, which has some heuristics to find the sample keys.""" - if isinstance(batch, list): - all_keys = [] - for b in batch: - k = default_get_keys(b) - if k is None: - return None - all_keys.extend(k) - return all_keys - if hasattr(batch, "__key__"): - return _flatten_str_list_or_none(batch.__key__) - elif hasattr(batch, "__keys__"): - return _flatten_str_list_or_none(batch.__keys__) - elif isinstance(batch, dict): - if "__key__" in batch: - return _flatten_str_list_or_none(batch["__key__"]) - elif "__keys__" in batch: - return _flatten_str_list_or_none(batch["__keys__"]) - elif "keys" in batch: - return _flatten_str_list_or_none(batch["keys"]) - return None - - class LogSampleDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): """This dataset logs every yielded sample to the debug logs.""" @@ -86,33 +44,16 @@ def reset_state_own(self) -> None: def __len__(self): return len(self.dataset) - def _log(self, sample: T_sample) -> dict: - log_entry = { - "idx": self._step, - } - keys = self.get_keys_fn(sample) - if keys is not None: - log_entry["keys"] = keys - - return log_entry - - def __iter__(self) -> Iterator[T_sample]: - trace_span = self.worker_config.worker_trace_span() - 
with trace_span.span( - "LogSampleDataset.__iter__", - args={ - "mode": self.mode, - }, - level=1, - ): - for sample in trace_span.iterable( - self.dataset, name="LogSampleDataset.__iter__.next", level=1 - ): - with trace_span.span( - "LogSampleDataset.__iter__.yield", args=self._log(sample), level=1 - ): - self._step += 1 - yield sample + @trace_iter( + next_args={ + "idx": lambda self: self._step, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: + for sample in self.dataset: + self._step += 1 + trace_iter.sample(sample) + yield sample def config(self) -> Dict[str, Any]: # Transparent logger, it won't change the samples diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index c91752e5..8564a6a8 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -10,14 +10,19 @@ Generic, Iterator, Optional, + Sequence, Tuple, TypeVar, Union, ) from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key -from megatron.energon.source_info import SourceInfo +from megatron.energon.flavors.base_dataset import ( + SavableDataset, + add_sample_restore_key, +) +from megatron.energon.flavors.trace import TraceIter, trace_iter +from megatron.energon.source_info import SourceInfo, get_source_info from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key @@ -31,7 +36,7 @@ class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T """This dataset wrapper applies a custom function to transform each sample.""" map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]] - error_handler: Callable[[Exception, T_sample, list[SourceInfo]], None] + 
error_handler: Callable[[Exception, T_sample, Sequence[SourceInfo]], None] stateless_map_fn: bool map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] _sample_index: SampleIndex @@ -49,7 +54,7 @@ def __init__( dataset: SavableDataset[T_sample], map_fn: Callable[[T_sample], Union[T_sample_out, Generator[T_sample_out, None, None]]], *, - error_handler: Callable[[Exception, T_sample, list[SourceInfo]], None] = log_exception, + error_handler: Callable[[Exception, T_sample, Sequence[SourceInfo]], None] = log_exception, stateless_map_fn: bool = False, map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, failure_tolerance: Optional[int] = 100, @@ -90,115 +95,98 @@ def reset_state_own(self) -> None: def __len__(self): return len(self.dataset) - def __iter__(self) -> Iterator[T_sample_out]: - trace_span = self.worker_config.worker_trace_span() - map_dataset_prefix = f"MapDataset({self._function_config_short(self.map_fn)})" - fn_span = self._function_config(self.map_fn) + @trace_iter( + name=lambda self: f"MapDataset({self._function_config_short(self.map_fn)})", + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "sample_idx": lambda self: self._sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample_out]: last_map_failures = 0 - with ( - trace_span.span( - f"{map_dataset_prefix}.__iter__", args={"config": self._own_config()}, level=1 - ), - self.worker_config.worker_trace_writer().generator( - "MapDataset.__iter__.next", level=2 - ) as trace_gen, - ): - if self._generator_sample_key is not None: - assert self._generator_offset is not None - sample = self.dataset.restore_sample(self._generator_sample_key) - # Do not increment the sample index, use previous index - with ( - self._sample_index.ctx(self._sample_index.current_idx) as sample_idx, - trace_span.span(fn_span, args={"sample_idx": sample_idx}, level=2), - ): - mapped_sample = 
self.map_fn(sample) - assert isinstance(mapped_sample, Generator) - assert inspect.isgeneratorfunction(self.map_fn), ( - f"Generator in {self.map_fn} but not marked as such." - ) - target_offset = self._generator_offset - self._generator_offset = 0 - for idx, (sample_idx, inner_sample) in trace_span.iterable( - enumerate(self._sample_index.iter_ctx(mapped_sample, sample_idx)), - name=f"{fn_span}.next", - level=2, - ): - # Skip other samples - if idx >= target_offset: - self._generator_offset = idx + 1 - with trace_gen.yield_(last_args={"sample_idx": sample_idx, "idx": idx}): - yield add_sample_restore_key( - inner_sample, - sample_idx, - idx, - src=self, - ) - self._generator_sample_key = None - self._generator_offset = None + map_fn = trace_iter.wrap_fn(self.map_fn) - for sample in self.dataset: - restore_key = get_sample_restore_key(sample) - try: - with ( - self._sample_index.ctx() as sample_idx, - trace_span.span(fn_span, args={"sample_idx": sample_idx}, level=2), - ): - mapped_sample = self.map_fn(sample) - if isinstance(mapped_sample, Generator): - assert inspect.isgeneratorfunction(self.map_fn), ( - f"Generator in {self.map_fn} but not marked as such." 
- ) - self._generator_sample_key = restore_key - self._generator_offset = 0 - # In case of a generator, additionally store the index of the yielded samples - # per input sample - for idx, (sample_idx, inner_sample) in trace_span.iterable( - enumerate(self._sample_index.iter_ctx(mapped_sample, sample_idx)), - name=f"{fn_span}.next", - level=2, - ): - self._generator_offset = idx + 1 - last_map_failures = 0 - with trace_gen.yield_(last_args={"sample_idx": sample_idx, "idx": idx}): - yield add_sample_restore_key( - inner_sample, - sample_idx, - idx, - src=self, - ) - self._generator_sample_key = None - self._generator_offset = None - else: - last_map_failures = 0 - with trace_gen.yield_(last_args={"sample_idx": sample_idx}): - yield add_sample_restore_key( - mapped_sample, - sample_idx, - src=self, - ) - except GeneratorExit: - raise - except SkipSample: - trace_span.instant(f"{map_dataset_prefix}.__iter__.skip", level=1) - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(sample) - except Exception as e: - self.error_handler(e, sample) - trace_span.instant( - f"{map_dataset_prefix}.__iter__.error/skip", - args={"exception": f"{type(e).__name__}: {str(e)}"}, - level=1, + if self._generator_sample_key is not None: + assert self._generator_offset is not None + sample = self.dataset.restore_sample(self._generator_sample_key) + # Do not increment the sample index, use previous index + with self._sample_index.ctx(self._sample_index.current_idx) as sample_idx: + mapped_sample = map_fn(sample) + assert isinstance(mapped_sample, Generator) + assert inspect.isgeneratorfunction(self.map_fn), ( + f"Generator in {self.map_fn} but not marked as such." 
+ ) + target_offset = self._generator_offset + self._generator_offset = 0 + for idx, (sample_idx, inner_sample) in enumerate( + self._sample_index.iter_ctx(mapped_sample, sample_idx) + ): + # Skip other samples + if idx >= target_offset: + self._generator_offset = idx + 1 + yield add_sample_restore_key( + inner_sample, + sample_idx, + idx, + src=self, + ) + self._generator_sample_key = None + self._generator_offset = None + + for sample in self.dataset: + restore_key = get_sample_restore_key(sample) + try: + with self._sample_index.ctx() as sample_idx: + mapped_sample = map_fn(sample) + if isinstance(mapped_sample, Generator): + assert inspect.isgeneratorfunction(self.map_fn), ( + f"Generator in {self.map_fn} but not marked as such." ) - last_map_failures += 1 - if ( - self.failure_tolerance is not None - and last_map_failures >= self.failure_tolerance + self._generator_sample_key = restore_key + self._generator_offset = 0 + # In case of a generator, additionally store the index of the yielded samples + # per input sample + for idx, (sample_idx, inner_sample) in enumerate( + self._sample_index.iter_ctx(mapped_sample, sample_idx) ): - raise FatalSampleError.from_sample( - sample, - f"MapDataset {self.map_fn} failed {last_map_failures} times in a row. 
Likely your code or dataset are broken.", + self._generator_offset = idx + 1 + last_map_failures = 0 + yield add_sample_restore_key( + inner_sample, + sample_idx, + idx, + src=self, ) + self._generator_sample_key = None + self._generator_offset = None + else: + last_map_failures = 0 + yield add_sample_restore_key( + mapped_sample, + sample_idx, + src=self, + ) + except GeneratorExit: + raise + except SkipSample: + trace_iter.skip_sample(sample) + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(sample) + except Exception as e: + self.error_handler(e, sample, get_source_info(sample)) + trace_iter.sample_exception(e, sample) + last_map_failures += 1 + if ( + self.failure_tolerance is not None + and last_map_failures >= self.failure_tolerance + ): + raise FatalSampleError.from_sample( + sample, + f"MapDataset {self.map_fn} failed {last_map_failures} times in a row. Likely your code or dataset are broken.", + ) def can_restore_sample(self) -> bool: return super().can_restore_sample() and self.stateless_map_fn diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index 2a048146..c31917aa 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -12,6 +12,7 @@ Iterator, List, Optional, + Sequence, TypeVar, Union, ) @@ -22,7 +23,8 @@ add_sample_restore_key, set_sample_restore_key, ) -from megatron.energon.source_info import SourceInfo +from megatron.energon.flavors.trace import TraceIter, trace_iter +from megatron.energon.source_info import SourceInfo, get_source_info from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key @@ -48,7 +50,7 @@ class PackingDataset( final_packer: Callable[[List[T_encoded_sample]], T_batch_sample] final_packer_stateless: bool packer_config: 
Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] - error_handler: Callable[[Exception, List[T_sample], list[SourceInfo]], None] + error_handler: Callable[[Exception, List[T_sample], Sequence[SourceInfo]], None] #: The buffer for collecting the samples that shall be packed. _reading_buffer: SavableSampleBuffer @@ -93,7 +95,7 @@ def __init__( sample_encoder_stateless: bool = False, packer_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] = None, error_handler: Callable[ - [Exception, List[T_sample], list[SourceInfo]], None + [Exception, List[T_sample], Sequence[SourceInfo]], None ] = log_exception, pre_packer_failure_tolerance: Optional[int] = 100, final_packer_failure_tolerance: Optional[int] = 100, @@ -161,95 +163,74 @@ def __len__(self): return len(self.dataset) - def _fill_reading_buffer(self, source_iter: Iterator, log_progress: bool = False) -> bool: - """ - Fill the reading buffer with samples from the dataset source iterator. - - Args: - source_iter: Iterator of samples from the dataset. - log_progress: If True, log the progress of the filling. - - Returns: - True if samples are successfully read into the buffer, False if no more data. 
- """ - - if log_progress: - import tqdm - - pbar_ctx = pbar = tqdm.tqdm(total=self.buffer_size, desc="Filling reading buffer") - else: - pbar_ctx = contextlib.nullcontext() - pbar = None - - with pbar_ctx: - while len(self._reading_buffer) + len(self._pre_packing_buffer) < self.buffer_size: - try: - sample = next(source_iter) - self._reading_buffer.append(sample) - if pbar is not None: - pbar.update(1) - except StopIteration: - return False - return True - - def __iter__(self) -> Iterator[T_batch_sample]: - trace_span = self.worker_config.worker_trace_span() + @trace_iter( + name=lambda self: f"PackingDataset({self._function_config(self.pre_packer)}, {self._function_config(self.final_packer)})", + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + "idx": lambda self: self._pre_packing_sample_index.current_idx, + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_batch_sample]: + pre_packer = trace_iter.wrap_fn(self.pre_packer) + final_packer = trace_iter.wrap_fn(self.final_packer) if self.sample_encoder is not None: - encode_name = self._function_config(self.sample_encoder) - pre_packer_name = self._function_config(self.pre_packer) - final_packer_name = self._function_config(self.final_packer) + sample_encoder = trace_iter.wrap_fn(self.sample_encoder) + else: + sample_encoder = None last_pre_pack_failures = 0 last_final_pack_failures = 0 last_sample_encoder_failures = 0 + @trace_iter.wrap_inner( + call_args=lambda pack: { + "len": len(pack), + "sample_encoder_idx": self._sample_encoder_sample_index.current_idx, + } + ) def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: """Encode the samples in the pack using the sample encoder.""" nonlocal last_sample_encoder_failures + assert sample_encoder is not None # Apply the sample encoder to the pack - if self.sample_encoder is None: - return pack encoded_pack = [] - with trace_span.span( - "PackingDataset._encode_pack_samples", args={"len": len(pack)}, level=2 
- ): - for sample in pack: - try: - with ( - self._sample_encoder_sample_index.ctx() as encode_idx, - trace_span.span(encode_name, args={"sample_idx": encode_idx}, level=2), - ): - encoded_sample = self.sample_encoder(sample) - assert not isinstance(encoded_sample, Generator), "Generator not supported" - encoded_pack.append( - add_sample_restore_key( - encoded_sample, - encode_idx, - src=self, - ) + for sample in pack: + try: + with self._sample_encoder_sample_index.ctx() as encode_idx: + encoded_sample = sample_encoder(sample) + assert not isinstance(encoded_sample, Generator), "Generator not supported" + encoded_pack.append( + add_sample_restore_key( + encoded_sample, + encode_idx, + src=self, ) - except SkipSample: - trace_span.instant("PackingDataset._encode_pack_samples.skip", level=2) - except SYSTEM_EXCEPTIONS: - raise FatalSampleError.from_sample(pack) - except Exception as e: - self.error_handler(e, [sample]) - trace_span.instant( - "PackingDataset._encode_pack_samples.error/skip", - args={"exception": f"{type(e).__name__}: {str(e)}"}, - level=2, + ) + except SkipSample: + trace_iter.skip_sample(pack) + except SYSTEM_EXCEPTIONS: + raise FatalSampleError.from_sample(pack) + except Exception as e: + self.error_handler(e, pack, get_source_info(pack)) + trace_iter.sample_exception(e, pack) + if ( + self.sample_encoder_failure_tolerance is not None + and last_sample_encoder_failures >= self.sample_encoder_failure_tolerance + ): + raise FatalSampleError.from_sample( + pack, + f"Sample encoder {sample_encoder} failed {last_sample_encoder_failures} times. Likely your code or dataset are broken.", ) - if ( - self.sample_encoder_failure_tolerance is not None - and last_sample_encoder_failures - >= self.sample_encoder_failure_tolerance - ): - raise FatalSampleError.from_sample( - pack, - f"Sample encoder {self.sample_encoder} failed {last_sample_encoder_failures} times. 
Likely your code or dataset are broken.", - ) return encoded_pack + @trace_iter.wrap_inner( + call_args=lambda: { + "len": len(self._reading_buffer), + "pre_packing_idx": self._pre_packing_sample_index.current_idx, + } + ) def next_pre_pack(): """Take the samples from the reading buffer and select groups of samples to be packed together.""" @@ -264,27 +245,16 @@ def next_pre_pack(): self._pre_packing_lengths.clear() # Now pre pack the samples try: - with ( - self._pre_packing_sample_index.ctx() as pre_pack_idx, - trace_span.span( - pre_packer_name, - args={"pre_pack_idx": pre_pack_idx, "len": len(samples)}, - level=2, - ), - ): - pre_packs = self.pre_packer(samples) + with self._pre_packing_sample_index.ctx(): + pre_packs = pre_packer(samples) except SkipSample: pre_packs = [] - trace_span.instant("PackingDataset.next_pre_pack.skip", level=2) + trace_iter.skip_sample(samples) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(samples) except Exception as e: - self.error_handler(e, samples) - trace_span.instant( - "PackingDataset.next_pre_pack.error/skip", - args={"exception": f"{type(e).__name__}: {str(e)}"}, - level=2, - ) + self.error_handler(e, samples, get_source_info(samples)) + trace_iter.sample_exception(e, samples) pre_packs = [] last_pre_pack_failures += 1 if ( @@ -293,7 +263,7 @@ def next_pre_pack(): ): raise FatalSampleError.from_sample( samples, - f"Pre packer {self.pre_packer} failed {last_pre_pack_failures} times. Likely your code or dataset are broken.", + f"Pre packer {pre_packer} failed {last_pre_pack_failures} times. 
Likely your code or dataset are broken.", ) # Put the pre-packed samples into the pre_packing_buffer @@ -305,6 +275,12 @@ def next_pre_pack(): self._pre_packing_buffer.extend(pre_pack) self._pre_packing_lengths.append(len(pre_pack)) + @trace_iter.wrap_inner( + call_args=lambda pack: { + "len": len(pack), + "final_packing_idx": self._final_packing_sample_index.current_idx, + } + ) def next_final_pack() -> Generator[T_batch_sample, None, None]: """Yield the next packs from the buffer. The final packer is called on the fly.""" nonlocal last_final_pack_failures @@ -312,62 +288,46 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: pack = list(self._pre_packing_buffer[: self._pre_packing_lengths[0]]) if len(pack) == 0: return - pack = encode_pack_samples(pack) - if len(pack) == 0: - # All samples in the pack were skipped - return + if sample_encoder is not None: + pack = encode_pack_samples(pack) + if len(pack) == 0: + # All samples in the pack were skipped + return del self._pre_packing_buffer[: self._pre_packing_lengths[0]] del self._pre_packing_lengths[0] try: pack_restore_keys = tuple(get_sample_restore_key(sample) for sample in pack) - with ( - self._final_packing_sample_index.ctx() as pack_idx, - trace_span.span( - final_packer_name, args={"pack_idx": pack_idx, "len": len(pack)}, level=2 - ), - ): - final_packed_sample = self.final_packer(pack) + with self._final_packing_sample_index.ctx() as pack_idx: + final_packed_sample = final_packer(pack) if isinstance(final_packed_sample, Generator): assert inspect.isgeneratorfunction(self.final_packer), ( f"Generator in {self.final_packer} but not marked as such." 
) - for pack_sub_idx, (pack_idx, inner_batch_sample) in trace_span.iterable( - enumerate( - self._final_packing_sample_index.iter_ctx(final_packed_sample, pack_idx) - ), - name=f"{final_packer_name}.next", - level=2, + for pack_sub_idx, (pack_idx, inner_batch_sample) in enumerate( + self._final_packing_sample_index.iter_ctx(final_packed_sample, pack_idx) ): - with trace_gen.yield_( - last_args={"pack_idx": pack_idx, "pack_sub_idx": pack_sub_idx} - ): - yield set_sample_restore_key( - inner_batch_sample, - pack_idx, - pack_sub_idx, - *pack_restore_keys, - src=self, - ) - else: - with trace_gen.yield_(last_args={"pack_idx": pack_idx}): yield set_sample_restore_key( - final_packed_sample, + inner_batch_sample, pack_idx, + pack_sub_idx, *pack_restore_keys, src=self, ) + else: + yield set_sample_restore_key( + final_packed_sample, + pack_idx, + *pack_restore_keys, + src=self, + ) except SkipSample: - trace_span.instant("PackingDataset.next_final_pack.skip", level=2) + trace_iter.skip_sample(pack) except SYSTEM_EXCEPTIONS: raise FatalSampleError.from_sample(pack) except Exception as e: - self.error_handler(e, pack) - trace_span.instant( - "PackingDataset.next_final_pack.error/skip", - args={"exception": f"{type(e).__name__}: {str(e)}"}, - level=2, - ) + self.error_handler(e, pack, get_source_info(pack)) + trace_iter.sample_exception(e, pack) last_final_pack_failures += 1 if ( self.final_packer_failure_tolerance is not None @@ -378,84 +338,101 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: f"Final packer {self.final_packer} failed {last_final_pack_failures} times. 
Likely your code or dataset are broken.", ) - with ( - trace_span.span( - "PackingDataset.__iter__", args={"config": self._own_config()}, level=1 - ), - self.worker_config.worker_trace_writer().generator( - "PackingDataset.__iter__.next", level=2 - ) as trace_gen, - ): - # The source dataset - src_iter = iter(self.dataset) + @trace_iter.wrap_inner( + call_args=lambda source_iter, log_progress: { + "to_fill": self.buffer_size + - len(self._reading_buffer) + - len(self._pre_packing_buffer), + "reading_buffer": len(self._reading_buffer), + "pre_packing_buffer": len(self._pre_packing_buffer), + "buffer_size": self.buffer_size, + } + ) + def fill_reading_buffer( + source_iter: Iterator[T_sample], log_progress: bool = False + ) -> bool: + """ + Fill the reading buffer with samples from the dataset source iterator. - try: - self._pre_packing_buffer.worker_start() - self._reading_buffer.worker_start() + Args: + source_iter: Iterator of samples from the dataset. + log_progress: If True, log the progress of the filling. - is_initial_pack = True + Returns: + True if samples are successfully read into the buffer, False if no more data. + """ - pre_pack_round = 0 + if log_progress: + import tqdm - # Main loop: - while True: - if pre_pack_round > self.pre_packer_failure_tolerance: - raise RuntimeError( - f"Pre packer {self.pre_packer} did not yield any packs after {pre_pack_round} rounds. Likely your code or dataset are broken." - ) - with trace_span.span( - "PackingDataset.__iter__.fill_reading_buffer", - args={ - "to_fill": self.buffer_size - - len(self._reading_buffer) - - len(self._pre_packing_buffer), - "reading_buffer": len(self._reading_buffer), - "pre_packing_buffer": len(self._pre_packing_buffer), - "buffer_size": self.buffer_size, - }, - level=2, - ): - # Fill a portion of the buffer - if not self._fill_reading_buffer(src_iter, log_progress=is_initial_pack): - # Break out of the main loop when the source is exhausted. 
- break - is_initial_pack = False + pbar_ctx = pbar = tqdm.tqdm(total=self.buffer_size, desc="Filling reading buffer") + else: + pbar_ctx = contextlib.nullcontext() + pbar = None - # Create new pre packs if necessary + with pbar_ctx: + while len(self._reading_buffer) + len(self._pre_packing_buffer) < self.buffer_size: + try: + sample = next(source_iter) + self._reading_buffer.append(sample) + if pbar is not None: + pbar.update(1) + except StopIteration: + return False + return True + + # The source dataset + src_iter = iter(self.dataset) + + try: + self._pre_packing_buffer.worker_start() + self._reading_buffer.worker_start() + + is_initial_pack = True + + pre_pack_round = 0 + + # Main loop: + while True: + if pre_pack_round > self.pre_packer_failure_tolerance: + raise RuntimeError( + f"Pre packer {self.pre_packer} did not yield any packs after {pre_pack_round} rounds. Likely your code or dataset are broken." + ) + # Fill a portion of the buffer + if not fill_reading_buffer(src_iter, log_progress=is_initial_pack): + # Break out of the main loop when the source is exhausted. + break + is_initial_pack = False + + # Create new pre packs if necessary + if len(self._pre_packing_lengths) == 0: + assert len(self._pre_packing_buffer) == 0 + assert len(self._reading_buffer) == self.buffer_size + next_pre_pack() if len(self._pre_packing_lengths) == 0: - with trace_span.span("PackingDataset.__iter__.next_pre_pack", level=1): - assert len(self._pre_packing_buffer) == 0 - assert len(self._reading_buffer) == self.buffer_size - next_pre_pack() - if len(self._pre_packing_lengths) == 0: - # Retry packing, nothing was returned. 
- pre_pack_round += 1 - continue - # Reset the pre pack round counter for failing - pre_pack_round = 0 - - with trace_span.span("PackingDataset.__iter__.final_pack", level=2): - yield from next_final_pack() - - with trace_span.span("PackingDataset.__iter__.last", level=1): - # Yield the remaining packs, flushing the collecting buffer - while len(self._pre_packing_lengths) > 0: - with trace_span.span("PackingDataset.__iter__.last.final_pack", level=2): - yield from next_final_pack() - - # If there are still samples in the partial reading buffer, pre-pack them and yield the - # resulting (partial) packs - if len(self._reading_buffer) > 0: - with trace_span.span("PackingDataset.__iter__.last.next_pre_pack", level=1): - next_pre_pack() - - # Yield the remaining packs, flushing the collecting buffer - while len(self._pre_packing_lengths) > 0: - with trace_span.span("PackingDataset.__iter__.last.final_pack", level=2): - yield from next_final_pack() - finally: - if hasattr(src_iter, "close"): - src_iter.close() + # Retry packing, nothing was returned. + pre_pack_round += 1 + continue + # Reset the pre pack round counter for failing + pre_pack_round = 0 + + yield from next_final_pack() + + # Yield the remaining packs, flushing the collecting buffer + while len(self._pre_packing_lengths) > 0: + yield from next_final_pack() + + # If there are still samples in the partial reading buffer, pre-pack them and yield the + # resulting (partial) packs + if len(self._reading_buffer) > 0: + next_pre_pack() + + # Yield the remaining packs, flushing the collecting buffer + while len(self._pre_packing_lengths) > 0: + yield from next_final_pack() + finally: + if hasattr(src_iter, "close"): + src_iter.close() def can_restore_sample(self) -> bool: # Cannot really verify if the returned elements contain a __restore_key__. 
diff --git a/src/megatron/energon/wrappers/repeat_dataset.py b/src/megatron/energon/wrappers/repeat_dataset.py index 557702ed..613b7450 100644 --- a/src/megatron/energon/wrappers/repeat_dataset.py +++ b/src/megatron/energon/wrappers/repeat_dataset.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Generic, Iterator, Optional, TypeVar, Union from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -52,73 +53,58 @@ def __len__(self): return len(self.dataset) return int(len(self.dataset) * self.repeats) - def __iter__(self) -> Iterator[T_sample]: + @trace_iter( + next_args={ + "idx": lambda self: self._index, + }, + call_args={ + "repetition": lambda self: self._repetition, + "inner_len": lambda self: len(self.dataset), + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: assert self.repeats is not None or self.dataset.worker_has_samples(), ( "Cannot repeat empty dataset indefinitely" ) + @trace_iter.wrap_inner( + call_args=lambda stop_after: { + "repetition": self._repetition, + "inner_len": len(self.dataset), + "stop_after": stop_after, + } + ) + def repeat(stop_after: Optional[int]): + for sample in self.dataset: + self._index += 1 + yield sample + + if stop_after is not None and self._index >= stop_after: + break + ds_len = len(self.dataset) - trace_span = self.worker_config.worker_trace_span() - with ( - trace_span.span( - "RepeatDataset.__iter__", - args={ - "repetition": self._repetition, - "inner_len": ds_len, - "config": self._own_config(), - }, - level=2, - ), - self.worker_config.worker_trace_writer().generator( - "RepeatDataset.__iter__.next", next_args={"idx": self._index}, level=2 - ) as trace_gen, - ): - while self.repeats is None or self._repetition < self.repeats: - with trace_span.span( - 
"RepeatDataset.__iter__.repeat", - args={ - "repetition": self._repetition, - "repeats": self.repeats, - }, - level=2, - ): - if self.repeats is not None and self._repetition == math.floor(self.repeats): - # Last iteration, adjust the number of samples - fraction = self.repeats - math.floor(self.repeats) - stop_after = math.floor(ds_len * fraction) - if self._index >= stop_after: - # We restored an index and it is already past the stop_after - trace_span.instant("RepeatDataset.__iter__.break(stop_after)", level=2) - break - else: - stop_after = None - - for sample in self.dataset: - trace_span.instant( - "RepeatDataset.__iter__.__iter__.yield", - args={ - "idx": self._index, - }, - level=2, - ) - self._index += 1 - with trace_gen.yield_(next_args={"idx": self._index}): - yield sample - - if stop_after is not None and self._index >= stop_after: - trace_span.instant( - "RepeatDataset.__iter__.__iter__.break(stop_after)", level=2 - ) - break - self._repetition += 1 - self._index = 0 - - if self.restart: - self._repetition = 0 + while self.repeats is None or self._repetition < self.repeats: + if self.repeats is not None and self._repetition == math.floor(self.repeats): + # Last iteration, adjust the number of samples + fraction = self.repeats - math.floor(self.repeats) + stop_after = math.floor(ds_len * fraction) + if self._index >= stop_after: + # We restored an index and it is already past the stop_after + break else: - # No more repeats - self._repetition = math.ceil(self.repeats) + stop_after = None + + yield from repeat(stop_after) + self._repetition += 1 + self._index = 0 + + if self.restart: + self._repetition = 0 + else: + # No more repeats + self._repetition = math.ceil(self.repeats) def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py index 6c930366..11c23343 100644 --- a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py +++ 
b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py @@ -4,6 +4,7 @@ from typing import Any, Dict, Generic, Iterator, List, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.rng import WorkerRng from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -44,56 +45,44 @@ def reset_state_own(self) -> None: def __len__(self) -> int: return len(self.dataset) - def __iter__(self) -> Iterator[T_sample]: - trace_span = self.worker_config.worker_trace_span() - with ( - trace_span.span( - "ShuffleBufferDataset.__iter__", args={"config": self._own_config()}, level=1 - ), - self.worker_config.worker_trace_writer().generator( - "ShuffleBufferDataset.__iter__.next", level=2 - ) as trace_gen, - ): - self._active_buffer.worker_start() - it = iter(self._active_buffer.append_iter()) - try: - while True: - if len(self._active_buffer) >= self.size: - pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) - sample_creation = self._sample_creation.pop(pop_idx) - trace_span.instant( - "ShuffleBufferDataset.__iter__.yield", - args={ - "idx": pop_idx, - "sample_creation": sample_creation, - "sample_age": self._iterations - sample_creation, - }, - level=2, - ) - with trace_gen.yield_(last_args={"idx": pop_idx}): - yield self._active_buffer.pop(pop_idx) - else: - try: - next(it) - self._sample_creation.append(self._iterations) - trace_span.instant( - "ShuffleBufferDataset.__iter__.append", - args={ - "idx": len(self._sample_creation) - 1, - "sample_creation": self._iterations, - }, - level=2, - ) - self._iterations += 1 - except StopIteration: - break - finally: - if hasattr(it, "close"): - it.close() - with trace_span.span("ShuffleBufferDataset.__iter__.final_buffer", level=2): - while len(self._active_buffer) > 0: + @trace_iter( + call_args={ + "config": lambda self: self._own_config(), + }, + next_args={ + 
"idx": lambda self: self._sample_creation[-1], + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: + self._active_buffer.worker_start() + it = iter(self._active_buffer.append_iter()) + try: + while True: + if len(self._active_buffer) >= self.size: pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) + sample_creation = self._sample_creation.pop(pop_idx) + trace_iter.sample( + self._active_buffer.pop(pop_idx), + { + "idx": pop_idx, + "sample_creation": sample_creation, + "sample_age": self._iterations - sample_creation, + }, + ) yield self._active_buffer.pop(pop_idx) + else: + try: + next(it) + self._sample_creation.append(self._iterations) + self._iterations += 1 + except StopIteration: + break + finally: + if hasattr(it, "close"): + it.close() + while len(self._active_buffer) > 0: + pop_idx = self._worker_rng.randbelow(len(self._active_buffer)) + yield self._active_buffer.pop(pop_idx) def _own_config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/watchdog_dataset.py b/src/megatron/energon/wrappers/watchdog_dataset.py index e2f7f59b..41c97806 100644 --- a/src/megatron/energon/wrappers/watchdog_dataset.py +++ b/src/megatron/energon/wrappers/watchdog_dataset.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Generic, Iterator, Optional, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.trace import TraceIter, trace_iter from megatron.energon.watchdog import Watchdog from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -57,20 +58,23 @@ def _watchdog_trigger(self) -> None: RuntimeWarning, ) - def __iter__(self) -> Iterator[T_sample]: - with self.worker_config.worker_trace_span().span( - "WatchdogDataset.__iter__", args={"config": self._own_config()}, level=1 - ): - if self.timeout_seconds is None: - yield from self.dataset - else: - watchdog = Watchdog( - timeout=self.timeout_seconds, - 
initial_timeout=self.initial_timeout_seconds, - callback=self._watchdog_trigger, - enabled=False, - ) - yield from watchdog.watch_iter(self.dataset) + @trace_iter( + name=lambda self: f"WatchdogDataset({self._function_config(self.dataset)})", + call_args={ + "config": lambda self: self._own_config(), + }, + ) + def __iter__(self, trace_iter: TraceIter) -> Iterator[T_sample]: + if self.timeout_seconds is None: + yield from self.dataset + else: + watchdog = Watchdog( + timeout=self.timeout_seconds, + initial_timeout=self.initial_timeout_seconds, + callback=self._watchdog_trigger, + enabled=False, + ) + yield from watchdog.watch_iter(self.dataset) def _own_config(self) -> Dict[str, Any]: return { From d86cf0ee9212de1a1660c889ab7fa98a1840cbce Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 3 Jul 2025 10:39:46 +0200 Subject: [PATCH 7/7] License header update --- src/megatron/energon/flavors/trace.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/megatron/energon/flavors/trace.py b/src/megatron/energon/flavors/trace.py index aae0be52..44a671db 100644 --- a/src/megatron/energon/flavors/trace.py +++ b/src/megatron/energon/flavors/trace.py @@ -1,3 +1,6 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + import functools from typing import Any, Callable, Dict, Iterator, List, Optional, Sequence, TypeVar, Union