From 044cc133493ed525e7cfacb0f45cf8d432785a45 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Fri, 4 Jul 2025 18:03:55 +0200 Subject: [PATCH 01/36] WIP: Custom dataloader --- src/megatron/energon/dataloader.py | 342 +++++++++++++++++++++++++++++ 1 file changed, 342 insertions(+) create mode 100644 src/megatron/energon/dataloader.py diff --git a/src/megatron/energon/dataloader.py b/src/megatron/energon/dataloader.py new file mode 100644 index 00000000..43e76237 --- /dev/null +++ b/src/megatron/energon/dataloader.py @@ -0,0 +1,342 @@ +from abc import abstractmethod +from concurrent.futures import Future +import functools +import multiprocessing +import os +import queue +import sys +import threading +from typing import Any, Concatenate, Generic, Protocol, Self, Sequence, TypeVar, cast +from typing import Callable, TypeVar, ParamSpec +import warnings +from torch.utils.data import IterableDataset + + +from megatron.energon.edataclass import edataclass +from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.metadataset.dataset_loader import DatasetLoader +from megatron.energon.rng import SystemRng +from megatron.energon.state import FlexState +from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.base import SampleIndex +from megatron.energon.wrappers.gc_dataset import gc_init_worker + + +class QueueProtocol(Protocol): + def get(self, /) -> Any: + ... + + def put(self, item: Any, /) -> None: + ... + + +P = ParamSpec('P') +R = TypeVar('R', covariant=True) +TSelf = TypeVar('TSelf', bound="DataLoaderWorker") + + +class Future(Protocol[R]): + def get(self) -> R: + ... + +TSample = TypeVar('TSample') + + +class DataLoaderWorker(Generic[TSample]): + """ + A worker for a :class:`DataLoader`. + + The worker is responsible for executing the commands sent by the main process and returning the results. + It also handles the communication with the main process via the command and result queues. + + There are different implementations of a worker: + - :class:`ForkDataLoaderWorker` - A worker that forks a new process for each worker. + - :class:`ThreadDataLoaderWorker` - A worker that uses threads to execute the commands. + """ + def __init__(self, dataset: SavableDataset, worker_config: WorkerConfig, worker_id: int, cmd_queue: QueueProtocol, result_queue: QueueProtocol, data_queue: QueueProtocol): + self.dataset = dataset + self.worker_config = worker_config + self._worker_id = worker_id + self._cmd_queue = cmd_queue + self._result_queue = result_queue + self._data_queue = data_queue + self._next_future_id = 0 + self._futures = {} + + # ------------------------------------------------------------------------------------------------ + # Section: Remote call implementation + + class FutureImpl(Future): + _outerself: "DataLoaderWorker" + _future_id: int + _result: Any + + def __init__(self, outerself: "DataLoaderWorker", future_id: int): + self._outerself = outerself + self._future_id = future_id + + def get(self) -> Any: + if not hasattr(self, "_result"): + self._outerself._wait_for_worker_result(self._future_id) + if isinstance(self._result, Exception): + raise self._result + return self._result + + def _set_result(self, result: Any) -> None: + self._result = result + + @edataclass + class WorkerCommand: + cmd: str + args: tuple[Any, ...] + kwargs: dict[str, Any] + future_id: int + + @staticmethod + def worker_call(fn: Callable[P, R]) -> Callable[P, R]: + """Make the function be called in the worker process via the command and result queues. + The function must be a method of the `DataLoaderWorker` class.""" + @functools.wraps(fn) + def wrapper(self, *args, **kwargs) -> R: + future = self._worker_call(fn.__name__, *args, **kwargs) + return future.get() + setattr(wrapper, '_orig', fn) + return cast(Callable[P, R], wrapper) + + @staticmethod + def worker_call_async(fn: Callable[P, R]) -> Callable[P, Future[R]]: + """Make the function be called in the worker process via the command and result queues. + The function must be a method of the `DataLoaderWorker` class.""" + @functools.wraps(fn) + def wrapper(self, *args, **kwargs) -> Future[R]: + return self._worker_call(fn.__name__, *args, **kwargs) + setattr(wrapper, '_orig', fn) + return cast(Callable[P, Future[R]], wrapper) + + def _wait_for_worker_result(self, future_id: int) -> None: + while True: + future_id, res = self._result_queue.get() + fut = self._futures.pop(future_id) + fut._set_result(res) + if future_id == future_id: + return + + def _worker_call(self, fn: str, *args: Any, **kwargs: Any) -> Future[Any]: + self._assert_running() + future_id = self._next_future_id + self._next_future_id += 1 + + self._futures[future_id] = self.FutureImpl(self, future_id) + self._cmd_queue.put(self.WorkerCommand(cmd=fn, args=args, kwargs=kwargs, future_id=future_id)) + return self._futures[future_id] + + def _worker_run(self, worker_id: int, cmd_queue: QueueProtocol, result_queue: QueueProtocol, data_queue: QueueProtocol, seed: int) -> None: + SystemRng.seed(seed) + self._worker_id = worker_id + self._data_queue = data_queue + while True: + cmd: DataLoaderWorker.WorkerCommand | None = cmd_queue.get() + if cmd is None: + break + try: + fn = getattr(self, cmd.cmd) + result = getattr(fn, '_orig')(self, *cmd.args, **cmd.kwargs) + result_queue.put((cmd.future_id, result)) + del result + except Exception as e: + result_queue.put((cmd.future_id, e)) + + # ------------------------------------------------------------------------------------------------ + # Section: Main control methods + + @abstractmethod + def start(self) -> None: + pass + + @abstractmethod + def shutdown(self) -> None: + pass + + @abstractmethod + def running(self) -> bool: + pass + + def _assert_running(self) -> None: + assert self.running(), "Worker must be running" + + # ------------------------------------------------------------------------------------------------ + # Section: Worker methods + + @worker_call + def dataset_init(self, initial_state: FlexState | None) -> None: + self._sample_index = SampleIndex(worker_config=self.worker_config, src=self) + if initial_state is None: + self.dataset.reset_state_deep() + else: + assert initial_state["__class__"] == type(self).__name__, "Worker type mismatch" + self._sample_index.restore_state(initial_state["_sample_index"]) + self.dataset.restore_state(initial_state["datasets"][0]) + + @worker_call + def start_iter(self) -> None: + self._dataset_iter = iter(self.dataset) + + @worker_call_async + def prefetch_next(self) -> tuple[int, TSample]: + assert self._dataset_iter is not None, "start_iter must be called before prefetch_next" + with self._sample_index.ctx() as sample_idx: + next_sample = next(self._dataset_iter) + return sample_idx, next_sample + + @worker_call + def save_state(self) -> FlexState: + return FlexState( + __class__=type(self).__name__, + rng=SystemRng.save_state(), + datasets=[self.dataset.save_state()], + _sample_index=self._sample_index.save_state(), + ) + + +class ForkDataLoaderWorker(DataLoaderWorker[TSample], Generic[TSample]): + _cmd_queue: multiprocessing.Queue + _result_queue: multiprocessing.Queue + _data_queue: multiprocessing.Queue + _process: multiprocessing.Process | None + + def __init__(self, dataset: SavableDataset, num_workers: int, worker_id: int): + super().__init__(dataset, num_workers=num_workers, worker_id=worker_id, cmd_queue=multiprocessing.Queue(), result_queue=multiprocessing.Queue(), data_queue=multiprocessing.Queue()) + self._spawning_process = multiprocessing.current_process() + + def _check_parent_process(self, evt_exit: threading.Event) -> None: + """Check if the parent process is alive. If it is not, exit the worker process.""" + parent_proc = multiprocessing.parent_process() + if parent_proc is None: + print("No parent process, exiting", file=sys.stderr) + os._exit(-1) + while not evt_exit.wait(1): + if parent_proc.exitcode is not None: + print("Parent process died, exiting", file=sys.stderr) + os._exit(-1) + + def _worker_run(self, worker_id: int, cmd_queue: multiprocessing.Queue, result_queue: multiprocessing.Queue, data_queue: multiprocessing.Queue, seed: int) -> None: + gc_init_worker(worker_id) + worker_exit_evt = threading.Event() + parent_check_thread = threading.Thread(target=self._check_parent_process, args=(worker_exit_evt,), daemon=True) + parent_check_thread.start() + try: + super()._worker_run(worker_id, cmd_queue, result_queue, data_queue, seed) + finally: + worker_exit_evt.set() + parent_check_thread.join() + cmd_queue.cancel_join_thread() + cmd_queue.close() + result_queue.cancel_join_thread() + result_queue.close() + data_queue.cancel_join_thread() + data_queue.close() + + def start(self) -> None: + self._process = multiprocessing.Process(target=self._worker_run, args=(self._worker_id, self._cmd_queue, self._result_queue, self._data_queue)) + self._process.start() + + def shutdown(self) -> None: + if self._spawning_process != multiprocessing.current_process(): + # Should avoid forked process containing a forked worker on exit. + warnings.warn("Shutting down worker from a different process than the one that spawned it, skipping") + return + if self._process is not None: + self._cmd_queue.put(None) + self._process.join() + self._cmd_queue.cancel_join_thread() + self._cmd_queue.close() + self._result_queue.cancel_join_thread() + self._result_queue.close() + self._data_queue.cancel_join_thread() + self._data_queue.close() + + def running(self) -> bool: + return self._process is not None + + def _assert_running(self) -> None: + assert self._process is not None, "Worker must be started first" + assert self._process.is_alive(), "Worker died" + + +class ThreadDataLoaderWorker(DataLoaderWorker[TSample], Generic[TSample]): + def __init__(self, dataset: SavableDataset, num_workers: int, worker_id: int): + super().__init__(dataset, num_workers=num_workers, worker_id=worker_id, cmd_queue=queue.Queue(), result_queue=queue.Queue(), data_queue=queue.Queue()) + + def _worker_run(self, worker_id: int, cmd_queue: queue.Queue, result_queue: queue.Queue, data_queue: queue.Queue, seed: int) -> None: + # TODO: Implement init_thread which should hook all randomness such that it's thread local. + SystemRng.init_thread() + super()._worker_run(worker_id, cmd_queue, result_queue, data_queue, seed) + + def start(self) -> None: + self._thread = threading.Thread(target=self._worker_run, args=(self._worker_id, self._cmd_queue, self._result_queue, self._data_queue)) + self._thread.start() + + def shutdown(self) -> None: + if self._thread is not None: + self._cmd_queue.put(None) + self._thread.join() + self._thread = None + + def running(self) -> bool: + return self._thread is not None + + +class WorkerType(Protocol[TSample]): + def __call__(self, dataset: SavableDataset, num_workers: int, worker_id: int) -> DataLoaderWorker[TSample]: + ... + + +class DataLoader(Generic[TSample]): + _workers: list[DataLoaderWorker[TSample]] + + _dataset: SavableDataset + _num_workers: int + _worker_type: WorkerType + + def __init__(self, dataset: SavableDataset, num_workers: int = 0, worker_type: WorkerType = ForkDataLoaderWorker): + self._dataset = dataset + self._num_workers = num_workers + self._worker_type = worker_type + self._workers = [] + + def start(self, initial_states: Sequence[FlexState | None]) -> None: + if self._num_workers == 0: + return + self._workers = [self._worker_type(self._dataset, self._num_workers, local_worker_id) for local_worker_id in range(self._num_workers)] + for worker in self._workers: + worker.start() + + for worker, initial_state in zip(self._workers, initial_states): + worker.dataset_init(initial_state, 0) + + def shutdown(self) -> None: + for worker in self._workers: + worker.shutdown() + self._workers = [] + + def _get_iterator(self): + if self._num_workers == 0: + # Easy case: no workers, just iterate over the dataset. + yield from self._dataset + return + + for worker in self._workers: + worker.start_iter() + + while True: + for worker in self._workers: + yield worker.prefetch_next() + + def __iter__(self): + return self._get_iterator() + + def __del__(self) -> None: + self.shutdown() + + def __len__(self): + return len(self._dataset) From 2ed657abff3a535dc468b4c68a29a740e5e3c726 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 8 Jul 2025 09:16:34 +0200 Subject: [PATCH 02/36] WIP: Custom dataloader --- src/megatron/energon/dataloader.py | 619 ++++++++++++++++++++++------- 1 file changed, 477 insertions(+), 142 deletions(-) diff --git a/src/megatron/energon/dataloader.py b/src/megatron/energon/dataloader.py index 43e76237..56c982d3 100644 --- a/src/megatron/energon/dataloader.py +++ b/src/megatron/energon/dataloader.py @@ -1,65 +1,154 @@ -from abc import abstractmethod -from concurrent.futures import Future +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause import functools import multiprocessing import os import queue import sys import threading -from typing import Any, Concatenate, Generic, Protocol, Self, Sequence, TypeVar, cast -from typing import Callable, TypeVar, ParamSpec import warnings -from torch.utils.data import IterableDataset - - +from typing import ( + Any, + Callable, + Generator, + Generic, + ParamSpec, + Protocol, + TypeVar, + cast, + override, +) + +from megatron.energon.cache.base import CachePool +from megatron.energon.cache.no_cache import NoCachePool from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import SavableDataset -from megatron.energon.metadataset.dataset_loader import DatasetLoader +from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key from megatron.energon.rng import SystemRng from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import SampleIndex +from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.batch_dataset import BatchDataset from megatron.energon.wrappers.gc_dataset import gc_init_worker class QueueProtocol(Protocol): - def get(self, /) -> Any: - ... + def get(self, /) -> Any: ... - def put(self, item: Any, /) -> None: - ... + def put(self, item: Any, /) -> None: ... -P = ParamSpec('P') -R = TypeVar('R', covariant=True) -TSelf = TypeVar('TSelf', bound="DataLoaderWorker") +P = ParamSpec("P") +R = TypeVar("R", covariant=True) +TSelf = TypeVar("TSelf", bound="DataLoaderWorker") +TSample = TypeVar("TSample") class Future(Protocol[R]): - def get(self) -> R: - ... + def get(self) -> R: ... + -TSample = TypeVar('TSample') +class DoneFuture(Future[TSample]): + def __init__(self, result: TSample): + self._result = result + + def get(self) -> TSample: + return self._result class DataLoaderWorker(Generic[TSample]): """ A worker for a :class:`DataLoader`. - The worker is responsible for executing the commands sent by the main process and returning the results. - It also handles the communication with the main process via the command and result queues. + The basic implementation iterates the dataset. + The async extension implements the main commands via a command and results queue. + """ + + def __init__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool, + ): + self.dataset = dataset + self.worker_config = worker_config + self._rank_worker_id = rank_worker_id + self._global_worker_id = worker_config.global_worker_id(rank_worker_id) + self._cache_pool = cache_pool + + # ------------------------------------------------------------------------------------------------ + # Section: Main control methods + + def start(self) -> None: + pass + + def shutdown(self) -> None: + pass + + def running(self) -> bool: + return True + + def _assert_running(self) -> None: + assert self.running(), "Worker must be running" + + # ------------------------------------------------------------------------------------------------ + # Section: Worker methods + + def dataset_init(self, initial_state: FlexState | None) -> None: + self._sample_index = SampleIndex(worker_config=self.worker_config, src=self) + self._global_worker_id = self.worker_config.global_worker_id() + if initial_state is None: + self.dataset.reset_state_deep() + else: + assert initial_state["__class__"] == "DataLoaderWorker", "Worker type mismatch" + self._sample_index.restore_state(initial_state["_sample_index"]) + self.dataset.restore_state(initial_state["datasets"][0]) + + def new_iter(self) -> None: + self._dataset_iter = iter(self.dataset) + + def prefetch_next(self) -> Future[TSample]: + assert self._dataset_iter is not None, "start_iter must be called before prefetch_next" + with self._sample_index.ctx() as sample_idx: + self.worker_config.worker_activate(sample_idx, cache_pool=self._cache_pool) + try: + next_sample = next(self._dataset_iter) + add_sample_restore_key(next_sample, self._global_worker_id, sample_idx, src=self) + finally: + self.worker_config.worker_deactivate() + return DoneFuture(next_sample) + + def save_state(self) -> FlexState: + return FlexState( + __class__="DataLoaderWorker", + rng=SystemRng.save_state(), + dataset=self.dataset.save_state(), + _sample_index=self._sample_index.save_state(), + ) + - There are different implementations of a worker: +class _DataLoaderAsyncWorker(DataLoaderWorker[TSample]): + """ + Extension of the `DataLoaderWorker`, which implements commands via a command and results queue. + + There are different implementations of the async worker: - :class:`ForkDataLoaderWorker` - A worker that forks a new process for each worker. - :class:`ThreadDataLoaderWorker` - A worker that uses threads to execute the commands. """ - def __init__(self, dataset: SavableDataset, worker_config: WorkerConfig, worker_id: int, cmd_queue: QueueProtocol, result_queue: QueueProtocol, data_queue: QueueProtocol): - self.dataset = dataset - self.worker_config = worker_config - self._worker_id = worker_id + + def __init__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cmd_queue: QueueProtocol, + result_queue: QueueProtocol, + cache_pool: CachePool, + ): + super().__init__(dataset, worker_config, rank_worker_id, cache_pool) + assert worker_config.num_workers > 0, "Async workers require num_workers > 0" self._cmd_queue = cmd_queue self._result_queue = result_queue - self._data_queue = data_queue self._next_future_id = 0 self._futures = {} @@ -67,11 +156,11 @@ def __init__(self, dataset: SavableDataset, worker_config: WorkerConfig, worker_ # Section: Remote call implementation class FutureImpl(Future): - _outerself: "DataLoaderWorker" + _outerself: "_DataLoaderAsyncWorker" _future_id: int _result: Any - def __init__(self, outerself: "DataLoaderWorker", future_id: int): + def __init__(self, outerself: "_DataLoaderAsyncWorker", future_id: int): self._outerself = outerself self._future_id = future_id @@ -81,7 +170,7 @@ def get(self) -> Any: if isinstance(self._result, Exception): raise self._result return self._result - + def _set_result(self, result: Any) -> None: self._result = result @@ -91,26 +180,30 @@ class WorkerCommand: args: tuple[Any, ...] kwargs: dict[str, Any] future_id: int - + @staticmethod def worker_call(fn: Callable[P, R]) -> Callable[P, R]: """Make the function be called in the worker process via the command and result queues. The function must be a method of the `DataLoaderWorker` class.""" + @functools.wraps(fn) def wrapper(self, *args, **kwargs) -> R: future = self._worker_call(fn.__name__, *args, **kwargs) return future.get() - setattr(wrapper, '_orig', fn) + + setattr(wrapper, "_orig", fn) return cast(Callable[P, R], wrapper) @staticmethod def worker_call_async(fn: Callable[P, R]) -> Callable[P, Future[R]]: """Make the function be called in the worker process via the command and result queues. The function must be a method of the `DataLoaderWorker` class.""" + @functools.wraps(fn) def wrapper(self, *args, **kwargs) -> Future[R]: return self._worker_call(fn.__name__, *args, **kwargs) - setattr(wrapper, '_orig', fn) + + setattr(wrapper, "_orig", fn) return cast(Callable[P, Future[R]], wrapper) def _wait_for_worker_result(self, future_id: int) -> None: @@ -127,87 +220,86 @@ def _worker_call(self, fn: str, *args: Any, **kwargs: Any) -> Future[Any]: self._next_future_id += 1 self._futures[future_id] = self.FutureImpl(self, future_id) - self._cmd_queue.put(self.WorkerCommand(cmd=fn, args=args, kwargs=kwargs, future_id=future_id)) + self._cmd_queue.put( + self.WorkerCommand(cmd=fn, args=args, kwargs=kwargs, future_id=future_id) + ) return self._futures[future_id] - def _worker_run(self, worker_id: int, cmd_queue: QueueProtocol, result_queue: QueueProtocol, data_queue: QueueProtocol, seed: int) -> None: + def _worker_run( + self, worker_id: int, cmd_queue: QueueProtocol, result_queue: QueueProtocol, seed: int + ) -> None: SystemRng.seed(seed) self._worker_id = worker_id - self._data_queue = data_queue + import torch.utils.data._utils + + torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( + id=worker_id, + num_workers=self.worker_config.num_workers, + seed=seed, + dataset=self.dataset, + ) + self._global_worker_id = self.worker_config.global_worker_id() + self.worker_config.assert_worker() while True: - cmd: DataLoaderWorker.WorkerCommand | None = cmd_queue.get() + cmd: _DataLoaderAsyncWorker.WorkerCommand | None = cmd_queue.get() if cmd is None: break try: fn = getattr(self, cmd.cmd) - result = getattr(fn, '_orig')(self, *cmd.args, **cmd.kwargs) + result = getattr(fn, "_orig")(self, *cmd.args, **cmd.kwargs) result_queue.put((cmd.future_id, result)) del result except Exception as e: result_queue.put((cmd.future_id, e)) # ------------------------------------------------------------------------------------------------ - # Section: Main control methods - - @abstractmethod - def start(self) -> None: - pass - - @abstractmethod - def shutdown(self) -> None: - pass - - @abstractmethod - def running(self) -> bool: - pass - - def _assert_running(self) -> None: - assert self.running(), "Worker must be running" - - # ------------------------------------------------------------------------------------------------ - # Section: Worker methods + # Section: Worker methods - now calling to workers via queues. + @override @worker_call def dataset_init(self, initial_state: FlexState | None) -> None: - self._sample_index = SampleIndex(worker_config=self.worker_config, src=self) - if initial_state is None: - self.dataset.reset_state_deep() - else: - assert initial_state["__class__"] == type(self).__name__, "Worker type mismatch" - self._sample_index.restore_state(initial_state["_sample_index"]) - self.dataset.restore_state(initial_state["datasets"][0]) - + super().dataset_init(initial_state) + + @override @worker_call - def start_iter(self) -> None: - self._dataset_iter = iter(self.dataset) + def new_iter(self) -> None: + super().new_iter() + @override @worker_call_async - def prefetch_next(self) -> tuple[int, TSample]: - assert self._dataset_iter is not None, "start_iter must be called before prefetch_next" - with self._sample_index.ctx() as sample_idx: - next_sample = next(self._dataset_iter) - return sample_idx, next_sample + def prefetch_next(self) -> TSample: + return super().prefetch_next().get() + @override @worker_call def save_state(self) -> FlexState: - return FlexState( - __class__=type(self).__name__, - rng=SystemRng.save_state(), - datasets=[self.dataset.save_state()], - _sample_index=self._sample_index.save_state(), - ) + return super().save_state() -class ForkDataLoaderWorker(DataLoaderWorker[TSample], Generic[TSample]): +class ForkDataLoaderWorker(_DataLoaderAsyncWorker[TSample], Generic[TSample]): _cmd_queue: multiprocessing.Queue _result_queue: multiprocessing.Queue - _data_queue: multiprocessing.Queue _process: multiprocessing.Process | None - def __init__(self, dataset: SavableDataset, num_workers: int, worker_id: int): - super().__init__(dataset, num_workers=num_workers, worker_id=worker_id, cmd_queue=multiprocessing.Queue(), result_queue=multiprocessing.Queue(), data_queue=multiprocessing.Queue()) - self._spawning_process = multiprocessing.current_process() - + _spawning_process: int + + def __init__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool, + ): + super().__init__( + dataset, + worker_config=worker_config, + rank_worker_id=rank_worker_id, + cmd_queue=multiprocessing.Queue(), + result_queue=multiprocessing.Queue(), + cache_pool=cache_pool, + ) + self._spawning_process = os.getpid() + def _check_parent_process(self, evt_exit: threading.Event) -> None: """Check if the parent process is alive. If it is not, exit the worker process.""" parent_proc = multiprocessing.parent_process() @@ -219,13 +311,21 @@ def _check_parent_process(self, evt_exit: threading.Event) -> None: print("Parent process died, exiting", file=sys.stderr) os._exit(-1) - def _worker_run(self, worker_id: int, cmd_queue: multiprocessing.Queue, result_queue: multiprocessing.Queue, data_queue: multiprocessing.Queue, seed: int) -> None: + def _worker_run( + self, + worker_id: int, + cmd_queue: multiprocessing.Queue, + result_queue: multiprocessing.Queue, + seed: int, + ) -> None: gc_init_worker(worker_id) worker_exit_evt = threading.Event() - parent_check_thread = threading.Thread(target=self._check_parent_process, args=(worker_exit_evt,), daemon=True) + parent_check_thread = threading.Thread( + target=self._check_parent_process, args=(worker_exit_evt,), daemon=True + ) parent_check_thread.start() try: - super()._worker_run(worker_id, cmd_queue, result_queue, data_queue, seed) + super()._worker_run(worker_id, cmd_queue, result_queue, seed) finally: worker_exit_evt.set() parent_check_thread.join() @@ -233,17 +333,22 @@ def _worker_run(self, worker_id: int, cmd_queue: multiprocessing.Queue, result_q cmd_queue.close() result_queue.cancel_join_thread() result_queue.close() - data_queue.cancel_join_thread() - data_queue.close() - + + @override def start(self) -> None: - self._process = multiprocessing.Process(target=self._worker_run, args=(self._worker_id, self._cmd_queue, self._result_queue, self._data_queue)) + self._process = multiprocessing.Process( + target=self._worker_run, + args=(self._rank_worker_id, self._cmd_queue, self._result_queue), + ) self._process.start() + @override def shutdown(self) -> None: - if self._spawning_process != multiprocessing.current_process(): + if self._spawning_process != os.getpid(): # Should avoid forked process containing a forked worker on exit. - warnings.warn("Shutting down worker from a different process than the one that spawned it, skipping") + warnings.warn( + "Shutting down worker from a different process than the one that spawned it, skipping" + ) return if self._process is not None: self._cmd_queue.put(None) @@ -252,9 +357,8 @@ def shutdown(self) -> None: self._cmd_queue.close() self._result_queue.cancel_join_thread() self._result_queue.close() - self._data_queue.cancel_join_thread() - self._data_queue.close() + @override def running(self) -> bool: return self._process is not None @@ -263,80 +367,311 @@ def _assert_running(self) -> None: assert self._process.is_alive(), "Worker died" -class ThreadDataLoaderWorker(DataLoaderWorker[TSample], Generic[TSample]): - def __init__(self, dataset: SavableDataset, num_workers: int, worker_id: int): - super().__init__(dataset, num_workers=num_workers, worker_id=worker_id, cmd_queue=queue.Queue(), result_queue=queue.Queue(), data_queue=queue.Queue()) - - def _worker_run(self, worker_id: int, cmd_queue: queue.Queue, result_queue: queue.Queue, data_queue: queue.Queue, seed: int) -> None: +class ThreadDataLoaderWorker(_DataLoaderAsyncWorker[TSample], Generic[TSample]): + def __init__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool, + ): + super().__init__( + dataset, + worker_config=worker_config, + rank_worker_id=rank_worker_id, + cmd_queue=queue.Queue(), + result_queue=queue.Queue(), + cache_pool=cache_pool, + ) + + def _worker_run( + self, worker_id: int, cmd_queue: queue.Queue, result_queue: queue.Queue, seed: int + ) -> None: # TODO: Implement init_thread which should hook all randomness such that it's thread local. SystemRng.init_thread() - super()._worker_run(worker_id, cmd_queue, result_queue, data_queue, seed) - + super()._worker_run(worker_id, cmd_queue, result_queue, seed) + + @override def start(self) -> None: - self._thread = threading.Thread(target=self._worker_run, args=(self._worker_id, self._cmd_queue, self._result_queue, self._data_queue)) + self._thread = threading.Thread( + target=self._worker_run, + args=(self._rank_worker_id, self._cmd_queue, self._result_queue), + ) self._thread.start() - + + @override def shutdown(self) -> None: if self._thread is not None: self._cmd_queue.put(None) self._thread.join() self._thread = None + @override def running(self) -> bool: return self._thread is not None class WorkerType(Protocol[TSample]): - def __call__(self, dataset: SavableDataset, num_workers: int, worker_id: int) -> DataLoaderWorker[TSample]: - ... + def __call__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool, + ) -> DataLoaderWorker[TSample]: ... class DataLoader(Generic[TSample]): - _workers: list[DataLoaderWorker[TSample]] + _workers: list[DataLoaderWorker[TSample]] | None = None + _exhausted_workers: list[bool] + _next_worker_id: int = 0 + + _restore_state: FlexState | None = None _dataset: SavableDataset - _num_workers: int + _worker_config: WorkerConfig + _prefetch_factor: int _worker_type: WorkerType - - def __init__(self, dataset: SavableDataset, num_workers: int = 0, worker_type: WorkerType = ForkDataLoaderWorker): + _prefetching_samples: list[list[Future[TSample]]] + + _current_epoch_iter: Generator[TSample, None, None] | None = None + + _spawning_process: int + + def __init__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + prefetch_factor: int = 2, + worker_type: WorkerType = ForkDataLoaderWorker, + cache_pool: CachePool = NoCachePool(), + ): + if self._worker_config.num_workers == 0 and worker_type == ForkDataLoaderWorker: + worker_type = DataLoaderWorker self._dataset = dataset - self._num_workers = num_workers + self._worker_config = worker_config + self._prefetch_factor = prefetch_factor self._worker_type = worker_type - self._workers = [] - - def start(self, initial_states: Sequence[FlexState | None]) -> None: - if self._num_workers == 0: - return - self._workers = [self._worker_type(self._dataset, self._num_workers, local_worker_id) for local_worker_id in range(self._num_workers)] - for worker in self._workers: - worker.start() + self._cache_pool = cache_pool + self._prefetching_samples = [[] for _ in range(self._worker_config.num_workers)] + self._exhausted_workers = [False] * self._worker_config.num_workers + + self._spawning_process = os.getpid() + + self._restore_state = None - for worker, initial_state in zip(self._workers, initial_states): - worker.dataset_init(initial_state, 0) - def shutdown(self) -> None: - for worker in self._workers: - worker.shutdown() - self._workers = [] + if self._workers is not None: + for worker in self._workers: + worker.shutdown() + self._workers = None - def _get_iterator(self): - if self._num_workers == 0: - # Easy case: no workers, just iterate over the dataset. - yield from self._dataset - return - - for worker in self._workers: - worker.start_iter() - - while True: + def start_iter(self) -> None: + if self._workers is not None: + for worker in self._workers: + worker.new_iter() + + def _epoch_iter(self) -> Generator[TSample, None, None]: + if self._workers is None: + self.start() + for worker, exhausted in zip(self._workers, self._exhausted_workers): + if not exhausted: + worker.new_iter() + + assert self._workers is not None, "DataLoader not started" + + if all(self._exhausted_workers): + # All workers are exhausted, restart for the next epoch. for worker in self._workers: - yield worker.prefetch_next() - - def __iter__(self): - return self._get_iterator() - + worker.new_iter() + self._exhausted_workers = [False] * self._worker_config.num_workers + + # For all workers, enqueue prefetching samples. + for worker_idx, (worker, exhausted) in enumerate( + zip(self._workers, self._exhausted_workers) + ): + while ( + len(self._prefetching_samples[worker_idx]) < self._prefetch_factor and not exhausted + ): + self._prefetching_samples[worker_idx].append(worker.prefetch_next()) + + # Main loop: + # - Get the next worker to prefetch samples from. + # - Prefetch samples from the worker. + # - Pop the first sample future from the prefetching samples. + # - Get the sample from the sample future (may wait for the sample to be prefetched). + # - Yield the sample. + while not all(self._exhausted_workers): + # Get the next worker to prefetch samples from. + worker_idx = self._next_worker_id + worker = self._workers[worker_idx] + self._next_worker_id = (worker_idx + 1) % self._worker_config.num_workers + if self._exhausted_workers[worker_idx]: + continue + # Prefetch samples from the worker. + while len(self._prefetching_samples[worker_idx]) < self._prefetch_factor: + # Add a new sample future to the prefetching samples if the worker has not prefetched enough samples. + self._prefetching_samples[worker_idx].append(worker.prefetch_next()) + # Pop the first sample future from the prefetching samples. + sample_future = self._prefetching_samples[worker_idx].pop(0) + try: + # Get the sample from the sample future (may wait for the sample to be ready). + sample = sample_future.get() + except Exception as e: + # If the sample future raises an exception, remove the worker from the list. + self._prefetching_samples[worker_idx] = [] + self._exhausted_workers[worker_idx] = True + raise e + else: + # Yield the sample. + yield sample + + def __iter__(self) -> Generator[TSample, None, None]: + # Restart the epoch iterator if was not created yet. Otherwise, the existing epoch iterator will be continued. + # That happens e.g. when iteration was interrupted. + if self._current_epoch_iter is None: + self._current_epoch_iter = self._epoch_iter() + yield from self._current_epoch_iter + # Reset the epoch iterator, it was exhausted. + self._current_epoch_iter = None + def __del__(self) -> None: - self.shutdown() + if self._spawning_process == os.getpid(): + # Otherwise we may be in a forked process which is not the one that spawned the DataLoader. + self.shutdown() def __len__(self): return len(self._dataset) + + def _get_batch_size(self) -> int | None: + """Try to infer micro batch size from the dataset""" + if ( + isinstance(self._dataset, BaseWrapperDataset) + and (bds := self._dataset._find_wrapped_dataset(BatchDataset)) is not None + ): + assert isinstance(bds, BatchDataset) + return bds.batch_size + else: + return None + + def save_state(self) -> FlexState: + prefetched_samples_keys = [ + [get_sample_restore_key(sample.get()) for sample in prefetching_sample] + for prefetching_sample in self._prefetching_samples + ] + if self._workers is None: + worker_states = [None] * self._worker_config.num_workers + else: + worker_states = [worker.save_state() for worker in self._workers] + + return FlexState( + __class__=type(self).__name__, + prefetched_samples_keys=prefetched_samples_keys, + worker_states=worker_states, + workers_exhausted=self._exhausted_workers.copy(), + next_worker_id=self._next_worker_id, + micro_batch_size=self._get_batch_size(), + ) + + def start(self, initial_state: FlexState | None = None) -> None: + assert self._workers is None and self._current_epoch_iter is None, ( + "DataLoader already started" + ) + self._workers = [ + self._worker_type(self._dataset, self._worker_config, local_worker_id, self._cache_pool) + for local_worker_id in range(max(self._worker_config.num_workers, 1)) + ] + for worker in self._workers: + worker.start() + + if initial_state is None: + if self._restore_state is not None: + initial_state = self._restore_state + self._restore_state = None + + if initial_state is None: + initial_states = [None] * self._worker_config.num_workers + else: + initial_states = initial_state["worker_states"] + + assert len(initial_states) == self._worker_config.num_workers, ( + "Number of initial states must match number of workers" + ) + + for worker, initial_state in zip(self._workers, initial_states): + worker.dataset_init(initial_state) + + if initial_state is not None: + self._prefetching_samples = [ + [ + DoneFuture(self.restore_sample(sample_key)) + for sample_key in prefetched_samples_keys + ] + for prefetched_samples_keys in initial_state["prefetched_samples_keys"] + ] + self._next_worker_id = initial_state["next_worker_id"] + self._exhausted_workers = initial_state["workers_exhausted"].copy() + + def restore_state_rank(self, state: FlexState | None) -> None: + assert self._workers is None and self._current_epoch_iter is None, ( + "DataLoader already started" + ) + + if state is None: + # Assume initial state. + return + + assert isinstance(state, FlexState) + assert state["__class__"] == type(self).__name__, "DataLoader type mismatch" + + old_micro_batch_size = state["micro_batch_size"] + micro_batch_size = self._get_batch_size() + + if self._worker_config.num_workers == 0: + assert micro_batch_size == old_micro_batch_size, "Micro batch size mismatch" + assert len(state["worker_states"]) == 1 + assert isinstance(state["worker_states"][0], FlexState) + self._dataset.restore_state(state["worker_states"][0]) + else: + # Check batch sizes (before and after) + if micro_batch_size != old_micro_batch_size: + assert micro_batch_size is not None and old_micro_batch_size is not None, ( + "Cannot resume with different batching mode " + "(batching to non-batching or vice versa)" + ) + + if micro_batch_size > old_micro_batch_size: + raise ValueError( + "Resuming with larger micro batch size is not allowed: " + f"{micro_batch_size} > {old_micro_batch_size}" + ) + elif ( + micro_batch_size < old_micro_batch_size + and old_micro_batch_size % micro_batch_size != 0 + ): + raise ValueError( + "Resuming with smaller micro batch size only allowed if the old " + f"micro batch size is a multiple of the new one: {micro_batch_size} < {old_micro_batch_size}" + ) + + self._restore_state = state + + def restore_sample(self, restore_key: tuple) -> TSample: + id, global_worker_id, sample_idx = restore_key[:3] + assert id == type(self).__name__ + restore_key = restore_key[3:] + self._worker_config.worker_activate( + sample_idx, override_global_rank=global_worker_id, cache_pool=self._cache_pool + ) + try: + return add_sample_restore_key( + self._dataset.restore_sample(restore_key), global_worker_id, sample_idx, src=self + ) + finally: + self._worker_config.worker_deactivate() + + def config(self) -> dict[str, Any]: + return self._dataset.config() + + def __str__(self) -> str: + return f"DataLoader(prefetch_factor={self._prefetch_factor}, worker_type={self._worker_type.__name__})" From 8cfd43a5d7efc0280edddfbf2c4d278db8b35526 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Wed, 9 Jul 2025 11:50:15 +0200 Subject: [PATCH 03/36] WIP: Custom dataloader --- src/megatron/energon/dataloader.py | 137 +++++++++++++++++++++-------- 1 file changed, 101 insertions(+), 36 deletions(-) diff --git a/src/megatron/energon/dataloader.py b/src/megatron/energon/dataloader.py index 56c982d3..32353df1 100644 --- a/src/megatron/energon/dataloader.py +++ b/src/megatron/energon/dataloader.py @@ -30,24 +30,25 @@ from megatron.energon.wrappers.batch_dataset import BatchDataset from megatron.energon.wrappers.gc_dataset import gc_init_worker - -class QueueProtocol(Protocol): - def get(self, /) -> Any: ... - - def put(self, item: Any, /) -> None: ... - - P = ParamSpec("P") R = TypeVar("R", covariant=True) TSelf = TypeVar("TSelf", bound="DataLoaderWorker") TSample = TypeVar("TSample") +class QueueProtocol(Protocol[TSample]): + def get(self, /) -> TSample: ... + + def put(self, item: TSample, /) -> None: ... + + class Future(Protocol[R]): def get(self) -> R: ... class DoneFuture(Future[TSample]): + """Future that is already done.""" + def __init__(self, result: TSample): self._result = result @@ -55,6 +56,16 @@ def get(self) -> TSample: return self._result +class ExceptionFuture(Future[Any]): + """Future that raises an exception.""" + + def __init__(self, exception: Exception): + self._exception = exception + + def get(self) -> Any: + raise self._exception + + class DataLoaderWorker(Generic[TSample]): """ A worker for a :class:`DataLoader`. @@ -63,9 +74,18 @@ class DataLoaderWorker(Generic[TSample]): The async extension implements the main commands via a command and results queue. """ + dataset: SavableDataset[TSample] + worker_config: WorkerConfig + + _rank_worker_id: int + _global_worker_id: int + _cache_pool: CachePool + + exhausted: bool = True + def __init__( self, - dataset: SavableDataset, + dataset: SavableDataset[TSample], worker_config: WorkerConfig, rank_worker_id: int, cache_pool: CachePool, @@ -106,6 +126,7 @@ def dataset_init(self, initial_state: FlexState | None) -> None: def new_iter(self) -> None: self._dataset_iter = iter(self.dataset) + self.exhausted = False def prefetch_next(self) -> Future[TSample]: assert self._dataset_iter is not None, "start_iter must be called before prefetch_next" @@ -114,6 +135,9 @@ def prefetch_next(self) -> Future[TSample]: try: next_sample = next(self._dataset_iter) add_sample_restore_key(next_sample, self._global_worker_id, sample_idx, src=self) + except StopIteration as e: + self.exhausted = True + return ExceptionFuture(e) finally: self.worker_config.worker_deactivate() return DoneFuture(next_sample) @@ -136,13 +160,18 @@ class _DataLoaderAsyncWorker(DataLoaderWorker[TSample]): - :class:`ThreadDataLoaderWorker` - A worker that uses threads to execute the commands. """ + _cmd_queue: QueueProtocol["WorkerCommand"] + _result_queue: QueueProtocol["WorkerResult"] + _next_future_id: int + _futures: dict[int, "FutureImpl"] + def __init__( self, dataset: SavableDataset, worker_config: WorkerConfig, rank_worker_id: int, - cmd_queue: QueueProtocol, - result_queue: QueueProtocol, + cmd_queue: QueueProtocol["WorkerCommand"], + result_queue: QueueProtocol["WorkerResult"], cache_pool: CachePool, ): super().__init__(dataset, worker_config, rank_worker_id, cache_pool) @@ -155,31 +184,47 @@ def __init__( # ------------------------------------------------------------------------------------------------ # Section: Remote call implementation - class FutureImpl(Future): + @edataclass + class WorkerResult: + """Internal class for communicating a result from the worker via the result queue.""" + + future_id: int + result: Any = None + exception: Exception | None = None + + @edataclass + class WorkerCommand: + """Internal class for communicating a command to the worker via the command queue.""" + + cmd: str + args: tuple[Any, ...] + kwargs: dict[str, Any] + future_id: int + + class FutureImpl(Future[Any]): + """Class for returning a future result from the worker..""" + _outerself: "_DataLoaderAsyncWorker" _future_id: int _result: Any + _exception: Exception def __init__(self, outerself: "_DataLoaderAsyncWorker", future_id: int): self._outerself = outerself self._future_id = future_id def get(self) -> Any: + if hasattr(self, "_exception"): + raise self._exception if not hasattr(self, "_result"): self._outerself._wait_for_worker_result(self._future_id) - if isinstance(self._result, Exception): - raise self._result return self._result def _set_result(self, result: Any) -> None: self._result = result - @edataclass - class WorkerCommand: - cmd: str - args: tuple[Any, ...] - kwargs: dict[str, Any] - future_id: int + def _set_exception(self, exception: Exception) -> None: + self._exception = exception @staticmethod def worker_call(fn: Callable[P, R]) -> Callable[P, R]: @@ -208,11 +253,16 @@ def wrapper(self, *args, **kwargs) -> Future[R]: def _wait_for_worker_result(self, future_id: int) -> None: while True: - future_id, res = self._result_queue.get() - fut = self._futures.pop(future_id) - fut._set_result(res) - if future_id == future_id: + res = self._result_queue.get() + fut = self._futures.pop(res.future_id) + if res.exception is not None: + fut._set_exception(res.exception) + else: + fut._set_result(res.result) + if res.future_id == future_id: return + else: + continue def _worker_call(self, fn: str, *args: Any, **kwargs: Any) -> Future[Any]: self._assert_running() @@ -226,7 +276,11 @@ def _worker_call(self, fn: str, *args: Any, **kwargs: Any) -> Future[Any]: return self._futures[future_id] def _worker_run( - self, worker_id: int, cmd_queue: QueueProtocol, result_queue: QueueProtocol, seed: int + self, + worker_id: int, + cmd_queue: QueueProtocol[WorkerCommand], + result_queue: QueueProtocol[WorkerResult], + seed: int, ) -> None: SystemRng.seed(seed) self._worker_id = worker_id @@ -241,20 +295,25 @@ def _worker_run( self._global_worker_id = self.worker_config.global_worker_id() self.worker_config.assert_worker() while True: - cmd: _DataLoaderAsyncWorker.WorkerCommand | None = cmd_queue.get() - if cmd is None: - break + cmd = cmd_queue.get() try: fn = getattr(self, cmd.cmd) result = getattr(fn, "_orig")(self, *cmd.args, **cmd.kwargs) - result_queue.put((cmd.future_id, result)) + result_queue.put(self.WorkerResult(future_id=cmd.future_id, result=result)) del result except Exception as e: - result_queue.put((cmd.future_id, e)) + result_queue.put(self.WorkerResult(future_id=cmd.future_id, exception=e)) + if cmd.cmd == "_shutdown_worker": + break # ------------------------------------------------------------------------------------------------ # Section: Worker methods - now calling to workers via queues. + @worker_call + def _shutdown_worker(self) -> None: + """Shutdown the worker. The actual shutdown is handled in the _worker_run method.""" + pass + @override @worker_call def dataset_init(self, initial_state: FlexState | None) -> None: @@ -277,9 +336,11 @@ def save_state(self) -> FlexState: class ForkDataLoaderWorker(_DataLoaderAsyncWorker[TSample], Generic[TSample]): - _cmd_queue: multiprocessing.Queue - _result_queue: multiprocessing.Queue - _process: multiprocessing.Process | None + """ + Implements the `DataLoaderWorker` interface using processes. + """ + + _process: multiprocessing.Process | None = None _spawning_process: int @@ -351,7 +412,7 @@ def shutdown(self) -> None: ) return if self._process is not None: - self._cmd_queue.put(None) + self._shutdown_worker() self._process.join() self._cmd_queue.cancel_join_thread() self._cmd_queue.close() @@ -368,6 +429,12 @@ def _assert_running(self) -> None: class ThreadDataLoaderWorker(_DataLoaderAsyncWorker[TSample], Generic[TSample]): + """ + Implements the `DataLoaderWorker` interface using threads. + """ + + _thread: threading.Thread | None = None + def __init__( self, dataset: SavableDataset, @@ -402,7 +469,7 @@ def start(self) -> None: @override def shutdown(self) -> None: if self._thread is not None: - self._cmd_queue.put(None) + self._shutdown_worker() self._thread.join() self._thread = None @@ -458,8 +525,6 @@ def __init__( self._spawning_process = os.getpid() - self._restore_state = None - def shutdown(self) -> None: if self._workers is not None: for worker in self._workers: From c9cc54bb96d2930bc650be84c51ffafb3bd2d528 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Wed, 9 Jul 2025 12:43:54 +0200 Subject: [PATCH 04/36] WIP: Custom dataloader --- src/megatron/energon/dataloader.py | 91 ++++++++++++++++-------------- 1 file changed, 50 insertions(+), 41 deletions(-) diff --git a/src/megatron/energon/dataloader.py b/src/megatron/energon/dataloader.py index 32353df1..9de5c296 100644 --- a/src/megatron/energon/dataloader.py +++ b/src/megatron/energon/dataloader.py @@ -42,6 +42,7 @@ def get(self, /) -> TSample: ... def put(self, item: TSample, /) -> None: ... + class Future(Protocol[R]): def get(self) -> R: ... @@ -123,6 +124,7 @@ def dataset_init(self, initial_state: FlexState | None) -> None: assert initial_state["__class__"] == "DataLoaderWorker", "Worker type mismatch" self._sample_index.restore_state(initial_state["_sample_index"]) self.dataset.restore_state(initial_state["datasets"][0]) + # TODO: exhausted def new_iter(self) -> None: self._dataset_iter = iter(self.dataset) @@ -147,11 +149,12 @@ def save_state(self) -> FlexState: __class__="DataLoaderWorker", rng=SystemRng.save_state(), dataset=self.dataset.save_state(), + exhausted=self.exhausted, _sample_index=self._sample_index.save_state(), ) -class _DataLoaderAsyncWorker(DataLoaderWorker[TSample]): +class _DataLoaderAsynchronousWorker(DataLoaderWorker[TSample]): """ Extension of the `DataLoaderWorker`, which implements commands via a command and results queue. @@ -187,15 +190,14 @@ def __init__( @edataclass class WorkerResult: """Internal class for communicating a result from the worker via the result queue.""" - future_id: int result: Any = None exception: Exception | None = None + @edataclass class WorkerCommand: """Internal class for communicating a command to the worker via the command queue.""" - cmd: str args: tuple[Any, ...] kwargs: dict[str, Any] @@ -203,21 +205,20 @@ class WorkerCommand: class FutureImpl(Future[Any]): """Class for returning a future result from the worker..""" - - _outerself: "_DataLoaderAsyncWorker" + _worker: "_DataLoaderAsynchronousWorker" _future_id: int _result: Any _exception: Exception - def __init__(self, outerself: "_DataLoaderAsyncWorker", future_id: int): - self._outerself = outerself + def __init__(self, worker: "_DataLoaderAsynchronousWorker", future_id: int): + self._worker = worker self._future_id = future_id def get(self) -> Any: + if not hasattr(self, "_result") and not hasattr(self, "_exception"): + self._worker._wait_for_worker_result(self._future_id) if hasattr(self, "_exception"): raise self._exception - if not hasattr(self, "_result"): - self._outerself._wait_for_worker_result(self._future_id) return self._result def _set_result(self, result: Any) -> None: @@ -240,7 +241,7 @@ def wrapper(self, *args, **kwargs) -> R: return cast(Callable[P, R], wrapper) @staticmethod - def worker_call_async(fn: Callable[P, R]) -> Callable[P, Future[R]]: + def worker_call_future(fn: Callable[P, R]) -> Callable[P, Future[R]]: """Make the function be called in the worker process via the command and result queues. The function must be a method of the `DataLoaderWorker` class.""" @@ -276,18 +277,13 @@ def _worker_call(self, fn: str, *args: Any, **kwargs: Any) -> Future[Any]: return self._futures[future_id] def _worker_run( - self, - worker_id: int, - cmd_queue: QueueProtocol[WorkerCommand], - result_queue: QueueProtocol[WorkerResult], - seed: int, + self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult], seed: int ) -> None: SystemRng.seed(seed) - self._worker_id = worker_id import torch.utils.data._utils torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( - id=worker_id, + id=self._rank_worker_id, num_workers=self.worker_config.num_workers, seed=seed, dataset=self.dataset, @@ -299,10 +295,11 @@ def _worker_run( try: fn = getattr(self, cmd.cmd) result = getattr(fn, "_orig")(self, *cmd.args, **cmd.kwargs) - result_queue.put(self.WorkerResult(future_id=cmd.future_id, result=result)) - del result except Exception as e: result_queue.put(self.WorkerResult(future_id=cmd.future_id, exception=e)) + else: + result_queue.put(self.WorkerResult(future_id=cmd.future_id, result=result)) + del result if cmd.cmd == "_shutdown_worker": break @@ -325,8 +322,11 @@ def new_iter(self) -> None: super().new_iter() @override - @worker_call_async + @worker_call_future def prefetch_next(self) -> TSample: + # The super class implementation already returns a resolved future (to be interface compatible), + # so immediately resolve the future to the result (get returns immediately). + # The worker_call_future will wrap the result again in a future implicitly. return super().prefetch_next().get() @override @@ -335,7 +335,7 @@ def save_state(self) -> FlexState: return super().save_state() -class ForkDataLoaderWorker(_DataLoaderAsyncWorker[TSample], Generic[TSample]): +class ForkDataLoaderWorker(_DataLoaderAsynchronousWorker[TSample], Generic[TSample]): """ Implements the `DataLoaderWorker` interface using processes. """ @@ -364,29 +364,29 @@ def __init__( def _check_parent_process(self, evt_exit: threading.Event) -> None: """Check if the parent process is alive. If it is not, exit the worker process.""" parent_proc = multiprocessing.parent_process() + parent_pid = os.getppid() if parent_proc is None: print("No parent process, exiting", file=sys.stderr) os._exit(-1) while not evt_exit.wait(1): - if parent_proc.exitcode is not None: + if parent_proc.exitcode is not None or os.getppid() != parent_pid: print("Parent process died, exiting", file=sys.stderr) os._exit(-1) def _worker_run( self, - worker_id: int, cmd_queue: multiprocessing.Queue, result_queue: multiprocessing.Queue, seed: int, ) -> None: - gc_init_worker(worker_id) + gc_init_worker(self._rank_worker_id) worker_exit_evt = threading.Event() parent_check_thread = threading.Thread( target=self._check_parent_process, args=(worker_exit_evt,), daemon=True ) parent_check_thread.start() try: - super()._worker_run(worker_id, cmd_queue, result_queue, seed) + super()._worker_run(cmd_queue, result_queue, seed) finally: worker_exit_evt.set() parent_check_thread.join() @@ -397,9 +397,10 @@ def _worker_run( @override def start(self) -> None: + # TODO: seed per worker self._process = multiprocessing.Process( target=self._worker_run, - args=(self._rank_worker_id, self._cmd_queue, self._result_queue), + args=(self._cmd_queue, self._result_queue, seed), ) self._process.start() @@ -428,11 +429,10 @@ def _assert_running(self) -> None: assert self._process.is_alive(), "Worker died" -class ThreadDataLoaderWorker(_DataLoaderAsyncWorker[TSample], Generic[TSample]): +class ThreadDataLoaderWorker(_DataLoaderAsynchronousWorker[TSample], Generic[TSample]): """ Implements the `DataLoaderWorker` interface using threads. """ - _thread: threading.Thread | None = None def __init__( @@ -452,11 +452,11 @@ def __init__( ) def _worker_run( - self, worker_id: int, cmd_queue: queue.Queue, result_queue: queue.Queue, seed: int + self, cmd_queue: queue.Queue, result_queue: queue.Queue, seed: int ) -> None: # TODO: Implement init_thread which should hook all randomness such that it's thread local. SystemRng.init_thread() - super()._worker_run(worker_id, cmd_queue, result_queue, seed) + super()._worker_run(cmd_queue, result_queue, seed) @override def start(self) -> None: @@ -508,21 +508,27 @@ class DataLoader(Generic[TSample]): def __init__( self, dataset: SavableDataset, - worker_config: WorkerConfig, prefetch_factor: int = 2, worker_type: WorkerType = ForkDataLoaderWorker, cache_pool: CachePool = NoCachePool(), ): - if self._worker_config.num_workers == 0 and worker_type == ForkDataLoaderWorker: + if dataset.worker_config.num_workers == 0 and worker_type == ForkDataLoaderWorker: worker_type = DataLoaderWorker self._dataset = dataset - self._worker_config = worker_config + self._worker_config = dataset.worker_config self._prefetch_factor = prefetch_factor self._worker_type = worker_type self._cache_pool = cache_pool self._prefetching_samples = [[] for _ in range(self._worker_config.num_workers)] self._exhausted_workers = [False] * self._worker_config.num_workers + if self._worker_config.num_workers == 0: + assert prefetch_factor == 1, "prefetch_factor must be 1 for num_workers == 0" + else: + assert prefetch_factor > 0, "prefetch_factor must be > 0 for num_workers > 0" + + # TODO: Seed per worker from SavableDataLoader + self._spawning_process = os.getpid() def shutdown(self) -> None: @@ -537,8 +543,10 @@ def start_iter(self) -> None: worker.new_iter() def _epoch_iter(self) -> Generator[TSample, None, None]: + """Iterate over the dataset for one epoch (i.e. all workers StopIteration). + One epoch may also be infinite (if looping the dataset).""" if self._workers is None: - self.start() + self._start() for worker, exhausted in zip(self._workers, self._exhausted_workers): if not exhausted: worker.new_iter() @@ -573,20 +581,20 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: self._next_worker_id = (worker_idx + 1) % self._worker_config.num_workers if self._exhausted_workers[worker_idx]: continue + # Pop the first sample future from the prefetching samples. + sample_future = self._prefetching_samples[worker_idx].pop(0) # Prefetch samples from the worker. while len(self._prefetching_samples[worker_idx]) < self._prefetch_factor: # Add a new sample future to the prefetching samples if the worker has not prefetched enough samples. self._prefetching_samples[worker_idx].append(worker.prefetch_next()) - # Pop the first sample future from the prefetching samples. - sample_future = self._prefetching_samples[worker_idx].pop(0) try: # Get the sample from the sample future (may wait for the sample to be ready). sample = sample_future.get() - except Exception as e: - # If the sample future raises an exception, remove the worker from the list. + except StopIteration: + # If the sample future raises StopIteration, remove the worker from the list. self._prefetching_samples[worker_idx] = [] self._exhausted_workers[worker_idx] = True - raise e + continue else: # Yield the sample. yield sample @@ -621,7 +629,7 @@ def _get_batch_size(self) -> int | None: def save_state(self) -> FlexState: prefetched_samples_keys = [ - [get_sample_restore_key(sample.get()) for sample in prefetching_sample] + [get_sample_restore_key(sample_fut.get()) for sample_fut in prefetching_sample] for prefetching_sample in self._prefetching_samples ] if self._workers is None: @@ -638,7 +646,7 @@ def save_state(self) -> FlexState: micro_batch_size=self._get_batch_size(), ) - def start(self, initial_state: FlexState | None = None) -> None: + def _start(self, initial_state: FlexState | None = None) -> None: assert self._workers is None and self._current_epoch_iter is None, ( "DataLoader already started" ) @@ -669,6 +677,7 @@ def start(self, initial_state: FlexState | None = None) -> None: if initial_state is not None: self._prefetching_samples = [ [ + # TODO: Use a callback future DoneFuture(self.restore_sample(sample_key)) for sample_key in prefetched_samples_keys ] From 4d97aeba7e16d57f96f8edb382b4f337e16e4f3f Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Wed, 9 Jul 2025 16:05:06 +0200 Subject: [PATCH 05/36] WIP: Custom dataloader --- src/megatron/energon/dataloader.py | 398 +++++++++++++++++++---------- 1 file changed, 267 insertions(+), 131 deletions(-) diff --git a/src/megatron/energon/dataloader.py b/src/megatron/energon/dataloader.py index 9de5c296..86fc5e3d 100644 --- a/src/megatron/energon/dataloader.py +++ b/src/megatron/energon/dataloader.py @@ -15,12 +15,10 @@ ParamSpec, Protocol, TypeVar, - cast, override, ) from megatron.energon.cache.base import CachePool -from megatron.energon.cache.no_cache import NoCachePool from megatron.energon.edataclass import edataclass from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key from megatron.energon.rng import SystemRng @@ -28,10 +26,12 @@ from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key from megatron.energon.wrappers.batch_dataset import BatchDataset -from megatron.energon.wrappers.gc_dataset import gc_init_worker +from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset, gc_init_worker +from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset P = ParamSpec("P") R = TypeVar("R", covariant=True) +T = TypeVar("T") TSelf = TypeVar("TSelf", bound="DataLoaderWorker") TSample = TypeVar("TSample") @@ -42,7 +42,6 @@ def get(self, /) -> TSample: ... def put(self, item: TSample, /) -> None: ... - class Future(Protocol[R]): def get(self) -> R: ... @@ -57,6 +56,41 @@ def get(self) -> TSample: return self._result +class CallableFuture(Future[R]): + """Future that calls a callable to get the result.""" + + _callable: Callable[[], R] + _value: R + _exception: Exception + + def __init__(self, callable: Callable[[], R]): + self._callable = callable + + def get(self) -> R: + if not hasattr(self, "_value") and not hasattr(self, "_exception"): + try: + self._value = self._callable() + except Exception as e: + self._exception = e + if hasattr(self, "_exception"): + raise self._exception + return self._value + + @staticmethod + def chain(future: Future[T], fn: Callable[[Future[T]], R]) -> Future[R]: + """ + Chain a function to a future. + + Args: + future: The future which provides the input for the function. + fn: The function to call on the result of the future, to transform the result. + + Returns: + A future that will be resolved to the result of the function given the result of the future. + """ + return CallableFuture(lambda: fn(future)) + + class ExceptionFuture(Future[Any]): """Future that raises an exception.""" @@ -80,57 +114,108 @@ class DataLoaderWorker(Generic[TSample]): _rank_worker_id: int _global_worker_id: int - _cache_pool: CachePool - - exhausted: bool = True + _seed: int + _cache_pool: CachePool | None + _sample_index: SampleIndex + _exhausted: bool = True def __init__( self, dataset: SavableDataset[TSample], worker_config: WorkerConfig, rank_worker_id: int, - cache_pool: CachePool, + cache_pool: CachePool | None, ): + """ + Initialize the worker. + + Args: + dataset: The dataset to iterate over. + worker_config: The worker configuration. + rank_worker_id: The rank of the worker. + cache_pool: The cache pool to use. + """ self.dataset = dataset self.worker_config = worker_config self._rank_worker_id = rank_worker_id self._global_worker_id = worker_config.global_worker_id(rank_worker_id) + self._seed = self.worker_config.worker_seed(rank_worker_id) self._cache_pool = cache_pool # ------------------------------------------------------------------------------------------------ # Section: Main control methods def start(self) -> None: + """ + Start the worker. + """ pass def shutdown(self) -> None: + """ + Shutdown the worker. + """ pass def running(self) -> bool: + """ + Check if the worker is running. + """ return True def _assert_running(self) -> None: + """ + Assert that the worker is running and alive. + """ assert self.running(), "Worker must be running" # ------------------------------------------------------------------------------------------------ # Section: Worker methods - def dataset_init(self, initial_state: FlexState | None) -> None: + def dataset_init(self, state: FlexState | None) -> None: + """ + Initialize the worker (may restore the state). + Calls `new_iter` if the worker is not exhausted and also initially (`state=None`). + + Args: + state: The state to restore the worker from or None for using the initial state. + """ + # This is called in the worker context (process/thread). self._sample_index = SampleIndex(worker_config=self.worker_config, src=self) - self._global_worker_id = self.worker_config.global_worker_id() - if initial_state is None: + assert self._global_worker_id == self.worker_config.global_worker_id(), ( + "Global worker ID mismatch" + ) + assert self._seed == self.worker_config.worker_seed(self._rank_worker_id), "Seed mismatch" + if state is None: self.dataset.reset_state_deep() + self.new_iter() else: - assert initial_state["__class__"] == "DataLoaderWorker", "Worker type mismatch" - self._sample_index.restore_state(initial_state["_sample_index"]) - self.dataset.restore_state(initial_state["datasets"][0]) - # TODO: exhausted + assert state["__class__"] == "DataLoaderWorker", "Worker type mismatch" + self._sample_index.restore_state(state["_sample_index"]) + self.dataset.restore_state(state["datasets"][0]) + if not state["exhausted"]: + self.new_iter() def new_iter(self) -> None: + """ + Start a new iterator of the dataset. + Called after the dataset is initialized and to start a new epoch (if the dataset is not infinite). + The iterator is stored in the worker and is used by the `prefetch_next` method, which calls `next` on it. + Updates the exhausted flag to False. + """ + # This is called in the worker context (process/thread). self._dataset_iter = iter(self.dataset) - self.exhausted = False + self._exhausted = False def prefetch_next(self) -> Future[TSample]: + """ + Fetch the next sample (i.e. call `next` on the iterator) and return a future for getting the result. + Updates the exhausted flag if the iterator is exhausted. + + Returns: + A future that will either be resolved to the next sample or raise StopIteration if the iterator is exhausted. + """ + # This is called in the worker context (process/thread). assert self._dataset_iter is not None, "start_iter must be called before prefetch_next" with self._sample_index.ctx() as sample_idx: self.worker_config.worker_activate(sample_idx, cache_pool=self._cache_pool) @@ -138,18 +223,22 @@ def prefetch_next(self) -> Future[TSample]: next_sample = next(self._dataset_iter) add_sample_restore_key(next_sample, self._global_worker_id, sample_idx, src=self) except StopIteration as e: - self.exhausted = True + self._exhausted = True return ExceptionFuture(e) finally: self.worker_config.worker_deactivate() return DoneFuture(next_sample) def save_state(self) -> FlexState: + """ + Save the state of the worker. + """ + # This is called in the worker context (process/thread). return FlexState( __class__="DataLoaderWorker", rng=SystemRng.save_state(), dataset=self.dataset.save_state(), - exhausted=self.exhausted, + exhausted=self._exhausted, _sample_index=self._sample_index.save_state(), ) @@ -175,7 +264,7 @@ def __init__( rank_worker_id: int, cmd_queue: QueueProtocol["WorkerCommand"], result_queue: QueueProtocol["WorkerResult"], - cache_pool: CachePool, + cache_pool: CachePool | None, ): super().__init__(dataset, worker_config, rank_worker_id, cache_pool) assert worker_config.num_workers > 0, "Async workers require num_workers > 0" @@ -190,14 +279,15 @@ def __init__( @edataclass class WorkerResult: """Internal class for communicating a result from the worker via the result queue.""" + future_id: int result: Any = None exception: Exception | None = None - @edataclass class WorkerCommand: """Internal class for communicating a command to the worker via the command queue.""" + cmd: str args: tuple[Any, ...] kwargs: dict[str, Any] @@ -205,6 +295,7 @@ class WorkerCommand: class FutureImpl(Future[Any]): """Class for returning a future result from the worker..""" + _worker: "_DataLoaderAsynchronousWorker" _future_id: int _result: Any @@ -227,32 +318,14 @@ def _set_result(self, result: Any) -> None: def _set_exception(self, exception: Exception) -> None: self._exception = exception - @staticmethod - def worker_call(fn: Callable[P, R]) -> Callable[P, R]: - """Make the function be called in the worker process via the command and result queues. - The function must be a method of the `DataLoaderWorker` class.""" - - @functools.wraps(fn) - def wrapper(self, *args, **kwargs) -> R: - future = self._worker_call(fn.__name__, *args, **kwargs) - return future.get() - - setattr(wrapper, "_orig", fn) - return cast(Callable[P, R], wrapper) - - @staticmethod - def worker_call_future(fn: Callable[P, R]) -> Callable[P, Future[R]]: - """Make the function be called in the worker process via the command and result queues. - The function must be a method of the `DataLoaderWorker` class.""" - - @functools.wraps(fn) - def wrapper(self, *args, **kwargs) -> Future[R]: - return self._worker_call(fn.__name__, *args, **kwargs) - - setattr(wrapper, "_orig", fn) - return cast(Callable[P, Future[R]], wrapper) - def _wait_for_worker_result(self, future_id: int) -> None: + """ + Wait for the result of a future. + If another result comes first, update the corresponding future. + + Args: + future_id: The ID of the future to wait for. + """ while True: res = self._result_queue.get() fut = self._futures.pop(res.future_id) @@ -265,33 +338,56 @@ def _wait_for_worker_result(self, future_id: int) -> None: else: continue - def _worker_call(self, fn: str, *args: Any, **kwargs: Any) -> Future[Any]: + def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Future[R]: + """ + Call a function in the worker and return a future for getting the result. + The function must be an instance method of `self`. Uses the name to identify the function in the worker + instance. + + Args: + fn: The function to call. + *args: The arguments to pass to the function. + **kwargs: The keyword arguments to pass to the function. + """ self._assert_running() future_id = self._next_future_id self._next_future_id += 1 self._futures[future_id] = self.FutureImpl(self, future_id) self._cmd_queue.put( - self.WorkerCommand(cmd=fn, args=args, kwargs=kwargs, future_id=future_id) + self.WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) ) return self._futures[future_id] def _worker_run( - self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult], seed: int + self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] ) -> None: - SystemRng.seed(seed) + """ + The worker main loop. + It waits for commands via the command queue and executes them. + The functions to call are identified by their name. + The result of the call is put into the result queue. + The worker exits when the command `_shutdown_worker` is received. + + Args: + cmd_queue: The command queue to wait for commands. + result_queue: The result queue to put the results into. + """ + SystemRng.seed(self._seed) import torch.utils.data._utils torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( id=self._rank_worker_id, num_workers=self.worker_config.num_workers, - seed=seed, + seed=self._seed, dataset=self.dataset, ) self._global_worker_id = self.worker_config.global_worker_id() self.worker_config.assert_worker() while True: cmd = cmd_queue.get() + if cmd.cmd == "_shutdown_worker": + break try: fn = getattr(self, cmd.cmd) result = getattr(fn, "_orig")(self, *cmd.args, **cmd.kwargs) @@ -300,39 +396,49 @@ def _worker_run( else: result_queue.put(self.WorkerResult(future_id=cmd.future_id, result=result)) del result - if cmd.cmd == "_shutdown_worker": - break # ------------------------------------------------------------------------------------------------ # Section: Worker methods - now calling to workers via queues. - @worker_call - def _shutdown_worker(self) -> None: + def _wrk_shutdown_worker(self) -> None: """Shutdown the worker. The actual shutdown is handled in the _worker_run method.""" - pass + # This is not actually a recursive call, because the worker loop will exit before calling this method. + self._worker_call(self._wrk_shutdown_worker) - @override - @worker_call - def dataset_init(self, initial_state: FlexState | None) -> None: + def _wrk_dataset_init(self, initial_state: FlexState | None) -> None: + """Wraps the super class method to call it in the worker process.""" super().dataset_init(initial_state) - @override - @worker_call - def new_iter(self) -> None: + def _wrk_new_iter(self) -> None: + """Wraps the super class method to call it in the worker process.""" super().new_iter() - @override - @worker_call_future - def prefetch_next(self) -> TSample: + def _wrk_prefetch_next(self) -> TSample: + """Wraps the super class method to call it in the worker process.""" # The super class implementation already returns a resolved future (to be interface compatible), # so immediately resolve the future to the result (get returns immediately). - # The worker_call_future will wrap the result again in a future implicitly. return super().prefetch_next().get() + def _wrk_save_state(self) -> FlexState: + """Wraps the super class method to call it in the worker process.""" + return super().save_state() + + @override + def dataset_init(self, initial_state: FlexState | None) -> None: + self._worker_call(self._wrk_dataset_init, initial_state).get() + + @override + def new_iter(self) -> None: + self._worker_call(self._wrk_new_iter).get() + + @override + def prefetch_next(self) -> Future[TSample]: + # Do not resolve the future here, but return it. + return self._worker_call(self._wrk_prefetch_next) + @override - @worker_call def save_state(self) -> FlexState: - return super().save_state() + return self._worker_call(self._wrk_save_state).get() class ForkDataLoaderWorker(_DataLoaderAsynchronousWorker[TSample], Generic[TSample]): @@ -341,6 +447,8 @@ class ForkDataLoaderWorker(_DataLoaderAsynchronousWorker[TSample], Generic[TSamp """ _process: multiprocessing.Process | None = None + _cmd_queue: multiprocessing.Queue + _result_queue: multiprocessing.Queue _spawning_process: int @@ -349,7 +457,7 @@ def __init__( dataset: SavableDataset, worker_config: WorkerConfig, rank_worker_id: int, - cache_pool: CachePool, + cache_pool: CachePool | None, ): super().__init__( dataset, @@ -362,7 +470,7 @@ def __init__( self._spawning_process = os.getpid() def _check_parent_process(self, evt_exit: threading.Event) -> None: - """Check if the parent process is alive. If it is not, exit the worker process.""" + """Check if the parent process is alive. If it is dead, exit the worker process.""" parent_proc = multiprocessing.parent_process() parent_pid = os.getppid() if parent_proc is None: @@ -377,7 +485,6 @@ def _worker_run( self, cmd_queue: multiprocessing.Queue, result_queue: multiprocessing.Queue, - seed: int, ) -> None: gc_init_worker(self._rank_worker_id) worker_exit_evt = threading.Event() @@ -386,7 +493,7 @@ def _worker_run( ) parent_check_thread.start() try: - super()._worker_run(cmd_queue, result_queue, seed) + super()._worker_run(cmd_queue, result_queue) finally: worker_exit_evt.set() parent_check_thread.join() @@ -397,10 +504,9 @@ def _worker_run( @override def start(self) -> None: - # TODO: seed per worker self._process = multiprocessing.Process( target=self._worker_run, - args=(self._cmd_queue, self._result_queue, seed), + args=(self._cmd_queue, self._result_queue), ) self._process.start() @@ -413,7 +519,7 @@ def shutdown(self) -> None: ) return if self._process is not None: - self._shutdown_worker() + self._wrk_shutdown_worker() self._process.join() self._cmd_queue.cancel_join_thread() self._cmd_queue.close() @@ -433,14 +539,17 @@ class ThreadDataLoaderWorker(_DataLoaderAsynchronousWorker[TSample], Generic[TSa """ Implements the `DataLoaderWorker` interface using threads. """ + _thread: threading.Thread | None = None + _cmd_queue: queue.Queue + _result_queue: queue.Queue def __init__( self, dataset: SavableDataset, worker_config: WorkerConfig, rank_worker_id: int, - cache_pool: CachePool, + cache_pool: CachePool | None, ): super().__init__( dataset, @@ -451,25 +560,21 @@ def __init__( cache_pool=cache_pool, ) - def _worker_run( - self, cmd_queue: queue.Queue, result_queue: queue.Queue, seed: int - ) -> None: - # TODO: Implement init_thread which should hook all randomness such that it's thread local. - SystemRng.init_thread() - super()._worker_run(cmd_queue, result_queue, seed) + def _worker_run(self, cmd_queue: queue.Queue, result_queue: queue.Queue) -> None: + super()._worker_run(cmd_queue, result_queue) @override def start(self) -> None: self._thread = threading.Thread( target=self._worker_run, - args=(self._rank_worker_id, self._cmd_queue, self._result_queue), + args=(self._cmd_queue, self._result_queue), ) self._thread.start() @override def shutdown(self) -> None: if self._thread is not None: - self._shutdown_worker() + self._wrk_shutdown_worker() self._thread.join() self._thread = None @@ -484,7 +589,7 @@ def __call__( dataset: SavableDataset, worker_config: WorkerConfig, rank_worker_id: int, - cache_pool: CachePool, + cache_pool: CachePool | None, ) -> DataLoaderWorker[TSample]: ... @@ -508,12 +613,59 @@ class DataLoader(Generic[TSample]): def __init__( self, dataset: SavableDataset, + *, prefetch_factor: int = 2, worker_type: WorkerType = ForkDataLoaderWorker, - cache_pool: CachePool = NoCachePool(), + cache_pool: CachePool | None = None, + # Garbage collection configuration + gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER, + gc_freeze_at_start: bool = True, + # Watchdog configuration + watchdog_timeout_seconds: float | None = 60, + watchdog_initial_timeout_seconds: float | None = None, + fail_on_timeout: bool = False, ): + """ + Create the dataloader supporting saving and restoring the state. + + Args: + dataset: The dataset to load. + prefetch_factor: The number of samples to prefetch from each worker. + worker_type: The type of worker to use. + cache_pool: If set, the cache pool to use for the dataset. + gc_collect_every_n_steps: The number of steps after which the garbage collector is + called. As we're usually handling large (but few) tensors here, and the python + garbage collection is already full of objects just by importing, this can improve + the memory footprint quite a lot, and may even be necessary to avoid memory + overflow. + gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker + processes. This improves the garbage collection performance by a lot. + In rare cases, this may cause issues and can be disabled. Keep enabled if you + experience no issues. + watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled. + watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. + fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. + """ if dataset.worker_config.num_workers == 0 and worker_type == ForkDataLoaderWorker: worker_type = DataLoaderWorker + + if watchdog_timeout_seconds is not None: + dataset = WatchdogDataset( + dataset, + worker_config=dataset.worker_config, + timeout_seconds=watchdog_timeout_seconds, + initial_timeout_seconds=watchdog_initial_timeout_seconds, + fail_on_timeout=fail_on_timeout, + ) + + if gc_collect_every_n_steps > 0: + dataset = GcDataset( + dataset, + worker_config=dataset.worker_config, + every_n_iter=gc_collect_every_n_steps, + freeze=gc_freeze_at_start, + ) + self._dataset = dataset self._worker_config = dataset.worker_config self._prefetch_factor = prefetch_factor @@ -527,8 +679,6 @@ def __init__( else: assert prefetch_factor > 0, "prefetch_factor must be > 0 for num_workers > 0" - # TODO: Seed per worker from SavableDataLoader - self._spawning_process = os.getpid() def shutdown(self) -> None: @@ -547,11 +697,7 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: One epoch may also be infinite (if looping the dataset).""" if self._workers is None: self._start() - for worker, exhausted in zip(self._workers, self._exhausted_workers): - if not exhausted: - worker.new_iter() - - assert self._workers is not None, "DataLoader not started" + assert self._workers is not None, "DataLoader not started" if all(self._exhausted_workers): # All workers are exhausted, restart for the next epoch. @@ -628,6 +774,10 @@ def _get_batch_size(self) -> int | None: return None def save_state(self) -> FlexState: + # TODO: The redist tool must be able to change the batch size. + # That means that the redist tool shall split a saved restore key for the "BatchDataset". + # It should also change the saved micro batch size to match that. + # TODO @pfischer: Add changing the batch size to the docs. prefetched_samples_keys = [ [get_sample_restore_key(sample_fut.get()) for sample_fut in prefetching_sample] for prefetching_sample in self._prefetching_samples @@ -641,7 +791,6 @@ def save_state(self) -> FlexState: __class__=type(self).__name__, prefetched_samples_keys=prefetched_samples_keys, worker_states=worker_states, - workers_exhausted=self._exhausted_workers.copy(), next_worker_id=self._next_worker_id, micro_batch_size=self._get_batch_size(), ) @@ -663,33 +812,40 @@ def _start(self, initial_state: FlexState | None = None) -> None: self._restore_state = None if initial_state is None: - initial_states = [None] * self._worker_config.num_workers + worker_states = [None] * self._worker_config.num_workers else: - initial_states = initial_state["worker_states"] + worker_states = initial_state["worker_states"] - assert len(initial_states) == self._worker_config.num_workers, ( + assert len(worker_states) == self._worker_config.num_workers, ( "Number of initial states must match number of workers" ) - for worker, initial_state in zip(self._workers, initial_states): - worker.dataset_init(initial_state) + for worker, worker_state in zip(self._workers, worker_states): + worker.dataset_init(worker_state) if initial_state is not None: self._prefetching_samples = [ [ - # TODO: Use a callback future - DoneFuture(self.restore_sample(sample_key)) + CallableFuture(functools.partial(self.restore_sample, sample_key)) for sample_key in prefetched_samples_keys ] for prefetched_samples_keys in initial_state["prefetched_samples_keys"] ] self._next_worker_id = initial_state["next_worker_id"] - self._exhausted_workers = initial_state["workers_exhausted"].copy() + self._exhausted_workers = [ + False if worker_state is None else worker_state["exhausted"] + for worker_state in worker_states + ] def restore_state_rank(self, state: FlexState | None) -> None: + """ + Restore the state of the DataLoader on the current rank. + The state is actually restored when the processes are started, in the iterator. + """ assert self._workers is None and self._current_epoch_iter is None, ( "DataLoader already started" ) + assert self._restore_state is None, "Restore state already set" if state is None: # Assume initial state. @@ -697,40 +853,20 @@ def restore_state_rank(self, state: FlexState | None) -> None: assert isinstance(state, FlexState) assert state["__class__"] == type(self).__name__, "DataLoader type mismatch" - - old_micro_batch_size = state["micro_batch_size"] - micro_batch_size = self._get_batch_size() - - if self._worker_config.num_workers == 0: - assert micro_batch_size == old_micro_batch_size, "Micro batch size mismatch" - assert len(state["worker_states"]) == 1 - assert isinstance(state["worker_states"][0], FlexState) - self._dataset.restore_state(state["worker_states"][0]) - else: - # Check batch sizes (before and after) - if micro_batch_size != old_micro_batch_size: - assert micro_batch_size is not None and old_micro_batch_size is not None, ( - "Cannot resume with different batching mode " - "(batching to non-batching or vice versa)" - ) - - if micro_batch_size > old_micro_batch_size: - raise ValueError( - "Resuming with larger micro batch size is not allowed: " - f"{micro_batch_size} > {old_micro_batch_size}" - ) - elif ( - micro_batch_size < old_micro_batch_size - and old_micro_batch_size % micro_batch_size != 0 - ): - raise ValueError( - "Resuming with smaller micro batch size only allowed if the old " - f"micro batch size is a multiple of the new one: {micro_batch_size} < {old_micro_batch_size}" - ) + assert state["micro_batch_size"] == self._get_batch_size(), "Micro batch size mismatch" self._restore_state = state def restore_sample(self, restore_key: tuple) -> TSample: + """ + Restore a sample from a restore key. + + Args: + restore_key: The restore key to restore the sample from. + + Returns: + The restored sample. + """ id, global_worker_id, sample_idx = restore_key[:3] assert id == type(self).__name__ restore_key = restore_key[3:] From 1d7fda58b88d4b885a1a42dc29278792d9bfffd8 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:05:28 +0200 Subject: [PATCH 06/36] Add dataloader test and fix dataloader cleanup --- src/megatron/energon/dataloader.py | 327 +++++++++++++++++++++-------- src/megatron/energon/worker.py | 5 + tests/test_dataloader.py | 237 +++++++++++++++++++++ 3 files changed, 479 insertions(+), 90 deletions(-) create mode 100644 tests/test_dataloader.py diff --git a/src/megatron/energon/dataloader.py b/src/megatron/energon/dataloader.py index 86fc5e3d..43318c0f 100644 --- a/src/megatron/energon/dataloader.py +++ b/src/megatron/energon/dataloader.py @@ -2,11 +2,14 @@ # SPDX-License-Identifier: BSD-3-Clause import functools import multiprocessing +import multiprocessing.managers import os import queue import sys import threading +import traceback import warnings +from abc import abstractmethod from typing import ( Any, Callable, @@ -24,7 +27,7 @@ from megatron.energon.rng import SystemRng from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key from megatron.energon.wrappers.batch_dataset import BatchDataset from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset, gc_init_worker from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset @@ -116,7 +119,7 @@ class DataLoaderWorker(Generic[TSample]): _global_worker_id: int _seed: int _cache_pool: CachePool | None - _sample_index: SampleIndex + _sample_index: int = 0 _exhausted: bool = True def __init__( @@ -151,9 +154,12 @@ def start(self) -> None: """ pass - def shutdown(self) -> None: + def shutdown(self, in_del: bool = False) -> None: """ Shutdown the worker. + + Args: + in_del: If True, the worker is being deleted. """ pass @@ -169,6 +175,9 @@ def _assert_running(self) -> None: """ assert self.running(), "Worker must be running" + def __del__(self) -> None: + self.shutdown(in_del=True) + # ------------------------------------------------------------------------------------------------ # Section: Worker methods @@ -181,20 +190,25 @@ def dataset_init(self, state: FlexState | None) -> None: state: The state to restore the worker from or None for using the initial state. """ # This is called in the worker context (process/thread). - self._sample_index = SampleIndex(worker_config=self.worker_config, src=self) assert self._global_worker_id == self.worker_config.global_worker_id(), ( "Global worker ID mismatch" ) assert self._seed == self.worker_config.worker_seed(self._rank_worker_id), "Seed mismatch" + print(f"dataset_init {state=}\n", end="") if state is None: + self._sample_index = 0 self.dataset.reset_state_deep() + print("dataset_init reset_state_deep\n", end="") self.new_iter() + print("dataset_init new_iter\n", end="") else: - assert state["__class__"] == "DataLoaderWorker", "Worker type mismatch" - self._sample_index.restore_state(state["_sample_index"]) - self.dataset.restore_state(state["datasets"][0]) + assert state["__class__"] == "DataLoaderWorker", "state type mismatch" + self._sample_index = state["sample_index"] + SystemRng.restore_state(state["rng"]) + self.dataset.restore_state(state["dataset"]) if not state["exhausted"]: self.new_iter() + assert self._exhausted == state["exhausted"], "Exhausted state mismatch" def new_iter(self) -> None: """ @@ -204,8 +218,10 @@ def new_iter(self) -> None: Updates the exhausted flag to False. """ # This is called in the worker context (process/thread). + print("new_iter\n", end="") self._dataset_iter = iter(self.dataset) self._exhausted = False + print("new_iter done\n", end="") def prefetch_next(self) -> Future[TSample]: """ @@ -217,16 +233,24 @@ def prefetch_next(self) -> Future[TSample]: """ # This is called in the worker context (process/thread). assert self._dataset_iter is not None, "start_iter must be called before prefetch_next" - with self._sample_index.ctx() as sample_idx: - self.worker_config.worker_activate(sample_idx, cache_pool=self._cache_pool) + if self._exhausted: try: - next_sample = next(self._dataset_iter) - add_sample_restore_key(next_sample, self._global_worker_id, sample_idx, src=self) + raise StopIteration() except StopIteration as e: - self._exhausted = True return ExceptionFuture(e) - finally: - self.worker_config.worker_deactivate() + sample_idx = self._sample_index + self.worker_config.worker_activate(sample_idx, cache_pool=self._cache_pool) + try: + next_sample = next(self._dataset_iter) + self._sample_index += 1 + next_sample = add_sample_restore_key( + next_sample, self._global_worker_id, sample_idx, src=self + ) + except StopIteration as e: + self._exhausted = True + return ExceptionFuture(e) + finally: + self.worker_config.worker_deactivate() return DoneFuture(next_sample) def save_state(self) -> FlexState: @@ -239,7 +263,7 @@ def save_state(self) -> FlexState: rng=SystemRng.save_state(), dataset=self.dataset.save_state(), exhausted=self._exhausted, - _sample_index=self._sample_index.save_state(), + sample_index=self._sample_index, ) @@ -327,15 +351,19 @@ def _wait_for_worker_result(self, future_id: int) -> None: future_id: The ID of the future to wait for. """ while True: + print(f"[fut={future_id}] waiting for result\n", end="") res = self._result_queue.get() fut = self._futures.pop(res.future_id) if res.exception is not None: fut._set_exception(res.exception) else: fut._set_result(res.result) + # self._result_queue.task_done() if res.future_id == future_id: + print(f"[fut={future_id}] got result, return\n", end="") return else: + print(f"[fut={future_id}] got result for {res.future_id=}, continue\n", end="") continue def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Future[R]: @@ -350,14 +378,20 @@ def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> **kwargs: The keyword arguments to pass to the function. """ self._assert_running() + assert not self._in_worker(), "worker_call must not be called in the worker" future_id = self._next_future_id self._next_future_id += 1 - self._futures[future_id] = self.FutureImpl(self, future_id) + self._futures[future_id] = future = self.FutureImpl(self, future_id) + print( + f"[wrk={self._rank_worker_id}] worker_call {fn.__name__=} {args=} {kwargs=} {future_id=}\n", + end="", + ) self._cmd_queue.put( self.WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) ) - return self._futures[future_id] + print(f"[wrk={self._rank_worker_id}] queue: {self._cmd_queue.qsize()}\n", end="") + return future def _worker_run( self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] @@ -373,45 +407,65 @@ def _worker_run( cmd_queue: The command queue to wait for commands. result_queue: The result queue to put the results into. """ - SystemRng.seed(self._seed) - import torch.utils.data._utils - - torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( - id=self._rank_worker_id, - num_workers=self.worker_config.num_workers, - seed=self._seed, - dataset=self.dataset, - ) - self._global_worker_id = self.worker_config.global_worker_id() - self.worker_config.assert_worker() - while True: - cmd = cmd_queue.get() - if cmd.cmd == "_shutdown_worker": - break - try: - fn = getattr(self, cmd.cmd) - result = getattr(fn, "_orig")(self, *cmd.args, **cmd.kwargs) - except Exception as e: - result_queue.put(self.WorkerResult(future_id=cmd.future_id, exception=e)) - else: - result_queue.put(self.WorkerResult(future_id=cmd.future_id, result=result)) - del result + assert self._in_worker(), "_worker_run must be called in the worker" + try: + SystemRng.seed(self._seed) + import torch.utils.data._utils + + torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( + id=self._rank_worker_id, + num_workers=self.worker_config.num_workers, + seed=self._seed, + dataset=self.dataset, + ) + self._global_worker_id = self.worker_config.global_worker_id() + self.worker_config.assert_worker() + while True: + print( + f"[wrk={self._rank_worker_id}] waiting for command, len: {cmd_queue.qsize()}\n", + end="", + ) + cmd = cmd_queue.get() + print( + f"[fut={cmd.future_id}] got command {cmd.cmd=} {cmd.args=} {cmd.kwargs=}\n", + end="", + ) + try: + fn = getattr(self, cmd.cmd) + result = fn(*cmd.args, **cmd.kwargs) + except Exception as e: + print(f"[fut={cmd.future_id}] send exception {e!r}\n", end="") + result_queue.put(self.WorkerResult(future_id=cmd.future_id, exception=e)) + else: + print(f"[fut={cmd.future_id}] send result {result!r}\n", end="") + result_queue.put(self.WorkerResult(future_id=cmd.future_id, result=result)) + del result + # cmd_queue.task_done() + if cmd.cmd == self._wrk_shutdown_worker.__name__: + print(f"[fut={cmd.future_id}] got shutdown command, exit\n", end="") + break + print(f"[fut={cmd.future_id}] processed, waiting for next command\n", end="") + except: + traceback.print_exc() + raise + + @abstractmethod + def _in_worker(self) -> bool: + """Check if the execution is within the worker.""" + ... # ------------------------------------------------------------------------------------------------ # Section: Worker methods - now calling to workers via queues. def _wrk_shutdown_worker(self) -> None: + """Does nothing. The actual shutdown is handled in the _worker_run method.""" + assert self._in_worker(), "_wrk_shutdown_worker must be called in the worker" + + def _shutdown_worker(self) -> None: """Shutdown the worker. The actual shutdown is handled in the _worker_run method.""" + assert not self._in_worker(), "shutdown_worker must not be called in the worker" # This is not actually a recursive call, because the worker loop will exit before calling this method. - self._worker_call(self._wrk_shutdown_worker) - - def _wrk_dataset_init(self, initial_state: FlexState | None) -> None: - """Wraps the super class method to call it in the worker process.""" - super().dataset_init(initial_state) - - def _wrk_new_iter(self) -> None: - """Wraps the super class method to call it in the worker process.""" - super().new_iter() + self._worker_call(self._wrk_shutdown_worker).get() def _wrk_prefetch_next(self) -> TSample: """Wraps the super class method to call it in the worker process.""" @@ -419,26 +473,33 @@ def _wrk_prefetch_next(self) -> TSample: # so immediately resolve the future to the result (get returns immediately). return super().prefetch_next().get() - def _wrk_save_state(self) -> FlexState: - """Wraps the super class method to call it in the worker process.""" - return super().save_state() - @override def dataset_init(self, initial_state: FlexState | None) -> None: - self._worker_call(self._wrk_dataset_init, initial_state).get() + if self._in_worker(): + return super().dataset_init(initial_state) + else: + return self._worker_call(self.dataset_init, initial_state).get() @override def new_iter(self) -> None: - self._worker_call(self._wrk_new_iter).get() + if self._in_worker(): + return super().new_iter() + else: + return self._worker_call(self.new_iter).get() @override def prefetch_next(self) -> Future[TSample]: # Do not resolve the future here, but return it. + if self._in_worker(): + return super().prefetch_next() return self._worker_call(self._wrk_prefetch_next) @override def save_state(self) -> FlexState: - return self._worker_call(self._wrk_save_state).get() + if self._in_worker(): + return super().save_state() + else: + return self._worker_call(self.save_state).get() class ForkDataLoaderWorker(_DataLoaderAsynchronousWorker[TSample], Generic[TSample]): @@ -450,6 +511,8 @@ class ForkDataLoaderWorker(_DataLoaderAsynchronousWorker[TSample], Generic[TSamp _cmd_queue: multiprocessing.Queue _result_queue: multiprocessing.Queue + _threaded_shutdown: threading.Thread | None = None + _spawning_process: int def __init__( @@ -459,6 +522,7 @@ def __init__( rank_worker_id: int, cache_pool: CachePool | None, ): + multiprocessing.set_start_method("fork", force=True) super().__init__( dataset, worker_config=worker_config, @@ -487,6 +551,8 @@ def _worker_run( result_queue: multiprocessing.Queue, ) -> None: gc_init_worker(self._rank_worker_id) + # cmd_queue is read only, so we can cancel the join thread. + cmd_queue.cancel_join_thread() worker_exit_evt = threading.Event() parent_check_thread = threading.Thread( target=self._check_parent_process, args=(worker_exit_evt,), daemon=True @@ -495,23 +561,35 @@ def _worker_run( try: super()._worker_run(cmd_queue, result_queue) finally: + print(f"[wrk={self._rank_worker_id}] shutting down\n", end="") worker_exit_evt.set() + print( + f"[wrk={self._rank_worker_id}] shutting down, wait for parent_check_thread\n", + end="", + ) parent_check_thread.join() - cmd_queue.cancel_join_thread() - cmd_queue.close() - result_queue.cancel_join_thread() + print(f"[wrk={self._rank_worker_id}] shutting down, close queues\n", end="") result_queue.close() + result_queue.join_thread() + cmd_queue.close() + cmd_queue.cancel_join_thread() + print(f"[wrk={self._rank_worker_id}] shutting down, done\n", end="") + + @override + def _in_worker(self) -> bool: + return multiprocessing.current_process() == self._process @override def start(self) -> None: self._process = multiprocessing.Process( target=self._worker_run, args=(self._cmd_queue, self._result_queue), + daemon=True, ) self._process.start() @override - def shutdown(self) -> None: + def shutdown(self, in_del: bool = False) -> None: if self._spawning_process != os.getpid(): # Should avoid forked process containing a forked worker on exit. warnings.warn( @@ -519,12 +597,41 @@ def shutdown(self) -> None: ) return if self._process is not None: - self._wrk_shutdown_worker() - self._process.join() - self._cmd_queue.cancel_join_thread() - self._cmd_queue.close() - self._result_queue.cancel_join_thread() - self._result_queue.close() + if in_del: + # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.\n", + end="", + file=sys.stderr, + ) + self._cmd_queue.close() + self._cmd_queue.cancel_join_thread() + self._result_queue.close() + self._result_queue.cancel_join_thread() + # Kill the process, because we cannot communicate with it in the gc loop. + self._process.terminate() + self._process = None + else: + try: + self._shutdown_worker() + except Exception: + self._process.join(10) + if self._process.is_alive(): + self._process.terminate() + else: + self._process.join() + assert self._process.exitcode == 0, ( + f"Process exit code {self._process.exitcode}" + ) + self._process = None + self._cmd_queue.close() + self._cmd_queue.cancel_join_thread() + self._result_queue.close() + self._result_queue.cancel_join_thread() @override def running(self) -> bool: @@ -563,25 +670,54 @@ def __init__( def _worker_run(self, cmd_queue: queue.Queue, result_queue: queue.Queue) -> None: super()._worker_run(cmd_queue, result_queue) + @override + def _in_worker(self) -> bool: + return threading.current_thread() == self._thread + @override def start(self) -> None: self._thread = threading.Thread( target=self._worker_run, args=(self._cmd_queue, self._result_queue), + daemon=True, ) self._thread.start() @override - def shutdown(self) -> None: + def shutdown(self, in_del: bool = False) -> None: if self._thread is not None: - self._wrk_shutdown_worker() - self._thread.join() - self._thread = None + if in_del: + # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking threads.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking threads.\n", + end="", + file=sys.stderr, + ) + # Just try to enqueue the shutdown command to the thread and hope for the best. Ignore the result. + self._cmd_queue.put( + self.WorkerCommand( + cmd=self._wrk_shutdown_worker.__name__, args=(), kwargs={}, future_id=-1 + ) + ) + self._thread = None + else: + self._shutdown_worker() + self._thread.join() + self._thread = None @override def running(self) -> bool: return self._thread is not None + @override + def _assert_running(self) -> None: + assert self._thread is not None, "Thread must be started first" + assert self._thread.is_alive(), "Thread died" + class WorkerType(Protocol[TSample]): def __call__( @@ -614,7 +750,7 @@ def __init__( self, dataset: SavableDataset, *, - prefetch_factor: int = 2, + prefetch_factor: int = 1, worker_type: WorkerType = ForkDataLoaderWorker, cache_pool: CachePool | None = None, # Garbage collection configuration @@ -671,8 +807,8 @@ def __init__( self._prefetch_factor = prefetch_factor self._worker_type = worker_type self._cache_pool = cache_pool - self._prefetching_samples = [[] for _ in range(self._worker_config.num_workers)] - self._exhausted_workers = [False] * self._worker_config.num_workers + self._prefetching_samples = [[] for _ in range(self._worker_config.safe_num_workers)] + self._exhausted_workers = [False] * self._worker_config.safe_num_workers if self._worker_config.num_workers == 0: assert prefetch_factor == 1, "prefetch_factor must be 1 for num_workers == 0" @@ -681,12 +817,25 @@ def __init__( self._spawning_process = os.getpid() - def shutdown(self) -> None: + def shutdown(self, in_del: bool = False) -> None: if self._workers is not None: + if in_del: + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking workers.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking workers.\n", + end="", + file=sys.stderr, + ) for worker in self._workers: - worker.shutdown() + worker.shutdown(in_del=in_del) self._workers = None + def __del__(self) -> None: + self.shutdown(in_del=True) + def start_iter(self) -> None: if self._workers is not None: for worker in self._workers: @@ -703,7 +852,7 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: # All workers are exhausted, restart for the next epoch. for worker in self._workers: worker.new_iter() - self._exhausted_workers = [False] * self._worker_config.num_workers + self._exhausted_workers = [False] * self._worker_config.safe_num_workers # For all workers, enqueue prefetching samples. for worker_idx, (worker, exhausted) in enumerate( @@ -720,15 +869,19 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: # - Pop the first sample future from the prefetching samples. # - Get the sample from the sample future (may wait for the sample to be prefetched). # - Yield the sample. + print(f"{self._exhausted_workers=}\n", end="") while not all(self._exhausted_workers): # Get the next worker to prefetch samples from. worker_idx = self._next_worker_id worker = self._workers[worker_idx] - self._next_worker_id = (worker_idx + 1) % self._worker_config.num_workers + print(f"{worker_idx=} {worker=}\n", end="") + self._next_worker_id = (worker_idx + 1) % self._worker_config.safe_num_workers if self._exhausted_workers[worker_idx]: + print(f"{worker_idx=} exhausted, continue with next worker\n", end="") continue # Pop the first sample future from the prefetching samples. sample_future = self._prefetching_samples[worker_idx].pop(0) + print(f"{sample_future=}\n", end="") # Prefetch samples from the worker. while len(self._prefetching_samples[worker_idx]) < self._prefetch_factor: # Add a new sample future to the prefetching samples if the worker has not prefetched enough samples. @@ -737,11 +890,13 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: # Get the sample from the sample future (may wait for the sample to be ready). sample = sample_future.get() except StopIteration: + print(f"{worker_idx=} exhausted, remove from prefetching samples\n", end="") # If the sample future raises StopIteration, remove the worker from the list. self._prefetching_samples[worker_idx] = [] self._exhausted_workers[worker_idx] = True continue else: + print(f"{worker_idx=} got sample, yield\n", end="") # Yield the sample. yield sample @@ -754,11 +909,6 @@ def __iter__(self) -> Generator[TSample, None, None]: # Reset the epoch iterator, it was exhausted. self._current_epoch_iter = None - def __del__(self) -> None: - if self._spawning_process == os.getpid(): - # Otherwise we may be in a forked process which is not the one that spawned the DataLoader. - self.shutdown() - def __len__(self): return len(self._dataset) @@ -773,7 +923,7 @@ def _get_batch_size(self) -> int | None: else: return None - def save_state(self) -> FlexState: + def save_state_rank(self) -> FlexState: # TODO: The redist tool must be able to change the batch size. # That means that the redist tool shall split a saved restore key for the "BatchDataset". # It should also change the saved micro batch size to match that. @@ -783,7 +933,7 @@ def save_state(self) -> FlexState: for prefetching_sample in self._prefetching_samples ] if self._workers is None: - worker_states = [None] * self._worker_config.num_workers + worker_states = [None] * self._worker_config.safe_num_workers else: worker_states = [worker.save_state() for worker in self._workers] @@ -796,12 +946,9 @@ def save_state(self) -> FlexState: ) def _start(self, initial_state: FlexState | None = None) -> None: - assert self._workers is None and self._current_epoch_iter is None, ( - "DataLoader already started" - ) self._workers = [ self._worker_type(self._dataset, self._worker_config, local_worker_id, self._cache_pool) - for local_worker_id in range(max(self._worker_config.num_workers, 1)) + for local_worker_id in range(self._worker_config.safe_num_workers) ] for worker in self._workers: worker.start() @@ -812,11 +959,11 @@ def _start(self, initial_state: FlexState | None = None) -> None: self._restore_state = None if initial_state is None: - worker_states = [None] * self._worker_config.num_workers + worker_states = [None] * self._worker_config.safe_num_workers else: worker_states = initial_state["worker_states"] - assert len(worker_states) == self._worker_config.num_workers, ( + assert len(worker_states) == self._worker_config.safe_num_workers, ( "Number of initial states must match number of workers" ) diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index ec01e0b2..401b21ae 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -128,6 +128,11 @@ def active_worker_batch_index(self) -> int: WorkerConfig._sample_index_stack[0] * max(self.num_workers, 1) + self.rank_worker_id() ) + @property + def safe_num_workers(self) -> int: + """Returns the number of workers, but at least 1.""" + return max(self.num_workers, 1) + def global_rank(self) -> int: """Returns the global rank of this worker config but as a global rank, not as a rank within the data parallel group.""" diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py new file mode 100644 index 00000000..d54633c1 --- /dev/null +++ b/tests/test_dataloader.py @@ -0,0 +1,237 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +"""This module defines tests for meta datasets.""" + +import gc +import logging +import sys +import tempfile +import unittest +import warnings +from collections import Counter +from pathlib import Path +from typing import Iterable + +import torch +import webdataset as wds + +from megatron.energon import ( + WorkerConfig, + get_train_dataset, +) +from megatron.energon.dataloader import DataLoader, ForkDataLoaderWorker, ThreadDataLoaderWorker +from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME + + +class TestDataloader(unittest.TestCase): + # Set up the test fixture + def setUp(self): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + # Create a temporary directory + self.temp_dir = tempfile.TemporaryDirectory() + self.dataset_path = Path(self.temp_dir.name) + # self.dataset_path = Path("./test_dataset") + + self.dataset_path.mkdir(exist_ok=True, parents=True) + + self.ds1_path = self.dataset_path / "ds1" + self.ds1_path.mkdir(exist_ok=True, parents=True) + + # Create a small dummy captioning dataset + self.create_text_test_dataset(self.ds1_path, range(55), range(55)) + print(self.ds1_path) + + def tearDown(self): + # Remove all temporary files + gc.collect() + self.temp_dir.cleanup() + + @staticmethod + def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) + + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for key, txt in zip(key_range, txt_range): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{key:06d}", + "txt": f"{txt}".encode(), + }, + ) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: TextSample", + "field_map:", + " text: txt", + "subflavors:", + " source: dataset.yaml", + " dataset.yaml: true", + " number: 42", + ] + ) + ) + + def test_dataloader_no_workers(self): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Train mode dataset + train_loader = DataLoader( + get_train_dataset( + self.ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ) + ) + assert len(train_loader) == 6, len(train_loader) + + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + + state1 = train_loader.save_state_rank() + + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + + train_loader.shutdown() + + train_loader = DataLoader( + get_train_dataset( + self.ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ) + ) + + train_loader.restore_state_rank(state1) + + cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) + + train_loader.shutdown() + + def test_dataloader_fork(self): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=1, + seed_offset=42, + ) + + # Train mode dataset + train_dataset = get_train_dataset( + self.ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ) + assert len(train_dataset) == 6, len(train_dataset) + + train_loader1 = DataLoader( + train_dataset, + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) + + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + + train_loader1.shutdown() + + def test_dataloader_thread(self): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=1, + seed_offset=42, + ) + + # Train mode dataset + train_dataset = get_train_dataset( + self.ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ) + assert len(train_dataset) == 6, len(train_dataset) + + train_loader1 = DataLoader( + train_dataset, + prefetch_factor=2, + worker_type=ThreadDataLoaderWorker, + gc_collect_every_n_steps=0, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) + + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + + train_loader1.shutdown() + + +if __name__ == "__main__": + unittest.main() From 08094407593dade3d15d1fe6e17c5d5fcc17c6b1 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 10 Jul 2025 15:09:33 +0200 Subject: [PATCH 07/36] Add tests for save/restore for forking/threaded data loader --- tests/test_dataloader.py | 113 +++++++++++++++++++++++++++++---------- 1 file changed, 85 insertions(+), 28 deletions(-) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index d54633c1..4c486f65 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -161,18 +161,15 @@ def test_dataloader_fork(self): ) # Train mode dataset - train_dataset = get_train_dataset( - self.ds1_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - ) - assert len(train_dataset) == 6, len(train_dataset) - - train_loader1 = DataLoader( - train_dataset, + train_loader = DataLoader( + get_train_dataset( + self.ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ), prefetch_factor=2, worker_type=ForkDataLoaderWorker, gc_collect_every_n_steps=10, @@ -180,9 +177,10 @@ def test_dataloader_fork(self): watchdog_timeout_seconds=60, fail_on_timeout=True, ) + assert len(train_loader) == 6, len(train_loader) train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text ] print(train_order1[:10]) print(Counter(train_order1)) @@ -190,7 +188,39 @@ def test_dataloader_fork(self): assert len(Counter(train_order1)) == 55, Counter(train_order1) assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) - train_loader1.shutdown() + state1 = train_loader.save_state_rank() + + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + + assert len(train_order1) == len(train_order2), (len(train_order1), len(train_order2)) + + train_loader.shutdown() + + train_loader = DataLoader( + get_train_dataset( + self.ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) + + train_loader.restore_state_rank(state1) + + cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) + + train_loader.shutdown() def test_dataloader_thread(self): torch.manual_seed(42) @@ -202,27 +232,25 @@ def test_dataloader_thread(self): ) # Train mode dataset - train_dataset = get_train_dataset( - self.ds1_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - ) - assert len(train_dataset) == 6, len(train_dataset) - - train_loader1 = DataLoader( - train_dataset, + train_loader = DataLoader( + get_train_dataset( + self.ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ), prefetch_factor=2, worker_type=ThreadDataLoaderWorker, gc_collect_every_n_steps=0, watchdog_timeout_seconds=60, fail_on_timeout=True, ) + assert len(train_loader) == 6, len(train_loader) train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text ] print(train_order1[:10]) print(Counter(train_order1)) @@ -230,7 +258,36 @@ def test_dataloader_thread(self): assert len(Counter(train_order1)) == 55, Counter(train_order1) assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) - train_loader1.shutdown() + state1 = train_loader.save_state_rank() + + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + + train_loader.shutdown() + + train_loader = DataLoader( + get_train_dataset( + self.ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ), + prefetch_factor=2, + worker_type=ThreadDataLoaderWorker, + gc_collect_every_n_steps=0, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) + + train_loader.restore_state_rank(state1) + + cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) + + train_loader.shutdown() if __name__ == "__main__": From 54e98beef1cc067eb11247974837d50bcec639f6 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 24 Jul 2025 09:48:03 +0200 Subject: [PATCH 08/36] WIP Refactoring loader --- src/megatron/energon/dataloader.py | 1034 ----------------- src/megatron/energon/flavors/base_dataset.py | 21 +- src/megatron/energon/loader/__init__.py | 2 + src/megatron/energon/loader/dataloader.py | 500 ++++++++ src/megatron/energon/loader/future.py | 69 ++ src/megatron/energon/loader/pin_memory.py | 123 ++ .../energon/loader/workers/__init__.py | 2 + .../energon/loader/workers/async_worker.py | 263 +++++ .../energon/loader/workers/base_worker.py | 175 +++ .../energon/loader/workers/fork_worker.py | 157 +++ .../energon/loader/workers/thread_worker.py | 92 ++ 11 files changed, 1394 insertions(+), 1044 deletions(-) delete mode 100644 src/megatron/energon/dataloader.py create mode 100644 src/megatron/energon/loader/__init__.py create mode 100644 src/megatron/energon/loader/dataloader.py create mode 100644 src/megatron/energon/loader/future.py create mode 100644 src/megatron/energon/loader/pin_memory.py create mode 100644 src/megatron/energon/loader/workers/__init__.py create mode 100644 src/megatron/energon/loader/workers/async_worker.py create mode 100644 src/megatron/energon/loader/workers/base_worker.py create mode 100644 src/megatron/energon/loader/workers/fork_worker.py create mode 100644 src/megatron/energon/loader/workers/thread_worker.py diff --git a/src/megatron/energon/dataloader.py b/src/megatron/energon/dataloader.py deleted file mode 100644 index 43318c0f..00000000 --- a/src/megatron/energon/dataloader.py +++ /dev/null @@ -1,1034 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. -# SPDX-License-Identifier: BSD-3-Clause -import functools -import multiprocessing -import multiprocessing.managers -import os -import queue -import sys -import threading -import traceback -import warnings -from abc import abstractmethod -from typing import ( - Any, - Callable, - Generator, - Generic, - ParamSpec, - Protocol, - TypeVar, - override, -) - -from megatron.energon.cache.base import CachePool -from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key -from megatron.energon.rng import SystemRng -from megatron.energon.state import FlexState -from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key -from megatron.energon.wrappers.batch_dataset import BatchDataset -from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset, gc_init_worker -from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset - -P = ParamSpec("P") -R = TypeVar("R", covariant=True) -T = TypeVar("T") -TSelf = TypeVar("TSelf", bound="DataLoaderWorker") -TSample = TypeVar("TSample") - - -class QueueProtocol(Protocol[TSample]): - def get(self, /) -> TSample: ... - - def put(self, item: TSample, /) -> None: ... - - -class Future(Protocol[R]): - def get(self) -> R: ... - - -class DoneFuture(Future[TSample]): - """Future that is already done.""" - - def __init__(self, result: TSample): - self._result = result - - def get(self) -> TSample: - return self._result - - -class CallableFuture(Future[R]): - """Future that calls a callable to get the result.""" - - _callable: Callable[[], R] - _value: R - _exception: Exception - - def __init__(self, callable: Callable[[], R]): - self._callable = callable - - def get(self) -> R: - if not hasattr(self, "_value") and not hasattr(self, "_exception"): - try: - self._value = self._callable() - except Exception as e: - self._exception = e - if hasattr(self, "_exception"): - raise self._exception - return self._value - - @staticmethod - def chain(future: Future[T], fn: Callable[[Future[T]], R]) -> Future[R]: - """ - Chain a function to a future. - - Args: - future: The future which provides the input for the function. - fn: The function to call on the result of the future, to transform the result. - - Returns: - A future that will be resolved to the result of the function given the result of the future. - """ - return CallableFuture(lambda: fn(future)) - - -class ExceptionFuture(Future[Any]): - """Future that raises an exception.""" - - def __init__(self, exception: Exception): - self._exception = exception - - def get(self) -> Any: - raise self._exception - - -class DataLoaderWorker(Generic[TSample]): - """ - A worker for a :class:`DataLoader`. - - The basic implementation iterates the dataset. - The async extension implements the main commands via a command and results queue. - """ - - dataset: SavableDataset[TSample] - worker_config: WorkerConfig - - _rank_worker_id: int - _global_worker_id: int - _seed: int - _cache_pool: CachePool | None - _sample_index: int = 0 - _exhausted: bool = True - - def __init__( - self, - dataset: SavableDataset[TSample], - worker_config: WorkerConfig, - rank_worker_id: int, - cache_pool: CachePool | None, - ): - """ - Initialize the worker. - - Args: - dataset: The dataset to iterate over. - worker_config: The worker configuration. - rank_worker_id: The rank of the worker. - cache_pool: The cache pool to use. - """ - self.dataset = dataset - self.worker_config = worker_config - self._rank_worker_id = rank_worker_id - self._global_worker_id = worker_config.global_worker_id(rank_worker_id) - self._seed = self.worker_config.worker_seed(rank_worker_id) - self._cache_pool = cache_pool - - # ------------------------------------------------------------------------------------------------ - # Section: Main control methods - - def start(self) -> None: - """ - Start the worker. - """ - pass - - def shutdown(self, in_del: bool = False) -> None: - """ - Shutdown the worker. - - Args: - in_del: If True, the worker is being deleted. - """ - pass - - def running(self) -> bool: - """ - Check if the worker is running. - """ - return True - - def _assert_running(self) -> None: - """ - Assert that the worker is running and alive. - """ - assert self.running(), "Worker must be running" - - def __del__(self) -> None: - self.shutdown(in_del=True) - - # ------------------------------------------------------------------------------------------------ - # Section: Worker methods - - def dataset_init(self, state: FlexState | None) -> None: - """ - Initialize the worker (may restore the state). - Calls `new_iter` if the worker is not exhausted and also initially (`state=None`). - - Args: - state: The state to restore the worker from or None for using the initial state. - """ - # This is called in the worker context (process/thread). - assert self._global_worker_id == self.worker_config.global_worker_id(), ( - "Global worker ID mismatch" - ) - assert self._seed == self.worker_config.worker_seed(self._rank_worker_id), "Seed mismatch" - print(f"dataset_init {state=}\n", end="") - if state is None: - self._sample_index = 0 - self.dataset.reset_state_deep() - print("dataset_init reset_state_deep\n", end="") - self.new_iter() - print("dataset_init new_iter\n", end="") - else: - assert state["__class__"] == "DataLoaderWorker", "state type mismatch" - self._sample_index = state["sample_index"] - SystemRng.restore_state(state["rng"]) - self.dataset.restore_state(state["dataset"]) - if not state["exhausted"]: - self.new_iter() - assert self._exhausted == state["exhausted"], "Exhausted state mismatch" - - def new_iter(self) -> None: - """ - Start a new iterator of the dataset. - Called after the dataset is initialized and to start a new epoch (if the dataset is not infinite). - The iterator is stored in the worker and is used by the `prefetch_next` method, which calls `next` on it. - Updates the exhausted flag to False. - """ - # This is called in the worker context (process/thread). - print("new_iter\n", end="") - self._dataset_iter = iter(self.dataset) - self._exhausted = False - print("new_iter done\n", end="") - - def prefetch_next(self) -> Future[TSample]: - """ - Fetch the next sample (i.e. call `next` on the iterator) and return a future for getting the result. - Updates the exhausted flag if the iterator is exhausted. - - Returns: - A future that will either be resolved to the next sample or raise StopIteration if the iterator is exhausted. - """ - # This is called in the worker context (process/thread). - assert self._dataset_iter is not None, "start_iter must be called before prefetch_next" - if self._exhausted: - try: - raise StopIteration() - except StopIteration as e: - return ExceptionFuture(e) - sample_idx = self._sample_index - self.worker_config.worker_activate(sample_idx, cache_pool=self._cache_pool) - try: - next_sample = next(self._dataset_iter) - self._sample_index += 1 - next_sample = add_sample_restore_key( - next_sample, self._global_worker_id, sample_idx, src=self - ) - except StopIteration as e: - self._exhausted = True - return ExceptionFuture(e) - finally: - self.worker_config.worker_deactivate() - return DoneFuture(next_sample) - - def save_state(self) -> FlexState: - """ - Save the state of the worker. - """ - # This is called in the worker context (process/thread). - return FlexState( - __class__="DataLoaderWorker", - rng=SystemRng.save_state(), - dataset=self.dataset.save_state(), - exhausted=self._exhausted, - sample_index=self._sample_index, - ) - - -class _DataLoaderAsynchronousWorker(DataLoaderWorker[TSample]): - """ - Extension of the `DataLoaderWorker`, which implements commands via a command and results queue. - - There are different implementations of the async worker: - - :class:`ForkDataLoaderWorker` - A worker that forks a new process for each worker. - - :class:`ThreadDataLoaderWorker` - A worker that uses threads to execute the commands. - """ - - _cmd_queue: QueueProtocol["WorkerCommand"] - _result_queue: QueueProtocol["WorkerResult"] - _next_future_id: int - _futures: dict[int, "FutureImpl"] - - def __init__( - self, - dataset: SavableDataset, - worker_config: WorkerConfig, - rank_worker_id: int, - cmd_queue: QueueProtocol["WorkerCommand"], - result_queue: QueueProtocol["WorkerResult"], - cache_pool: CachePool | None, - ): - super().__init__(dataset, worker_config, rank_worker_id, cache_pool) - assert worker_config.num_workers > 0, "Async workers require num_workers > 0" - self._cmd_queue = cmd_queue - self._result_queue = result_queue - self._next_future_id = 0 - self._futures = {} - - # ------------------------------------------------------------------------------------------------ - # Section: Remote call implementation - - @edataclass - class WorkerResult: - """Internal class for communicating a result from the worker via the result queue.""" - - future_id: int - result: Any = None - exception: Exception | None = None - - @edataclass - class WorkerCommand: - """Internal class for communicating a command to the worker via the command queue.""" - - cmd: str - args: tuple[Any, ...] - kwargs: dict[str, Any] - future_id: int - - class FutureImpl(Future[Any]): - """Class for returning a future result from the worker..""" - - _worker: "_DataLoaderAsynchronousWorker" - _future_id: int - _result: Any - _exception: Exception - - def __init__(self, worker: "_DataLoaderAsynchronousWorker", future_id: int): - self._worker = worker - self._future_id = future_id - - def get(self) -> Any: - if not hasattr(self, "_result") and not hasattr(self, "_exception"): - self._worker._wait_for_worker_result(self._future_id) - if hasattr(self, "_exception"): - raise self._exception - return self._result - - def _set_result(self, result: Any) -> None: - self._result = result - - def _set_exception(self, exception: Exception) -> None: - self._exception = exception - - def _wait_for_worker_result(self, future_id: int) -> None: - """ - Wait for the result of a future. - If another result comes first, update the corresponding future. - - Args: - future_id: The ID of the future to wait for. - """ - while True: - print(f"[fut={future_id}] waiting for result\n", end="") - res = self._result_queue.get() - fut = self._futures.pop(res.future_id) - if res.exception is not None: - fut._set_exception(res.exception) - else: - fut._set_result(res.result) - # self._result_queue.task_done() - if res.future_id == future_id: - print(f"[fut={future_id}] got result, return\n", end="") - return - else: - print(f"[fut={future_id}] got result for {res.future_id=}, continue\n", end="") - continue - - def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Future[R]: - """ - Call a function in the worker and return a future for getting the result. - The function must be an instance method of `self`. Uses the name to identify the function in the worker - instance. - - Args: - fn: The function to call. - *args: The arguments to pass to the function. - **kwargs: The keyword arguments to pass to the function. - """ - self._assert_running() - assert not self._in_worker(), "worker_call must not be called in the worker" - future_id = self._next_future_id - self._next_future_id += 1 - - self._futures[future_id] = future = self.FutureImpl(self, future_id) - print( - f"[wrk={self._rank_worker_id}] worker_call {fn.__name__=} {args=} {kwargs=} {future_id=}\n", - end="", - ) - self._cmd_queue.put( - self.WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) - ) - print(f"[wrk={self._rank_worker_id}] queue: {self._cmd_queue.qsize()}\n", end="") - return future - - def _worker_run( - self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] - ) -> None: - """ - The worker main loop. - It waits for commands via the command queue and executes them. - The functions to call are identified by their name. - The result of the call is put into the result queue. - The worker exits when the command `_shutdown_worker` is received. - - Args: - cmd_queue: The command queue to wait for commands. - result_queue: The result queue to put the results into. - """ - assert self._in_worker(), "_worker_run must be called in the worker" - try: - SystemRng.seed(self._seed) - import torch.utils.data._utils - - torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( - id=self._rank_worker_id, - num_workers=self.worker_config.num_workers, - seed=self._seed, - dataset=self.dataset, - ) - self._global_worker_id = self.worker_config.global_worker_id() - self.worker_config.assert_worker() - while True: - print( - f"[wrk={self._rank_worker_id}] waiting for command, len: {cmd_queue.qsize()}\n", - end="", - ) - cmd = cmd_queue.get() - print( - f"[fut={cmd.future_id}] got command {cmd.cmd=} {cmd.args=} {cmd.kwargs=}\n", - end="", - ) - try: - fn = getattr(self, cmd.cmd) - result = fn(*cmd.args, **cmd.kwargs) - except Exception as e: - print(f"[fut={cmd.future_id}] send exception {e!r}\n", end="") - result_queue.put(self.WorkerResult(future_id=cmd.future_id, exception=e)) - else: - print(f"[fut={cmd.future_id}] send result {result!r}\n", end="") - result_queue.put(self.WorkerResult(future_id=cmd.future_id, result=result)) - del result - # cmd_queue.task_done() - if cmd.cmd == self._wrk_shutdown_worker.__name__: - print(f"[fut={cmd.future_id}] got shutdown command, exit\n", end="") - break - print(f"[fut={cmd.future_id}] processed, waiting for next command\n", end="") - except: - traceback.print_exc() - raise - - @abstractmethod - def _in_worker(self) -> bool: - """Check if the execution is within the worker.""" - ... - - # ------------------------------------------------------------------------------------------------ - # Section: Worker methods - now calling to workers via queues. - - def _wrk_shutdown_worker(self) -> None: - """Does nothing. The actual shutdown is handled in the _worker_run method.""" - assert self._in_worker(), "_wrk_shutdown_worker must be called in the worker" - - def _shutdown_worker(self) -> None: - """Shutdown the worker. The actual shutdown is handled in the _worker_run method.""" - assert not self._in_worker(), "shutdown_worker must not be called in the worker" - # This is not actually a recursive call, because the worker loop will exit before calling this method. - self._worker_call(self._wrk_shutdown_worker).get() - - def _wrk_prefetch_next(self) -> TSample: - """Wraps the super class method to call it in the worker process.""" - # The super class implementation already returns a resolved future (to be interface compatible), - # so immediately resolve the future to the result (get returns immediately). - return super().prefetch_next().get() - - @override - def dataset_init(self, initial_state: FlexState | None) -> None: - if self._in_worker(): - return super().dataset_init(initial_state) - else: - return self._worker_call(self.dataset_init, initial_state).get() - - @override - def new_iter(self) -> None: - if self._in_worker(): - return super().new_iter() - else: - return self._worker_call(self.new_iter).get() - - @override - def prefetch_next(self) -> Future[TSample]: - # Do not resolve the future here, but return it. - if self._in_worker(): - return super().prefetch_next() - return self._worker_call(self._wrk_prefetch_next) - - @override - def save_state(self) -> FlexState: - if self._in_worker(): - return super().save_state() - else: - return self._worker_call(self.save_state).get() - - -class ForkDataLoaderWorker(_DataLoaderAsynchronousWorker[TSample], Generic[TSample]): - """ - Implements the `DataLoaderWorker` interface using processes. - """ - - _process: multiprocessing.Process | None = None - _cmd_queue: multiprocessing.Queue - _result_queue: multiprocessing.Queue - - _threaded_shutdown: threading.Thread | None = None - - _spawning_process: int - - def __init__( - self, - dataset: SavableDataset, - worker_config: WorkerConfig, - rank_worker_id: int, - cache_pool: CachePool | None, - ): - multiprocessing.set_start_method("fork", force=True) - super().__init__( - dataset, - worker_config=worker_config, - rank_worker_id=rank_worker_id, - cmd_queue=multiprocessing.Queue(), - result_queue=multiprocessing.Queue(), - cache_pool=cache_pool, - ) - self._spawning_process = os.getpid() - - def _check_parent_process(self, evt_exit: threading.Event) -> None: - """Check if the parent process is alive. If it is dead, exit the worker process.""" - parent_proc = multiprocessing.parent_process() - parent_pid = os.getppid() - if parent_proc is None: - print("No parent process, exiting", file=sys.stderr) - os._exit(-1) - while not evt_exit.wait(1): - if parent_proc.exitcode is not None or os.getppid() != parent_pid: - print("Parent process died, exiting", file=sys.stderr) - os._exit(-1) - - def _worker_run( - self, - cmd_queue: multiprocessing.Queue, - result_queue: multiprocessing.Queue, - ) -> None: - gc_init_worker(self._rank_worker_id) - # cmd_queue is read only, so we can cancel the join thread. - cmd_queue.cancel_join_thread() - worker_exit_evt = threading.Event() - parent_check_thread = threading.Thread( - target=self._check_parent_process, args=(worker_exit_evt,), daemon=True - ) - parent_check_thread.start() - try: - super()._worker_run(cmd_queue, result_queue) - finally: - print(f"[wrk={self._rank_worker_id}] shutting down\n", end="") - worker_exit_evt.set() - print( - f"[wrk={self._rank_worker_id}] shutting down, wait for parent_check_thread\n", - end="", - ) - parent_check_thread.join() - print(f"[wrk={self._rank_worker_id}] shutting down, close queues\n", end="") - result_queue.close() - result_queue.join_thread() - cmd_queue.close() - cmd_queue.cancel_join_thread() - print(f"[wrk={self._rank_worker_id}] shutting down, done\n", end="") - - @override - def _in_worker(self) -> bool: - return multiprocessing.current_process() == self._process - - @override - def start(self) -> None: - self._process = multiprocessing.Process( - target=self._worker_run, - args=(self._cmd_queue, self._result_queue), - daemon=True, - ) - self._process.start() - - @override - def shutdown(self, in_del: bool = False) -> None: - if self._spawning_process != os.getpid(): - # Should avoid forked process containing a forked worker on exit. - warnings.warn( - "Shutting down worker from a different process than the one that spawned it, skipping" - ) - return - if self._process is not None: - if in_del: - # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. - warnings.warn( - "Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.", - ResourceWarning, - ) - print( - "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.\n", - end="", - file=sys.stderr, - ) - self._cmd_queue.close() - self._cmd_queue.cancel_join_thread() - self._result_queue.close() - self._result_queue.cancel_join_thread() - # Kill the process, because we cannot communicate with it in the gc loop. - self._process.terminate() - self._process = None - else: - try: - self._shutdown_worker() - except Exception: - self._process.join(10) - if self._process.is_alive(): - self._process.terminate() - else: - self._process.join() - assert self._process.exitcode == 0, ( - f"Process exit code {self._process.exitcode}" - ) - self._process = None - self._cmd_queue.close() - self._cmd_queue.cancel_join_thread() - self._result_queue.close() - self._result_queue.cancel_join_thread() - - @override - def running(self) -> bool: - return self._process is not None - - def _assert_running(self) -> None: - assert self._process is not None, "Worker must be started first" - assert self._process.is_alive(), "Worker died" - - -class ThreadDataLoaderWorker(_DataLoaderAsynchronousWorker[TSample], Generic[TSample]): - """ - Implements the `DataLoaderWorker` interface using threads. - """ - - _thread: threading.Thread | None = None - _cmd_queue: queue.Queue - _result_queue: queue.Queue - - def __init__( - self, - dataset: SavableDataset, - worker_config: WorkerConfig, - rank_worker_id: int, - cache_pool: CachePool | None, - ): - super().__init__( - dataset, - worker_config=worker_config, - rank_worker_id=rank_worker_id, - cmd_queue=queue.Queue(), - result_queue=queue.Queue(), - cache_pool=cache_pool, - ) - - def _worker_run(self, cmd_queue: queue.Queue, result_queue: queue.Queue) -> None: - super()._worker_run(cmd_queue, result_queue) - - @override - def _in_worker(self) -> bool: - return threading.current_thread() == self._thread - - @override - def start(self) -> None: - self._thread = threading.Thread( - target=self._worker_run, - args=(self._cmd_queue, self._result_queue), - daemon=True, - ) - self._thread.start() - - @override - def shutdown(self, in_del: bool = False) -> None: - if self._thread is not None: - if in_del: - # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. - warnings.warn( - "Explicitly call DataLoader.shutdown() to avoid leaking threads.", - ResourceWarning, - ) - print( - "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking threads.\n", - end="", - file=sys.stderr, - ) - # Just try to enqueue the shutdown command to the thread and hope for the best. Ignore the result. - self._cmd_queue.put( - self.WorkerCommand( - cmd=self._wrk_shutdown_worker.__name__, args=(), kwargs={}, future_id=-1 - ) - ) - self._thread = None - else: - self._shutdown_worker() - self._thread.join() - self._thread = None - - @override - def running(self) -> bool: - return self._thread is not None - - @override - def _assert_running(self) -> None: - assert self._thread is not None, "Thread must be started first" - assert self._thread.is_alive(), "Thread died" - - -class WorkerType(Protocol[TSample]): - def __call__( - self, - dataset: SavableDataset, - worker_config: WorkerConfig, - rank_worker_id: int, - cache_pool: CachePool | None, - ) -> DataLoaderWorker[TSample]: ... - - -class DataLoader(Generic[TSample]): - _workers: list[DataLoaderWorker[TSample]] | None = None - _exhausted_workers: list[bool] - _next_worker_id: int = 0 - - _restore_state: FlexState | None = None - - _dataset: SavableDataset - _worker_config: WorkerConfig - _prefetch_factor: int - _worker_type: WorkerType - _prefetching_samples: list[list[Future[TSample]]] - - _current_epoch_iter: Generator[TSample, None, None] | None = None - - _spawning_process: int - - def __init__( - self, - dataset: SavableDataset, - *, - prefetch_factor: int = 1, - worker_type: WorkerType = ForkDataLoaderWorker, - cache_pool: CachePool | None = None, - # Garbage collection configuration - gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER, - gc_freeze_at_start: bool = True, - # Watchdog configuration - watchdog_timeout_seconds: float | None = 60, - watchdog_initial_timeout_seconds: float | None = None, - fail_on_timeout: bool = False, - ): - """ - Create the dataloader supporting saving and restoring the state. - - Args: - dataset: The dataset to load. - prefetch_factor: The number of samples to prefetch from each worker. - worker_type: The type of worker to use. - cache_pool: If set, the cache pool to use for the dataset. - gc_collect_every_n_steps: The number of steps after which the garbage collector is - called. As we're usually handling large (but few) tensors here, and the python - garbage collection is already full of objects just by importing, this can improve - the memory footprint quite a lot, and may even be necessary to avoid memory - overflow. - gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker - processes. This improves the garbage collection performance by a lot. - In rare cases, this may cause issues and can be disabled. Keep enabled if you - experience no issues. - watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled. - watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. - fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. - """ - if dataset.worker_config.num_workers == 0 and worker_type == ForkDataLoaderWorker: - worker_type = DataLoaderWorker - - if watchdog_timeout_seconds is not None: - dataset = WatchdogDataset( - dataset, - worker_config=dataset.worker_config, - timeout_seconds=watchdog_timeout_seconds, - initial_timeout_seconds=watchdog_initial_timeout_seconds, - fail_on_timeout=fail_on_timeout, - ) - - if gc_collect_every_n_steps > 0: - dataset = GcDataset( - dataset, - worker_config=dataset.worker_config, - every_n_iter=gc_collect_every_n_steps, - freeze=gc_freeze_at_start, - ) - - self._dataset = dataset - self._worker_config = dataset.worker_config - self._prefetch_factor = prefetch_factor - self._worker_type = worker_type - self._cache_pool = cache_pool - self._prefetching_samples = [[] for _ in range(self._worker_config.safe_num_workers)] - self._exhausted_workers = [False] * self._worker_config.safe_num_workers - - if self._worker_config.num_workers == 0: - assert prefetch_factor == 1, "prefetch_factor must be 1 for num_workers == 0" - else: - assert prefetch_factor > 0, "prefetch_factor must be > 0 for num_workers > 0" - - self._spawning_process = os.getpid() - - def shutdown(self, in_del: bool = False) -> None: - if self._workers is not None: - if in_del: - warnings.warn( - "Explicitly call DataLoader.shutdown() to avoid leaking workers.", - ResourceWarning, - ) - print( - "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking workers.\n", - end="", - file=sys.stderr, - ) - for worker in self._workers: - worker.shutdown(in_del=in_del) - self._workers = None - - def __del__(self) -> None: - self.shutdown(in_del=True) - - def start_iter(self) -> None: - if self._workers is not None: - for worker in self._workers: - worker.new_iter() - - def _epoch_iter(self) -> Generator[TSample, None, None]: - """Iterate over the dataset for one epoch (i.e. all workers StopIteration). - One epoch may also be infinite (if looping the dataset).""" - if self._workers is None: - self._start() - assert self._workers is not None, "DataLoader not started" - - if all(self._exhausted_workers): - # All workers are exhausted, restart for the next epoch. - for worker in self._workers: - worker.new_iter() - self._exhausted_workers = [False] * self._worker_config.safe_num_workers - - # For all workers, enqueue prefetching samples. - for worker_idx, (worker, exhausted) in enumerate( - zip(self._workers, self._exhausted_workers) - ): - while ( - len(self._prefetching_samples[worker_idx]) < self._prefetch_factor and not exhausted - ): - self._prefetching_samples[worker_idx].append(worker.prefetch_next()) - - # Main loop: - # - Get the next worker to prefetch samples from. - # - Prefetch samples from the worker. - # - Pop the first sample future from the prefetching samples. - # - Get the sample from the sample future (may wait for the sample to be prefetched). - # - Yield the sample. - print(f"{self._exhausted_workers=}\n", end="") - while not all(self._exhausted_workers): - # Get the next worker to prefetch samples from. - worker_idx = self._next_worker_id - worker = self._workers[worker_idx] - print(f"{worker_idx=} {worker=}\n", end="") - self._next_worker_id = (worker_idx + 1) % self._worker_config.safe_num_workers - if self._exhausted_workers[worker_idx]: - print(f"{worker_idx=} exhausted, continue with next worker\n", end="") - continue - # Pop the first sample future from the prefetching samples. - sample_future = self._prefetching_samples[worker_idx].pop(0) - print(f"{sample_future=}\n", end="") - # Prefetch samples from the worker. - while len(self._prefetching_samples[worker_idx]) < self._prefetch_factor: - # Add a new sample future to the prefetching samples if the worker has not prefetched enough samples. - self._prefetching_samples[worker_idx].append(worker.prefetch_next()) - try: - # Get the sample from the sample future (may wait for the sample to be ready). - sample = sample_future.get() - except StopIteration: - print(f"{worker_idx=} exhausted, remove from prefetching samples\n", end="") - # If the sample future raises StopIteration, remove the worker from the list. - self._prefetching_samples[worker_idx] = [] - self._exhausted_workers[worker_idx] = True - continue - else: - print(f"{worker_idx=} got sample, yield\n", end="") - # Yield the sample. - yield sample - - def __iter__(self) -> Generator[TSample, None, None]: - # Restart the epoch iterator if was not created yet. Otherwise, the existing epoch iterator will be continued. - # That happens e.g. when iteration was interrupted. - if self._current_epoch_iter is None: - self._current_epoch_iter = self._epoch_iter() - yield from self._current_epoch_iter - # Reset the epoch iterator, it was exhausted. - self._current_epoch_iter = None - - def __len__(self): - return len(self._dataset) - - def _get_batch_size(self) -> int | None: - """Try to infer micro batch size from the dataset""" - if ( - isinstance(self._dataset, BaseWrapperDataset) - and (bds := self._dataset._find_wrapped_dataset(BatchDataset)) is not None - ): - assert isinstance(bds, BatchDataset) - return bds.batch_size - else: - return None - - def save_state_rank(self) -> FlexState: - # TODO: The redist tool must be able to change the batch size. - # That means that the redist tool shall split a saved restore key for the "BatchDataset". - # It should also change the saved micro batch size to match that. - # TODO @pfischer: Add changing the batch size to the docs. - prefetched_samples_keys = [ - [get_sample_restore_key(sample_fut.get()) for sample_fut in prefetching_sample] - for prefetching_sample in self._prefetching_samples - ] - if self._workers is None: - worker_states = [None] * self._worker_config.safe_num_workers - else: - worker_states = [worker.save_state() for worker in self._workers] - - return FlexState( - __class__=type(self).__name__, - prefetched_samples_keys=prefetched_samples_keys, - worker_states=worker_states, - next_worker_id=self._next_worker_id, - micro_batch_size=self._get_batch_size(), - ) - - def _start(self, initial_state: FlexState | None = None) -> None: - self._workers = [ - self._worker_type(self._dataset, self._worker_config, local_worker_id, self._cache_pool) - for local_worker_id in range(self._worker_config.safe_num_workers) - ] - for worker in self._workers: - worker.start() - - if initial_state is None: - if self._restore_state is not None: - initial_state = self._restore_state - self._restore_state = None - - if initial_state is None: - worker_states = [None] * self._worker_config.safe_num_workers - else: - worker_states = initial_state["worker_states"] - - assert len(worker_states) == self._worker_config.safe_num_workers, ( - "Number of initial states must match number of workers" - ) - - for worker, worker_state in zip(self._workers, worker_states): - worker.dataset_init(worker_state) - - if initial_state is not None: - self._prefetching_samples = [ - [ - CallableFuture(functools.partial(self.restore_sample, sample_key)) - for sample_key in prefetched_samples_keys - ] - for prefetched_samples_keys in initial_state["prefetched_samples_keys"] - ] - self._next_worker_id = initial_state["next_worker_id"] - self._exhausted_workers = [ - False if worker_state is None else worker_state["exhausted"] - for worker_state in worker_states - ] - - def restore_state_rank(self, state: FlexState | None) -> None: - """ - Restore the state of the DataLoader on the current rank. - The state is actually restored when the processes are started, in the iterator. - """ - assert self._workers is None and self._current_epoch_iter is None, ( - "DataLoader already started" - ) - assert self._restore_state is None, "Restore state already set" - - if state is None: - # Assume initial state. - return - - assert isinstance(state, FlexState) - assert state["__class__"] == type(self).__name__, "DataLoader type mismatch" - assert state["micro_batch_size"] == self._get_batch_size(), "Micro batch size mismatch" - - self._restore_state = state - - def restore_sample(self, restore_key: tuple) -> TSample: - """ - Restore a sample from a restore key. - - Args: - restore_key: The restore key to restore the sample from. - - Returns: - The restored sample. - """ - id, global_worker_id, sample_idx = restore_key[:3] - assert id == type(self).__name__ - restore_key = restore_key[3:] - self._worker_config.worker_activate( - sample_idx, override_global_rank=global_worker_id, cache_pool=self._cache_pool - ) - try: - return add_sample_restore_key( - self._dataset.restore_sample(restore_key), global_worker_id, sample_idx, src=self - ) - finally: - self._worker_config.worker_deactivate() - - def config(self) -> dict[str, Any]: - return self._dataset.config() - - def __str__(self) -> str: - return f"DataLoader(prefetch_factor={self._prefetch_factor}, worker_type={self._worker_type.__name__})" diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index f51aef01..928092be 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -41,32 +41,33 @@ class PinMemoryMixin: """A mixin class providing a generic `pin_memory` function.""" - def _pin_memory(self, batch: T, device: Union[torch.device, str, None] = None) -> T: + @classmethod + def sample_pin_memory(cls, batch: T, device: Union[torch.device, str, None] = None) -> T: """Pin memory of a batch. Uses recursion to handle nested structures. Supports nested structures of dicts, dataclasses, namedtuples, lists and tuples.""" - if isinstance(batch, torch.Tensor): + if hasattr(batch, "pin_memory"): return batch.pin_memory(device) - elif isinstance(batch, dict): - return {key: self._pin_memory(value, device) for key, value in batch.items()} + if isinstance(batch, dict): + return {key: cls.sample_pin_memory(value, device) for key, value in batch.items()} elif dataclasses.is_dataclass(batch): return type(batch)( **{ - field.name: self._pin_memory(getattr(batch, field.name), device) + field.name: cls.sample_pin_memory(getattr(batch, field.name), device) for field in dataclasses.fields(batch) } ) - elif isinstance(batch, (tuple, list)): + elif not isinstance(batch, (str, bytes)) and isinstance(batch, (tuple, list)): if hasattr(batch, "_fields"): # NamedTuple - return type(batch)(*[self._pin_memory(val, device) for val in batch]) + return type(batch)(*[cls.sample_pin_memory(val, device) for val in batch]) else: # list / tuple - return type(batch)(self._pin_memory(val, device) for val in batch) + return type(batch)(cls.sample_pin_memory(val, device) for val in batch) else: return batch - def pin_memory(self: Self) -> Self: - return self._pin_memory(self) + def pin_memory(self: Self, device: torch.device | str | None = None) -> Self: + return self.sample_pin_memory(self, device) class ExtendableDataclassMixin: diff --git a/src/megatron/energon/loader/__init__.py b/src/megatron/energon/loader/__init__.py new file mode 100644 index 00000000..43842921 --- /dev/null +++ b/src/megatron/energon/loader/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause diff --git a/src/megatron/energon/loader/dataloader.py b/src/megatron/energon/loader/dataloader.py new file mode 100644 index 00000000..4a2f3fb6 --- /dev/null +++ b/src/megatron/energon/loader/dataloader.py @@ -0,0 +1,500 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import functools +import os +import sys +import warnings +from typing import ( + Any, + Generator, + Generic, + Literal, + Protocol, + Sequence, + TypeVar, +) + +import torch.distributed + +from megatron.energon.cache.base import CachePool +from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.loader.future import CallableFuture, Future +from megatron.energon.loader.pin_memory import NoPinMemory, PinMemory, PinMemoryThread +from megatron.energon.loader.workers.base_worker import DataLoaderWorker +from megatron.energon.loader.workers.fork_worker import ForkDataLoaderWorker +from megatron.energon.state import FlexState +from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key +from megatron.energon.wrappers.batch_dataset import BatchDataset +from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset +from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset + +TSample = TypeVar("TSample", covariant=True) + + +class WorkerType(Protocol[TSample]): + """Protocol for a worker type, i.e. for the constructor of a worker class.""" + + def __call__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool | None, + ) -> DataLoaderWorker[TSample]: ... + + +class DataLoader(Generic[TSample]): + _workers: list[DataLoaderWorker[TSample]] | None = None + _exhausted_workers: list[bool] + _next_worker_id: int = 0 + + _restore_state: FlexState | None = None + + _dataset: SavableDataset + _worker_config: WorkerConfig + _prefetch_factor: int + _worker_type: WorkerType + _prefetching_samples: list[list[Future[TSample]]] + _pin_memory: PinMemory[TSample] + + _current_epoch_iter: Generator[TSample, None, None] | None = None + + _spawning_process: int + + def __init__( + self, + dataset: SavableDataset, + *, + prefetch_factor: int = 1, + worker_type: WorkerType = ForkDataLoaderWorker, + cache_pool: CachePool | None = None, + # Garbage collection configuration + gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER, + gc_freeze_at_start: bool = True, + # Watchdog configuration + watchdog_timeout_seconds: float | None = 60, + watchdog_initial_timeout_seconds: float | None = None, + fail_on_timeout: bool = False, + # Pin memory configuration + pin_memory: PinMemory[TSample] | None | Literal["automatic"] = "automatic", + ): + """ + Create the dataloader supporting saving and restoring the state. + + Args: + dataset: The dataset to load. + prefetch_factor: The number of samples to prefetch from each worker. + worker_type: The type of worker to use. + cache_pool: If set, the cache pool to use for the dataset. + gc_collect_every_n_steps: The number of steps after which the garbage collector is + called. As we're usually handling large (but few) tensors here, and the python + garbage collection is already full of objects just by importing, this can improve + the memory footprint quite a lot, and may even be necessary to avoid memory + overflow. + gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker + processes. This improves the garbage collection performance by a lot. + In rare cases, this may cause issues and can be disabled. Keep enabled if you + experience no issues. + watchdog_timeout_seconds: The timeout in seconds. If `None`, the watchdog is disabled. + watchdog_initial_timeout_seconds: The initial timeout in seconds. If `None`, the timeout is the same as `watchdog_timeout_seconds`. + fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. + pin_memory: The memory pinner to use. If `None`, no memory is not pinned. If "automatic", the memory is pinned automatically if cuda is available. + """ + if dataset.worker_config.num_workers == 0 and worker_type == ForkDataLoaderWorker: + worker_type = DataLoaderWorker + + if watchdog_timeout_seconds is not None: + dataset = WatchdogDataset( + dataset, + worker_config=dataset.worker_config, + timeout_seconds=watchdog_timeout_seconds, + initial_timeout_seconds=watchdog_initial_timeout_seconds, + fail_on_timeout=fail_on_timeout, + ) + + if gc_collect_every_n_steps > 0: + dataset = GcDataset( + dataset, + worker_config=dataset.worker_config, + every_n_iter=gc_collect_every_n_steps, + freeze=gc_freeze_at_start, + ) + + self._dataset = dataset + self._worker_config = dataset.worker_config + self._prefetch_factor = prefetch_factor + self._worker_type = worker_type + self._cache_pool = cache_pool + self._prefetching_samples = [[] for _ in range(self._worker_config.safe_num_workers)] + self._exhausted_workers = [False] * self._worker_config.safe_num_workers + if pin_memory == "automatic": + # Automatic pinning + if torch.cuda.is_available(): + # Use cuda + self._pin_memory = PinMemoryThread(torch.device("cuda")) + else: + self._pin_memory = NoPinMemory() + else: + if pin_memory is None: + self._pin_memory = NoPinMemory() + else: + self._pin_memory = pin_memory + + if self._worker_config.num_workers == 0: + assert prefetch_factor == 1, "prefetch_factor must be 1 for num_workers == 0" + else: + assert prefetch_factor > 0, "prefetch_factor must be > 0 for num_workers > 0" + + self._spawning_process = os.getpid() + + def shutdown(self, in_del: bool = False) -> None: + if self._workers is not None: + if in_del: + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking workers.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking workers.\n", + end="", + file=sys.stderr, + ) + for worker in self._workers: + worker.shutdown(in_del=in_del) + self._workers = None + if self._pin_memory is not None: + self._pin_memory.shutdown() + + def __del__(self) -> None: + self.shutdown(in_del=True) + + def start_iter(self) -> None: + if self._workers is not None: + for worker in self._workers: + worker.new_iter() + + def _epoch_iter(self) -> Generator[TSample, None, None]: + """Iterate over the dataset for one epoch (i.e. all workers StopIteration). + One epoch may also be infinite (if looping the dataset).""" + if self._workers is None: + self._start() + assert self._workers is not None, "DataLoader not started" + + if all(self._exhausted_workers): + # All workers are exhausted, restart for the next epoch. + for worker in self._workers: + worker.new_iter() + self._exhausted_workers = [False] * self._worker_config.safe_num_workers + + # For all workers, enqueue prefetching samples. + for worker_idx, (worker, exhausted) in enumerate( + zip(self._workers, self._exhausted_workers) + ): + while ( + len(self._prefetching_samples[worker_idx]) < self._prefetch_factor and not exhausted + ): + self._prefetching_samples[worker_idx].append( + self._pin_memory(worker.prefetch_next()) + ) + + # Main loop: + # - Get the next worker to prefetch samples from. + # - Prefetch samples from the worker. + # - Pop the first sample future from the prefetching samples. + # - Get the sample from the sample future (may wait for the sample to be prefetched). + # - Yield the sample. + print(f"{self._exhausted_workers=}\n", end="") + while not all(self._exhausted_workers): + # Get the next worker to prefetch samples from. + worker_idx = self._next_worker_id + worker = self._workers[worker_idx] + print(f"{worker_idx=} {worker=}\n", end="") + self._next_worker_id = (worker_idx + 1) % self._worker_config.safe_num_workers + if self._exhausted_workers[worker_idx]: + print(f"{worker_idx=} exhausted, continue with next worker\n", end="") + continue + # Pop the first sample future from the prefetching samples. + sample_future = self._prefetching_samples[worker_idx].pop(0) + print(f"{sample_future=}\n", end="") + # Prefetch samples from the worker. + while len(self._prefetching_samples[worker_idx]) < self._prefetch_factor: + # Add a new sample future to the prefetching samples if the worker has not prefetched enough samples. + self._prefetching_samples[worker_idx].append( + self._pin_memory(worker.prefetch_next()) + ) + try: + # Get the sample from the sample future (may wait for the sample to be ready). + sample = sample_future.get() + except StopIteration: + print(f"{worker_idx=} exhausted, remove from prefetching samples\n", end="") + # If the sample future raises StopIteration, remove the worker from the list. + self._prefetching_samples[worker_idx] = [] + self._exhausted_workers[worker_idx] = True + continue + else: + print(f"{worker_idx=} got sample, yield\n", end="") + # Yield the sample. + yield sample + + def __iter__(self) -> Generator[TSample, None, None]: + # Restart the epoch iterator if was not created yet. Otherwise, the existing epoch iterator will be continued. + # That happens e.g. when iteration was interrupted. + if self._current_epoch_iter is None: + self._current_epoch_iter = self._epoch_iter() + assert self._current_epoch_iter is not None + yield from self._current_epoch_iter + # Reset the epoch iterator, it was exhausted. + self._current_epoch_iter.close() + self._current_epoch_iter = None + + def __len__(self): + return len(self._dataset) + + def _get_batch_size(self) -> int | None: + """Try to infer micro batch size from the dataset""" + if ( + isinstance(self._dataset, BaseWrapperDataset) + and (bds := self._dataset._find_wrapped_dataset(BatchDataset)) is not None + ): + assert isinstance(bds, BatchDataset) + return bds.batch_size + else: + return None + + def save_state_rank(self) -> FlexState: + # TODO: The redist tool must be able to change the batch size. + # That means that the redist tool shall split a saved restore key for the "BatchDataset". + # It should also change the saved micro batch size to match that. + # TODO @pfischer: Add changing the batch size to the docs. + prefetched_samples_keys = [ + [get_sample_restore_key(sample_fut.get()) for sample_fut in prefetching_sample] + for prefetching_sample in self._prefetching_samples + ] + if self._workers is None: + worker_states = [None] * self._worker_config.safe_num_workers + else: + worker_states = [worker.save_state() for worker in self._workers] + + return FlexState( + __class__=type(self).__name__, + prefetched_samples_keys=prefetched_samples_keys, + worker_states=worker_states, + next_worker_id=self._next_worker_id, + micro_batch_size=self._get_batch_size(), + ) + + def save_state_global(self, global_dst_rank: int) -> Sequence[FlexState | None] | None: + """ + Saves the state of the dataset globally, collecting the state from all ranks using torch + distributed. Allows for restoring the state later using `restore_state_global`, given the + result of this method. + Typical scenario: Save the state to disk only on the `dst_rank`, the other ranks do not + save the state. Later, restore the state either only loaded on the `dst_rank` or + loading on all ranks separately using `restore_state_global`. + + Note: If you want to save/restore the state per rank separately, use `save_state_rank` and + the corresponding `restore_state_rank`. Also, these do not rely on torch distributed. + + Args: + global_dst_rank: The state will be gathered to this rank. The rank refers to the + global rank, not the rank within the data parallel group. + + Returns: + The state of the dataset (or `None`, if not on `dst_rank`). + """ + # Fetch current rank's worker's state + merged_state = self.save_state_rank() + + # Gather the merged states + if self._worker_config.world_size > 1: + output: Sequence[FlexState | None] | None + if self._worker_config.global_rank() == global_dst_rank: + output = [None] * self._worker_config.world_size + else: + # Check if the global_dst_rank is in the same group at all + if self._worker_config.data_parallel_group is not None: + try: + _ = torch.distributed.get_group_rank( + self._worker_config.data_parallel_group, global_dst_rank + ) + except RuntimeError: + raise ValueError( + f"global_dst_rank {global_dst_rank} is not in the group of the current rank's worker config" + ) + + output = None + + torch.distributed.gather_object( + merged_state, + output, + global_dst_rank, + group=self._worker_config.data_parallel_group, + ) + + return output + else: + # Not distributed -> return the merged state + return [merged_state] + + def _start(self, initial_state: FlexState | None = None) -> None: + self._workers = [ + self._worker_type(self._dataset, self._worker_config, local_worker_id, self._cache_pool) + for local_worker_id in range(self._worker_config.safe_num_workers) + ] + for worker in self._workers: + worker.start() + + if initial_state is None: + if self._restore_state is not None: + initial_state = self._restore_state + self._restore_state = None + + if initial_state is None: + worker_states = [None] * self._worker_config.safe_num_workers + else: + worker_states = initial_state["worker_states"] + + assert len(worker_states) == self._worker_config.safe_num_workers, ( + "Number of initial states must match number of workers" + ) + + for worker, worker_state in zip(self._workers, worker_states): + worker.dataset_init(worker_state) + + if initial_state is not None: + self._prefetching_samples = [ + [ + CallableFuture(functools.partial(self.restore_sample, sample_key)) + for sample_key in prefetched_samples_keys + ] + for prefetched_samples_keys in initial_state["prefetched_samples_keys"] + ] + self._next_worker_id = initial_state["next_worker_id"] + self._exhausted_workers = [ + False if worker_state is None else worker_state["exhausted"] + for worker_state in worker_states + ] + + def restore_state_rank(self, state: FlexState | None) -> None: + """ + Restore the state of the DataLoader on the current rank. + The state is actually restored when the processes are started, in the iterator. + """ + assert self._workers is None and self._current_epoch_iter is None, ( + "Cannot restore state while workers are running" + ) + assert self._restore_state is None, "Restore state already set" + + if state is None: + # Assume initial state. + return + + assert isinstance(state, FlexState) + assert state["__class__"] == type(self).__name__, "DataLoader type mismatch" + assert state["micro_batch_size"] == self._get_batch_size(), "Micro batch size mismatch" + + self._restore_state = state + + def restore_state_global( + self, + state: Sequence[FlexState | None] | None, + *, + src_rank: int | None = None, + ) -> None: + """ + Restores the saved state from `save_state_global` (in torch distributed setup). + The global state needs be loaded on every rank that has a data loader instance. + + Optionally, one can specify a src_rank and only provide the state once. + In case of multiple data parallel groups, you must provide the state once + in each data parallel group. In this case the `src_rank` is the rank within the + data parallel group. + + Args: + state: The state to restore, as saved by `save_state_global`. + src_rank: The rank from which the state is broadcasted (within the data parallel group, if using DP groups). + """ + + assert self._workers is None and self._current_epoch_iter is None, ( + "Cannot restore state while workers are running" + ) + assert self._restore_state is None, "Restore state already set" + + # Only restore multi-rank if state is actually a list and we are in a distributed setup. + # Otherwise treat as single rank state. + if src_rank is None or self._worker_config.world_size == 1: + assert isinstance(state, list), "State must be a list in distributed setup" + assert len(state) == self._worker_config.world_size, ( + "State must be a list of size world_size" + ) + + # All ranks have the state + # Select the state of the current rank + rank_state = state[self._worker_config.rank] + else: + if self._worker_config.data_parallel_group is not None: + # Only the src_rank has the state within this dp group + try: + global_src_rank = torch.distributed.get_global_rank( + self._worker_config.data_parallel_group, src_rank + ) + except RuntimeError: + raise ValueError( + f"src_rank {src_rank} is not in the group of the current rank's worker config" + ) + else: + # If no DP group is given, we assume the global rank is + # the same as the data parallel rank + global_src_rank = src_rank + + if self._worker_config.rank != src_rank: + # Send the state to all other ranks + assert state is None + # Must still be a list of Nones + state = [None] * self._worker_config.world_size + else: + assert isinstance(state, list), "State must be a list in distributed setup" + assert len(state) == self._worker_config.world_size, ( + "State must be a list of size world_size" + ) + + local_object = [None] + torch.distributed.scatter_object_list( + local_object, + state, + src=global_src_rank, + group=self._worker_config.data_parallel_group, + ) + rank_state = local_object[0] + + self.restore_state_rank(rank_state) + + def restore_sample(self, restore_key: tuple) -> TSample: + """ + Restore a sample from a restore key. + + Args: + restore_key: The restore key to restore the sample from. + + Returns: + The restored sample. + """ + id, global_worker_id, sample_idx = restore_key[:3] + assert id == type(self).__name__ + restore_key = restore_key[3:] + self._worker_config.worker_activate( + sample_idx, override_global_rank=global_worker_id, cache_pool=self._cache_pool + ) + try: + return add_sample_restore_key( + self._dataset.restore_sample(restore_key), global_worker_id, sample_idx, src=self + ) + finally: + self._worker_config.worker_deactivate() + + def config(self) -> dict[str, Any]: + return self._dataset.config() + + def __str__(self) -> str: + return f"DataLoader(prefetch_factor={self._prefetch_factor}, worker_type={self._worker_type.__name__})" diff --git a/src/megatron/energon/loader/future.py b/src/megatron/energon/loader/future.py new file mode 100644 index 00000000..e433080e --- /dev/null +++ b/src/megatron/energon/loader/future.py @@ -0,0 +1,69 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from abc import abstractmethod +from typing import Any, Callable, Generic, TypeVar + +R = TypeVar("R", covariant=True) +T = TypeVar("T", covariant=True) + + +class Future(Generic[R]): + """Base class for abstract futures.""" + + @abstractmethod + def get(self) -> R: ... + + +class DoneFuture(Future[R]): + """Future that is already done.""" + + def __init__(self, result: R): + self._result = result + + def get(self) -> R: + return self._result + + +class CallableFuture(Future[R]): + """Future that calls a callable to get the result.""" + + _callable: Callable[[], R] + _value: R + _exception: Exception + + def __init__(self, callable: Callable[[], R]): + self._callable = callable + + def get(self) -> R: + if not hasattr(self, "_value") and not hasattr(self, "_exception"): + try: + self._value = self._callable() + except Exception as e: + self._exception = e + if hasattr(self, "_exception"): + raise self._exception + return self._value + + @staticmethod + def chain(future: Future[T], fn: Callable[[Future[T]], R]) -> Future[R]: + """ + Chain a function to a future. + + Args: + future: The future which provides the input for the function. + fn: The function to call on the result of the future, to transform the result. + + Returns: + A future that will be resolved to the result of the function given the result of the future. + """ + return CallableFuture(lambda: fn(future)) + + +class ExceptionFuture(Future[Any]): + """Future that raises an exception.""" + + def __init__(self, exception: Exception): + self._exception = exception + + def get(self) -> Any: + raise self._exception diff --git a/src/megatron/energon/loader/pin_memory.py b/src/megatron/energon/loader/pin_memory.py new file mode 100644 index 00000000..9aba23c2 --- /dev/null +++ b/src/megatron/energon/loader/pin_memory.py @@ -0,0 +1,123 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import queue +import threading +from typing import Generic, Protocol, TypeVar, cast, override + +import torch + +from megatron.energon.flavors.base_dataset import PinMemoryMixin +from megatron.energon.loader.future import CallableFuture, Future + +TSample = TypeVar("TSample") +T = TypeVar("T") + + +class PinMemory(Generic[TSample]): + """Base class for pinning memory of samples. + + This class is used to pin memory of samples in the primary process. + """ + + def __init__(self, device: str | torch.device): + self._device = device + + def _pin_memory(self, sample: TSample) -> TSample: + return PinMemoryMixin.sample_pin_memory(sample, self._device) + + def __call__(self, sample: Future[TSample]) -> Future[TSample]: + """Pin the memory of a sample. The default implementation runs in the main thread.""" + return CallableFuture.chain(sample, lambda fut: self._pin_memory(fut.get())) + + def shutdown(self) -> None: + """Shutdown any running threads.""" + pass + + +class NoPinMemory(PinMemory[TSample]): + """No-op implementation of :class:`PinMemory`. + + Does not pin the memory of samples. + """ + + def __init__(self): + super().__init__(device="cpu") + + def __call__(self, sample: Future[TSample]) -> Future[TSample]: + return sample + + +class QueueProtocol(Protocol[T]): + def get(self, /) -> T: ... + + def put(self, item: T, /) -> None: ... + + def qsize(self, /) -> int: ... + + def task_done(self, /) -> None: ... + + def join(self, /) -> None: ... + + +class PinMemoryThread(PinMemory[TSample], Generic[TSample]): + """Threaded implementation of :class:`PinMemory`. + + Pins the memory of samples in a separate thread in the background. + + Creates the thread on first use and shuts it down on shutdown. May be reused after shutdown. + """ + + _SHUTDOWN = cast(Future[TSample], object()) + + _thread: threading.Thread | None = None + + _item_queue: QueueProtocol[Future[TSample]] + _result_queue: QueueProtocol[tuple[TSample, None] | tuple[None, Exception]] + + def __init__( + self, + device: str | torch.device, + ): + super().__init__(device) + self._item_queue = queue.Queue() + self._result_queue = queue.Queue() + + def _run(self) -> None: + """The pin memory thread. It will fetch the sample from the item future queue and pin the memory.""" + while True: + try: + sample = self._item_queue.get() + if sample is self._SHUTDOWN: + break + sample = self._pin_memory(sample.get()) + except Exception as e: + self._result_queue.put((None, e)) + else: + self._result_queue.put((sample, None)) + self._item_queue.task_done() + + def _get_next_result(self) -> TSample: + result, exception = self._result_queue.get() + if exception is not None: + raise exception + return cast(TSample, result) + + def __call__(self, sample: Future[TSample]) -> Future[TSample]: + """ + Pin the memory of a sample. + Submits the sample future to the thread to fetch it and pins the memory in the thread, + then returns a future for fetching the pinned sample. + """ + if self._thread is None: + self._thread = threading.Thread(target=self._run, daemon=True, name="PinMemoryThread") + self._thread.start() + self._item_queue.put(sample) + return CallableFuture(self._get_next_result) + + @override + def shutdown(self) -> None: + if self._thread is not None: + self._item_queue.put(self._SHUTDOWN) + self._item_queue.join() + self._thread.join() + self._thread = None diff --git a/src/megatron/energon/loader/workers/__init__.py b/src/megatron/energon/loader/workers/__init__.py new file mode 100644 index 00000000..43842921 --- /dev/null +++ b/src/megatron/energon/loader/workers/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause diff --git a/src/megatron/energon/loader/workers/async_worker.py b/src/megatron/energon/loader/workers/async_worker.py new file mode 100644 index 00000000..1abab7b4 --- /dev/null +++ b/src/megatron/energon/loader/workers/async_worker.py @@ -0,0 +1,263 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import traceback +from abc import abstractmethod +from typing import Any, Callable, ParamSpec, Protocol, TypeVar, override + +from megatron.energon.cache.base import CachePool +from megatron.energon.edataclass import edataclass +from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.loader.future import Future +from megatron.energon.loader.workers.base_worker import DataLoaderWorker +from megatron.energon.rng import SystemRng +from megatron.energon.state import FlexState +from megatron.energon.worker import WorkerConfig + +TSample = TypeVar("TSample", covariant=True) + +P = ParamSpec("P") +T = TypeVar("T") +R = TypeVar("R", covariant=True) + + +class QueueProtocol(Protocol[T]): + def get(self, /) -> T: ... + + def put(self, item: T, /) -> None: ... + + def qsize(self, /) -> int: ... + + +class DataLoaderAsynchronousWorker(DataLoaderWorker[TSample]): + """ + Extension of the `DataLoaderWorker`, which implements commands via a command and results queue. + + There are different implementations of the async worker: + - :class:`ForkDataLoaderWorker` - A worker that forks a new process for each worker. + - :class:`ThreadDataLoaderWorker` - A worker that uses threads to execute the commands. + """ + + _cmd_queue: QueueProtocol["WorkerCommand"] + _result_queue: QueueProtocol["WorkerResult"] + _next_future_id: int + _futures: dict[int, "FutureImpl"] + + def __init__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cmd_queue: QueueProtocol["WorkerCommand"], + result_queue: QueueProtocol["WorkerResult"], + cache_pool: CachePool | None, + ): + super().__init__(dataset, worker_config, rank_worker_id, cache_pool) + assert worker_config.num_workers > 0, "Async workers require num_workers > 0" + self._cmd_queue = cmd_queue + self._result_queue = result_queue + self._next_future_id = 0 + self._futures = {} + + # ------------------------------------------------------------------------------------------------ + # Section: Remote call implementation + + @edataclass + class WorkerResult: + """Internal class for communicating a result from the worker via the result queue.""" + + future_id: int + result: Any = None + exception: Exception | None = None + + @edataclass + class WorkerCommand: + """Internal class for communicating a command to the worker via the command queue.""" + + cmd: str + args: tuple[Any, ...] + kwargs: dict[str, Any] + future_id: int + + class FutureImpl(Future[Any]): + """Class for returning a future result from the worker..""" + + _worker: "DataLoaderAsynchronousWorker" + _future_id: int + _result: Any + _exception: Exception + + def __init__(self, worker: "DataLoaderAsynchronousWorker", future_id: int): + self._worker = worker + self._future_id = future_id + + def get(self) -> Any: + if not hasattr(self, "_result") and not hasattr(self, "_exception"): + self._worker._wait_for_worker_result(self._future_id) + if hasattr(self, "_exception"): + raise self._exception + return self._result + + def _set_result(self, result: Any) -> None: + self._result = result + + def _set_exception(self, exception: Exception) -> None: + self._exception = exception + + def _wait_for_worker_result(self, future_id: int) -> None: + """ + Wait for the result of a future. + If another result comes first, update the corresponding future. + + Args: + future_id: The ID of the future to wait for. + """ + while True: + print(f"[fut={future_id}] waiting for result\n", end="") + res = self._result_queue.get() + fut = self._futures.pop(res.future_id) + if res.exception is not None: + fut._set_exception(res.exception) + else: + fut._set_result(res.result) + # self._result_queue.task_done() + if res.future_id == future_id: + print(f"[fut={future_id}] got result, return\n", end="") + return + else: + print(f"[fut={future_id}] got result for {res.future_id=}, continue\n", end="") + continue + + def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Future[R]: + """ + Call a function in the worker and return a future for getting the result. + The function must be an instance method of `self`. Uses the name to identify the function in the worker + instance. + + Args: + fn: The function to call. + *args: The arguments to pass to the function. + **kwargs: The keyword arguments to pass to the function. + """ + self._assert_running() + assert not self._in_worker(), "worker_call must not be called in the worker" + future_id = self._next_future_id + self._next_future_id += 1 + + self._futures[future_id] = future = self.FutureImpl(self, future_id) + print( + f"[wrk={self._rank_worker_id}] worker_call {fn.__name__=} {args=} {kwargs=} {future_id=}\n", + end="", + ) + self._cmd_queue.put( + self.WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) + ) + print(f"[wrk={self._rank_worker_id}] queue: {self._cmd_queue.qsize()}\n", end="") + return future + + def _worker_run( + self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] + ) -> None: + """ + The worker main loop. + It waits for commands via the command queue and executes them. + The functions to call are identified by their name. + The result of the call is put into the result queue. + The worker exits when the command `_shutdown_worker` is received. + + Args: + cmd_queue: The command queue to wait for commands. + result_queue: The result queue to put the results into. + """ + assert self._in_worker(), "_worker_run must be called in the worker" + try: + SystemRng.seed(self._seed) + import torch.utils.data._utils + + torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( + id=self._rank_worker_id, + num_workers=self.worker_config.num_workers, + seed=self._seed, + dataset=self.dataset, + ) + self._global_worker_id = self.worker_config.global_worker_id() + self.worker_config.assert_worker() + while True: + print( + f"[wrk={self._rank_worker_id}] waiting for command, len: {cmd_queue.qsize()}\n", + end="", + ) + cmd = cmd_queue.get() + print( + f"[fut={cmd.future_id}] got command {cmd.cmd=} {cmd.args=} {cmd.kwargs=}\n", + end="", + ) + try: + fn = getattr(self, cmd.cmd) + result = fn(*cmd.args, **cmd.kwargs) + except Exception as e: + print(f"[fut={cmd.future_id}] send exception {e!r}\n", end="") + result_queue.put(self.WorkerResult(future_id=cmd.future_id, exception=e)) + else: + print(f"[fut={cmd.future_id}] send result {result!r}\n", end="") + result_queue.put(self.WorkerResult(future_id=cmd.future_id, result=result)) + del result + # cmd_queue.task_done() + if cmd.cmd == self._wrk_shutdown_worker.__name__: + print(f"[fut={cmd.future_id}] got shutdown command, exit\n", end="") + break + print(f"[fut={cmd.future_id}] processed, waiting for next command\n", end="") + except: + traceback.print_exc() + raise + + @abstractmethod + def _in_worker(self) -> bool: + """Check if the execution is within the worker.""" + ... + + # ------------------------------------------------------------------------------------------------ + # Section: Worker methods - now calling to workers via queues. + + def _wrk_shutdown_worker(self) -> None: + """Does nothing. The actual shutdown is handled in the _worker_run method.""" + assert self._in_worker(), "_wrk_shutdown_worker must be called in the worker" + + def _shutdown_worker(self) -> None: + """Shutdown the worker. The actual shutdown is handled in the _worker_run method.""" + assert not self._in_worker(), "shutdown_worker must not be called in the worker" + # This is not actually a recursive call, because the worker loop will exit before calling this method. + self._worker_call(self._wrk_shutdown_worker).get() + + def _wrk_prefetch_next(self) -> TSample: + """Wraps the super class method to call it in the worker process.""" + # The super class implementation already returns a resolved future (to be interface compatible), + # so immediately resolve the future to the result (get returns immediately). + return super().prefetch_next().get() + + @override + def dataset_init(self, initial_state: FlexState | None) -> None: + if self._in_worker(): + return super().dataset_init(initial_state) + else: + return self._worker_call(self.dataset_init, initial_state).get() + + @override + def new_iter(self) -> None: + if self._in_worker(): + return super().new_iter() + else: + return self._worker_call(self.new_iter).get() + + @override + def prefetch_next(self) -> Future[TSample]: + # Do not resolve the future here, but return it. + if self._in_worker(): + return super().prefetch_next() + return self._worker_call(self._wrk_prefetch_next) + + @override + def save_state(self) -> FlexState: + if self._in_worker(): + return super().save_state() + else: + return self._worker_call(self.save_state).get() diff --git a/src/megatron/energon/loader/workers/base_worker.py b/src/megatron/energon/loader/workers/base_worker.py new file mode 100644 index 00000000..bc60c8e3 --- /dev/null +++ b/src/megatron/energon/loader/workers/base_worker.py @@ -0,0 +1,175 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from typing import Generic, TypeVar + +from megatron.energon.cache.base import CachePool +from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.loader.future import DoneFuture, ExceptionFuture, Future +from megatron.energon.rng import SystemRng +from megatron.energon.state import FlexState +from megatron.energon.worker import WorkerConfig + +TSample = TypeVar("TSample", covariant=True) + + +class DataLoaderWorker(Generic[TSample]): + """ + A worker for a :class:`DataLoader`. + + The basic implementation iterates the dataset. + The async extension implements the main commands via a command and results queue. + """ + + dataset: SavableDataset[TSample] + worker_config: WorkerConfig + + _rank_worker_id: int + _global_worker_id: int + _seed: int + _cache_pool: CachePool | None + _sample_index: int = 0 + _exhausted: bool = True + + def __init__( + self, + dataset: SavableDataset[TSample], + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool | None, + ): + """ + Initialize the worker. + + Args: + dataset: The dataset to iterate over. + worker_config: The worker configuration. + rank_worker_id: The rank of the worker. + cache_pool: The cache pool to use. + """ + self.dataset = dataset + self.worker_config = worker_config + self._rank_worker_id = rank_worker_id + self._global_worker_id = worker_config.global_worker_id(rank_worker_id) + self._seed = self.worker_config.worker_seed(rank_worker_id) + self._cache_pool = cache_pool + + # ------------------------------------------------------------------------------------------------ + # Section: Main control methods + + def start(self) -> None: + """ + Start the worker. + """ + pass + + def shutdown(self, in_del: bool = False) -> None: + """ + Shutdown the worker. + + Args: + in_del: If True, the worker is being deleted. + """ + pass + + def running(self) -> bool: + """ + Check if the worker is running. + """ + return True + + def _assert_running(self) -> None: + """ + Assert that the worker is running and alive. + """ + assert self.running(), "Worker must be running" + + def __del__(self) -> None: + self.shutdown(in_del=True) + + # ------------------------------------------------------------------------------------------------ + # Section: Worker methods + + def dataset_init(self, state: FlexState | None) -> None: + """ + Initialize the worker (may restore the state). + Calls `new_iter` if the worker is not exhausted and also initially (`state=None`). + + Args: + state: The state to restore the worker from or None for using the initial state. + """ + # This is called in the worker context (process/thread). + assert self._global_worker_id == self.worker_config.global_worker_id(), ( + "Global worker ID mismatch" + ) + assert self._seed == self.worker_config.worker_seed(self._rank_worker_id), "Seed mismatch" + print(f"dataset_init {state=}\n", end="") + if state is None: + self._sample_index = 0 + self.dataset.reset_state_deep() + print("dataset_init reset_state_deep\n", end="") + self.new_iter() + print("dataset_init new_iter\n", end="") + else: + assert state["__class__"] == "DataLoaderWorker", "state type mismatch" + self._sample_index = state["sample_index"] + SystemRng.restore_state(state["rng"]) + self.dataset.restore_state(state["dataset"]) + if not state["exhausted"]: + self.new_iter() + assert self._exhausted == state["exhausted"], "Exhausted state mismatch" + + def new_iter(self) -> None: + """ + Start a new iterator of the dataset. + Called after the dataset is initialized and to start a new epoch (if the dataset is not infinite). + The iterator is stored in the worker and is used by the `prefetch_next` method, which calls `next` on it. + Updates the exhausted flag to False. + """ + # This is called in the worker context (process/thread). + print("new_iter\n", end="") + self._dataset_iter = iter(self.dataset) + self._exhausted = False + print("new_iter done\n", end="") + + def prefetch_next(self) -> Future[TSample]: + """ + Fetch the next sample (i.e. call `next` on the iterator) and return a future for getting the result. + Updates the exhausted flag if the iterator is exhausted. + + Returns: + A future that will either be resolved to the next sample or raise StopIteration if the iterator is exhausted. + """ + # This is called in the worker context (process/thread). + assert self._dataset_iter is not None, "start_iter must be called before prefetch_next" + if self._exhausted: + try: + raise StopIteration() + except StopIteration as e: + return ExceptionFuture(e) + sample_idx = self._sample_index + self.worker_config.worker_activate(sample_idx, cache_pool=self._cache_pool) + try: + next_sample = next(self._dataset_iter) + self._sample_index += 1 + next_sample = add_sample_restore_key( + next_sample, self._global_worker_id, sample_idx, src=self + ) + except StopIteration as e: + self._exhausted = True + return ExceptionFuture(e) + finally: + self.worker_config.worker_deactivate() + return DoneFuture(next_sample) + + def save_state(self) -> FlexState: + """ + Save the state of the worker. + """ + # This is called in the worker context (process/thread). + return FlexState( + __class__="DataLoaderWorker", + rng=SystemRng.save_state(), + dataset=self.dataset.save_state(), + exhausted=self._exhausted, + sample_index=self._sample_index, + ) diff --git a/src/megatron/energon/loader/workers/fork_worker.py b/src/megatron/energon/loader/workers/fork_worker.py new file mode 100644 index 00000000..3b070e9e --- /dev/null +++ b/src/megatron/energon/loader/workers/fork_worker.py @@ -0,0 +1,157 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import multiprocessing +import os +import sys +import threading +import warnings +from typing import Generic, TypeVar, override + +from megatron.energon.cache.base import CachePool +from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.loader.workers.async_worker import DataLoaderAsynchronousWorker +from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.gc_dataset import gc_init_worker + +TSample = TypeVar("TSample", covariant=True) + + +class ForkDataLoaderWorker(DataLoaderAsynchronousWorker[TSample], Generic[TSample]): + """ + Implements the `DataLoaderWorker` interface using processes. + """ + + _process: multiprocessing.Process | None = None + _cmd_queue: multiprocessing.Queue + _result_queue: multiprocessing.Queue + + _threaded_shutdown: threading.Thread | None = None + + _spawning_process: int + + def __init__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool | None, + ): + multiprocessing.set_start_method("fork", force=True) + super().__init__( + dataset, + worker_config=worker_config, + rank_worker_id=rank_worker_id, + cmd_queue=multiprocessing.Queue(), + result_queue=multiprocessing.Queue(), + cache_pool=cache_pool, + ) + self._spawning_process = os.getpid() + + def _check_parent_process(self, evt_exit: threading.Event) -> None: + """Check if the parent process is alive. If it is dead, exit the worker process.""" + parent_proc = multiprocessing.parent_process() + parent_pid = os.getppid() + if parent_proc is None: + print("No parent process, exiting", file=sys.stderr) + os._exit(-1) + while not evt_exit.wait(1): + if parent_proc.exitcode is not None or os.getppid() != parent_pid: + print("Parent process died, exiting", file=sys.stderr) + os._exit(-1) + + def _worker_run( + self, + cmd_queue: multiprocessing.Queue, + result_queue: multiprocessing.Queue, + ) -> None: + gc_init_worker(self._rank_worker_id) + # cmd_queue is read only, so we can cancel the join thread. + cmd_queue.cancel_join_thread() + worker_exit_evt = threading.Event() + parent_check_thread = threading.Thread( + target=self._check_parent_process, args=(worker_exit_evt,), daemon=True + ) + parent_check_thread.start() + try: + super()._worker_run(cmd_queue, result_queue) + finally: + print(f"[wrk={self._rank_worker_id}] shutting down\n", end="") + worker_exit_evt.set() + print( + f"[wrk={self._rank_worker_id}] shutting down, wait for parent_check_thread\n", + end="", + ) + parent_check_thread.join() + print(f"[wrk={self._rank_worker_id}] shutting down, close queues\n", end="") + result_queue.close() + result_queue.join_thread() + cmd_queue.close() + cmd_queue.cancel_join_thread() + print(f"[wrk={self._rank_worker_id}] shutting down, done\n", end="") + + @override + def _in_worker(self) -> bool: + return multiprocessing.current_process() == self._process + + @override + def start(self) -> None: + self._process = multiprocessing.Process( + target=self._worker_run, + args=(self._cmd_queue, self._result_queue), + daemon=True, + name=f"ForkDataLoaderWorker-{self._rank_worker_id}", + ) + self._process.start() + + @override + def shutdown(self, in_del: bool = False) -> None: + if self._spawning_process != os.getpid(): + # Should avoid forked process containing a forked worker on exit. + warnings.warn( + "Shutting down worker from a different process than the one that spawned it, skipping" + ) + return + if self._process is not None: + if in_del: + # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.\n", + end="", + file=sys.stderr, + ) + self._cmd_queue.close() + self._cmd_queue.cancel_join_thread() + self._result_queue.close() + self._result_queue.cancel_join_thread() + # Kill the process, because we cannot communicate with it in the gc loop. + self._process.terminate() + self._process = None + else: + try: + self._shutdown_worker() + except Exception: + self._process.join(10) + if self._process.is_alive(): + self._process.terminate() + else: + self._process.join() + assert self._process.exitcode == 0, ( + f"Process exit code {self._process.exitcode}" + ) + self._process = None + self._cmd_queue.close() + self._cmd_queue.cancel_join_thread() + self._result_queue.close() + self._result_queue.cancel_join_thread() + + @override + def running(self) -> bool: + return self._process is not None + + def _assert_running(self) -> None: + assert self._process is not None, "Worker must be started first" + assert self._process.is_alive(), "Worker died" diff --git a/src/megatron/energon/loader/workers/thread_worker.py b/src/megatron/energon/loader/workers/thread_worker.py new file mode 100644 index 00000000..cd646ae1 --- /dev/null +++ b/src/megatron/energon/loader/workers/thread_worker.py @@ -0,0 +1,92 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import queue +import sys +import threading +import warnings +from typing import Generic, TypeVar, override + +from megatron.energon.cache.base import CachePool +from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.loader.workers.async_worker import DataLoaderAsynchronousWorker +from megatron.energon.worker import WorkerConfig + +TSample = TypeVar("TSample", covariant=True) + + +class ThreadDataLoaderWorker(DataLoaderAsynchronousWorker[TSample], Generic[TSample]): + """ + Implements the `DataLoaderWorker` interface using threads. + """ + + _thread: threading.Thread | None = None + _cmd_queue: queue.Queue + _result_queue: queue.Queue + + def __init__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool | None, + ): + super().__init__( + dataset, + worker_config=worker_config, + rank_worker_id=rank_worker_id, + cmd_queue=queue.Queue(), + result_queue=queue.Queue(), + cache_pool=cache_pool, + ) + + def _worker_run(self, cmd_queue: queue.Queue, result_queue: queue.Queue) -> None: + super()._worker_run(cmd_queue, result_queue) + + @override + def _in_worker(self) -> bool: + return threading.current_thread() == self._thread + + @override + def start(self) -> None: + self._thread = threading.Thread( + target=self._worker_run, + args=(self._cmd_queue, self._result_queue), + daemon=True, + name=f"ThreadDataLoaderWorker-{self._rank_worker_id}", + ) + self._thread.start() + + @override + def shutdown(self, in_del: bool = False) -> None: + if self._thread is not None: + if in_del: + # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking threads.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking threads.\n", + end="", + file=sys.stderr, + ) + # Just try to enqueue the shutdown command to the thread and hope for the best. Ignore the result. + self._cmd_queue.put( + self.WorkerCommand( + cmd=self._wrk_shutdown_worker.__name__, args=(), kwargs={}, future_id=-1 + ) + ) + self._thread = None + else: + self._shutdown_worker() + self._thread.join() + self._thread = None + + @override + def running(self) -> bool: + return self._thread is not None + + @override + def _assert_running(self) -> None: + assert self._thread is not None, "Thread must be started first" + assert self._thread.is_alive(), "Thread died" From 50f83750b3a747ece73f6eae47d106cfc483f3ea Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 24 Jul 2025 10:22:51 +0200 Subject: [PATCH 09/36] WIP Refactoring loader --- src/megatron/energon/dataloader/__init__.py | 15 ++ .../{loader => dataloader}/dataloader.py | 8 +- .../energon/{loader => dataloader}/future.py | 45 +++- .../{loader => dataloader}/pin_memory.py | 65 +---- .../energon/dataloader/workers/__init__.py | 11 + .../workers/async_worker.py | 247 +++++++++++------- .../workers/base_worker.py | 2 +- .../workers/fork_worker.py | 75 +++--- .../workers/thread_worker.py | 53 ++-- src/megatron/energon/flavors/base_dataset.py | 8 +- src/megatron/energon/loader/__init__.py | 2 - .../energon/loader/workers/__init__.py | 2 - tests/test_dataloader.py | 10 +- tests/test_dataset_det.py | 4 +- 14 files changed, 325 insertions(+), 222 deletions(-) create mode 100644 src/megatron/energon/dataloader/__init__.py rename src/megatron/energon/{loader => dataloader}/dataloader.py (98%) rename src/megatron/energon/{loader => dataloader}/future.py (59%) rename src/megatron/energon/{loader => dataloader}/pin_memory.py (51%) create mode 100644 src/megatron/energon/dataloader/workers/__init__.py rename src/megatron/energon/{loader => dataloader}/workers/async_worker.py (54%) rename src/megatron/energon/{loader => dataloader}/workers/base_worker.py (98%) rename src/megatron/energon/{loader => dataloader}/workers/fork_worker.py (76%) rename src/megatron/energon/{loader => dataloader}/workers/thread_worker.py (65%) delete mode 100644 src/megatron/energon/loader/__init__.py delete mode 100644 src/megatron/energon/loader/workers/__init__.py diff --git a/src/megatron/energon/dataloader/__init__.py b/src/megatron/energon/dataloader/__init__.py new file mode 100644 index 00000000..9839f8e2 --- /dev/null +++ b/src/megatron/energon/dataloader/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from .dataloader import DataLoader +from .pin_memory import NoPinMemory, PinMemory, PinMemoryThread +from .workers import DataLoaderWorker, ForkDataLoaderWorker, ThreadDataLoaderWorker + +__all__ = [ + "DataLoader", + "PinMemory", + "NoPinMemory", + "PinMemoryThread", + "DataLoaderWorker", + "ThreadDataLoaderWorker", + "ForkDataLoaderWorker", +] diff --git a/src/megatron/energon/loader/dataloader.py b/src/megatron/energon/dataloader/dataloader.py similarity index 98% rename from src/megatron/energon/loader/dataloader.py rename to src/megatron/energon/dataloader/dataloader.py index 4a2f3fb6..3f3ec0ba 100644 --- a/src/megatron/energon/loader/dataloader.py +++ b/src/megatron/energon/dataloader/dataloader.py @@ -17,11 +17,11 @@ import torch.distributed from megatron.energon.cache.base import CachePool +from megatron.energon.dataloader.future import CallableFuture, Future +from megatron.energon.dataloader.pin_memory import NoPinMemory, PinMemory, PinMemoryThread +from megatron.energon.dataloader.workers.base_worker import DataLoaderWorker +from megatron.energon.dataloader.workers.fork_worker import ForkDataLoaderWorker from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key -from megatron.energon.loader.future import CallableFuture, Future -from megatron.energon.loader.pin_memory import NoPinMemory, PinMemory, PinMemoryThread -from megatron.energon.loader.workers.base_worker import DataLoaderWorker -from megatron.energon.loader.workers.fork_worker import ForkDataLoaderWorker from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key diff --git a/src/megatron/energon/loader/future.py b/src/megatron/energon/dataloader/future.py similarity index 59% rename from src/megatron/energon/loader/future.py rename to src/megatron/energon/dataloader/future.py index e433080e..b8f1e53c 100644 --- a/src/megatron/energon/loader/future.py +++ b/src/megatron/energon/dataloader/future.py @@ -7,11 +7,35 @@ T = TypeVar("T", covariant=True) +class CancelledError(Exception): + """Exception raised when a future was cancelled.""" + + @classmethod + def with_current_traceback(cls): + try: + raise cls() + except cls as e: + if e.__traceback__ is not None and e.__traceback__.tb_next is not None: + return e.with_traceback(e.__traceback__.tb_next) + return e + + class Future(Generic[R]): """Base class for abstract futures.""" @abstractmethod - def get(self) -> R: ... + def get(self) -> R: + """Get the result of the future. Waits until the future is done.""" + ... + + @abstractmethod + def cancel(self) -> bool: + """Cancel the future. + + Returns: + True if the future was cancelled, False if already done. + """ + ... class DoneFuture(Future[R]): @@ -23,18 +47,26 @@ def __init__(self, result: R): def get(self) -> R: return self._result + def cancel(self) -> bool: + return False + class CallableFuture(Future[R]): """Future that calls a callable to get the result.""" + __slots__ = ("_callable", "_value", "_exception", "_cancelled") + _callable: Callable[[], R] _value: R _exception: Exception + _cancelled: bool def __init__(self, callable: Callable[[], R]): self._callable = callable def get(self) -> R: + if getattr(self, "_cancelled", False): + raise CancelledError("Future was cancelled") if not hasattr(self, "_value") and not hasattr(self, "_exception"): try: self._value = self._callable() @@ -44,6 +76,14 @@ def get(self) -> R: raise self._exception return self._value + def cancel(self) -> bool: + if getattr(self, "_cancelled", False): + return True + if hasattr(self, "_value") or hasattr(self, "_exception"): + return False + self._cancelled = True + return True + @staticmethod def chain(future: Future[T], fn: Callable[[Future[T]], R]) -> Future[R]: """ @@ -67,3 +107,6 @@ def __init__(self, exception: Exception): def get(self) -> Any: raise self._exception + + def cancel(self) -> bool: + return False diff --git a/src/megatron/energon/loader/pin_memory.py b/src/megatron/energon/dataloader/pin_memory.py similarity index 51% rename from src/megatron/energon/loader/pin_memory.py rename to src/megatron/energon/dataloader/pin_memory.py index 9aba23c2..ad493914 100644 --- a/src/megatron/energon/loader/pin_memory.py +++ b/src/megatron/energon/dataloader/pin_memory.py @@ -2,12 +2,13 @@ # SPDX-License-Identifier: BSD-3-Clause import queue import threading -from typing import Generic, Protocol, TypeVar, cast, override +from typing import Generic, TypeVar, cast import torch +from megatron.energon.dataloader.future import CallableFuture, Future +from megatron.energon.dataloader.workers.async_worker import AsynchronousMixin from megatron.energon.flavors.base_dataset import PinMemoryMixin -from megatron.energon.loader.future import CallableFuture, Future TSample = TypeVar("TSample") T = TypeVar("T") @@ -47,19 +48,7 @@ def __call__(self, sample: Future[TSample]) -> Future[TSample]: return sample -class QueueProtocol(Protocol[T]): - def get(self, /) -> T: ... - - def put(self, item: T, /) -> None: ... - - def qsize(self, /) -> int: ... - - def task_done(self, /) -> None: ... - - def join(self, /) -> None: ... - - -class PinMemoryThread(PinMemory[TSample], Generic[TSample]): +class PinMemoryThread(PinMemory[TSample], AsynchronousMixin, Generic[TSample]): """Threaded implementation of :class:`PinMemory`. Pins the memory of samples in a separate thread in the background. @@ -71,36 +60,18 @@ class PinMemoryThread(PinMemory[TSample], Generic[TSample]): _thread: threading.Thread | None = None - _item_queue: QueueProtocol[Future[TSample]] - _result_queue: QueueProtocol[tuple[TSample, None] | tuple[None, Exception]] - def __init__( self, device: str | torch.device, ): super().__init__(device) - self._item_queue = queue.Queue() - self._result_queue = queue.Queue() - - def _run(self) -> None: - """The pin memory thread. It will fetch the sample from the item future queue and pin the memory.""" - while True: - try: - sample = self._item_queue.get() - if sample is self._SHUTDOWN: - break - sample = self._pin_memory(sample.get()) - except Exception as e: - self._result_queue.put((None, e)) - else: - self._result_queue.put((sample, None)) - self._item_queue.task_done() - - def _get_next_result(self) -> TSample: - result, exception = self._result_queue.get() - if exception is not None: - raise exception - return cast(TSample, result) + self._asynchronous_init(cmd_queue=queue.Queue(), result_queue=queue.Queue()) + + def _wrk_pin_memory(self, sample: Future[TSample]) -> TSample: + return self._pin_memory(sample.get()) + + def _worker_pin_memory(self, sample: Future[TSample]) -> Future[TSample]: + return self._worker_call(self._wrk_pin_memory, sample) def __call__(self, sample: Future[TSample]) -> Future[TSample]: """ @@ -108,16 +79,4 @@ def __call__(self, sample: Future[TSample]) -> Future[TSample]: Submits the sample future to the thread to fetch it and pins the memory in the thread, then returns a future for fetching the pinned sample. """ - if self._thread is None: - self._thread = threading.Thread(target=self._run, daemon=True, name="PinMemoryThread") - self._thread.start() - self._item_queue.put(sample) - return CallableFuture(self._get_next_result) - - @override - def shutdown(self) -> None: - if self._thread is not None: - self._item_queue.put(self._SHUTDOWN) - self._item_queue.join() - self._thread.join() - self._thread = None + return self._worker_pin_memory(sample) diff --git a/src/megatron/energon/dataloader/workers/__init__.py b/src/megatron/energon/dataloader/workers/__init__.py new file mode 100644 index 00000000..ee41db0f --- /dev/null +++ b/src/megatron/energon/dataloader/workers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from .base_worker import DataLoaderWorker +from .fork_worker import ForkDataLoaderWorker +from .thread_worker import ThreadDataLoaderWorker + +__all__ = [ + "DataLoaderWorker", + "ThreadDataLoaderWorker", + "ForkDataLoaderWorker", +] diff --git a/src/megatron/energon/loader/workers/async_worker.py b/src/megatron/energon/dataloader/workers/async_worker.py similarity index 54% rename from src/megatron/energon/loader/workers/async_worker.py rename to src/megatron/energon/dataloader/workers/async_worker.py index 1abab7b4..d2ff0df2 100644 --- a/src/megatron/energon/loader/workers/async_worker.py +++ b/src/megatron/energon/dataloader/workers/async_worker.py @@ -2,13 +2,13 @@ # SPDX-License-Identifier: BSD-3-Clause import traceback from abc import abstractmethod -from typing import Any, Callable, ParamSpec, Protocol, TypeVar, override +from typing import Any, Callable, Generic, ParamSpec, Protocol, TypeVar, override from megatron.energon.cache.base import CachePool +from megatron.energon.dataloader.future import CancelledError, Future +from megatron.energon.dataloader.workers.base_worker import DataLoaderWorker from megatron.energon.edataclass import edataclass from megatron.energon.flavors.base_dataset import SavableDataset -from megatron.energon.loader.future import Future -from megatron.energon.loader.workers.base_worker import DataLoaderWorker from megatron.energon.rng import SystemRng from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig @@ -28,80 +28,87 @@ def put(self, item: T, /) -> None: ... def qsize(self, /) -> int: ... -class DataLoaderAsynchronousWorker(DataLoaderWorker[TSample]): - """ - Extension of the `DataLoaderWorker`, which implements commands via a command and results queue. +@edataclass +class WorkerResult: + """Internal class for communicating a result from the worker via the result queue.""" - There are different implementations of the async worker: - - :class:`ForkDataLoaderWorker` - A worker that forks a new process for each worker. - - :class:`ThreadDataLoaderWorker` - A worker that uses threads to execute the commands. - """ + future_id: int + result: Any = None + exception: Exception | None = None - _cmd_queue: QueueProtocol["WorkerCommand"] - _result_queue: QueueProtocol["WorkerResult"] - _next_future_id: int - _futures: dict[int, "FutureImpl"] - def __init__( - self, - dataset: SavableDataset, - worker_config: WorkerConfig, - rank_worker_id: int, - cmd_queue: QueueProtocol["WorkerCommand"], - result_queue: QueueProtocol["WorkerResult"], - cache_pool: CachePool | None, - ): - super().__init__(dataset, worker_config, rank_worker_id, cache_pool) - assert worker_config.num_workers > 0, "Async workers require num_workers > 0" - self._cmd_queue = cmd_queue - self._result_queue = result_queue - self._next_future_id = 0 - self._futures = {} +@edataclass +class WorkerCommand: + """Internal class for communicating a command to the worker via the command queue.""" - # ------------------------------------------------------------------------------------------------ - # Section: Remote call implementation + cmd: str + args: tuple[Any, ...] + kwargs: dict[str, Any] + future_id: int - @edataclass - class WorkerResult: - """Internal class for communicating a result from the worker via the result queue.""" - future_id: int - result: Any = None - exception: Exception | None = None +class FutureImpl(Future[Any]): + """Class for returning a future result from the worker..""" - @edataclass - class WorkerCommand: - """Internal class for communicating a command to the worker via the command queue.""" + __slots__ = ("_worker", "_future_id", "_result", "_exception", "_cancelled") - cmd: str - args: tuple[Any, ...] - kwargs: dict[str, Any] - future_id: int + _worker: "DataLoaderAsynchronousWorker" + _future_id: int + _result: Any + _exception: Exception + _cancelled: bool - class FutureImpl(Future[Any]): - """Class for returning a future result from the worker..""" + def __init__(self, worker: "DataLoaderAsynchronousWorker", future_id: int): + self._worker = worker + self._future_id = future_id - _worker: "DataLoaderAsynchronousWorker" - _future_id: int - _result: Any - _exception: Exception + def get(self) -> Any: + if getattr(self, "_cancelled", False): + raise CancelledError() + if not hasattr(self, "_result") and not hasattr(self, "_exception"): + self._worker._wait_for_worker_result(self._future_id) + if hasattr(self, "_exception"): + raise self._exception + return self._result - def __init__(self, worker: "DataLoaderAsynchronousWorker", future_id: int): - self._worker = worker - self._future_id = future_id + def cancel(self) -> bool: + if getattr(self, "_cancelled", False): + return True + if hasattr(self, "_result") or hasattr(self, "_exception"): + return False + # In case the main process is waiting for thie future to complete, add the result + self._worker._result_queue.put( + WorkerResult( + future_id=self._future_id, exception=CancelledError.with_current_traceback() + ) + ) + self._cancelled = True + return True + + def _set_result(self, result: Any) -> None: + self._result = result + + def _set_exception(self, exception: Exception) -> None: + self._exception = exception + + +class AsynchronousMixin: + """Mixin for asynchronous workers.""" - def get(self) -> Any: - if not hasattr(self, "_result") and not hasattr(self, "_exception"): - self._worker._wait_for_worker_result(self._future_id) - if hasattr(self, "_exception"): - raise self._exception - return self._result + _cmd_queue: QueueProtocol[WorkerCommand] + _result_queue: QueueProtocol[WorkerResult] + _next_future_id: int + _futures: dict[int, FutureImpl] + _name: str - def _set_result(self, result: Any) -> None: - self._result = result + def _asynchronous_init(self, name: str) -> None: + self._cmd_queue, self._result_queue = self._queues() + self._next_future_id = 0 + self._futures = {} + self._name = name - def _set_exception(self, exception: Exception) -> None: - self._exception = exception + @abstractmethod + def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: ... def _wait_for_worker_result(self, future_id: int) -> None: """ @@ -112,7 +119,7 @@ def _wait_for_worker_result(self, future_id: int) -> None: future_id: The ID of the future to wait for. """ while True: - print(f"[fut={future_id}] waiting for result\n", end="") + print(f"[{self._name}, fut={future_id}] waiting for result\n", end="") res = self._result_queue.get() fut = self._futures.pop(res.future_id) if res.exception is not None: @@ -121,12 +128,21 @@ def _wait_for_worker_result(self, future_id: int) -> None: fut._set_result(res.result) # self._result_queue.task_done() if res.future_id == future_id: - print(f"[fut={future_id}] got result, return\n", end="") + print(f"[{self._name}, fut={future_id}] got result, return\n", end="") return else: - print(f"[fut={future_id}] got result for {res.future_id=}, continue\n", end="") + print( + f"[{self._name}, fut={future_id}] got result for {res.future_id=}, continue\n", + end="", + ) continue + def _cancel_futures(self) -> None: + """Cancel all futures after worker shutdown.""" + for fut in self._futures.values(): + fut.cancel() + self._futures.clear() + def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Future[R]: """ Call a function in the worker and return a future for getting the result. @@ -143,15 +159,15 @@ def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> future_id = self._next_future_id self._next_future_id += 1 - self._futures[future_id] = future = self.FutureImpl(self, future_id) + self._futures[future_id] = future = FutureImpl(self, future_id) print( - f"[wrk={self._rank_worker_id}] worker_call {fn.__name__=} {args=} {kwargs=} {future_id=}\n", + f"[{self._name}] worker_call {fn.__name__=} {args=} {kwargs=} {future_id=}\n", end="", ) self._cmd_queue.put( - self.WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) + WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) ) - print(f"[wrk={self._rank_worker_id}] queue: {self._cmd_queue.qsize()}\n", end="") + print(f"[{self._name}] queue: {self._cmd_queue.qsize()}\n", end="") return future def _worker_run( @@ -170,53 +186,49 @@ def _worker_run( """ assert self._in_worker(), "_worker_run must be called in the worker" try: - SystemRng.seed(self._seed) - import torch.utils.data._utils - - torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( - id=self._rank_worker_id, - num_workers=self.worker_config.num_workers, - seed=self._seed, - dataset=self.dataset, - ) - self._global_worker_id = self.worker_config.global_worker_id() - self.worker_config.assert_worker() while True: print( - f"[wrk={self._rank_worker_id}] waiting for command, len: {cmd_queue.qsize()}\n", + f"[{self._name}] waiting for command {cmd_queue.qsize()=}\n", end="", ) cmd = cmd_queue.get() print( - f"[fut={cmd.future_id}] got command {cmd.cmd=} {cmd.args=} {cmd.kwargs=}\n", + f"[{self._name}, fut={cmd.future_id}] got command {cmd.cmd=} {cmd.args=} {cmd.kwargs=}\n", end="", ) try: fn = getattr(self, cmd.cmd) result = fn(*cmd.args, **cmd.kwargs) except Exception as e: - print(f"[fut={cmd.future_id}] send exception {e!r}\n", end="") - result_queue.put(self.WorkerResult(future_id=cmd.future_id, exception=e)) + print(f"[{self._name}, fut={cmd.future_id}] send exception {e!r}\n", end="") + result_queue.put(WorkerResult(future_id=cmd.future_id, exception=e)) else: - print(f"[fut={cmd.future_id}] send result {result!r}\n", end="") - result_queue.put(self.WorkerResult(future_id=cmd.future_id, result=result)) + print(f"[{self._name}, fut={cmd.future_id}] send result {result!r}\n", end="") + result_queue.put(WorkerResult(future_id=cmd.future_id, result=result)) del result # cmd_queue.task_done() if cmd.cmd == self._wrk_shutdown_worker.__name__: - print(f"[fut={cmd.future_id}] got shutdown command, exit\n", end="") + print( + f"[{self._name}, fut={cmd.future_id}] got shutdown command, exit\n", end="" + ) break - print(f"[fut={cmd.future_id}] processed, waiting for next command\n", end="") + print( + f"[{self._name}, fut={cmd.future_id}] processed, waiting for next command\n", + end="", + ) except: traceback.print_exc() raise @abstractmethod - def _in_worker(self) -> bool: + def _assert_running(self) -> bool: """Check if the execution is within the worker.""" ... - # ------------------------------------------------------------------------------------------------ - # Section: Worker methods - now calling to workers via queues. + @abstractmethod + def _in_worker(self) -> bool: + """Check if the execution is within the worker.""" + ... def _wrk_shutdown_worker(self) -> None: """Does nothing. The actual shutdown is handled in the _worker_run method.""" @@ -227,6 +239,57 @@ def _shutdown_worker(self) -> None: assert not self._in_worker(), "shutdown_worker must not be called in the worker" # This is not actually a recursive call, because the worker loop will exit before calling this method. self._worker_call(self._wrk_shutdown_worker).get() + self._cancel_futures() + print(f"[{self._name}] shutdown\n", end="") + + @abstractmethod + def start(self) -> None: ... + + @abstractmethod + def shutdown(self) -> None: ... + + @abstractmethod + def running(self) -> bool: ... + + +class DataLoaderAsynchronousWorker(DataLoaderWorker[TSample], AsynchronousMixin, Generic[TSample]): + """ + Extension of the `DataLoaderWorker`, which implements commands via a command and results queue. + + There are different implementations of the async worker: + - :class:`ForkDataLoaderWorker` - A worker that forks a new process for each worker. + - :class:`ThreadDataLoaderWorker` - A worker that uses threads to execute the commands. + """ + + def __init__( + self, + dataset: SavableDataset, + worker_config: WorkerConfig, + rank_worker_id: int, + cache_pool: CachePool | None, + ): + super().__init__(dataset, worker_config, rank_worker_id, cache_pool) + assert worker_config.num_workers > 0, "Async workers require num_workers > 0" + self._asynchronous_init(name=f"wkr-{rank_worker_id}") + + # ------------------------------------------------------------------------------------------------ + # Section: Worker methods - now calling to workers via queues. + + def _worker_run( + self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] + ) -> None: + SystemRng.seed(self._seed) + import torch.utils.data._utils + + torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( + id=self._rank_worker_id, + num_workers=self.worker_config.num_workers, + seed=self._seed, + dataset=self.dataset, + ) + self._global_worker_id = self.worker_config.global_worker_id() + + super()._worker_run(cmd_queue, result_queue) def _wrk_prefetch_next(self) -> TSample: """Wraps the super class method to call it in the worker process.""" diff --git a/src/megatron/energon/loader/workers/base_worker.py b/src/megatron/energon/dataloader/workers/base_worker.py similarity index 98% rename from src/megatron/energon/loader/workers/base_worker.py rename to src/megatron/energon/dataloader/workers/base_worker.py index bc60c8e3..bf348d91 100644 --- a/src/megatron/energon/loader/workers/base_worker.py +++ b/src/megatron/energon/dataloader/workers/base_worker.py @@ -3,8 +3,8 @@ from typing import Generic, TypeVar from megatron.energon.cache.base import CachePool +from megatron.energon.dataloader.future import DoneFuture, ExceptionFuture, Future from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key -from megatron.energon.loader.future import DoneFuture, ExceptionFuture, Future from megatron.energon.rng import SystemRng from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig diff --git a/src/megatron/energon/loader/workers/fork_worker.py b/src/megatron/energon/dataloader/workers/fork_worker.py similarity index 76% rename from src/megatron/energon/loader/workers/fork_worker.py rename to src/megatron/energon/dataloader/workers/fork_worker.py index 3b070e9e..72b280f8 100644 --- a/src/megatron/energon/loader/workers/fork_worker.py +++ b/src/megatron/energon/dataloader/workers/fork_worker.py @@ -7,19 +7,20 @@ import warnings from typing import Generic, TypeVar, override -from megatron.energon.cache.base import CachePool -from megatron.energon.flavors.base_dataset import SavableDataset -from megatron.energon.loader.workers.async_worker import DataLoaderAsynchronousWorker -from megatron.energon.worker import WorkerConfig +from megatron.energon.dataloader.workers.async_worker import ( + AsynchronousMixin, + DataLoaderAsynchronousWorker, + QueueProtocol, + WorkerCommand, + WorkerResult, +) from megatron.energon.wrappers.gc_dataset import gc_init_worker TSample = TypeVar("TSample", covariant=True) -class ForkDataLoaderWorker(DataLoaderAsynchronousWorker[TSample], Generic[TSample]): - """ - Implements the `DataLoaderWorker` interface using processes. - """ +class ForkAsynchronousMixin(AsynchronousMixin): + """Mixin for asynchronous workers that use processes.""" _process: multiprocessing.Process | None = None _cmd_queue: multiprocessing.Queue @@ -29,42 +30,33 @@ class ForkDataLoaderWorker(DataLoaderAsynchronousWorker[TSample], Generic[TSampl _spawning_process: int - def __init__( - self, - dataset: SavableDataset, - worker_config: WorkerConfig, - rank_worker_id: int, - cache_pool: CachePool | None, - ): - multiprocessing.set_start_method("fork", force=True) - super().__init__( - dataset, - worker_config=worker_config, - rank_worker_id=rank_worker_id, - cmd_queue=multiprocessing.Queue(), - result_queue=multiprocessing.Queue(), - cache_pool=cache_pool, - ) + @override + def _asynchronous_init(self, name: str) -> None: + super()._asynchronous_init(name) self._spawning_process = os.getpid() + @override + def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: + return multiprocessing.Queue(), multiprocessing.Queue() + def _check_parent_process(self, evt_exit: threading.Event) -> None: """Check if the parent process is alive. If it is dead, exit the worker process.""" parent_proc = multiprocessing.parent_process() parent_pid = os.getppid() if parent_proc is None: - print("No parent process, exiting", file=sys.stderr) + print(f"[{self._name}] No parent process, exiting", file=sys.stderr) os._exit(-1) while not evt_exit.wait(1): if parent_proc.exitcode is not None or os.getppid() != parent_pid: - print("Parent process died, exiting", file=sys.stderr) + print(f"[{self._name}] Parent process died, exiting", file=sys.stderr) os._exit(-1) + @override def _worker_run( self, cmd_queue: multiprocessing.Queue, result_queue: multiprocessing.Queue, ) -> None: - gc_init_worker(self._rank_worker_id) # cmd_queue is read only, so we can cancel the join thread. cmd_queue.cancel_join_thread() worker_exit_evt = threading.Event() @@ -75,19 +67,19 @@ def _worker_run( try: super()._worker_run(cmd_queue, result_queue) finally: - print(f"[wrk={self._rank_worker_id}] shutting down\n", end="") + print(f"[{self._name}] shutting down\n", end="") worker_exit_evt.set() print( - f"[wrk={self._rank_worker_id}] shutting down, wait for parent_check_thread\n", + f"[{self._name}] shutting down, wait for parent_check_thread\n", end="", ) parent_check_thread.join() - print(f"[wrk={self._rank_worker_id}] shutting down, close queues\n", end="") + print(f"[{self._name}] shutting down, close queues\n", end="") result_queue.close() result_queue.join_thread() cmd_queue.close() cmd_queue.cancel_join_thread() - print(f"[wrk={self._rank_worker_id}] shutting down, done\n", end="") + print(f"[{self._name}] shutting down, done\n", end="") @override def _in_worker(self) -> bool: @@ -95,11 +87,12 @@ def _in_worker(self) -> bool: @override def start(self) -> None: + multiprocessing.set_start_method("fork", force=True) self._process = multiprocessing.Process( target=self._worker_run, args=(self._cmd_queue, self._result_queue), daemon=True, - name=f"ForkDataLoaderWorker-{self._rank_worker_id}", + name=f"ForkDataLoaderWorker-{self._name}", ) self._process.start() @@ -130,6 +123,7 @@ def shutdown(self, in_del: bool = False) -> None: # Kill the process, because we cannot communicate with it in the gc loop. self._process.terminate() self._process = None + self._cancel_futures() else: try: self._shutdown_worker() @@ -152,6 +146,23 @@ def shutdown(self, in_del: bool = False) -> None: def running(self) -> bool: return self._process is not None + @override def _assert_running(self) -> None: assert self._process is not None, "Worker must be started first" assert self._process.is_alive(), "Worker died" + + +class ForkDataLoaderWorker( + ForkAsynchronousMixin, DataLoaderAsynchronousWorker[TSample], Generic[TSample] +): + """ + Implements the `DataLoaderWorker` interface using processes. + """ + + def _worker_run( + self, + cmd_queue: multiprocessing.Queue, + result_queue: multiprocessing.Queue, + ) -> None: + gc_init_worker(self._rank_worker_id) + super()._worker_run(cmd_queue, result_queue) diff --git a/src/megatron/energon/loader/workers/thread_worker.py b/src/megatron/energon/dataloader/workers/thread_worker.py similarity index 65% rename from src/megatron/energon/loader/workers/thread_worker.py rename to src/megatron/energon/dataloader/workers/thread_worker.py index cd646ae1..e8b066bc 100644 --- a/src/megatron/energon/loader/workers/thread_worker.py +++ b/src/megatron/energon/dataloader/workers/thread_worker.py @@ -6,41 +6,25 @@ import warnings from typing import Generic, TypeVar, override -from megatron.energon.cache.base import CachePool -from megatron.energon.flavors.base_dataset import SavableDataset -from megatron.energon.loader.workers.async_worker import DataLoaderAsynchronousWorker -from megatron.energon.worker import WorkerConfig +from megatron.energon.dataloader.workers.async_worker import ( + AsynchronousMixin, + DataLoaderAsynchronousWorker, + QueueProtocol, + WorkerCommand, + WorkerResult, +) TSample = TypeVar("TSample", covariant=True) -class ThreadDataLoaderWorker(DataLoaderAsynchronousWorker[TSample], Generic[TSample]): - """ - Implements the `DataLoaderWorker` interface using threads. - """ +class ThreadAsynchronousMixin(AsynchronousMixin): + """Mixin for asynchronous workers that use threads.""" _thread: threading.Thread | None = None - _cmd_queue: queue.Queue - _result_queue: queue.Queue - def __init__( - self, - dataset: SavableDataset, - worker_config: WorkerConfig, - rank_worker_id: int, - cache_pool: CachePool | None, - ): - super().__init__( - dataset, - worker_config=worker_config, - rank_worker_id=rank_worker_id, - cmd_queue=queue.Queue(), - result_queue=queue.Queue(), - cache_pool=cache_pool, - ) - - def _worker_run(self, cmd_queue: queue.Queue, result_queue: queue.Queue) -> None: - super()._worker_run(cmd_queue, result_queue) + @override + def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: + return queue.Queue(), queue.Queue() @override def _in_worker(self) -> bool: @@ -52,7 +36,7 @@ def start(self) -> None: target=self._worker_run, args=(self._cmd_queue, self._result_queue), daemon=True, - name=f"ThreadDataLoaderWorker-{self._rank_worker_id}", + name=f"{self._name}", ) self._thread.start() @@ -72,10 +56,11 @@ def shutdown(self, in_del: bool = False) -> None: ) # Just try to enqueue the shutdown command to the thread and hope for the best. Ignore the result. self._cmd_queue.put( - self.WorkerCommand( + WorkerCommand( cmd=self._wrk_shutdown_worker.__name__, args=(), kwargs={}, future_id=-1 ) ) + self._cancel_futures() self._thread = None else: self._shutdown_worker() @@ -90,3 +75,11 @@ def running(self) -> bool: def _assert_running(self) -> None: assert self._thread is not None, "Thread must be started first" assert self._thread.is_alive(), "Thread died" + + +class ThreadDataLoaderWorker( + ThreadAsynchronousMixin, DataLoaderAsynchronousWorker[TSample], Generic[TSample] +): + """ + Implements the `DataLoaderWorker` interface using threads. + """ diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index 928092be..c18ad96f 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -67,7 +67,13 @@ def sample_pin_memory(cls, batch: T, device: Union[torch.device, str, None] = No return batch def pin_memory(self: Self, device: torch.device | str | None = None) -> Self: - return self.sample_pin_memory(self, device) + assert dataclasses.is_dataclass(self), "Must be a dataclass" + return type(self)( + **{ + field.name: self.sample_pin_memory(getattr(self, field.name), device) + for field in dataclasses.fields(self) + } + ) class ExtendableDataclassMixin: diff --git a/src/megatron/energon/loader/__init__.py b/src/megatron/energon/loader/__init__.py deleted file mode 100644 index 43842921..00000000 --- a/src/megatron/energon/loader/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. -# SPDX-License-Identifier: BSD-3-Clause diff --git a/src/megatron/energon/loader/workers/__init__.py b/src/megatron/energon/loader/workers/__init__.py deleted file mode 100644 index 43842921..00000000 --- a/src/megatron/energon/loader/workers/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. -# SPDX-License-Identifier: BSD-3-Clause diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 4c486f65..06193a10 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -112,7 +112,8 @@ def test_dataloader_no_workers(self): shuffle_buffer_size=None, max_samples_per_sequence=None, repeat=False, - ) + ), + pin_memory=None, ) assert len(train_loader) == 6, len(train_loader) @@ -141,7 +142,8 @@ def test_dataloader_no_workers(self): shuffle_buffer_size=None, max_samples_per_sequence=None, repeat=False, - ) + ), + pin_memory=None, ) train_loader.restore_state_rank(state1) @@ -176,6 +178,7 @@ def test_dataloader_fork(self): gc_freeze_at_start=True, watchdog_timeout_seconds=60, fail_on_timeout=True, + pin_memory=None, ) assert len(train_loader) == 6, len(train_loader) @@ -213,6 +216,7 @@ def test_dataloader_fork(self): gc_freeze_at_start=True, watchdog_timeout_seconds=60, fail_on_timeout=True, + pin_memory=None, ) train_loader.restore_state_rank(state1) @@ -246,6 +250,7 @@ def test_dataloader_thread(self): gc_collect_every_n_steps=0, watchdog_timeout_seconds=60, fail_on_timeout=True, + pin_memory=None, ) assert len(train_loader) == 6, len(train_loader) @@ -280,6 +285,7 @@ def test_dataloader_thread(self): gc_collect_every_n_steps=0, watchdog_timeout_seconds=60, fail_on_timeout=True, + pin_memory=None, ) train_loader.restore_state_rank(state1) diff --git a/tests/test_dataset_det.py b/tests/test_dataset_det.py index 92919b0a..7cac5b21 100644 --- a/tests/test_dataset_det.py +++ b/tests/test_dataset_det.py @@ -23,12 +23,12 @@ TextSample, WorkerConfig, get_loader, + get_savable_loader, get_train_dataset, + stateless, ) from megatron.energon.dataset_config import get_dataset_from_config from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME -from megatron.energon.loader import get_savable_loader -from megatron.energon.task_encoder.base import stateless from megatron.energon.tools.checkpoint import command_redist # Speed up tests significantly by reducing the torch status check interval for broken worker shutdown From be6265688c249dc18a86e4228676966f9dd8ea07 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:20:25 +0200 Subject: [PATCH 10/36] Add pinning --- .../dataloader/asynchronous/__init__.py | 14 + .../energon/dataloader/asynchronous/base.py | 259 ++++++++++++++++++ .../energon/dataloader/asynchronous/fork.py | 148 ++++++++++ .../energon/dataloader/asynchronous/thread.py | 74 +++++ src/megatron/energon/dataloader/pin_memory.py | 27 +- .../dataloader/workers/async_worker.py | 251 +---------------- .../energon/dataloader/workers/fork_worker.py | 146 +--------- .../dataloader/workers/thread_worker.py | 73 +---- tests/test_dataloader.py | 6 - 9 files changed, 529 insertions(+), 469 deletions(-) create mode 100644 src/megatron/energon/dataloader/asynchronous/__init__.py create mode 100644 src/megatron/energon/dataloader/asynchronous/base.py create mode 100644 src/megatron/energon/dataloader/asynchronous/fork.py create mode 100644 src/megatron/energon/dataloader/asynchronous/thread.py diff --git a/src/megatron/energon/dataloader/asynchronous/__init__.py b/src/megatron/energon/dataloader/asynchronous/__init__.py new file mode 100644 index 00000000..f77fef80 --- /dev/null +++ b/src/megatron/energon/dataloader/asynchronous/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from .base import Asynchronous, QueueProtocol, WorkerCommand, WorkerResult +from .fork import ForkAsynchronous +from .thread import ThreadAsynchronous + +__all__ = [ + "Asynchronous", + "QueueProtocol", + "WorkerCommand", + "WorkerResult", + "ForkAsynchronous", + "ThreadAsynchronous", +] diff --git a/src/megatron/energon/dataloader/asynchronous/base.py b/src/megatron/energon/dataloader/asynchronous/base.py new file mode 100644 index 00000000..879a6bc5 --- /dev/null +++ b/src/megatron/energon/dataloader/asynchronous/base.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import threading +import traceback +from abc import abstractmethod +from typing import Any, Callable, ParamSpec, Protocol, TypeVar + +from megatron.energon.dataloader.future import CancelledError, Future +from megatron.energon.edataclass import edataclass + +P = ParamSpec("P") +T = TypeVar("T") +R = TypeVar("R", covariant=True) + + +class QueueProtocol(Protocol[T]): + """Protocol for a queue.""" + + def get(self, /) -> T: ... + + def put(self, item: T, /) -> None: ... + + def qsize(self, /) -> int: ... + + +@edataclass +class WorkerCommand: + """Internal class for communicating a command to the worker via the command queue.""" + + cmd: str + args: tuple[Any, ...] + kwargs: dict[str, Any] + future_id: int + + +@edataclass +class WorkerResult: + """Internal class for communicating a result from the worker via the result queue.""" + + future_id: int + result: Any = None + exception: Exception | None = None + + +class FutureImpl(Future[Any]): + """Class for returning a future result from the worker..""" + + __slots__ = ("_worker", "_future_id", "_result", "_exception", "_cancelled") + + _worker: "Asynchronous" + _future_id: int + _result: Any + _exception: Exception + + def __init__(self, worker: "Asynchronous", future_id: int): + self._worker = worker + self._future_id = future_id + + def get(self) -> Any: + if not hasattr(self, "_result") and not hasattr(self, "_exception"): + self._worker._wait_for_worker_result(self) + if hasattr(self, "_exception"): + raise self._exception + return self._result + + def cancel(self) -> bool: + if hasattr(self, "_result") or hasattr(self, "_exception"): + print( + f"[{self._worker._name}, fut={self._future_id}] already has result or exception\n", + end="", + ) + return False + self._exception = CancelledError.with_current_traceback() + self._worker._cancel_future(self._future_id) + return True + + def done(self) -> bool: + return hasattr(self, "_result") or hasattr(self, "_exception") + + def _set_result(self, result: Any) -> None: + self._result = result + + def _set_exception(self, exception: Exception) -> None: + self._exception = exception + + +class Asynchronous: + """Asynchronous base class.""" + + _cmd_queue: QueueProtocol[WorkerCommand] + _result_queue: QueueProtocol[WorkerResult] + _next_future_id: int + _pending_futures: dict[int, FutureImpl] + _name: str + _result_lock: threading.Lock + + def _asynchronous_init(self, name: str) -> None: + self._cmd_queue, self._result_queue = self._queues() + self._next_future_id = 0 + self._pending_futures = {} + self._name = name + self._result_lock = threading.Lock() + + @abstractmethod + def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: ... + + def _wait_for_worker_result(self, future: FutureImpl) -> None: + """ + Wait for the result of a future. + If another result comes first, update the corresponding future. + + Args: + future: The future to wait for. + """ + print(f"[{self._name}, fut={future._future_id}] waiting for result\n", end="") + with self._result_lock: + if future.done(): + # If calling get() from multiple threads, the future may be done now, because + # the other thread already set the result. + return + print(f"[{self._name}, fut={future._future_id}] got future\n", end="") + while True: + res = self._result_queue.get() + fut = self._pending_futures.pop(res.future_id) + if res.exception is not None: + fut._set_exception(res.exception) + else: + fut._set_result(res.result) + if res.future_id == future._future_id: + print(f"[{self._name}, fut={future._future_id}] got result, return\n", end="") + return + else: + print( + f"[{self._name}, fut={future._future_id}] got result for {res.future_id=}, continue\n", + end="", + ) + continue + + def _cancel_future(self, future_id: int) -> None: + """Cancel a future.""" + print(f"[{self._name}, fut={future_id}] cancelling future\n", end="") + # In case the main process is waiting for thie future to complete, add the result + self._result_queue.put( + WorkerResult(future_id=future_id, exception=CancelledError.with_current_traceback()) + ) + + def _cancel_futures(self) -> None: + """Cancel all futures after worker shutdown.""" + for fut in self._pending_futures.values(): + fut.cancel() + self._pending_futures.clear() + + def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Future[R]: + """ + Call a function in the worker and return a future for getting the result. + The function must be an instance method of `self`. Uses the name to identify the function in the worker + instance. + + Args: + fn: The function to call. + *args: The arguments to pass to the function. + **kwargs: The keyword arguments to pass to the function. + """ + self._assert_running() + assert not self._in_worker(), "worker_call must not be called in the worker" + future_id = self._next_future_id + self._next_future_id += 1 + + self._pending_futures[future_id] = future = FutureImpl(self, future_id) + print( + f"[{self._name}] worker_call {fn.__name__=} {args=} {kwargs=} {future_id=}\n", + end="", + ) + self._cmd_queue.put( + WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) + ) + print(f"[{self._name}] queue: {self._cmd_queue.qsize()}\n", end="") + return future + + def _worker_run( + self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] + ) -> None: + """ + The worker main loop. + It waits for commands via the command queue and executes them. + The functions to call are identified by their name. + The result of the call is put into the result queue. + The worker exits when the command `_shutdown_worker` is received. + + Args: + cmd_queue: The command queue to wait for commands. + result_queue: The result queue to put the results into. + """ + assert self._in_worker(), "_worker_run must be called in the worker" + try: + while True: + print( + f"[{self._name}] waiting for command {cmd_queue.qsize()=}\n", + end="", + ) + cmd = cmd_queue.get() + print( + f"[{self._name}, fut={cmd.future_id}] got command {cmd.cmd=} {cmd.args=} {cmd.kwargs=}\n", + end="", + ) + try: + fn = getattr(self, cmd.cmd) + result = fn(*cmd.args, **cmd.kwargs) + except Exception as e: + print(f"[{self._name}, fut={cmd.future_id}] send exception {e!r}\n", end="") + result_queue.put(WorkerResult(future_id=cmd.future_id, exception=e)) + else: + print(f"[{self._name}, fut={cmd.future_id}] send result {result!r}\n", end="") + result_queue.put(WorkerResult(future_id=cmd.future_id, result=result)) + del result + # cmd_queue.task_done() + if cmd.cmd == self._wrk_shutdown_worker.__name__: + print( + f"[{self._name}, fut={cmd.future_id}] got shutdown command, exit\n", end="" + ) + break + print( + f"[{self._name}, fut={cmd.future_id}] processed, waiting for next command\n", + end="", + ) + except: + traceback.print_exc() + raise + + @abstractmethod + def _assert_running(self) -> bool: + """Check if the execution is within the worker.""" + ... + + @abstractmethod + def _in_worker(self) -> bool: + """Check if the execution is within the worker.""" + ... + + def _wrk_shutdown_worker(self) -> None: + """Does nothing. The actual shutdown is handled in the _worker_run method.""" + assert self._in_worker(), "_wrk_shutdown_worker must be called in the worker" + + def _shutdown_worker(self) -> None: + """Shutdown the worker. The actual shutdown is handled in the _worker_run method.""" + assert not self._in_worker(), "shutdown_worker must not be called in the worker" + # This is not actually a recursive call, because the worker loop will exit before calling this method. + self._worker_call(self._wrk_shutdown_worker).get() + self._cancel_futures() + print(f"[{self._name}] shutdown\n", end="") + + @abstractmethod + def start(self) -> None: ... + + @abstractmethod + def shutdown(self) -> None: ... + + @abstractmethod + def running(self) -> bool: ... diff --git a/src/megatron/energon/dataloader/asynchronous/fork.py b/src/megatron/energon/dataloader/asynchronous/fork.py new file mode 100644 index 00000000..1b14298b --- /dev/null +++ b/src/megatron/energon/dataloader/asynchronous/fork.py @@ -0,0 +1,148 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import multiprocessing +import os +import sys +import threading +import warnings +from typing import override + +from megatron.energon.dataloader.asynchronous.base import ( + Asynchronous, + QueueProtocol, + WorkerCommand, + WorkerResult, +) + + +class ForkAsynchronous(Asynchronous): + """Mixin for asynchronous workers that use processes.""" + + _process: multiprocessing.Process | None = None + _cmd_queue: multiprocessing.Queue + _result_queue: multiprocessing.Queue + + _threaded_shutdown: threading.Thread | None = None + + _spawning_process: int + + @override + def _asynchronous_init(self, name: str) -> None: + super()._asynchronous_init(name) + self._spawning_process = os.getpid() + + @override + def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: + return multiprocessing.Queue(), multiprocessing.Queue() + + def _check_parent_process(self, evt_exit: threading.Event) -> None: + """Check if the parent process is alive. If it is dead, exit the worker process.""" + parent_proc = multiprocessing.parent_process() + parent_pid = os.getppid() + if parent_proc is None: + print(f"[{self._name}] No parent process, exiting", file=sys.stderr) + os._exit(-1) + while not evt_exit.wait(1): + if parent_proc.exitcode is not None or os.getppid() != parent_pid: + print(f"[{self._name}] Parent process died, exiting", file=sys.stderr) + os._exit(-1) + + @override + def _worker_run( + self, + cmd_queue: multiprocessing.Queue, + result_queue: multiprocessing.Queue, + ) -> None: + # cmd_queue is read only, so we can cancel the join thread. + cmd_queue.cancel_join_thread() + worker_exit_evt = threading.Event() + parent_check_thread = threading.Thread( + target=self._check_parent_process, args=(worker_exit_evt,), daemon=True + ) + parent_check_thread.start() + try: + super()._worker_run(cmd_queue, result_queue) + finally: + print(f"[{self._name}] shutting down\n", end="") + worker_exit_evt.set() + print( + f"[{self._name}] shutting down, wait for parent_check_thread\n", + end="", + ) + parent_check_thread.join() + print(f"[{self._name}] shutting down, close queues\n", end="") + result_queue.close() + result_queue.join_thread() + cmd_queue.close() + cmd_queue.cancel_join_thread() + print(f"[{self._name}] shutting down, done\n", end="") + + @override + def _in_worker(self) -> bool: + return multiprocessing.current_process() == self._process + + @override + def start(self) -> None: + multiprocessing.set_start_method("fork", force=True) + self._process = multiprocessing.Process( + target=self._worker_run, + args=(self._cmd_queue, self._result_queue), + daemon=True, + name=f"ForkDataLoaderWorker-{self._name}", + ) + self._process.start() + + @override + def shutdown(self, in_del: bool = False) -> None: + if self._spawning_process != os.getpid(): + # Should avoid forked process containing a forked worker on exit. + warnings.warn( + "Shutting down worker from a different process than the one that spawned it, skipping" + ) + return + if self._process is not None: + if in_del: + # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.\n", + end="", + file=sys.stderr, + ) + self._cmd_queue.close() + self._cmd_queue.cancel_join_thread() + self._result_queue.close() + self._result_queue.cancel_join_thread() + # Kill the process, because we cannot communicate with it in the gc loop. + self._process.terminate() + self._process = None + self._cancel_futures() + else: + try: + self._shutdown_worker() + except Exception: + self._process.join(10) + if self._process.is_alive(): + self._process.terminate() + else: + self._process.join() + assert self._process.exitcode == 0, ( + f"Process exit code {self._process.exitcode}" + ) + self._process = None + self._cmd_queue.close() + self._cmd_queue.cancel_join_thread() + self._result_queue.close() + self._result_queue.cancel_join_thread() + + @override + def running(self) -> bool: + return self._process is not None + + @override + def _assert_running(self) -> None: + assert self._process is not None, "Worker must be started first" + assert self._process.is_alive(), "Worker died" diff --git a/src/megatron/energon/dataloader/asynchronous/thread.py b/src/megatron/energon/dataloader/asynchronous/thread.py new file mode 100644 index 00000000..8a9710a6 --- /dev/null +++ b/src/megatron/energon/dataloader/asynchronous/thread.py @@ -0,0 +1,74 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import queue +import sys +import threading +import warnings +from typing import override + +from megatron.energon.dataloader.asynchronous.base import ( + Asynchronous, + QueueProtocol, + WorkerCommand, + WorkerResult, +) + + +class ThreadAsynchronous(Asynchronous): + """Mixin for asynchronous workers that use threads.""" + + _thread: threading.Thread | None = None + + @override + def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: + return queue.Queue(), queue.Queue() + + @override + def _in_worker(self) -> bool: + return threading.current_thread() == self._thread + + @override + def start(self) -> None: + self._thread = threading.Thread( + target=self._worker_run, + args=(self._cmd_queue, self._result_queue), + daemon=True, + name=f"{self._name}", + ) + self._thread.start() + + @override + def shutdown(self, in_del: bool = False) -> None: + if self._thread is not None: + if in_del: + # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking threads.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking threads.\n", + end="", + file=sys.stderr, + ) + # Just try to enqueue the shutdown command to the thread and hope for the best. Ignore the result. + self._cmd_queue.put( + WorkerCommand( + cmd=self._wrk_shutdown_worker.__name__, args=(), kwargs={}, future_id=-1 + ) + ) + self._cancel_futures() + self._thread = None + else: + self._shutdown_worker() + self._thread.join() + self._thread = None + + @override + def running(self) -> bool: + return self._thread is not None + + @override + def _assert_running(self) -> None: + assert self._thread is not None, "Thread must be started first" + assert self._thread.is_alive(), "Thread died" diff --git a/src/megatron/energon/dataloader/pin_memory.py b/src/megatron/energon/dataloader/pin_memory.py index ad493914..4c1539d5 100644 --- a/src/megatron/energon/dataloader/pin_memory.py +++ b/src/megatron/energon/dataloader/pin_memory.py @@ -1,13 +1,12 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -import queue import threading from typing import Generic, TypeVar, cast import torch +from megatron.energon.dataloader.asynchronous import ThreadAsynchronous from megatron.energon.dataloader.future import CallableFuture, Future -from megatron.energon.dataloader.workers.async_worker import AsynchronousMixin from megatron.energon.flavors.base_dataset import PinMemoryMixin TSample = TypeVar("TSample") @@ -30,8 +29,13 @@ def __call__(self, sample: Future[TSample]) -> Future[TSample]: """Pin the memory of a sample. The default implementation runs in the main thread.""" return CallableFuture.chain(sample, lambda fut: self._pin_memory(fut.get())) - def shutdown(self) -> None: - """Shutdown any running threads.""" + def shutdown(self, in_del: bool = False) -> None: + """ + Shutdown any running threads. + + Args: + in_del: Whether the shutdown is called from the garbage collector. + """ pass @@ -48,12 +52,12 @@ def __call__(self, sample: Future[TSample]) -> Future[TSample]: return sample -class PinMemoryThread(PinMemory[TSample], AsynchronousMixin, Generic[TSample]): +class PinMemoryThread(PinMemory[TSample], ThreadAsynchronous, Generic[TSample]): """Threaded implementation of :class:`PinMemory`. Pins the memory of samples in a separate thread in the background. - Creates the thread on first use and shuts it down on shutdown. May be reused after shutdown. + Creates the thread on first use. """ _SHUTDOWN = cast(Future[TSample], object()) @@ -65,10 +69,15 @@ def __init__( device: str | torch.device, ): super().__init__(device) - self._asynchronous_init(cmd_queue=queue.Queue(), result_queue=queue.Queue()) + self._asynchronous_init(name="pin-memory") def _wrk_pin_memory(self, sample: Future[TSample]) -> TSample: - return self._pin_memory(sample.get()) + print( + f"[{self._name}] Pinning memory of sample {sample}, waiting for sample data\n", end="" + ) + sample_data = sample.get() + print(f"[{self._name}] Got sample data\n", end="") + return self._pin_memory(sample_data) def _worker_pin_memory(self, sample: Future[TSample]) -> Future[TSample]: return self._worker_call(self._wrk_pin_memory, sample) @@ -79,4 +88,6 @@ def __call__(self, sample: Future[TSample]) -> Future[TSample]: Submits the sample future to the thread to fetch it and pins the memory in the thread, then returns a future for fetching the pinned sample. """ + if not self.running(): + self.start() return self._worker_pin_memory(sample) diff --git a/src/megatron/energon/dataloader/workers/async_worker.py b/src/megatron/energon/dataloader/workers/async_worker.py index d2ff0df2..00cff3a9 100644 --- a/src/megatron/energon/dataloader/workers/async_worker.py +++ b/src/megatron/energon/dataloader/workers/async_worker.py @@ -1,13 +1,16 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -import traceback -from abc import abstractmethod -from typing import Any, Callable, Generic, ParamSpec, Protocol, TypeVar, override +from typing import Generic, TypeVar, override from megatron.energon.cache.base import CachePool -from megatron.energon.dataloader.future import CancelledError, Future +from megatron.energon.dataloader.asynchronous import ( + Asynchronous, + QueueProtocol, + WorkerCommand, + WorkerResult, +) +from megatron.energon.dataloader.future import Future from megatron.energon.dataloader.workers.base_worker import DataLoaderWorker -from megatron.energon.edataclass import edataclass from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.rng import SystemRng from megatron.energon.state import FlexState @@ -15,244 +18,8 @@ TSample = TypeVar("TSample", covariant=True) -P = ParamSpec("P") -T = TypeVar("T") -R = TypeVar("R", covariant=True) - -class QueueProtocol(Protocol[T]): - def get(self, /) -> T: ... - - def put(self, item: T, /) -> None: ... - - def qsize(self, /) -> int: ... - - -@edataclass -class WorkerResult: - """Internal class for communicating a result from the worker via the result queue.""" - - future_id: int - result: Any = None - exception: Exception | None = None - - -@edataclass -class WorkerCommand: - """Internal class for communicating a command to the worker via the command queue.""" - - cmd: str - args: tuple[Any, ...] - kwargs: dict[str, Any] - future_id: int - - -class FutureImpl(Future[Any]): - """Class for returning a future result from the worker..""" - - __slots__ = ("_worker", "_future_id", "_result", "_exception", "_cancelled") - - _worker: "DataLoaderAsynchronousWorker" - _future_id: int - _result: Any - _exception: Exception - _cancelled: bool - - def __init__(self, worker: "DataLoaderAsynchronousWorker", future_id: int): - self._worker = worker - self._future_id = future_id - - def get(self) -> Any: - if getattr(self, "_cancelled", False): - raise CancelledError() - if not hasattr(self, "_result") and not hasattr(self, "_exception"): - self._worker._wait_for_worker_result(self._future_id) - if hasattr(self, "_exception"): - raise self._exception - return self._result - - def cancel(self) -> bool: - if getattr(self, "_cancelled", False): - return True - if hasattr(self, "_result") or hasattr(self, "_exception"): - return False - # In case the main process is waiting for thie future to complete, add the result - self._worker._result_queue.put( - WorkerResult( - future_id=self._future_id, exception=CancelledError.with_current_traceback() - ) - ) - self._cancelled = True - return True - - def _set_result(self, result: Any) -> None: - self._result = result - - def _set_exception(self, exception: Exception) -> None: - self._exception = exception - - -class AsynchronousMixin: - """Mixin for asynchronous workers.""" - - _cmd_queue: QueueProtocol[WorkerCommand] - _result_queue: QueueProtocol[WorkerResult] - _next_future_id: int - _futures: dict[int, FutureImpl] - _name: str - - def _asynchronous_init(self, name: str) -> None: - self._cmd_queue, self._result_queue = self._queues() - self._next_future_id = 0 - self._futures = {} - self._name = name - - @abstractmethod - def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: ... - - def _wait_for_worker_result(self, future_id: int) -> None: - """ - Wait for the result of a future. - If another result comes first, update the corresponding future. - - Args: - future_id: The ID of the future to wait for. - """ - while True: - print(f"[{self._name}, fut={future_id}] waiting for result\n", end="") - res = self._result_queue.get() - fut = self._futures.pop(res.future_id) - if res.exception is not None: - fut._set_exception(res.exception) - else: - fut._set_result(res.result) - # self._result_queue.task_done() - if res.future_id == future_id: - print(f"[{self._name}, fut={future_id}] got result, return\n", end="") - return - else: - print( - f"[{self._name}, fut={future_id}] got result for {res.future_id=}, continue\n", - end="", - ) - continue - - def _cancel_futures(self) -> None: - """Cancel all futures after worker shutdown.""" - for fut in self._futures.values(): - fut.cancel() - self._futures.clear() - - def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Future[R]: - """ - Call a function in the worker and return a future for getting the result. - The function must be an instance method of `self`. Uses the name to identify the function in the worker - instance. - - Args: - fn: The function to call. - *args: The arguments to pass to the function. - **kwargs: The keyword arguments to pass to the function. - """ - self._assert_running() - assert not self._in_worker(), "worker_call must not be called in the worker" - future_id = self._next_future_id - self._next_future_id += 1 - - self._futures[future_id] = future = FutureImpl(self, future_id) - print( - f"[{self._name}] worker_call {fn.__name__=} {args=} {kwargs=} {future_id=}\n", - end="", - ) - self._cmd_queue.put( - WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) - ) - print(f"[{self._name}] queue: {self._cmd_queue.qsize()}\n", end="") - return future - - def _worker_run( - self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] - ) -> None: - """ - The worker main loop. - It waits for commands via the command queue and executes them. - The functions to call are identified by their name. - The result of the call is put into the result queue. - The worker exits when the command `_shutdown_worker` is received. - - Args: - cmd_queue: The command queue to wait for commands. - result_queue: The result queue to put the results into. - """ - assert self._in_worker(), "_worker_run must be called in the worker" - try: - while True: - print( - f"[{self._name}] waiting for command {cmd_queue.qsize()=}\n", - end="", - ) - cmd = cmd_queue.get() - print( - f"[{self._name}, fut={cmd.future_id}] got command {cmd.cmd=} {cmd.args=} {cmd.kwargs=}\n", - end="", - ) - try: - fn = getattr(self, cmd.cmd) - result = fn(*cmd.args, **cmd.kwargs) - except Exception as e: - print(f"[{self._name}, fut={cmd.future_id}] send exception {e!r}\n", end="") - result_queue.put(WorkerResult(future_id=cmd.future_id, exception=e)) - else: - print(f"[{self._name}, fut={cmd.future_id}] send result {result!r}\n", end="") - result_queue.put(WorkerResult(future_id=cmd.future_id, result=result)) - del result - # cmd_queue.task_done() - if cmd.cmd == self._wrk_shutdown_worker.__name__: - print( - f"[{self._name}, fut={cmd.future_id}] got shutdown command, exit\n", end="" - ) - break - print( - f"[{self._name}, fut={cmd.future_id}] processed, waiting for next command\n", - end="", - ) - except: - traceback.print_exc() - raise - - @abstractmethod - def _assert_running(self) -> bool: - """Check if the execution is within the worker.""" - ... - - @abstractmethod - def _in_worker(self) -> bool: - """Check if the execution is within the worker.""" - ... - - def _wrk_shutdown_worker(self) -> None: - """Does nothing. The actual shutdown is handled in the _worker_run method.""" - assert self._in_worker(), "_wrk_shutdown_worker must be called in the worker" - - def _shutdown_worker(self) -> None: - """Shutdown the worker. The actual shutdown is handled in the _worker_run method.""" - assert not self._in_worker(), "shutdown_worker must not be called in the worker" - # This is not actually a recursive call, because the worker loop will exit before calling this method. - self._worker_call(self._wrk_shutdown_worker).get() - self._cancel_futures() - print(f"[{self._name}] shutdown\n", end="") - - @abstractmethod - def start(self) -> None: ... - - @abstractmethod - def shutdown(self) -> None: ... - - @abstractmethod - def running(self) -> bool: ... - - -class DataLoaderAsynchronousWorker(DataLoaderWorker[TSample], AsynchronousMixin, Generic[TSample]): +class DataLoaderAsynchronousWorker(DataLoaderWorker[TSample], Asynchronous, Generic[TSample]): """ Extension of the `DataLoaderWorker`, which implements commands via a command and results queue. diff --git a/src/megatron/energon/dataloader/workers/fork_worker.py b/src/megatron/energon/dataloader/workers/fork_worker.py index 72b280f8..ec165be6 100644 --- a/src/megatron/energon/dataloader/workers/fork_worker.py +++ b/src/megatron/energon/dataloader/workers/fork_worker.py @@ -1,159 +1,19 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause import multiprocessing -import os -import sys -import threading -import warnings -from typing import Generic, TypeVar, override +from typing import Generic, TypeVar +from megatron.energon.dataloader.asynchronous import ForkAsynchronous from megatron.energon.dataloader.workers.async_worker import ( - AsynchronousMixin, DataLoaderAsynchronousWorker, - QueueProtocol, - WorkerCommand, - WorkerResult, ) from megatron.energon.wrappers.gc_dataset import gc_init_worker TSample = TypeVar("TSample", covariant=True) -class ForkAsynchronousMixin(AsynchronousMixin): - """Mixin for asynchronous workers that use processes.""" - - _process: multiprocessing.Process | None = None - _cmd_queue: multiprocessing.Queue - _result_queue: multiprocessing.Queue - - _threaded_shutdown: threading.Thread | None = None - - _spawning_process: int - - @override - def _asynchronous_init(self, name: str) -> None: - super()._asynchronous_init(name) - self._spawning_process = os.getpid() - - @override - def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: - return multiprocessing.Queue(), multiprocessing.Queue() - - def _check_parent_process(self, evt_exit: threading.Event) -> None: - """Check if the parent process is alive. If it is dead, exit the worker process.""" - parent_proc = multiprocessing.parent_process() - parent_pid = os.getppid() - if parent_proc is None: - print(f"[{self._name}] No parent process, exiting", file=sys.stderr) - os._exit(-1) - while not evt_exit.wait(1): - if parent_proc.exitcode is not None or os.getppid() != parent_pid: - print(f"[{self._name}] Parent process died, exiting", file=sys.stderr) - os._exit(-1) - - @override - def _worker_run( - self, - cmd_queue: multiprocessing.Queue, - result_queue: multiprocessing.Queue, - ) -> None: - # cmd_queue is read only, so we can cancel the join thread. - cmd_queue.cancel_join_thread() - worker_exit_evt = threading.Event() - parent_check_thread = threading.Thread( - target=self._check_parent_process, args=(worker_exit_evt,), daemon=True - ) - parent_check_thread.start() - try: - super()._worker_run(cmd_queue, result_queue) - finally: - print(f"[{self._name}] shutting down\n", end="") - worker_exit_evt.set() - print( - f"[{self._name}] shutting down, wait for parent_check_thread\n", - end="", - ) - parent_check_thread.join() - print(f"[{self._name}] shutting down, close queues\n", end="") - result_queue.close() - result_queue.join_thread() - cmd_queue.close() - cmd_queue.cancel_join_thread() - print(f"[{self._name}] shutting down, done\n", end="") - - @override - def _in_worker(self) -> bool: - return multiprocessing.current_process() == self._process - - @override - def start(self) -> None: - multiprocessing.set_start_method("fork", force=True) - self._process = multiprocessing.Process( - target=self._worker_run, - args=(self._cmd_queue, self._result_queue), - daemon=True, - name=f"ForkDataLoaderWorker-{self._name}", - ) - self._process.start() - - @override - def shutdown(self, in_del: bool = False) -> None: - if self._spawning_process != os.getpid(): - # Should avoid forked process containing a forked worker on exit. - warnings.warn( - "Shutting down worker from a different process than the one that spawned it, skipping" - ) - return - if self._process is not None: - if in_del: - # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. - warnings.warn( - "Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.", - ResourceWarning, - ) - print( - "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking processes. Terminating worker process.\n", - end="", - file=sys.stderr, - ) - self._cmd_queue.close() - self._cmd_queue.cancel_join_thread() - self._result_queue.close() - self._result_queue.cancel_join_thread() - # Kill the process, because we cannot communicate with it in the gc loop. - self._process.terminate() - self._process = None - self._cancel_futures() - else: - try: - self._shutdown_worker() - except Exception: - self._process.join(10) - if self._process.is_alive(): - self._process.terminate() - else: - self._process.join() - assert self._process.exitcode == 0, ( - f"Process exit code {self._process.exitcode}" - ) - self._process = None - self._cmd_queue.close() - self._cmd_queue.cancel_join_thread() - self._result_queue.close() - self._result_queue.cancel_join_thread() - - @override - def running(self) -> bool: - return self._process is not None - - @override - def _assert_running(self) -> None: - assert self._process is not None, "Worker must be started first" - assert self._process.is_alive(), "Worker died" - - class ForkDataLoaderWorker( - ForkAsynchronousMixin, DataLoaderAsynchronousWorker[TSample], Generic[TSample] + ForkAsynchronous, DataLoaderAsynchronousWorker[TSample], Generic[TSample] ): """ Implements the `DataLoaderWorker` interface using processes. diff --git a/src/megatron/energon/dataloader/workers/thread_worker.py b/src/megatron/energon/dataloader/workers/thread_worker.py index e8b066bc..3d450eaa 100644 --- a/src/megatron/energon/dataloader/workers/thread_worker.py +++ b/src/megatron/energon/dataloader/workers/thread_worker.py @@ -1,84 +1,17 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -import queue -import sys -import threading -import warnings -from typing import Generic, TypeVar, override +from typing import Generic, TypeVar +from megatron.energon.dataloader.asynchronous import ThreadAsynchronous from megatron.energon.dataloader.workers.async_worker import ( - AsynchronousMixin, DataLoaderAsynchronousWorker, - QueueProtocol, - WorkerCommand, - WorkerResult, ) TSample = TypeVar("TSample", covariant=True) -class ThreadAsynchronousMixin(AsynchronousMixin): - """Mixin for asynchronous workers that use threads.""" - - _thread: threading.Thread | None = None - - @override - def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: - return queue.Queue(), queue.Queue() - - @override - def _in_worker(self) -> bool: - return threading.current_thread() == self._thread - - @override - def start(self) -> None: - self._thread = threading.Thread( - target=self._worker_run, - args=(self._cmd_queue, self._result_queue), - daemon=True, - name=f"{self._name}", - ) - self._thread.start() - - @override - def shutdown(self, in_del: bool = False) -> None: - if self._thread is not None: - if in_del: - # It seems that the ResourceWarning does not work in the gc loop? Also print a warning here. - warnings.warn( - "Explicitly call DataLoader.shutdown() to avoid leaking threads.", - ResourceWarning, - ) - print( - "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking threads.\n", - end="", - file=sys.stderr, - ) - # Just try to enqueue the shutdown command to the thread and hope for the best. Ignore the result. - self._cmd_queue.put( - WorkerCommand( - cmd=self._wrk_shutdown_worker.__name__, args=(), kwargs={}, future_id=-1 - ) - ) - self._cancel_futures() - self._thread = None - else: - self._shutdown_worker() - self._thread.join() - self._thread = None - - @override - def running(self) -> bool: - return self._thread is not None - - @override - def _assert_running(self) -> None: - assert self._thread is not None, "Thread must be started first" - assert self._thread.is_alive(), "Thread died" - - class ThreadDataLoaderWorker( - ThreadAsynchronousMixin, DataLoaderAsynchronousWorker[TSample], Generic[TSample] + ThreadAsynchronous, DataLoaderAsynchronousWorker[TSample], Generic[TSample] ): """ Implements the `DataLoaderWorker` interface using threads. diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 06193a10..2c057c8b 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -113,7 +113,6 @@ def test_dataloader_no_workers(self): max_samples_per_sequence=None, repeat=False, ), - pin_memory=None, ) assert len(train_loader) == 6, len(train_loader) @@ -143,7 +142,6 @@ def test_dataloader_no_workers(self): max_samples_per_sequence=None, repeat=False, ), - pin_memory=None, ) train_loader.restore_state_rank(state1) @@ -178,7 +176,6 @@ def test_dataloader_fork(self): gc_freeze_at_start=True, watchdog_timeout_seconds=60, fail_on_timeout=True, - pin_memory=None, ) assert len(train_loader) == 6, len(train_loader) @@ -216,7 +213,6 @@ def test_dataloader_fork(self): gc_freeze_at_start=True, watchdog_timeout_seconds=60, fail_on_timeout=True, - pin_memory=None, ) train_loader.restore_state_rank(state1) @@ -250,7 +246,6 @@ def test_dataloader_thread(self): gc_collect_every_n_steps=0, watchdog_timeout_seconds=60, fail_on_timeout=True, - pin_memory=None, ) assert len(train_loader) == 6, len(train_loader) @@ -285,7 +280,6 @@ def test_dataloader_thread(self): gc_collect_every_n_steps=0, watchdog_timeout_seconds=60, fail_on_timeout=True, - pin_memory=None, ) train_loader.restore_state_rank(state1) From e63463131c73003e65fa4617673e0d6ecdcf5c1a Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 24 Jul 2025 11:46:02 +0200 Subject: [PATCH 11/36] DataLoader as ctx mgr --- src/megatron/energon/dataloader/dataloader.py | 113 +++++++++----- src/megatron/energon/dataloader/pin_memory.py | 5 +- tests/test_dataloader.py | 143 ++++++++---------- 3 files changed, 139 insertions(+), 122 deletions(-) diff --git a/src/megatron/energon/dataloader/dataloader.py b/src/megatron/energon/dataloader/dataloader.py index 3f3ec0ba..e09564f1 100644 --- a/src/megatron/energon/dataloader/dataloader.py +++ b/src/megatron/energon/dataloader/dataloader.py @@ -45,6 +45,12 @@ def __call__( class DataLoader(Generic[TSample]): + """ + Implementation for a data loader. Orchestrates the workers for prefetching samples. + Opposing the `torch.utils.data.DataLoader`, this loader needs explicit shutdown when done, + to avoid leaking workers (fixes a bug). + """ + _workers: list[DataLoaderWorker[TSample]] | None = None _exhausted_workers: list[bool] _next_worker_id: int = 0 @@ -99,7 +105,9 @@ def __init__( watchdog_timeout_seconds: The timeout in seconds. If `None`, the watchdog is disabled. watchdog_initial_timeout_seconds: The initial timeout in seconds. If `None`, the timeout is the same as `watchdog_timeout_seconds`. fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. - pin_memory: The memory pinner to use. If `None`, no memory is not pinned. If "automatic", the memory is pinned automatically if cuda is available. + pin_memory: The memory pinner to use. If `None`, no memory is not pinned. + If "automatic", the memory is pinned automatically if cuda is available. + If a `PinMemory` instance, the instance may only be used for one `DataLoader`. """ if dataset.worker_config.num_workers == 0 and worker_type == ForkDataLoaderWorker: worker_type = DataLoaderWorker @@ -148,32 +156,6 @@ def __init__( self._spawning_process = os.getpid() - def shutdown(self, in_del: bool = False) -> None: - if self._workers is not None: - if in_del: - warnings.warn( - "Explicitly call DataLoader.shutdown() to avoid leaking workers.", - ResourceWarning, - ) - print( - "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking workers.\n", - end="", - file=sys.stderr, - ) - for worker in self._workers: - worker.shutdown(in_del=in_del) - self._workers = None - if self._pin_memory is not None: - self._pin_memory.shutdown() - - def __del__(self) -> None: - self.shutdown(in_del=True) - - def start_iter(self) -> None: - if self._workers is not None: - for worker in self._workers: - worker.new_iter() - def _epoch_iter(self) -> Generator[TSample, None, None]: """Iterate over the dataset for one epoch (i.e. all workers StopIteration). One epoch may also be infinite (if looping the dataset).""" @@ -337,7 +319,8 @@ def save_state_global(self, global_dst_rank: int) -> Sequence[FlexState | None] # Not distributed -> return the merged state return [merged_state] - def _start(self, initial_state: FlexState | None = None) -> None: + def _start(self) -> None: + """Start the workers and restore the state if available.""" self._workers = [ self._worker_type(self._dataset, self._worker_config, local_worker_id, self._cache_pool) for local_worker_id in range(self._worker_config.safe_num_workers) @@ -345,15 +328,10 @@ def _start(self, initial_state: FlexState | None = None) -> None: for worker in self._workers: worker.start() - if initial_state is None: - if self._restore_state is not None: - initial_state = self._restore_state - self._restore_state = None - - if initial_state is None: + if self._restore_state is None: worker_states = [None] * self._worker_config.safe_num_workers else: - worker_states = initial_state["worker_states"] + worker_states = self._restore_state["worker_states"] assert len(worker_states) == self._worker_config.safe_num_workers, ( "Number of initial states must match number of workers" @@ -362,19 +340,59 @@ def _start(self, initial_state: FlexState | None = None) -> None: for worker, worker_state in zip(self._workers, worker_states): worker.dataset_init(worker_state) - if initial_state is not None: + if self._restore_state is not None: self._prefetching_samples = [ [ - CallableFuture(functools.partial(self.restore_sample, sample_key)) + self._pin_memory( + CallableFuture(functools.partial(self.restore_sample, sample_key)) + ) for sample_key in prefetched_samples_keys ] - for prefetched_samples_keys in initial_state["prefetched_samples_keys"] + for prefetched_samples_keys in self._restore_state["prefetched_samples_keys"] ] - self._next_worker_id = initial_state["next_worker_id"] + self._next_worker_id = self._restore_state["next_worker_id"] self._exhausted_workers = [ False if worker_state is None else worker_state["exhausted"] for worker_state in worker_states ] + # State was restored, clear + self._restore_state = None + + def shutdown(self, in_del: bool = False) -> None: + """ + Shutdown the workers and the pin memory thread. + + Args: + in_del: Whether the shutdown is called from the garbage collector (in __del__). + Users should not need to set this. + """ + if self._workers is not None: + if in_del: + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking workers or run as context manager.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking workers or run as context manager.\n", + end="", + file=sys.stderr, + ) + for worker in self._workers: + worker.shutdown(in_del=in_del) + self._workers = None + self._pin_memory.shutdown(in_del=in_del) + + def __del__(self) -> None: + self.shutdown(in_del=True) + + def __enter__(self) -> "DataLoader[TSample]": + # Already start if using the context manager. This ensures the lifecycle is fixed. + # Otherwise, will start when iterating. + self._start() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.shutdown() def restore_state_rank(self, state: FlexState | None) -> None: """ @@ -493,7 +511,24 @@ def restore_sample(self, restore_key: tuple) -> TSample: finally: self._worker_config.worker_deactivate() + def with_restored_state_rank(self, state: FlexState | None) -> "DataLoader[TSample]": + """ + Use this data loader and restore the state. Useful for chaining commands. See `save_state_rank` for more details. + """ + self.restore_state_rank(state) + return self + + def with_restored_state_global( + self, state: Sequence[FlexState | None] | None, src_rank: int | None = None + ) -> "DataLoader[TSample]": + """ + Use this data loader and restore the state. Useful for chaining commands. See `save_state_global` for more details. + """ + self.restore_state_global(state, src_rank=src_rank) + return self + def config(self) -> dict[str, Any]: + """Get the configuration of the dataset.""" return self._dataset.config() def __str__(self) -> str: diff --git a/src/megatron/energon/dataloader/pin_memory.py b/src/megatron/energon/dataloader/pin_memory.py index 4c1539d5..4e2dc247 100644 --- a/src/megatron/energon/dataloader/pin_memory.py +++ b/src/megatron/energon/dataloader/pin_memory.py @@ -79,9 +79,6 @@ def _wrk_pin_memory(self, sample: Future[TSample]) -> TSample: print(f"[{self._name}] Got sample data\n", end="") return self._pin_memory(sample_data) - def _worker_pin_memory(self, sample: Future[TSample]) -> Future[TSample]: - return self._worker_call(self._wrk_pin_memory, sample) - def __call__(self, sample: Future[TSample]) -> Future[TSample]: """ Pin the memory of a sample. @@ -90,4 +87,4 @@ def __call__(self, sample: Future[TSample]) -> Future[TSample]: """ if not self.running(): self.start() - return self._worker_pin_memory(sample) + return self._worker_call(self._wrk_pin_memory, sample) diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 2c057c8b..94b8baef 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -104,7 +104,7 @@ def test_dataloader_no_workers(self): ) # Train mode dataset - train_loader = DataLoader( + with DataLoader( get_train_dataset( self.ds1_path, worker_config=worker_config, @@ -113,27 +113,25 @@ def test_dataloader_no_workers(self): max_samples_per_sequence=None, repeat=False, ), - ) - assert len(train_loader) == 6, len(train_loader) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(train_order1) == 55, len(train_order1) - assert len(Counter(train_order1)) == 55, Counter(train_order1) - assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + ) as train_loader: + assert len(train_loader) == 6, len(train_loader) - state1 = train_loader.save_state_rank() + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) - train_order2 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] + state1 = train_loader.save_state_rank() - train_loader.shutdown() + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] - train_loader = DataLoader( + with DataLoader( get_train_dataset( self.ds1_path, worker_config=worker_config, @@ -142,14 +140,11 @@ def test_dataloader_no_workers(self): max_samples_per_sequence=None, repeat=False, ), - ) - - train_loader.restore_state_rank(state1) - - cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] - assert train_order2 == cmp_order2, (train_order1, cmp_order2) - - train_loader.shutdown() + ).with_restored_state_rank(state1) as train_loader: + cmp_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) def test_dataloader_fork(self): torch.manual_seed(42) @@ -161,7 +156,7 @@ def test_dataloader_fork(self): ) # Train mode dataset - train_loader = DataLoader( + with DataLoader( get_train_dataset( self.ds1_path, worker_config=worker_config, @@ -176,29 +171,27 @@ def test_dataloader_fork(self): gc_freeze_at_start=True, watchdog_timeout_seconds=60, fail_on_timeout=True, - ) - assert len(train_loader) == 6, len(train_loader) + ) as train_loader: + assert len(train_loader) == 6, len(train_loader) - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(train_order1) == 55, len(train_order1) - assert len(Counter(train_order1)) == 55, Counter(train_order1) - assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) - state1 = train_loader.save_state_rank() + state1 = train_loader.save_state_rank() - train_order2 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] - assert len(train_order1) == len(train_order2), (len(train_order1), len(train_order2)) + assert len(train_order1) == len(train_order2), (len(train_order1), len(train_order2)) - train_loader.shutdown() - - train_loader = DataLoader( + with DataLoader( get_train_dataset( self.ds1_path, worker_config=worker_config, @@ -213,14 +206,11 @@ def test_dataloader_fork(self): gc_freeze_at_start=True, watchdog_timeout_seconds=60, fail_on_timeout=True, - ) - - train_loader.restore_state_rank(state1) - - cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] - assert train_order2 == cmp_order2, (train_order1, cmp_order2) - - train_loader.shutdown() + ).with_restored_state_rank(state1) as train_loader: + cmp_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) def test_dataloader_thread(self): torch.manual_seed(42) @@ -232,7 +222,7 @@ def test_dataloader_thread(self): ) # Train mode dataset - train_loader = DataLoader( + with DataLoader( get_train_dataset( self.ds1_path, worker_config=worker_config, @@ -246,27 +236,25 @@ def test_dataloader_thread(self): gc_collect_every_n_steps=0, watchdog_timeout_seconds=60, fail_on_timeout=True, - ) - assert len(train_loader) == 6, len(train_loader) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(train_order1) == 55, len(train_order1) - assert len(Counter(train_order1)) == 55, Counter(train_order1) - assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + ) as train_loader: + assert len(train_loader) == 6, len(train_loader) - state1 = train_loader.save_state_rank() + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) - train_order2 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] + state1 = train_loader.save_state_rank() - train_loader.shutdown() + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] - train_loader = DataLoader( + with DataLoader( get_train_dataset( self.ds1_path, worker_config=worker_config, @@ -280,14 +268,11 @@ def test_dataloader_thread(self): gc_collect_every_n_steps=0, watchdog_timeout_seconds=60, fail_on_timeout=True, - ) - - train_loader.restore_state_rank(state1) - - cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] - assert train_order2 == cmp_order2, (train_order1, cmp_order2) - - train_loader.shutdown() + ).with_restored_state_rank(state1) as train_loader: + cmp_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) if __name__ == "__main__": From ffa25a44b5ff841a9d26b05b480870b978ec8253 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Wed, 30 Jul 2025 11:50:34 +0200 Subject: [PATCH 12/36] WIP: Thread-local state --- .../energon/dataloader/asynchronous/fork.py | 9 - .../energon/dataloader/asynchronous/thread.py | 7 - src/megatron/energon/dataloader/dataloader.py | 150 +++++++------- .../dataloader/workers/async_worker.py | 14 +- .../energon/dataloader/workers/base_worker.py | 2 +- .../energon/dataloader/workers/fork_worker.py | 9 + .../dataloader/workers/thread_worker.py | 31 ++- src/megatron/energon/flavors/base_dataset.py | 44 +++- .../energon/flavors/webdataset/itar_reader.py | 15 +- .../flavors/webdataset/sample_loader.py | 5 +- src/megatron/energon/loader.py | 74 ++++--- src/megatron/energon/rng.py | 41 ++++ src/megatron/energon/savable_loader.py | 8 +- src/megatron/energon/task_encoder/base.py | 188 ++++++++++++------ src/megatron/energon/worker.py | 94 +++++---- src/megatron/energon/wrappers/base.py | 38 ++-- .../energon/wrappers/batch_dataset.py | 2 - .../energon/wrappers/blend_dataset.py | 1 - src/megatron/energon/wrappers/buffer.py | 1 - .../energon/wrappers/epochize_dataset.py | 2 - .../energon/wrappers/filter_dataset.py | 5 - src/megatron/energon/wrappers/gc_dataset.py | 3 - .../energon/wrappers/group_batch_dataset.py | 2 - .../energon/wrappers/iter_map_dataset.py | 2 - .../energon/wrappers/limit_dataset.py | 5 +- .../energon/wrappers/log_sample_dataset.py | 5 - src/megatron/energon/wrappers/map_dataset.py | 5 - .../energon/wrappers/mix_batch_dataset.py | 3 - .../energon/wrappers/packing_dataset.py | 2 - .../energon/wrappers/repeat_dataset.py | 2 - .../wrappers/shuffle_buffer_dataset.py | 4 - .../wrappers/task_encoder_state_dataset.py | 81 ++++++++ .../energon/wrappers/watchdog_dataset.py | 3 - tests/test_crudedataset.py | 18 -- tests/test_dataloader.py | 4 +- tests/test_dataset.py | 10 - tests/test_dataset_det.py | 3 +- tests/test_metadataset.py | 10 - tests/test_metadataset_fewsamp.py | 6 - tests/test_metadataset_v2.py | 28 --- 40 files changed, 559 insertions(+), 377 deletions(-) create mode 100644 src/megatron/energon/wrappers/task_encoder_state_dataset.py diff --git a/src/megatron/energon/dataloader/asynchronous/fork.py b/src/megatron/energon/dataloader/asynchronous/fork.py index 1b14298b..75a12a92 100644 --- a/src/megatron/energon/dataloader/asynchronous/fork.py +++ b/src/megatron/energon/dataloader/asynchronous/fork.py @@ -5,7 +5,6 @@ import sys import threading import warnings -from typing import override from megatron.energon.dataloader.asynchronous.base import ( Asynchronous, @@ -26,12 +25,10 @@ class ForkAsynchronous(Asynchronous): _spawning_process: int - @override def _asynchronous_init(self, name: str) -> None: super()._asynchronous_init(name) self._spawning_process = os.getpid() - @override def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: return multiprocessing.Queue(), multiprocessing.Queue() @@ -47,7 +44,6 @@ def _check_parent_process(self, evt_exit: threading.Event) -> None: print(f"[{self._name}] Parent process died, exiting", file=sys.stderr) os._exit(-1) - @override def _worker_run( self, cmd_queue: multiprocessing.Queue, @@ -77,11 +73,9 @@ def _worker_run( cmd_queue.cancel_join_thread() print(f"[{self._name}] shutting down, done\n", end="") - @override def _in_worker(self) -> bool: return multiprocessing.current_process() == self._process - @override def start(self) -> None: multiprocessing.set_start_method("fork", force=True) self._process = multiprocessing.Process( @@ -92,7 +86,6 @@ def start(self) -> None: ) self._process.start() - @override def shutdown(self, in_del: bool = False) -> None: if self._spawning_process != os.getpid(): # Should avoid forked process containing a forked worker on exit. @@ -138,11 +131,9 @@ def shutdown(self, in_del: bool = False) -> None: self._result_queue.close() self._result_queue.cancel_join_thread() - @override def running(self) -> bool: return self._process is not None - @override def _assert_running(self) -> None: assert self._process is not None, "Worker must be started first" assert self._process.is_alive(), "Worker died" diff --git a/src/megatron/energon/dataloader/asynchronous/thread.py b/src/megatron/energon/dataloader/asynchronous/thread.py index 8a9710a6..8b108659 100644 --- a/src/megatron/energon/dataloader/asynchronous/thread.py +++ b/src/megatron/energon/dataloader/asynchronous/thread.py @@ -4,7 +4,6 @@ import sys import threading import warnings -from typing import override from megatron.energon.dataloader.asynchronous.base import ( Asynchronous, @@ -19,15 +18,12 @@ class ThreadAsynchronous(Asynchronous): _thread: threading.Thread | None = None - @override def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: return queue.Queue(), queue.Queue() - @override def _in_worker(self) -> bool: return threading.current_thread() == self._thread - @override def start(self) -> None: self._thread = threading.Thread( target=self._worker_run, @@ -37,7 +33,6 @@ def start(self) -> None: ) self._thread.start() - @override def shutdown(self, in_del: bool = False) -> None: if self._thread is not None: if in_del: @@ -64,11 +59,9 @@ def shutdown(self, in_del: bool = False) -> None: self._thread.join() self._thread = None - @override def running(self) -> bool: return self._thread is not None - @override def _assert_running(self) -> None: assert self._thread is not None, "Thread must be started first" assert self._thread.is_alive(), "Thread died" diff --git a/src/megatron/energon/dataloader/dataloader.py b/src/megatron/energon/dataloader/dataloader.py index e09564f1..c1cd95c1 100644 --- a/src/megatron/energon/dataloader/dataloader.py +++ b/src/megatron/energon/dataloader/dataloader.py @@ -156,6 +156,81 @@ def __init__( self._spawning_process = os.getpid() + def _start(self) -> None: + """Start the workers and restore the state if available.""" + self._workers = [ + self._worker_type(self._dataset, self._worker_config, local_worker_id, self._cache_pool) + for local_worker_id in range(self._worker_config.safe_num_workers) + ] + for worker in self._workers: + worker.start() + + if self._restore_state is None: + worker_states = [None] * self._worker_config.safe_num_workers + else: + worker_states = self._restore_state["worker_states"] + + assert len(worker_states) == self._worker_config.safe_num_workers, ( + "Number of initial states must match number of workers" + ) + + for worker, worker_state in zip(self._workers, worker_states): + worker.dataset_init(worker_state) + + if self._restore_state is not None: + self._prefetching_samples = [ + [ + self._pin_memory( + CallableFuture(functools.partial(self.restore_sample, sample_key)) + ) + for sample_key in prefetched_samples_keys + ] + for prefetched_samples_keys in self._restore_state["prefetched_samples_keys"] + ] + self._next_worker_id = self._restore_state["next_worker_id"] + self._exhausted_workers = [ + False if worker_state is None else worker_state["exhausted"] + for worker_state in worker_states + ] + # State was restored, clear + self._restore_state = None + + def shutdown(self, in_del: bool = False) -> None: + """ + Shutdown the workers and the pin memory thread. + + Args: + in_del: Whether the shutdown is called from the garbage collector (in __del__). + Users should not need to set this. + """ + if self._workers is not None: + if in_del: + warnings.warn( + "Explicitly call DataLoader.shutdown() to avoid leaking workers or run as context manager.", + ResourceWarning, + ) + print( + "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking workers or run as context manager.\n", + end="", + file=sys.stderr, + ) + for worker in self._workers: + worker.shutdown(in_del=in_del) + self._workers = None + self._pin_memory.shutdown(in_del=in_del) + + def __del__(self) -> None: + self.shutdown(in_del=True) + + def __enter__(self) -> "DataLoader[TSample]": + # Already start if using the context manager. This ensures the lifecycle is fixed. + # Otherwise, will start when iterating. + self._start() + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.shutdown() + def _epoch_iter(self) -> Generator[TSample, None, None]: """Iterate over the dataset for one epoch (i.e. all workers StopIteration). One epoch may also be infinite (if looping the dataset).""" @@ -319,81 +394,6 @@ def save_state_global(self, global_dst_rank: int) -> Sequence[FlexState | None] # Not distributed -> return the merged state return [merged_state] - def _start(self) -> None: - """Start the workers and restore the state if available.""" - self._workers = [ - self._worker_type(self._dataset, self._worker_config, local_worker_id, self._cache_pool) - for local_worker_id in range(self._worker_config.safe_num_workers) - ] - for worker in self._workers: - worker.start() - - if self._restore_state is None: - worker_states = [None] * self._worker_config.safe_num_workers - else: - worker_states = self._restore_state["worker_states"] - - assert len(worker_states) == self._worker_config.safe_num_workers, ( - "Number of initial states must match number of workers" - ) - - for worker, worker_state in zip(self._workers, worker_states): - worker.dataset_init(worker_state) - - if self._restore_state is not None: - self._prefetching_samples = [ - [ - self._pin_memory( - CallableFuture(functools.partial(self.restore_sample, sample_key)) - ) - for sample_key in prefetched_samples_keys - ] - for prefetched_samples_keys in self._restore_state["prefetched_samples_keys"] - ] - self._next_worker_id = self._restore_state["next_worker_id"] - self._exhausted_workers = [ - False if worker_state is None else worker_state["exhausted"] - for worker_state in worker_states - ] - # State was restored, clear - self._restore_state = None - - def shutdown(self, in_del: bool = False) -> None: - """ - Shutdown the workers and the pin memory thread. - - Args: - in_del: Whether the shutdown is called from the garbage collector (in __del__). - Users should not need to set this. - """ - if self._workers is not None: - if in_del: - warnings.warn( - "Explicitly call DataLoader.shutdown() to avoid leaking workers or run as context manager.", - ResourceWarning, - ) - print( - "WARNING: Explicitly call DataLoader.shutdown() to avoid leaking workers or run as context manager.\n", - end="", - file=sys.stderr, - ) - for worker in self._workers: - worker.shutdown(in_del=in_del) - self._workers = None - self._pin_memory.shutdown(in_del=in_del) - - def __del__(self) -> None: - self.shutdown(in_del=True) - - def __enter__(self) -> "DataLoader[TSample]": - # Already start if using the context manager. This ensures the lifecycle is fixed. - # Otherwise, will start when iterating. - self._start() - return self - - def __exit__(self, exc_type, exc_value, traceback) -> None: - self.shutdown() - def restore_state_rank(self, state: FlexState | None) -> None: """ Restore the state of the DataLoader on the current rank. diff --git a/src/megatron/energon/dataloader/workers/async_worker.py b/src/megatron/energon/dataloader/workers/async_worker.py index 00cff3a9..d4a4e90b 100644 --- a/src/megatron/energon/dataloader/workers/async_worker.py +++ b/src/megatron/energon/dataloader/workers/async_worker.py @@ -1,6 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from typing import Generic, TypeVar, override +from typing import Generic, TypeVar from megatron.energon.cache.base import CachePool from megatron.energon.dataloader.asynchronous import ( @@ -46,14 +46,6 @@ def _worker_run( self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] ) -> None: SystemRng.seed(self._seed) - import torch.utils.data._utils - - torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( - id=self._rank_worker_id, - num_workers=self.worker_config.num_workers, - seed=self._seed, - dataset=self.dataset, - ) self._global_worker_id = self.worker_config.global_worker_id() super()._worker_run(cmd_queue, result_queue) @@ -64,28 +56,24 @@ def _wrk_prefetch_next(self) -> TSample: # so immediately resolve the future to the result (get returns immediately). return super().prefetch_next().get() - @override def dataset_init(self, initial_state: FlexState | None) -> None: if self._in_worker(): return super().dataset_init(initial_state) else: return self._worker_call(self.dataset_init, initial_state).get() - @override def new_iter(self) -> None: if self._in_worker(): return super().new_iter() else: return self._worker_call(self.new_iter).get() - @override def prefetch_next(self) -> Future[TSample]: # Do not resolve the future here, but return it. if self._in_worker(): return super().prefetch_next() return self._worker_call(self._wrk_prefetch_next) - @override def save_state(self) -> FlexState: if self._in_worker(): return super().save_state() diff --git a/src/megatron/energon/dataloader/workers/base_worker.py b/src/megatron/energon/dataloader/workers/base_worker.py index bf348d91..a650bffc 100644 --- a/src/megatron/energon/dataloader/workers/base_worker.py +++ b/src/megatron/energon/dataloader/workers/base_worker.py @@ -103,9 +103,9 @@ def dataset_init(self, state: FlexState | None) -> None: ) assert self._seed == self.worker_config.worker_seed(self._rank_worker_id), "Seed mismatch" print(f"dataset_init {state=}\n", end="") + self.dataset.reset_state() if state is None: self._sample_index = 0 - self.dataset.reset_state_deep() print("dataset_init reset_state_deep\n", end="") self.new_iter() print("dataset_init new_iter\n", end="") diff --git a/src/megatron/energon/dataloader/workers/fork_worker.py b/src/megatron/energon/dataloader/workers/fork_worker.py index ec165be6..39a8d591 100644 --- a/src/megatron/energon/dataloader/workers/fork_worker.py +++ b/src/megatron/energon/dataloader/workers/fork_worker.py @@ -25,4 +25,13 @@ def _worker_run( result_queue: multiprocessing.Queue, ) -> None: gc_init_worker(self._rank_worker_id) + import torch.utils.data._utils + + torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( + id=self._rank_worker_id, + num_workers=self.worker_config.num_workers, + seed=self._seed, + dataset=self.dataset, + ) + super()._worker_run(cmd_queue, result_queue) diff --git a/src/megatron/energon/dataloader/workers/thread_worker.py b/src/megatron/energon/dataloader/workers/thread_worker.py index 3d450eaa..7f5ea4fe 100644 --- a/src/megatron/energon/dataloader/workers/thread_worker.py +++ b/src/megatron/energon/dataloader/workers/thread_worker.py @@ -1,8 +1,12 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +import threading from typing import Generic, TypeVar -from megatron.energon.dataloader.asynchronous import ThreadAsynchronous +import torch.utils.data + +from megatron.energon.dataloader.asynchronous import ThreadAsynchronous, WorkerCommand, WorkerResult +from megatron.energon.dataloader.asynchronous.base import QueueProtocol from megatron.energon.dataloader.workers.async_worker import ( DataLoaderAsynchronousWorker, ) @@ -10,9 +14,34 @@ TSample = TypeVar("TSample", covariant=True) +_orig_get_worker_info = torch.utils.data.get_worker_info + +_thread_local_worker_info = threading.local() + + +def _patch_get_worker_info(): + if not hasattr(_thread_local_worker_info, "_worker_info"): + _thread_local_worker_info._worker_info = _orig_get_worker_info() + return _orig_get_worker_info() + + +torch.utils.data.get_worker_info = _patch_get_worker_info + + class ThreadDataLoaderWorker( ThreadAsynchronous, DataLoaderAsynchronousWorker[TSample], Generic[TSample] ): """ Implements the `DataLoaderWorker` interface using threads. """ + + def _worker_run( + self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] + ) -> None: + _thread_local_worker_info._worker_info = torch.utils.data._utils.worker.WorkerInfo( + id=self._rank_worker_id, + num_workers=self.worker_config.num_workers, + seed=self._seed, + dataset=self.dataset, + ) + return super()._worker_run(cmd_queue, result_queue) diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index c18ad96f..4c2ee398 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -3,6 +3,7 @@ import dataclasses import inspect +import threading import typing from abc import ABC, abstractmethod from copy import deepcopy @@ -249,6 +250,9 @@ def save_state(self) -> MyExtendedState: """ +THREAD_SAFE = True + + class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC): """A dataset that can be saved and restored (i.e. the random state, internal buffers, etc.). I.e. it can be resumed from a checkpoint. @@ -272,6 +276,8 @@ class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC) def __init__(self, worker_config: WorkerConfig): self.worker_config = worker_config + if THREAD_SAFE: + self._thread_state = threading.local() @abstractmethod def len_worker(self, worker_idx: int | None = None) -> int: @@ -348,14 +354,12 @@ def restore_state(self, state: FlexState) -> None: else: setattr(self, key, value) - @abstractmethod - def reset_state_own(self) -> None: - """Resets the state of the dataset to the initial state. Can only be called in a worker process.""" - ... - - def reset_state_deep(self) -> None: - """Resets the state of the dataset to the initial state. Can only be called in a worker process.""" - self.reset_state_own() + def reset_state(self) -> None: + """ + Resets the state of the dataset. Called at least once in the worker process before iterating. + Recursively resets the state of all wrapped datasets as well. + """ + pass @abstractmethod def worker_has_samples(self) -> bool: @@ -401,6 +405,30 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s "This dataset does not support indexing, because it is not safely deterministic." ) + if THREAD_SAFE: + + def __getattribute__(self, name: str) -> Any: + if name in ("_savable_fields", "_thread_state", "worker_config"): + return object.__getattribute__(self, name) + elif name in self._savable_fields: + return getattr(self._thread_state, name) + else: + return object.__getattribute__(self, name) + + def __delattr__(self, name: str) -> None: + if name in self._savable_fields: + delattr(self._thread_state, name) + else: + object.__delattr__(self, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name in ("_savable_fields", "_thread_state", "worker_config"): + object.__setattr__(self, name, value) + elif name in self._savable_fields: + setattr(self._thread_state, name, value) + else: + object.__setattr__(self, name, value) + class BaseCoreDatasetFactory(Generic[T_sample], ABC): """Base type for an inner dataset sample loader. This factory can be used to construct a sample loader, or for diff --git a/src/megatron/energon/flavors/webdataset/itar_reader.py b/src/megatron/energon/flavors/webdataset/itar_reader.py index 961de23b..d2bdd826 100644 --- a/src/megatron/energon/flavors/webdataset/itar_reader.py +++ b/src/megatron/energon/flavors/webdataset/itar_reader.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +import threading from abc import ABC, abstractmethod from bisect import bisect_right from typing import ( @@ -328,7 +329,7 @@ class ShardInfosITarReader(ITarReader[int]): shard_infos: List[ShardInfo] shard_tar_file_idxs: List[int] shard_count_cumsum: List[int] - cached_offset_reader: CachedItarOffsetReader + _thread_local: threading.local def __init__( self, @@ -365,8 +366,8 @@ def __init__( tar_filenames = list(cur_tar_files.keys()) tar_filepaths = [p[1] for p in cur_tar_files.values()] - # Instantiate cached reader for the .tar.idx files - self.cached_offset_reader = CachedItarOffsetReader(cache_size=itar_cache_size) + self._itar_cache_size = itar_cache_size + self._thread_local = threading.local() super().__init__( base_path=base_path, @@ -377,6 +378,14 @@ def __init__( sample_filter=sample_filter, ) + @property + def cached_offset_reader(self) -> CachedItarOffsetReader: + if not hasattr(self._thread_local, "_cached_offset_reader"): + self._thread_local._cached_offset_reader = CachedItarOffsetReader( + cache_size=self._itar_cache_size + ) + return self._thread_local._cached_offset_reader + def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer: """ Get the ITarSample object for the given index. diff --git a/src/megatron/energon/flavors/webdataset/sample_loader.py b/src/megatron/energon/flavors/webdataset/sample_loader.py index df4ad103..78c2399f 100644 --- a/src/megatron/energon/flavors/webdataset/sample_loader.py +++ b/src/megatron/energon/flavors/webdataset/sample_loader.py @@ -118,12 +118,11 @@ def __init__( self.workers_slice_offsets = workers_sample_slice_offsets self.slice_offsets = None - self.reset_state_own() - assert shuffle_over_epochs is None or shuffle_over_epochs == -1 or shuffle_over_epochs >= 1 assert self.parallel_slice_iters >= 1 - def reset_state_own(self) -> None: + def reset_state(self) -> None: + super().reset_state() self._worker_rng = WorkerRng(self.worker_config) self._pending_slice_indexes = None self._pending_slices_offset = None diff --git a/src/megatron/energon/loader.py b/src/megatron/energon/loader.py index 97b796b9..f9526459 100644 --- a/src/megatron/energon/loader.py +++ b/src/megatron/energon/loader.py @@ -1,12 +1,17 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional, TypeVar +from typing import Literal, Optional, TypeVar from megatron.energon.cache import CachePool +from megatron.energon.dataloader import ( + DataLoader, + DataLoaderWorker, + ForkDataLoaderWorker, + ThreadDataLoaderWorker, +) from megatron.energon.errors import warn_deprecated from megatron.energon.flavors import SavableDataset -from megatron.energon.savable_loader import BasicDataLoader, SavableDataLoader from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER @@ -17,16 +22,16 @@ def get_savable_loader( dataset: SavableDataset[T], *, worker_config: Optional[WorkerConfig] = None, - checkpoint_every_sec: float = 60, - checkpoint_every_min_n_samples: Optional[int] = None, - n_checkpoints: Optional[int] = None, + worker_type: Literal["main", "fork", "thread"] | type[DataLoaderWorker] = "fork", + gc_freeze_at_start: bool = True, gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER, prefetch_factor: int = 2, cache_pool: Optional[CachePool] = None, watchdog_timeout_seconds: Optional[float] = 60, watchdog_initial_timeout_seconds: Optional[float] = None, fail_on_timeout: bool = False, -) -> SavableDataLoader[T]: + pin_memory: bool = True, +) -> DataLoader[T]: """ Get a dataloader for the given dataset. @@ -34,21 +39,18 @@ def get_savable_loader( Args: dataset: The dataset to create a loader for. worker_config: Deprecated. Please pass this to the dataset instead. - checkpoint_every_sec: This is the time in seconds after which an internal checkpoint is - saved. It may take the same duration to restore a checkpoint, but introduces additional - overhead during reading data from the dataset, so this should be chosen accordingly. - Only applies if using workers. - checkpoint_every_min_n_samples: Overwrites the minimum number of samples between - checkpoints. Defaults to `number of workers * 2`. Only applies if using workers. - n_checkpoints: The number of internal checkpoints to keep. Only applies if using workers. - If None, computes a suitable value. + worker_type: The type of worker to use. + gc_freeze_at_start: If True, the garbage collector is frozen at the start of the loader. + gc_collect_every_n_steps: The number of steps after which the garbage collector is called. + prefetch_factor: The factor by which to prefetch the dataset. cache_pool: If set, the cache pool to use for the dataset. watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled. watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. + pin_memory: If True, the dataset is pinned to memory. + Returns: - The instantiated :class:`megatron.energon.SavableDataLoader`, yielding batches from the dataset, - allowing to save the state of the dataset. + The instantiated :class:`megatron.energon.DataLoader`, yielding batches from the dataset. """ if worker_config is not None: if worker_config != dataset.worker_config: @@ -61,17 +63,32 @@ def get_savable_loader( "Passing a worker_config to get_savable_loader() is deprecated and will have no effect." ) - return SavableDataLoader( + if worker_type == "fork": + worker_type = ForkDataLoaderWorker + elif worker_type == "thread": + worker_type = ThreadDataLoaderWorker + elif worker_type == "main": + worker_type = DataLoaderWorker + elif not issubclass(worker_type, DataLoaderWorker): + raise ValueError(f"Invalid worker type: {worker_type}") + if dataset.worker_config.num_workers == 0: + assert prefetch_factor == 2 + prefetch_factor = 1 + pin_memory_arg = None + else: + pin_memory_arg = "automatic" if pin_memory else None + + return DataLoader( dataset, - checkpoint_every_sec=checkpoint_every_sec, - checkpoint_every_min_n_samples=checkpoint_every_min_n_samples, - n_checkpoints=n_checkpoints, - gc_collect_every_n_steps=gc_collect_every_n_steps, prefetch_factor=prefetch_factor, + worker_type=worker_type, cache_pool=cache_pool, + gc_collect_every_n_steps=gc_collect_every_n_steps, + gc_freeze_at_start=gc_freeze_at_start, watchdog_timeout_seconds=watchdog_timeout_seconds, watchdog_initial_timeout_seconds=watchdog_initial_timeout_seconds, fail_on_timeout=fail_on_timeout, + pin_memory=pin_memory_arg, ) @@ -84,7 +101,7 @@ def get_loader( watchdog_timeout_seconds: Optional[float] = 60, watchdog_initial_timeout_seconds: Optional[float] = None, fail_on_timeout: bool = False, -) -> BasicDataLoader[T]: +) -> DataLoader[T]: """ Get a dataloader for the given dataset. @@ -95,8 +112,9 @@ def get_loader( watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled. watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. + Returns: - The instantiated :class:`torch.data.DataLoader`, yielding batches from the dataset. + The instantiated :class:`DataLoader`, yielding batches from the dataset. """ if worker_config is not None: if worker_config != dataset.worker_config: @@ -109,11 +127,19 @@ def get_loader( "Passing a worker_config to get_loader() is deprecated and will have no effect." ) - return BasicDataLoader( + if dataset.worker_config.num_workers == 0: + assert prefetch_factor == 2 + prefetch_factor = 1 + pin_memory = None + else: + pin_memory = "automatic" + + return DataLoader( dataset, prefetch_factor=prefetch_factor, cache_pool=cache_pool, watchdog_timeout_seconds=watchdog_timeout_seconds, watchdog_initial_timeout_seconds=watchdog_initial_timeout_seconds, fail_on_timeout=fail_on_timeout, + pin_memory=pin_memory, ) diff --git a/src/megatron/energon/rng.py b/src/megatron/energon/rng.py index b2991580..8eb2f008 100644 --- a/src/megatron/energon/rng.py +++ b/src/megatron/energon/rng.py @@ -89,6 +89,47 @@ def restore_state(self, state: FlexState): self._restore_state = state["rng"] +class UserRng: + """User random generators. To be used within the task encoder, providing local seeding.""" + + def __init__(self, seed: int): + self.torch = torch.Generator() + self.torch.manual_seed(seed) + if torch.cuda.is_available(): + self.torch_cuda = torch.Generator(device="cuda") + self.torch_cuda.manual_seed(seed) + + self.numpy = numpy.random.default_rng(seed=seed) + self.random = random.Random(seed) + + def seed(self, seed: int) -> None: + self.torch.manual_seed(seed) + if torch.cuda.is_available(): + self.torch_cuda.manual_seed(seed) + self.numpy.bit_generator.state = numpy.random.default_rng(seed).bit_generator.state + self.random.seed(seed) + + def seed_args(self, *args: Any) -> None: + self.seed(SystemRng.get_seed_from_args(*args)) + + def save_state(self) -> FlexState: + state = FlexState( + torch=self.torch.get_state().tolist(), + numpy=self.numpy.bit_generator.state, + random=self.random.getstate(), + ) + if torch.cuda.is_available(): + state["torch_cuda"] = self.torch_cuda.get_state().tolist() + return state + + def restore_state(self, state: FlexState): + self.torch.set_state(torch.as_tensor(state["torch"])) + if torch.cuda.is_available(): + self.torch_cuda.set_state(torch.as_tensor(state["torch_cuda"], device="cuda")) + self.numpy.bit_generator.state = state["numpy"] + self.random.setstate(state["random"]) + + @edataclass class SystemRngState: """The state of the global random generators. diff --git a/src/megatron/energon/savable_loader.py b/src/megatron/energon/savable_loader.py index 1f319e45..a8a5d777 100644 --- a/src/megatron/energon/savable_loader.py +++ b/src/megatron/energon/savable_loader.py @@ -85,8 +85,6 @@ def __init__( super().__init__(dataset, worker_config=worker_config) self.cache_pool = cache_pool - self.reset_state_own() - def reset_state_own(self) -> None: self._sample_index = 0 self._state_restored = False @@ -353,13 +351,13 @@ def __iter__(self): my_state = self._workers_restore_from[self._worker_id] my_ds_state = my_state.dataset_state assert my_state is not None - if my_ds_state is None: - self.dataset.reset_state_deep() - else: + self.dataset.reset_state() + if my_ds_state is not None: self.dataset.restore_state(my_ds_state) self._restore_state(my_state) self._workers_restore_from[self._worker_id] = None else: + self.dataset.reset_state() # Store the initial state of the worker if we stop before the first sample self._store_checkpoint() # If skipping, also restart the iterator to reach the start of the restored diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index 6140ed0c..f92664d1 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -10,6 +10,7 @@ from typing import ( Any, Callable, + ClassVar, Dict, Generator, Generic, @@ -39,7 +40,8 @@ ) from megatron.energon.flavors.base_dataset import ExtendableDataclassMixin from megatron.energon.metadataset.loader_interface import DatasetBlendMode, LoadedDataset -from megatron.energon.rng import SystemRng +from megatron.energon.rng import SystemRng, UserRng +from megatron.energon.savable import Savable from megatron.energon.source_info import SourceInfo from megatron.energon.task_encoder.cooking import Cooker from megatron.energon.worker import WorkerConfig @@ -132,6 +134,7 @@ def stateless( fn: Optional[Callable[..., T]] = None, *, restore_seeds: bool = False, + restore_task_encoder_seeds: bool = False, failure_tolerance: Optional[int] = None, ) -> Union[Callable[[Callable[..., T]], Callable[..., T]], Callable[..., T]]: """Decorator to mark a function of the task encoder as restorable. @@ -141,6 +144,9 @@ def stateless( restore_seeds: Whether to restore the seeds for the function. I.e. the seeds are set from the sample index and the worker seed, such that they can be restored when a sample is restored from that function. + restore_task_encoder_seeds: Whether to restore the seeds for the task encoder. I.e. the seeds are set + from the sample index and the worker seed, such that they can be restored when a sample + is restored from that function. failure_tolerance: The number of consecutive exceptions that are handled, after which a `FatalSampleError` is raised for this function. @@ -165,77 +171,132 @@ def encode_sample(self, sample: T_sample) -> T_encoded_sample: ) if restore_seeds: worker_seed = None + orig_fn = fn + + if inspect.isgeneratorfunction(orig_fn): - @functools.wraps(fn) - def seed_wrapper_generator(self, *args, **kwargs): - nonlocal worker_seed - if worker_seed is None: - worker_seed = WorkerConfig.active_worker_config.worker_seed() + @functools.wraps(orig_fn) + def seed_wrapper_generator(self, *args, **kwargs): + nonlocal worker_seed + if worker_seed is None: + worker_seed = WorkerConfig.active_worker_config.worker_seed() - # Save the RNG states and set the new seed - outer_rng_state = SystemRng.save_state() + # Save the RNG states and set the new seed + outer_rng_state = SystemRng.save_state() - # Before constructing the generator and before the first - # iteration, set inner RNG based on seed computed - # from worker_seed and current sample index - SystemRng.seed_args(worker_seed, self.current_sample_index) + # Before constructing the generator and before the first + # iteration, set inner RNG based on seed computed + # from worker_seed and current sample index + SystemRng.seed_args(worker_seed, self.current_sample_index) + + it = iter(orig_fn(self, *args, **kwargs)) + + inner_rand_state = None + + while True: + if inner_rand_state is not None: + # Restore inner random state before calling the generator + # This will not be done on the first iteration + SystemRng.restore_state(inner_rand_state) + + try: + # Now call the generator. This will yield the sample + # But note it may also throw an exception or a StopIteration + sample = next(it) + + # Save inner random state after calling the generator + inner_rand_state = SystemRng.save_state() + except StopIteration: + # We're stopping here, but the outer random state + # will be restored before returning (in finally below) + break + finally: + # Restore outer rand state before yielding or when an exception was raised + SystemRng.restore_state(outer_rng_state) + + # Now yield the sample. + # This will give control back to the caller who may + # change the random state. + yield sample + + # Save outer random state after yielding + outer_rng_state = SystemRng.save_state() + + fn = seed_wrapper_generator + else: - it = iter(fn(self, *args, **kwargs)) + @functools.wraps(orig_fn) + def seed_wrapper(self, *args, **kwargs): + nonlocal worker_seed + if worker_seed is None: + worker_seed = WorkerConfig.active_worker_config.worker_seed() - inner_rand_state = None + # Save the RNG states and set the new seed + rng_state = SystemRng.save_state() - while True: - if inner_rand_state is not None: - # Restore inner random state before calling the generator - # This will not be done on the first iteration - SystemRng.restore_state(inner_rand_state) + SystemRng.seed_args(worker_seed, self.current_sample_index) try: - # Now call the generator. This will yield the sample - # But note it may also throw an exception or a StopIteration - sample = next(it) - - # Save inner random state after calling the generator - inner_rand_state = SystemRng.save_state() - except StopIteration: - # We're stopping here, but the outer random state - # will be restored before returning (in finally below) - break + return orig_fn(self, *args, **kwargs) finally: - # Restore outer rand state before yielding or when an exception was raised - SystemRng.restore_state(outer_rng_state) + # Restore the RNGs + SystemRng.restore_state(rng_state) - # Now yield the sample. - # This will give control back to the caller who may - # change the random state. - yield sample + fn = seed_wrapper - # Save outer random state after yielding - outer_rng_state = SystemRng.save_state() + if restore_task_encoder_seeds: + te_orig_fn = fn + worker_seed = None + if inspect.isgeneratorfunction(te_orig_fn): + + @functools.wraps(te_orig_fn) + def seed_wrapper_generator(self, *args, **kwargs): + nonlocal worker_seed + if worker_seed is None: + worker_seed = WorkerConfig.active_worker_config.worker_seed() + + te_outer_rng_state = self.rng.save_state() - @functools.wraps(fn) - def seed_wrapper(self, *args, **kwargs): - nonlocal worker_seed - if worker_seed is None: - worker_seed = WorkerConfig.active_worker_config.worker_seed() + self.rng.seed_args(worker_seed, self.current_sample_index) - # Save the RNG states and set the new seed - rng_state = SystemRng.save_state() + it = iter(te_orig_fn(self, *args, **kwargs)) - SystemRng.seed_args(worker_seed, self.current_sample_index) + inner_rand_state = None - try: - return fn(self, *args, **kwargs) - finally: - # Restore the RNGs - SystemRng.restore_state(rng_state) + while True: + if inner_rand_state is not None: + self.rng.restore_state(inner_rand_state) + try: + sample = next(it) + inner_rand_state = self.rng.save_state() + except StopIteration: + break + finally: + self.rng.restore_state(te_outer_rng_state) - if inspect.isgeneratorfunction(fn): - setattr(seed_wrapper_generator, "__stateless__", True) - return seed_wrapper_generator + yield sample + + te_outer_rng_state = self.rng.save_state() else: - setattr(seed_wrapper, "__stateless__", True) - return seed_wrapper + + @functools.wraps(te_orig_fn) + def seed_wrapper(self, *args, **kwargs): + nonlocal worker_seed + if worker_seed is None: + worker_seed = WorkerConfig.active_worker_config.worker_seed() + + # Save the RNG states and set the new seed + te_rng_state = self.rng.save_state() + + self.rng.seed_args(worker_seed, self.current_sample_index) + + try: + return te_orig_fn(self, *args, **kwargs) + finally: + # Restore the RNGs + self.rng.restore_state(te_rng_state) + + fn = seed_wrapper setattr(fn, "__stateless__", True) setattr(fn, "__failure_tolerance__", failure_tolerance) @@ -345,7 +406,7 @@ def from_samples(cls: Type[T_batch], samples: Sequence[Sample], **kwargs) -> T_b return cls(**init_args) -class TaskEncoder(ABC, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]): +class TaskEncoder(Savable, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]): """ Base class for task encoders. @@ -371,6 +432,12 @@ class TaskEncoder(ABC, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch] #: The decoder to use for decoding samples. Set manually as needed to override options. decoder: Optional[SampleDecoder] = SampleDecoder() + # Defines which fields are saved and restored when saving and restoring the state of the task encoder. + _state_fields: ClassVar[Tuple[str, ...]] = ("rng",) + + # State fields, they are initialized when the dataloader is started. + rng: UserRng + @stateless def cook_crude_sample( self, @@ -972,6 +1039,15 @@ def build_val_datasets( return dataset + def reset_state(self) -> None: + """Internally reset the state of the task encoder. This is called when the dataloader is started.""" + assert WorkerConfig.active_worker_config is not None, "Must be called within worker" + self.rng = UserRng(WorkerConfig.active_worker_config.worker_seed()) + + # Burrow the save_state and restore_state methods from the SavableDataset class. + save_state = SavableDataset.save_state + restore_state = SavableDataset.restore_state + @property def current_batch_index(self) -> int: """Returns the current index for the next batch yielded from the current worker. Each batch diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index 401b21ae..22027bc5 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -4,6 +4,7 @@ import hashlib import json import multiprocessing +import threading from dataclasses import dataclass from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, TextIO, TypeVar @@ -12,11 +13,34 @@ import torch.utils.data from megatron.energon.cache import CachePool +from megatron.energon.edataclass import edataclass __all__ = ("WorkerConfig",) T = TypeVar("T") +THREAD_SAFE = True + + +@edataclass +class ActiveWorkerState(threading.local): + #: The current sample index within the current iterating worker + sample_index_stack: Optional[List[int]] = None + #: The global rank override for the worker. Required for restoring samples. + override_global_rank: Optional[int] = None + #: The current cache pool for the worker. + cache_pool: Optional[CachePool] = None + #: The current worker config within the current iterating worker + worker_config: "WorkerConfig | None" = None + + +class classproperty: + def __init__(self, getter): + self.getter = getter + + def __get__(self, instance, owner): + return self.getter(owner) + @dataclass(slots=True, kw_only=True, eq=False) class WorkerConfig: @@ -62,16 +86,12 @@ class WorkerConfig: #: worker_id of the opened worker debug file _worker_debug_file_worker_id: Optional[int] = None - #: The current sample index within the current iterating worker - _sample_index_stack: ClassVar[Optional[List[int]]] = None - #: The current worker config within the current iterating worker - active_worker_config: ClassVar[Optional["WorkerConfig"]] = None + _active_state: ClassVar[ActiveWorkerState] = ActiveWorkerState() - #: The global rank override for the worker. Required for restoring samples. - _worker_override_global_rank: ClassVar[Optional[List[int]]] = None - - #: The current cache pool for the worker. - _cache_pool: "ClassVar[Optional[CachePool]]" = None + @classproperty + def active_worker_config(cls) -> Optional["WorkerConfig"]: + """The current worker config within the current iterating worker""" + return cls._active_state.worker_config def worker_activate( self, @@ -81,42 +101,49 @@ def worker_activate( ): """Activates the worker config for the current worker and sets it as actively iterating. Must be called before next() call on the datasets.""" - assert WorkerConfig.active_worker_config is None - WorkerConfig._sample_index_stack = [sample_index] - WorkerConfig.active_worker_config = self - WorkerConfig._worker_override_global_rank = override_global_rank - WorkerConfig._cache_pool = cache_pool + WorkerConfig._active_state.sample_index_stack = [sample_index] + WorkerConfig._active_state.worker_config = self + WorkerConfig._active_state.override_global_rank = override_global_rank + WorkerConfig._active_state.cache_pool = cache_pool + print( + f"worker_activate {self.rank} {self.num_workers} {self.rank_worker_id()} on {threading.get_ident()}\n", + end="", + ) def worker_push_sample_index(self, sample_index: int): """Pushes a new sample index to the sample index stack. Should be set by wrapping datasets before calling inners.""" - assert WorkerConfig.active_worker_config is not None - WorkerConfig._sample_index_stack.append(sample_index) + assert WorkerConfig._active_state.sample_index_stack is not None + WorkerConfig._active_state.sample_index_stack.append(sample_index) def worker_pop_sample_index(self): """Pushes a new sample index to the sample index stack. Should be set by wrapping datasets before calling inners.""" - assert WorkerConfig.active_worker_config is not None - return WorkerConfig._sample_index_stack.pop() + assert WorkerConfig._active_state.sample_index_stack is not None + return WorkerConfig._active_state.sample_index_stack.pop() def worker_deactivate(self): """Deactivates the worker config for the current worker and deactivates it for iterating. Must be called after next() call on the datasets.""" if WorkerConfig.active_worker_config is not None: - assert len(WorkerConfig._sample_index_stack) == 1, ( - f"Sample index stack not empty: {WorkerConfig._sample_index_stack}" + assert WorkerConfig._active_state.sample_index_stack is not None + assert len(WorkerConfig._active_state.sample_index_stack) == 1, ( + f"Sample index stack not empty: {WorkerConfig._active_state.sample_index_stack}" ) - WorkerConfig._sample_index_stack = None - WorkerConfig.active_worker_config = None - WorkerConfig._worker_override_global_rank = None + WorkerConfig._active_state.sample_index_stack = None + WorkerConfig._active_state.worker_config = None + WorkerConfig._active_state.override_global_rank = None + WorkerConfig._active_state.cache_pool = None @property def active_worker_sample_index(self) -> int: """Returns the current sample index for the actively iterating worker.""" # Internal sample index is for the local worker. If using multiple workers per rank, this # must be multiplied by the number of workers and offset by the local worker index. + assert WorkerConfig._active_state.sample_index_stack is not None return ( - WorkerConfig._sample_index_stack[-1] * max(self.num_workers, 1) + self.rank_worker_id() + WorkerConfig._active_state.sample_index_stack[-1] * max(self.num_workers, 1) + + self.rank_worker_id() ) @property @@ -124,8 +151,10 @@ def active_worker_batch_index(self) -> int: """Returns the current batch index for the actively iterating worker.""" # Internal batch index is for the local worker. If using multiple workers per rank, this # must be multiplied by the number of workers and offset by the local worker index. + assert WorkerConfig._active_state.sample_index_stack is not None return ( - WorkerConfig._sample_index_stack[0] * max(self.num_workers, 1) + self.rank_worker_id() + WorkerConfig._active_state.sample_index_stack[0] * max(self.num_workers, 1) + + self.rank_worker_id() ) @property @@ -176,9 +205,9 @@ def default_worker_config( def rank_worker_id(self) -> int: """Returns the self worker id within the current rank.""" - if self._worker_override_global_rank: + if WorkerConfig._active_state.override_global_rank: assert self.worker_id_offset == 0 - return self._worker_override_global_rank % self.num_workers + return WorkerConfig._active_state.override_global_rank % self.num_workers worker_info = torch.utils.data.get_worker_info() if worker_info is None: return self.worker_id_offset @@ -211,15 +240,12 @@ def global_worker_id(self, override_local_worker_id: Optional[int] = None) -> in override_local_worker_id (int, optional): The local worker id to override. None means the current worker, which is the default. """ - if self._worker_override_global_rank is not None: - assert override_local_worker_id is None - return self._worker_override_global_rank - if override_local_worker_id is not None: return self.rank * self.num_workers + override_local_worker_id - else: - self.assert_worker() - return self.rank * self.num_workers + self.rank_worker_id() + if WorkerConfig._active_state.override_global_rank is not None: + return WorkerConfig._active_state.override_global_rank + self.assert_worker() + return self.rank * self.num_workers + self.rank_worker_id() def worker_seed(self, override_local_worker_id: Optional[int] = None) -> int: """Returns the seed for the current worker (or a specified worker). diff --git a/src/megatron/energon/wrappers/base.py b/src/megatron/energon/wrappers/base.py index 2f4dba89..2bfe4618 100644 --- a/src/megatron/energon/wrappers/base.py +++ b/src/megatron/energon/wrappers/base.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +import threading from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Generator, Generic, Iterable, Optional, Tuple, Type, TypeVar, Union @@ -54,6 +55,9 @@ def dataset(self) -> SavableDataset: assert len(self.datasets) == 1 return self.datasets[0] + def len_worker(self, worker_idx: int | None = None) -> int: + return sum(ds.len_worker(worker_idx) for ds in self.datasets) + def can_restore_sample(self) -> bool: return all(ds.can_restore_sample() for ds in self.datasets) @@ -90,6 +94,20 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s src=self, ) + @abstractmethod + def reset_state_own(self) -> None: + """Resets the state of the dataset, excl. the inner datasets.""" + ... + + def reset_state(self) -> None: + """Resets the state of the inner datasets and then the own state.""" + + for ds in self.datasets: + ds.reset_state() + + super().reset_state() + self.reset_state_own() + def save_state(self) -> FlexState: own_state = super().save_state() @@ -102,22 +120,6 @@ def restore_state(self, state: FlexState) -> None: super().restore_state(state) - def reset_state_deep(self) -> None: - """Resets the state of the inner datasets and then the own state.""" - - for ds in self.datasets: - if isinstance(ds, BaseWrapperDataset): - ds.reset_state_deep() - else: - ds.reset_state_own() - - self.reset_state_own() - - @abstractmethod - def reset_state_own(self) -> None: - """Resets the state of the dataset, excl. the inner datasets.""" - ... - class SampleIndex(Savable): """A simple class to hold the sample index for one worker.""" @@ -141,7 +143,9 @@ def get_next(self) -> int: def ctx(self, sample_idx: Optional[int] = None): if sample_idx is None: sample_idx = self.get_next() - assert WorkerConfig.active_worker_config is not None + assert WorkerConfig.active_worker_config is not None, ( + f"WorkerConfig.active_worker_config is None on thread {threading.get_ident()}" + ) WorkerConfig.active_worker_config.worker_push_sample_index(sample_idx) # print(" " * SampleIndex.actives + f"Activated from {type(self.src).__name__}({id(self.src)}) {sample_idx} -> {WorkerConfig.active_worker_config._sample_index_stack}") SampleIndex.actives += 1 diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index b3e5f158..2b150bbb 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -83,8 +83,6 @@ def __init__( self.error_handler = error_handler self.failure_tolerance = failure_tolerance - self.reset_state_own() - def reset_state_own(self) -> None: self._sample_index = SampleIndex(self.worker_config, src=self) self._generator_sample_keys = None diff --git a/src/megatron/energon/wrappers/blend_dataset.py b/src/megatron/energon/wrappers/blend_dataset.py index e1ed5738..7b60ce5a 100644 --- a/src/megatron/energon/wrappers/blend_dataset.py +++ b/src/megatron/energon/wrappers/blend_dataset.py @@ -46,7 +46,6 @@ def __init__( super().__init__(self.datasets, worker_config=worker_config) self.dataset_weights = dataset_weights - self.reset_state_own() def reset_state_own(self) -> None: self._worker_rng = WorkerRng(self.worker_config) diff --git a/src/megatron/energon/wrappers/buffer.py b/src/megatron/energon/wrappers/buffer.py index 09da5739..8b6808a8 100644 --- a/src/megatron/energon/wrappers/buffer.py +++ b/src/megatron/energon/wrappers/buffer.py @@ -33,7 +33,6 @@ class SavableSampleBuffer(BaseWrapperDataset[T_sample, T_sample], Generic[T_samp def __init__(self, dataset: SavableDataset[T_sample], *, worker_config: WorkerConfig): super().__init__(dataset, worker_config=worker_config) - self.reset_state_own() def reset_state_own(self) -> None: self._buffer = [] diff --git a/src/megatron/energon/wrappers/epochize_dataset.py b/src/megatron/energon/wrappers/epochize_dataset.py index 3f6f71a0..85fc7727 100644 --- a/src/megatron/energon/wrappers/epochize_dataset.py +++ b/src/megatron/energon/wrappers/epochize_dataset.py @@ -43,8 +43,6 @@ def __init__( self.length = length self._active_iter = None - self.reset_state_own() - def reset_state_own(self) -> None: self._offset = 0 diff --git a/src/megatron/energon/wrappers/filter_dataset.py b/src/megatron/energon/wrappers/filter_dataset.py index f28d84f0..2302c104 100644 --- a/src/megatron/energon/wrappers/filter_dataset.py +++ b/src/megatron/energon/wrappers/filter_dataset.py @@ -42,14 +42,9 @@ def __init__( self.filter_fn = filter_fn self.filter_fn_config = filter_fn_config - self.reset_state_own() - def reset_state_own(self) -> None: self._sample_index = SampleIndex(self.worker_config, src=self) - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def __iter__(self) -> Iterator[T_sample]: for sample in self.dataset: with self._sample_index.ctx(): diff --git a/src/megatron/energon/wrappers/gc_dataset.py b/src/megatron/energon/wrappers/gc_dataset.py index 70f31688..bcb89584 100644 --- a/src/megatron/energon/wrappers/gc_dataset.py +++ b/src/megatron/energon/wrappers/gc_dataset.py @@ -97,9 +97,6 @@ def __init__( def reset_state_own(self) -> None: return - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def __iter__(self) -> Iterator[T_sample]: in_worker = torch.utils.data.get_worker_info() is not None if in_worker and not _frozen_cuda_tensors_initialized: diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index 747a042c..23e8d2e6 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -112,8 +112,6 @@ def __init__( self.error_handler = error_handler self.failure_tolerance = failure_tolerance - self.reset_state_own() - assert not inspect.isgeneratorfunction(batcher), ( f"Batcher {batcher} must not be a generator function for grouped batching." ) diff --git a/src/megatron/energon/wrappers/iter_map_dataset.py b/src/megatron/energon/wrappers/iter_map_dataset.py index e555ec02..3faf02ee 100644 --- a/src/megatron/energon/wrappers/iter_map_dataset.py +++ b/src/megatron/energon/wrappers/iter_map_dataset.py @@ -83,8 +83,6 @@ def __init__( self.stateless_iter_fn = stateless_iter_fn self.iter_map_fn_config = iter_map_fn_config - self.reset_state_own() - def reset_state_own(self) -> None: self._sample_index = SampleIndex(self.worker_config, src=self) diff --git a/src/megatron/energon/wrappers/limit_dataset.py b/src/megatron/energon/wrappers/limit_dataset.py index 9ee76200..0090e97f 100644 --- a/src/megatron/energon/wrappers/limit_dataset.py +++ b/src/megatron/energon/wrappers/limit_dataset.py @@ -38,7 +38,6 @@ def __init__( super().__init__(dataset, worker_config=worker_config) self.length = length self.reset_after_epoch = reset_after_epoch - self.reset_state_own() def reset_state_own(self) -> None: self.current_offset = 0 @@ -104,10 +103,10 @@ def __iter__(self) -> Iterator[T_sample]: ) # Reset the inner dataset - self.dataset.reset_state_deep() + self.dataset.reset_state() self.current_offset = 0 if self.reset_after_epoch: - self.dataset.reset_state_deep() + self.dataset.reset_state() def worker_has_samples(self) -> bool: return super().worker_has_samples() and self.length > 0 diff --git a/src/megatron/energon/wrappers/log_sample_dataset.py b/src/megatron/energon/wrappers/log_sample_dataset.py index 415b9f0d..a0851693 100644 --- a/src/megatron/energon/wrappers/log_sample_dataset.py +++ b/src/megatron/energon/wrappers/log_sample_dataset.py @@ -72,14 +72,9 @@ def __init__( self.get_keys_fn = get_keys_fn self.mode = mode - self.reset_state_own() - def reset_state_own(self) -> None: self._step = 0 - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def _log(self, sample: T_sample) -> None: if self.worker_config.should_log(level=1): log_entry = { diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index dd6d3fa6..b3e35ec6 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -81,16 +81,11 @@ def __init__( self.map_fn_config = map_fn_config self.failure_tolerance = failure_tolerance - self.reset_state_own() - def reset_state_own(self) -> None: self._sample_index = SampleIndex(self.worker_config, src=self) self._generator_sample_key = None self._generator_offset = None - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def __iter__(self) -> Iterator[T_sample_out]: last_map_failures = 0 diff --git a/src/megatron/energon/wrappers/mix_batch_dataset.py b/src/megatron/energon/wrappers/mix_batch_dataset.py index 7b12aca2..eb6d9b69 100644 --- a/src/megatron/energon/wrappers/mix_batch_dataset.py +++ b/src/megatron/energon/wrappers/mix_batch_dataset.py @@ -117,9 +117,6 @@ def __init__( def reset_state_own(self) -> None: return - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def __iter__(self) -> Iterator[T_batch]: yield from self.dataset diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index f0b85358..1bc28c61 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -144,8 +144,6 @@ def __init__( self.final_packer_failure_tolerance = final_packer_failure_tolerance self.sample_encoder_failure_tolerance = sample_encoder_failure_tolerance - self.reset_state_own() - def reset_state_own(self) -> None: self._reading_buffer = SavableSampleBuffer(self.dataset, worker_config=self.worker_config) self._pre_packing_buffer = SavableSampleBuffer( diff --git a/src/megatron/energon/wrappers/repeat_dataset.py b/src/megatron/energon/wrappers/repeat_dataset.py index eb2298cd..f63010e9 100644 --- a/src/megatron/energon/wrappers/repeat_dataset.py +++ b/src/megatron/energon/wrappers/repeat_dataset.py @@ -41,8 +41,6 @@ def __init__( self.repeats = repeats self.restart = restart - self.reset_state_own() - def reset_state_own(self) -> None: self._repetition = 0 self._index = 0 diff --git a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py index 4d7bbc24..19926436 100644 --- a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py +++ b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py @@ -31,15 +31,11 @@ def __init__( """Create a shuffle buffer for the dataset.""" super().__init__(dataset, worker_config=worker_config) self.size = size - self.reset_state_own() def reset_state_own(self) -> None: self._worker_rng = WorkerRng(self.worker_config) self._active_buffer = SavableSampleBuffer(self.dataset, worker_config=self.worker_config) - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def __iter__(self) -> Iterator[T_sample]: self._active_buffer.worker_start() it = iter(self._active_buffer.append_iter()) diff --git a/src/megatron/energon/wrappers/task_encoder_state_dataset.py b/src/megatron/energon/wrappers/task_encoder_state_dataset.py new file mode 100644 index 00000000..32534acc --- /dev/null +++ b/src/megatron/energon/wrappers/task_encoder_state_dataset.py @@ -0,0 +1,81 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import ( + Any, + Dict, + Generic, + Iterator, + Tuple, + TypeVar, + Union, +) + +import megatron.energon +from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.base import BaseWrapperDataset + +T_sample = TypeVar("T_sample") + + +class TaskEncoderStateDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): + """This dataset wrapper applies a custom function to transform each sample.""" + + # Will save it's own state + _task_encoder: "megatron.energon.TaskEncoder" + _task_encoder_was_reset: bool = False + + _savable_fields = ("_task_encoder",) + + def __init__( + self, + dataset: SavableDataset[T_sample], + task_encoder: "megatron.energon.TaskEncoder", + *, + worker_config: WorkerConfig, + ): + """Construct a wrapper for saving/restoring the state of the task encoder. + The dataset is transparently delegated. + + Args: + dataset: The input dataset to wrap + task_encoder: The task encoder to wrap. + worker_config: Worker configuration. + """ + super().__init__(dataset, worker_config=worker_config) + self._task_encoder = task_encoder + + def reset_state_own(self) -> None: + self._task_encoder_was_reset = False + + def __iter__(self) -> Iterator[T_sample]: + if not self._task_encoder_was_reset: + self._task_encoder_was_reset = True + self._task_encoder.reset_state() + yield from self.dataset + + def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: + inner_sample = self.dataset.restore_sample(restore_key) + self._task_encoder.restore_sample() + return inner_sample + + def config(self) -> Dict[str, Any]: + return { + "type": type(self).__qualname__, + "dataset": self.dataset.config(), + "map_fn": self._function_config(self.map_fn), + **( + { + "map_fn_config": ( + self.map_fn_config() if callable(self.map_fn_config) else self.map_fn_config + ) + } + if self.map_fn_config + else {} + ), + "map_fn_stateless": self.stateless_map_fn, + } + + def __str__(self): + return f"MapDataset(map_fn={self.map_fn}, dataset={self.dataset})" diff --git a/src/megatron/energon/wrappers/watchdog_dataset.py b/src/megatron/energon/wrappers/watchdog_dataset.py index 68ecffea..07c85d11 100644 --- a/src/megatron/energon/wrappers/watchdog_dataset.py +++ b/src/megatron/energon/wrappers/watchdog_dataset.py @@ -41,9 +41,6 @@ def __init__( def reset_state_own(self) -> None: pass - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - def _watchdog_trigger(self) -> None: if self.fail_on_timeout: # Raising an exception here will kill the whole process diff --git a/tests/test_crudedataset.py b/tests/test_crudedataset.py index 637ce8fb..213dd229 100644 --- a/tests/test_crudedataset.py +++ b/tests/test_crudedataset.py @@ -401,8 +401,6 @@ def test_loader(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) samples = [s.__key__ for idx, s in zip(range(100), loader)] @@ -423,8 +421,6 @@ def test_loader(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) loader.restore_state_rank(state) @@ -454,8 +450,6 @@ def test_aux_random_access(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) print("Iterating from dataset") @@ -483,8 +477,6 @@ def test_aux_random_access(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) loader.restore_state_rank(state) @@ -514,8 +506,6 @@ def test_aux_random_access_with_cache(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, cache_pool=FileStoreCachePool( parent_cache_dir=self.dataset_path / "cache", num_workers=1, @@ -548,8 +538,6 @@ def test_aux_random_access_with_cache(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, cache_pool=FileStoreCachePool( parent_cache_dir=self.dataset_path / "cache", num_workers=1, @@ -583,8 +571,6 @@ def test_aux_random_access_with_cache_and_postencode(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, cache_pool=FileStoreCachePool( parent_cache_dir=self.dataset_path / "cache", num_workers=1, @@ -617,8 +603,6 @@ def test_aux_random_access_with_cache_and_postencode(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, cache_pool=FileStoreCachePool( parent_cache_dir=self.dataset_path / "cache", num_workers=1, @@ -718,8 +702,6 @@ def test_nomds(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) samples = [s.__key__ for idx, s in zip(range(100), loader)] diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 94b8baef..39dd7ca3 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -151,7 +151,7 @@ def test_dataloader_fork(self): worker_config = WorkerConfig( rank=0, world_size=1, - num_workers=1, + num_workers=2, seed_offset=42, ) @@ -217,7 +217,7 @@ def test_dataloader_thread(self): worker_config = WorkerConfig( rank=0, world_size=1, - num_workers=1, + num_workers=2, seed_offset=42, ) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 43567c9d..c3c5f319 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1435,8 +1435,6 @@ def pack_selected_samples( max_samples_per_sequence=None, task_encoder=TestTaskEncoder(), ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, ) samples_r0 = list(loader_r0) @@ -1465,8 +1463,6 @@ def pack_selected_samples( max_samples_per_sequence=None, task_encoder=TestTaskEncoder(), ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, ) loader_r0.restore_state_rank(rank_state_r0) @@ -1592,8 +1588,6 @@ def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: max_samples_per_sequence=None, task_encoder=GroupingTaskEncoder(), ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, ) batches = list(zip(range(40), loader)) print([batch.__key__ for idx, batch in batches]) @@ -1612,8 +1606,6 @@ def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: max_samples_per_sequence=None, task_encoder=GroupingTaskEncoder(), ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, ) batches = list(zip(range(40), loader_r0)) @@ -1637,8 +1629,6 @@ def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: max_samples_per_sequence=None, task_encoder=GroupingTaskEncoder(), ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, ) loader_r0.restore_state_rank(state) diff --git a/tests/test_dataset_det.py b/tests/test_dataset_det.py index 7cac5b21..911ee2f8 100644 --- a/tests/test_dataset_det.py +++ b/tests/test_dataset_det.py @@ -659,7 +659,6 @@ def test_restore_state_workers(self): n1 = 18 n2 = 109 n3 = 28 - ces = 0 # This seed is used by the dataset to shuffle the data torch.manual_seed(42) @@ -673,7 +672,7 @@ def test_restore_state_workers(self): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - loader = get_savable_loader(ds, checkpoint_every_sec=ces) + loader = get_savable_loader(ds) # print("save state") state_0 = loader.save_state_rank() diff --git a/tests/test_metadataset.py b/tests/test_metadataset.py index 45230f55..d362d616 100644 --- a/tests/test_metadataset.py +++ b/tests/test_metadataset.py @@ -977,8 +977,6 @@ def new_loader(): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - checkpoint_every_sec=0.5, - checkpoint_every_min_n_samples=1, ) # Train mode dataset @@ -1303,8 +1301,6 @@ def test_save_restore_next(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=0, ) skip_initial = 9 @@ -1321,8 +1317,6 @@ def test_save_restore_next(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=0, ) rst_loader.restore_state_rank(previous_cp) for i, rst_sample in zip(range(1), rst_loader): @@ -1355,8 +1349,6 @@ def test_save_restore_next(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=0, ) second_loader.restore_state_rank(state_initial) @@ -1409,8 +1401,6 @@ def test_save_restore_next(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=0, ) ref_loader.restore_state_rank(state_offset) diff --git a/tests/test_metadataset_fewsamp.py b/tests/test_metadataset_fewsamp.py index f1ad30c5..7a9336bb 100644 --- a/tests/test_metadataset_fewsamp.py +++ b/tests/test_metadataset_fewsamp.py @@ -181,8 +181,6 @@ def test_metadataset_few_samples_save_restore(self): train_loader = get_savable_loader( train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) # Load 3 samples @@ -203,8 +201,6 @@ def test_metadataset_few_samples_save_restore(self): shuffle_buffer_size=100, max_samples_per_sequence=None, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) train_loader.restore_state_rank(state1) # Load 5 samples @@ -235,8 +231,6 @@ def test_too_few_samples(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - checkpoint_every_min_n_samples=1, - checkpoint_every_sec=0, ) lens.append(len(loader)) diff --git a/tests/test_metadataset_v2.py b/tests/test_metadataset_v2.py index f0a03ed7..d11be83b 100644 --- a/tests/test_metadataset_v2.py +++ b/tests/test_metadataset_v2.py @@ -362,8 +362,6 @@ def test_joined_metadataset(self): train_loader = get_savable_loader( train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) data = list(zip(range(2 * 55), train_loader)) @@ -403,8 +401,6 @@ def test_joined_metadataset(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) train_loader.restore_state_rank(state) @@ -473,8 +469,6 @@ def test_joined_metadataset_joiner(self): train_loader = get_savable_loader( train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) data = list(zip(range(2 * 55), train_loader)) @@ -553,8 +547,6 @@ def test_left_join(self): train_loader = get_savable_loader( train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) data = list(zip(range(2 * 55), train_loader)) @@ -712,8 +704,6 @@ def test_left_join_exclude(self): train_loader = get_savable_loader( train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) data = list(zip(range(2 * 55), train_loader)) @@ -828,8 +818,6 @@ def test_metadataset_fixed_epochs(self): train_loader = get_savable_loader( train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) data = list(enumerate(train_loader)) @@ -889,8 +877,6 @@ def test_metadataset_fixed_epochs(self): max_samples_per_sequence=None, repeat=False, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) train_loader.restore_state_rank(state1) data2_restore = list(enumerate(train_loader)) @@ -964,8 +950,6 @@ def test_metadataset_fixed_fractional_epochs(self): train_loader = get_savable_loader( train_dataset, - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) assert len(train_loader) == 38 + 55 + 27, len(train_loader) @@ -1002,8 +986,6 @@ def test_metadataset_fixed_fractional_epochs(self): max_samples_per_sequence=None, repeat=False, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) data1 = list(zip(range(95), train_loader)) @@ -1020,8 +1002,6 @@ def test_metadataset_fixed_fractional_epochs(self): max_samples_per_sequence=None, repeat=False, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) train_loader.restore_state_rank(state1) data2_restore = list(enumerate(train_loader)) @@ -1055,8 +1035,6 @@ def test_metadataset_fixed_fractional_epochs(self): max_samples_per_sequence=None, repeat=False, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) ds1_counter = 0 @@ -1082,8 +1060,6 @@ def test_metadataset_fixed_fractional_epochs(self): max_samples_per_sequence=None, repeat=False, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) train_loader.restore_state_rank(state1) data2_restore = list(enumerate(train_loader)) @@ -1114,8 +1090,6 @@ def test_metadataset_fixed_fractional_epochs(self): parallel_shard_iters=1, max_samples_per_sequence=None, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) data = list(zip(range(200), train_loader)) @@ -1146,8 +1120,6 @@ def test_metadataset_fixed_fractional_epochs(self): max_samples_per_sequence=None, repeat=False, ), - checkpoint_every_sec=0, - checkpoint_every_min_n_samples=1, ) # TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py From 139e7f571f7ee72a92aa0933fdd30bb75ee73d71 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 12 Aug 2025 15:34:34 +0200 Subject: [PATCH 13/36] WIP: Fix thread local state --- .../energon/dataloader/asynchronous/base.py | 13 ++- .../energon/dataloader/asynchronous/fork.py | 32 +++++-- src/megatron/energon/dataloader/dataloader.py | 11 ++- src/megatron/energon/dataloader/pin_memory.py | 7 ++ .../energon/dataloader/workers/base_worker.py | 4 +- src/megatron/energon/flavors/base_dataset.py | 30 ++++--- src/megatron/energon/task_encoder/base.py | 1 + src/megatron/energon/watchdog.py | 4 +- .../energon/wrappers/batch_dataset.py | 4 +- src/megatron/energon/wrappers/buffer.py | 28 +++--- .../energon/wrappers/group_batch_dataset.py | 12 +-- .../energon/wrappers/iter_map_dataset.py | 5 +- src/megatron/energon/wrappers/map_dataset.py | 4 +- .../energon/wrappers/packing_dataset.py | 6 +- .../wrappers/shuffle_buffer_dataset.py | 4 +- .../wrappers/task_encoder_state_dataset.py | 2 +- tests/test_dataloader.py | 87 +++++++++++++++++++ 17 files changed, 200 insertions(+), 54 deletions(-) diff --git a/src/megatron/energon/dataloader/asynchronous/base.py b/src/megatron/energon/dataloader/asynchronous/base.py index 879a6bc5..cfcacefc 100644 --- a/src/megatron/energon/dataloader/asynchronous/base.py +++ b/src/megatron/energon/dataloader/asynchronous/base.py @@ -83,6 +83,9 @@ def _set_result(self, result: Any) -> None: def _set_exception(self, exception: Exception) -> None: self._exception = exception + def __str__(self) -> str: + return f"FutureImpl(worker={self._worker._name!r}, future_id={self._future_id!r}, done={self.done()!r}, exception={getattr(self, '_exception', '')})" + class Asynchronous: """Asynchronous base class.""" @@ -168,13 +171,13 @@ def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> self._pending_futures[future_id] = future = FutureImpl(self, future_id) print( - f"[{self._name}] worker_call {fn.__name__=} {args=} {kwargs=} {future_id=}\n", + f"[{self._name}] worker_call {fn.__name__=} {future_id=}\n", end="", ) self._cmd_queue.put( WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) ) - print(f"[{self._name}] queue: {self._cmd_queue.qsize()}\n", end="") + print(f"[{self._name}] cmd_queue: {self._cmd_queue.qsize()=}\n", end="") return future def _worker_run( @@ -200,7 +203,7 @@ def _worker_run( ) cmd = cmd_queue.get() print( - f"[{self._name}, fut={cmd.future_id}] got command {cmd.cmd=} {cmd.args=} {cmd.kwargs=}\n", + f"[{self._name}, fut={cmd.future_id}] got command {cmd.cmd=}\n", end="", ) try: @@ -209,9 +212,11 @@ def _worker_run( except Exception as e: print(f"[{self._name}, fut={cmd.future_id}] send exception {e!r}\n", end="") result_queue.put(WorkerResult(future_id=cmd.future_id, exception=e)) + print(f"[{self._name}] result_queue: {result_queue.qsize()=}\n", end="") else: - print(f"[{self._name}, fut={cmd.future_id}] send result {result!r}\n", end="") + print(f"[{self._name}, fut={cmd.future_id}] send result\n", end="") result_queue.put(WorkerResult(future_id=cmd.future_id, result=result)) + print(f"[{self._name}] result_queue: {result_queue.qsize()=}\n", end="") del result # cmd_queue.task_done() if cmd.cmd == self._wrk_shutdown_worker.__name__: diff --git a/src/megatron/energon/dataloader/asynchronous/fork.py b/src/megatron/energon/dataloader/asynchronous/fork.py index 75a12a92..d8f12a0f 100644 --- a/src/megatron/energon/dataloader/asynchronous/fork.py +++ b/src/megatron/energon/dataloader/asynchronous/fork.py @@ -6,6 +6,8 @@ import threading import warnings +import torch.multiprocessing + from megatron.energon.dataloader.asynchronous.base import ( Asynchronous, QueueProtocol, @@ -30,11 +32,11 @@ def _asynchronous_init(self, name: str) -> None: self._spawning_process = os.getpid() def _queues(self) -> tuple[QueueProtocol[WorkerCommand], QueueProtocol[WorkerResult]]: - return multiprocessing.Queue(), multiprocessing.Queue() + return torch.multiprocessing.Queue(), torch.multiprocessing.Queue() def _check_parent_process(self, evt_exit: threading.Event) -> None: """Check if the parent process is alive. If it is dead, exit the worker process.""" - parent_proc = multiprocessing.parent_process() + parent_proc = torch.multiprocessing.parent_process() parent_pid = os.getppid() if parent_proc is None: print(f"[{self._name}] No parent process, exiting", file=sys.stderr) @@ -49,6 +51,21 @@ def _worker_run( cmd_queue: multiprocessing.Queue, result_queue: multiprocessing.Queue, ) -> None: + try: + from torch.utils.data._utils import signal_handling + + signal_handling._set_worker_signal_handlers() + except (ImportError, AttributeError): + pass + + try: + torch.multiprocessing._set_thread_name("pt_data_worker") + except (ImportError, AttributeError): + pass + + # Disable torch internal multithreading, it may deadlock the forked process. + torch.set_num_threads(1) + # cmd_queue is read only, so we can cancel the join thread. cmd_queue.cancel_join_thread() worker_exit_evt = threading.Event() @@ -74,17 +91,22 @@ def _worker_run( print(f"[{self._name}] shutting down, done\n", end="") def _in_worker(self) -> bool: - return multiprocessing.current_process() == self._process + return torch.multiprocessing.current_process() == self._process def start(self) -> None: - multiprocessing.set_start_method("fork", force=True) - self._process = multiprocessing.Process( + torch.multiprocessing.set_start_method("fork", force=True) + orig_num_threads = torch.get_num_threads() + # Disable torch internal multithreading, it may deadlock the forked process. + torch.set_num_threads(1) + self._process = torch.multiprocessing.Process( target=self._worker_run, args=(self._cmd_queue, self._result_queue), daemon=True, name=f"ForkDataLoaderWorker-{self._name}", ) self._process.start() + # Revert the original number of threads in the main process. + torch.set_num_threads(orig_num_threads) def shutdown(self, in_del: bool = False) -> None: if self._spawning_process != os.getpid(): diff --git a/src/megatron/energon/dataloader/dataloader.py b/src/megatron/energon/dataloader/dataloader.py index c1cd95c1..bfcc9faf 100644 --- a/src/megatron/energon/dataloader/dataloader.py +++ b/src/megatron/energon/dataloader/dataloader.py @@ -499,14 +499,17 @@ def restore_sample(self, restore_key: tuple) -> TSample: The restored sample. """ id, global_worker_id, sample_idx = restore_key[:3] - assert id == type(self).__name__ + assert id == "DataLoaderWorker", f"id {id} != DataLoaderWorker" restore_key = restore_key[3:] self._worker_config.worker_activate( sample_idx, override_global_rank=global_worker_id, cache_pool=self._cache_pool ) try: return add_sample_restore_key( - self._dataset.restore_sample(restore_key), global_worker_id, sample_idx, src=self + self._dataset.restore_sample(restore_key), + global_worker_id, + sample_idx, + src=DataLoaderWorker.__name__, ) finally: self._worker_config.worker_deactivate() @@ -527,6 +530,10 @@ def with_restored_state_global( self.restore_state_global(state, src_rank=src_rank) return self + def can_restore_sample(self) -> bool: + """Check if the dataset can save and restore samples.""" + return self._dataset.can_restore_sample() + def config(self) -> dict[str, Any]: """Get the configuration of the dataset.""" return self._dataset.config() diff --git a/src/megatron/energon/dataloader/pin_memory.py b/src/megatron/energon/dataloader/pin_memory.py index 4e2dc247..bc35f2c6 100644 --- a/src/megatron/energon/dataloader/pin_memory.py +++ b/src/megatron/energon/dataloader/pin_memory.py @@ -71,6 +71,13 @@ def __init__( super().__init__(device) self._asynchronous_init(name="pin-memory") + def _worker_run(self, *args, **kwargs) -> None: + try: + torch.multiprocessing._set_thread_name("pt_data_pin") + except AttributeError: + pass + super()._worker_run(*args, **kwargs) + def _wrk_pin_memory(self, sample: Future[TSample]) -> TSample: print( f"[{self._name}] Pinning memory of sample {sample}, waiting for sample data\n", end="" diff --git a/src/megatron/energon/dataloader/workers/base_worker.py b/src/megatron/energon/dataloader/workers/base_worker.py index a650bffc..6fcc0bb2 100644 --- a/src/megatron/energon/dataloader/workers/base_worker.py +++ b/src/megatron/energon/dataloader/workers/base_worker.py @@ -102,7 +102,7 @@ def dataset_init(self, state: FlexState | None) -> None: "Global worker ID mismatch" ) assert self._seed == self.worker_config.worker_seed(self._rank_worker_id), "Seed mismatch" - print(f"dataset_init {state=}\n", end="") + print("dataset_init\n", end="") self.dataset.reset_state() if state is None: self._sample_index = 0 @@ -152,7 +152,7 @@ def prefetch_next(self) -> Future[TSample]: next_sample = next(self._dataset_iter) self._sample_index += 1 next_sample = add_sample_restore_key( - next_sample, self._global_worker_id, sample_idx, src=self + next_sample, self._global_worker_id, sample_idx, src=DataLoaderWorker.__name__ ) except StopIteration as e: self._exhausted = True diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index 4c2ee398..7c1a9cc1 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -274,6 +274,9 @@ class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC) #: List of names of the fields that are saved and restored in the state. _savable_fields: ClassVar[Tuple[str, ...]] = () + #: List of names of the fields that are not saved, but are still part of the state (i.e. not shared between workers). + _state_fields: ClassVar[Tuple[str, ...]] = () + def __init__(self, worker_config: WorkerConfig): self.worker_config = worker_config if THREAD_SAFE: @@ -408,23 +411,26 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s if THREAD_SAFE: def __getattribute__(self, name: str) -> Any: - if name in ("_savable_fields", "_thread_state", "worker_config"): + if name in ("_savable_fields", "_state_fields", "_thread_state", "worker_config"): return object.__getattribute__(self, name) - elif name in self._savable_fields: - return getattr(self._thread_state, name) + elif name in self._savable_fields or name in self._state_fields: + try: + return getattr(self._thread_state, name) + except AttributeError: + return object.__getattribute__(self, name) else: return object.__getattribute__(self, name) def __delattr__(self, name: str) -> None: - if name in self._savable_fields: + if name in self._savable_fields or name in self._state_fields: delattr(self._thread_state, name) else: object.__delattr__(self, name) def __setattr__(self, name: str, value: Any) -> None: - if name in ("_savable_fields", "_thread_state", "worker_config"): + if name in ("_savable_fields", "_state_fields", "_thread_state", "worker_config"): object.__setattr__(self, name, value) - elif name in self._savable_fields: + elif name in self._savable_fields or name in self._state_fields: setattr(self._thread_state, name, value) else: object.__setattr__(self, name, value) @@ -461,13 +467,15 @@ def add_sample_restore_key( """Adds a key to a sample. The sample must be a valid `Sample` or dict containing __restore_key__, which is a tuple of keys that can be used to restore the inner sample. This restore key is prepended with the `key`.""" + if not isinstance(src, str): + src = type(src).__name__ if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"): try: - sample.__restore_key__ = (type(src).__name__, *key, *sample.__restore_key__) + sample.__restore_key__ = (src, *key, *sample.__restore_key__) except KeyError: pass elif isinstance(sample, dict) and "__restore_key__" in sample: - sample["__restore_key__"] = (type(src).__name__, *key, *sample["__restore_key__"]) + sample["__restore_key__"] = (src, *key, *sample["__restore_key__"]) elif fail_otherwise: raise RuntimeError( "Did not yield a sample with a restore key, but is marked stateless/deterministic." @@ -481,13 +489,15 @@ def set_sample_restore_key( """Sets the restore key for a sample. The sample must be a valid `Sample` or dict containing __restore_key__, which is a tuple of keys that can be used to restore the inner sample. This restore key is prepended with the `key`.""" + if not isinstance(src, str): + src = type(src).__name__ if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"): try: - sample.__restore_key__ = (type(src).__name__, *key) + sample.__restore_key__ = (src, *key) except KeyError: pass elif isinstance(sample, dict) and "__restore_key__" in sample: - sample["__restore_key__"] = (type(src).__name__, *key) + sample["__restore_key__"] = (src, *key) elif fail_otherwise: raise RuntimeError( "Did not yield a sample with a restore key, but is marked stateless/deterministic." diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index f92664d1..48a0ad6e 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -707,6 +707,7 @@ def build_batch( fixed_batch_size=batch_size, sample_group_key=self.batch_group_criterion, batcher=self.batch, + batcher_stateless=get_stateless(self.batch), drop_last=batch_drop_last, worker_config=worker_config, failure_tolerance=get_failure_tolerance( diff --git a/src/megatron/energon/watchdog.py b/src/megatron/energon/watchdog.py index 2561d6a0..9755dd21 100644 --- a/src/megatron/energon/watchdog.py +++ b/src/megatron/energon/watchdog.py @@ -67,7 +67,9 @@ def __init__( # Condition variable to manage state changes self._cv = threading.Condition() # Background thread (daemon) that monitors timeouts - self._worker_thread = threading.Thread(target=self._worker, daemon=True) + self._worker_thread = threading.Thread( + name=f"watchdog-{id(self)}", target=self._worker, daemon=True + ) self._worker_thread.start() def _get_next_timeout(self) -> float: diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index 2b150bbb..edfcfd47 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -212,14 +212,14 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_b id, sample_idx, *samples_restore_keys = restore_key assert id == type(self).__name__ batch = [self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys] - with self._sample_index.ctx(sample_idx): + with SampleIndex(self.worker_config, src=self).ctx(sample_idx): batch_sample = self.batcher(batch) if isinstance(batch_sample, Generator): assert inspect.isgeneratorfunction(self.batcher), ( f"Generator in {self.batcher} but not marked as such." ) for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( - self._sample_index.iter_ctx(batch_sample, sample_idx) + SampleIndex(self.worker_config, src=self).iter_ctx(batch_sample, sample_idx) ): if cur_batch_sub_idx == batch_sub_idx: return set_sample_restore_key( diff --git a/src/megatron/energon/wrappers/buffer.py b/src/megatron/energon/wrappers/buffer.py index 8b6808a8..af7afb63 100644 --- a/src/megatron/energon/wrappers/buffer.py +++ b/src/megatron/energon/wrappers/buffer.py @@ -16,25 +16,24 @@ ) from megatron.energon.flavors.base_dataset import FlexState, SavableDataset +from megatron.energon.savable import Savable from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key +from megatron.energon.wrappers.base import get_sample_restore_key T_sample = TypeVar("T_sample") -class SavableSampleBuffer(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): - """A buffer of samples, savable.""" +class SavableSampleBuffer(Savable, Generic[T_sample]): + """A buffer of samples, savable. State is shared, create a state-local instance.""" _buffer: List[T_sample] _restore_keys: List[Tuple[Union[str, int, tuple], ...]] - _savable_fields = ("_restore_keys",) _restore_pending: bool = False def __init__(self, dataset: SavableDataset[T_sample], *, worker_config: WorkerConfig): - super().__init__(dataset, worker_config=worker_config) - - def reset_state_own(self) -> None: + self.dataset = dataset + self.worker_config = worker_config self._buffer = [] self._restore_keys = [] @@ -43,7 +42,7 @@ def worker_start(self) -> None: assert len(self._buffer) == 0 self._restore_pending = False for restore_key in self._restore_keys: - self._buffer.append(self.restore_sample(restore_key)) + self._buffer.append(self.dataset.restore_sample(restore_key)) assert len(self._buffer) == len(self._restore_keys) def append(self, sample: T_sample) -> T_sample: @@ -107,13 +106,18 @@ def len_rank(self) -> int: def save_state(self) -> FlexState: # Don't call super().save_state() because we don't want to save the wrapped datasets # Just save the own state - return SavableDataset.save_state(self) + return FlexState( + __class__=type(self).__name__, + _restore_keys=self._restore_keys, + ) def restore_state(self, state: FlexState) -> None: # Don't call super().restore_state() because we don't want to restore the wrapped datasets # Just restore the own state - SavableDataset.restore_state(self, state) - + assert state["__class__"] == type(self).__name__, ( + f"Expected class {type(self).__name__}, got {state['__class__']}" + ) + self._restore_keys = state["_restore_keys"].copy() self._restore_pending = True def restore_key(self) -> Tuple[Union[str, int], ...]: @@ -125,7 +129,7 @@ def restore_samples( buffer = [] restore_keys = [] for sub_index in index: - sample = self.restore_sample(sub_index) + sample = self.dataset.restore_sample(sub_index) restore_keys.append(get_sample_restore_key(sample)) buffer.append(sample) return tuple(restore_keys), buffer diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index 23e8d2e6..2ca78620 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -73,6 +73,10 @@ class GroupBatchDataset( _batch_sample_index: SampleIndex _buckets: Dict[Hashable, Bucket[T_batch_sample]] + _savable_fields = ("_group_key_sample_index", "_batch_sample_index") + # Buckets are saved manually + _state_fields = ("_buckets",) + def __init__( self, dataset: SavableDataset[T_batch_sample], @@ -219,17 +223,13 @@ def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: def save_state(self) -> FlexState: return FlexState( - bucket_sample_index=self._group_key_sample_index.save_state(), - batch_sample_index=self._batch_sample_index.save_state(), - buckets={key: bucket.save_state() for key, bucket in self._buckets.items()}, **super().save_state(), + buckets={key: bucket.save_state() for key, bucket in self._buckets.items()}, ) def restore_state(self, state: FlexState) -> None: super().restore_state(state) - self._group_key_sample_index.restore_state(state["bucket_sample_index"]) - self._batch_sample_index.restore_state(state["batch_sample_index"]) for key, bucket_state in state["buckets"].items(): self._buckets[key] = Bucket( batch_size=-1, @@ -251,7 +251,7 @@ def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch: id, sample_idx, *sample_restore_keys = index assert id == type(self).__name__ batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys] - with self._batch_sample_index.ctx(sample_idx): + with SampleIndex(self.worker_config, src=self).ctx(sample_idx): batch_sample = self.batcher(batch) set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) return batch_sample diff --git a/src/megatron/energon/wrappers/iter_map_dataset.py b/src/megatron/energon/wrappers/iter_map_dataset.py index 3faf02ee..41c05a15 100644 --- a/src/megatron/energon/wrappers/iter_map_dataset.py +++ b/src/megatron/energon/wrappers/iter_map_dataset.py @@ -150,12 +150,13 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s ) ) try: + sample_index = SampleIndex(self.worker_config, src=self) # Skip inner yielded samples to get the correct sample for skip_idx in range(iter_idx): - with self._sample_index.ctx(sample_idx - iter_idx + skip_idx): + with sample_index.ctx(sample_idx - iter_idx + skip_idx): next(inner_iter) # This is the sample to restore - with self._sample_index.ctx(sample_idx): + with sample_index.ctx(sample_idx): sample = next(inner_iter) return set_sample_restore_key( sample, diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index b3e35ec6..91810f16 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -188,14 +188,14 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s assert id == type(self).__name__ restore_key = restore_key[2:] inner_sample = self.dataset.restore_sample(restore_key) - with self._sample_index.ctx(sample_idx): + with SampleIndex(self.worker_config, src=self).ctx(sample_idx): mapped_sample = self.map_fn(inner_sample) if isinstance(mapped_sample, Generator): assert inspect.isgeneratorfunction(self.map_fn), ( f"Generator in {self.map_fn} but not marked as such." ) for idx, (sample_idx, res_sample) in enumerate( - self._sample_index.iter_ctx(mapped_sample, sample_idx) + SampleIndex(self.worker_config, src=self).iter_ctx(mapped_sample, sample_idx) ): if idx == local_idx: return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self) diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index 1bc28c61..2727ce9b 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -416,19 +416,19 @@ def restore_sample(self, restore_key: Any) -> T_sample: assert isinstance(sample_idx, int) sample = self.dataset.restore_sample(inner_idx) if self.sample_encoder is not None: - with self._sample_encoder_sample_index.ctx(sample_idx): + with SampleIndex(self.worker_config, src=self).ctx(sample_idx): sample = self.sample_encoder(sample) assert not isinstance(sample, Generator), "Generator not supported" sample = add_sample_restore_key(sample, sample_idx, src=self) pack.append(sample) - with self._final_packing_sample_index.ctx(pack_idx): + with SampleIndex(self.worker_config, src=self).ctx(pack_idx): final_pack = self.final_packer(pack) if isinstance(final_pack, Generator): assert inspect.isgeneratorfunction(self.final_packer), ( f"Generator in {self.final_packer} but not marked as such." ) for cur_batch_sub_idx, (pack_idx, inner_batch_sample) in enumerate( - self._final_packing_sample_index.iter_ctx(final_pack, pack_idx) + SampleIndex(self.worker_config, src=self).iter_ctx(final_pack, pack_idx) ): if cur_batch_sub_idx == pack_sub_idx: return set_sample_restore_key( diff --git a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py index 19926436..9337f1c5 100644 --- a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py +++ b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py @@ -19,7 +19,7 @@ class ShuffleBufferDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sam _worker_rng: WorkerRng _active_buffer: SavableSampleBuffer[T_sample] - _savable_fields = ("_active_buffer", "_worker_rng") + _savable_fields = ("_worker_rng", "_active_buffer") def __init__( self, @@ -53,7 +53,7 @@ def __iter__(self) -> Iterator[T_sample]: yield self._active_buffer.pop(pop_idx) def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: - return self._active_buffer.restore_sample(restore_key) + return self.dataset.restore_sample(restore_key) def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/task_encoder_state_dataset.py b/src/megatron/energon/wrappers/task_encoder_state_dataset.py index 32534acc..6363f4c9 100644 --- a/src/megatron/energon/wrappers/task_encoder_state_dataset.py +++ b/src/megatron/energon/wrappers/task_encoder_state_dataset.py @@ -57,7 +57,7 @@ def __iter__(self) -> Iterator[T_sample]: def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: inner_sample = self.dataset.restore_sample(restore_key) - self._task_encoder.restore_sample() + inner_sample = self._task_encoder.restore_sample(inner_sample) return inner_sample def config(self) -> Dict[str, Any]: diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 39dd7ca3..76b66e70 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -212,6 +212,93 @@ def test_dataloader_fork(self): ] assert train_order2 == cmp_order2, (train_order1, cmp_order2) + def test_dataloader_fork_multi_parallel(self): + torch.manual_seed(42) + worker_config_r0 = WorkerConfig( + rank=0, + world_size=2, + num_workers=2, + seed_offset=42, + ) + worker_config_r1 = WorkerConfig( + rank=1, + world_size=2, + num_workers=2, + seed_offset=42, + ) + + # Train mode dataset + train_loader_r0 = DataLoader( + get_train_dataset( + self.ds1_path, + worker_config=worker_config_r0, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) + assert len(train_loader_r0) == 4, len(train_loader_r0) + + train_order1_r0 = [ + text for idx, data in zip(range(55 * 10), train_loader_r0) for text in data.text + ] + print(train_order1_r0[:10]) + print(Counter(train_order1_r0)) + assert len(train_order1_r0) == 28, len(train_order1_r0) + assert len(Counter(train_order1_r0)) == 28, Counter(train_order1_r0) + assert all(v == 1 for v in Counter(train_order1_r0).values()), Counter(train_order1_r0) + + train_loader_r1 = DataLoader( + get_train_dataset( + self.ds1_path, + worker_config=worker_config_r1, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) + assert len(train_loader_r1) == 4, len(train_loader_r1) + + train_order1_r1 = [ + text for idx, data in zip(range(55 * 10), train_loader_r1) for text in data.text + ] + print(train_order1_r1[:10]) + print(Counter(train_order1_r1)) + assert len(train_order1_r1) == 27, len(train_order1_r1) + assert len(Counter(train_order1_r1)) == 27, Counter(train_order1_r1) + assert all(v == 1 for v in Counter(train_order1_r1).values()), Counter(train_order1_r1) + + train_loader_r1.save_state_rank() + + train_loader_r0.save_state_rank() + + train_order2_r0 = [ + text for idx, data in zip(range(55 * 10), train_loader_r0) for text in data.text + ] + assert len(train_order2_r0) == 28 + + train_order2_r1 = [ + text for idx, data in zip(range(55 * 10), train_loader_r1) for text in data.text + ] + assert len(train_order2_r1) == 27 + + train_loader_r0.shutdown() + train_loader_r1.shutdown() + def test_dataloader_thread(self): torch.manual_seed(42) worker_config = WorkerConfig( From e4db0ae83917491c8668602a559f2037a700543d Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Fri, 15 Aug 2025 13:16:36 +0200 Subject: [PATCH 14/36] Fix several minor issues, e.g. resetting the state properly after ending a generator, saving state of sample buffer, skipping state of watchdog, log, gc dataset. Adapt required changes in tests. Add typing to specific saved states (instead of FlexState). Make debug output conditional. --- src/megatron/energon/__init__.py | 2 - .../energon/dataloader/asynchronous/base.py | 95 +- .../energon/dataloader/asynchronous/fork.py | 20 +- src/megatron/energon/dataloader/dataloader.py | 109 +- src/megatron/energon/dataloader/pin_memory.py | 13 +- .../dataloader/workers/async_worker.py | 7 +- .../energon/dataloader/workers/base_worker.py | 50 +- src/megatron/energon/flavors/base_dataset.py | 2 +- .../flavors/webdataset/error_handler.py | 1 + src/megatron/energon/rng.py | 38 +- src/megatron/energon/savable_loader.py | 1406 ----------------- src/megatron/energon/task_encoder/base.py | 4 +- src/megatron/energon/worker.py | 6 +- .../energon/wrappers/batch_dataset.py | 2 +- src/megatron/energon/wrappers/buffer.py | 24 +- src/megatron/energon/wrappers/gc_dataset.py | 9 + .../energon/wrappers/limit_dataset.py | 1 - .../energon/wrappers/log_sample_dataset.py | 9 + .../energon/wrappers/packing_dataset.py | 4 - .../energon/wrappers/watchdog_dataset.py | 13 + tests/test_crudedataset.py | 3 + tests/test_dataloader.py | 37 + tests/test_dataset.py | 18 +- tests/test_dataset_det.py | 4 + tests/test_metadataset.py | 439 ++--- 25 files changed, 564 insertions(+), 1752 deletions(-) delete mode 100644 src/megatron/energon/savable_loader.py diff --git a/src/megatron/energon/__init__.py b/src/megatron/energon/__init__.py index 2da8d544..203a005f 100644 --- a/src/megatron/energon/__init__.py +++ b/src/megatron/energon/__init__.py @@ -66,7 +66,6 @@ load_dataset, prepare_metadataset, ) -from megatron.energon.savable_loader import SavableDataLoader from megatron.energon.source_info import SourceInfo from megatron.energon.task_encoder import ( AugmentTaskEncoder, @@ -180,7 +179,6 @@ "RepeatDataset", "Sample", "SampleDecoder", - "SavableDataLoader", "SavableDataset", "ShuffleBufferDataset", "SimilarityInterleavedSample", diff --git a/src/megatron/energon/dataloader/asynchronous/base.py b/src/megatron/energon/dataloader/asynchronous/base.py index cfcacefc..4ba11813 100644 --- a/src/megatron/energon/dataloader/asynchronous/base.py +++ b/src/megatron/energon/dataloader/asynchronous/base.py @@ -13,6 +13,9 @@ R = TypeVar("R", covariant=True) +DEBUG_LEVEL = 0 + + class QueueProtocol(Protocol[T]): """Protocol for a queue.""" @@ -65,10 +68,11 @@ def get(self) -> Any: def cancel(self) -> bool: if hasattr(self, "_result") or hasattr(self, "_exception"): - print( - f"[{self._worker._name}, fut={self._future_id}] already has result or exception\n", - end="", - ) + if DEBUG_LEVEL >= 1: + print( + f"[{self._worker._name}, fut={self._future_id}] already has result or exception\n", + end="", + ) return False self._exception = CancelledError.with_current_traceback() self._worker._cancel_future(self._future_id) @@ -115,13 +119,15 @@ def _wait_for_worker_result(self, future: FutureImpl) -> None: Args: future: The future to wait for. """ - print(f"[{self._name}, fut={future._future_id}] waiting for result\n", end="") + if DEBUG_LEVEL >= 1: + print(f"[{self._name}, fut={future._future_id}] waiting for result\n", end="") with self._result_lock: if future.done(): # If calling get() from multiple threads, the future may be done now, because # the other thread already set the result. return - print(f"[{self._name}, fut={future._future_id}] got future\n", end="") + if DEBUG_LEVEL >= 2: + print(f"[{self._name}, fut={future._future_id}] got future\n", end="") while True: res = self._result_queue.get() fut = self._pending_futures.pop(res.future_id) @@ -130,18 +136,23 @@ def _wait_for_worker_result(self, future: FutureImpl) -> None: else: fut._set_result(res.result) if res.future_id == future._future_id: - print(f"[{self._name}, fut={future._future_id}] got result, return\n", end="") + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}, fut={future._future_id}] got result, return\n", end="" + ) return else: - print( - f"[{self._name}, fut={future._future_id}] got result for {res.future_id=}, continue\n", - end="", - ) + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}, fut={future._future_id}] got result for {res.future_id=}, continue\n", + end="", + ) continue def _cancel_future(self, future_id: int) -> None: """Cancel a future.""" - print(f"[{self._name}, fut={future_id}] cancelling future\n", end="") + if DEBUG_LEVEL >= 1: + print(f"[{self._name}, fut={future_id}] cancelling future\n", end="") # In case the main process is waiting for thie future to complete, add the result self._result_queue.put( WorkerResult(future_id=future_id, exception=CancelledError.with_current_traceback()) @@ -170,14 +181,16 @@ def _worker_call(self, fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> self._next_future_id += 1 self._pending_futures[future_id] = future = FutureImpl(self, future_id) - print( - f"[{self._name}] worker_call {fn.__name__=} {future_id=}\n", - end="", - ) + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}] worker_call {fn.__name__=} {future_id=}\n", + end="", + ) self._cmd_queue.put( WorkerCommand(cmd=fn.__name__, args=args, kwargs=kwargs, future_id=future_id) ) - print(f"[{self._name}] cmd_queue: {self._cmd_queue.qsize()=}\n", end="") + if DEBUG_LEVEL >= 2: + print(f"[{self._name}] cmd_queue: {self._cmd_queue.qsize()=}\n", end="") return future def _worker_run( @@ -197,37 +210,46 @@ def _worker_run( assert self._in_worker(), "_worker_run must be called in the worker" try: while True: - print( - f"[{self._name}] waiting for command {cmd_queue.qsize()=}\n", - end="", - ) + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}] waiting for command {cmd_queue.qsize()=}\n", + end="", + ) cmd = cmd_queue.get() - print( - f"[{self._name}, fut={cmd.future_id}] got command {cmd.cmd=}\n", - end="", - ) + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}, fut={cmd.future_id}] got command {cmd.cmd=}\n", + end="", + ) try: fn = getattr(self, cmd.cmd) result = fn(*cmd.args, **cmd.kwargs) except Exception as e: - print(f"[{self._name}, fut={cmd.future_id}] send exception {e!r}\n", end="") + if DEBUG_LEVEL >= 2: + print(f"[{self._name}, fut={cmd.future_id}] send exception {e!r}\n", end="") result_queue.put(WorkerResult(future_id=cmd.future_id, exception=e)) - print(f"[{self._name}] result_queue: {result_queue.qsize()=}\n", end="") + if DEBUG_LEVEL >= 2: + print(f"[{self._name}] result_queue: {result_queue.qsize()=}\n", end="") else: - print(f"[{self._name}, fut={cmd.future_id}] send result\n", end="") + if DEBUG_LEVEL >= 2: + print(f"[{self._name}, fut={cmd.future_id}] send result\n", end="") result_queue.put(WorkerResult(future_id=cmd.future_id, result=result)) - print(f"[{self._name}] result_queue: {result_queue.qsize()=}\n", end="") + if DEBUG_LEVEL >= 2: + print(f"[{self._name}] result_queue: {result_queue.qsize()=}\n", end="") del result # cmd_queue.task_done() if cmd.cmd == self._wrk_shutdown_worker.__name__: + if DEBUG_LEVEL >= 1: + print( + f"[{self._name}, fut={cmd.future_id}] got shutdown command, exit\n", + end="", + ) + break + if DEBUG_LEVEL >= 2: print( - f"[{self._name}, fut={cmd.future_id}] got shutdown command, exit\n", end="" + f"[{self._name}, fut={cmd.future_id}] processed, waiting for next command\n", + end="", ) - break - print( - f"[{self._name}, fut={cmd.future_id}] processed, waiting for next command\n", - end="", - ) except: traceback.print_exc() raise @@ -252,7 +274,8 @@ def _shutdown_worker(self) -> None: # This is not actually a recursive call, because the worker loop will exit before calling this method. self._worker_call(self._wrk_shutdown_worker).get() self._cancel_futures() - print(f"[{self._name}] shutdown\n", end="") + if DEBUG_LEVEL >= 1: + print(f"[{self._name}] shutdown\n", end="") @abstractmethod def start(self) -> None: ... diff --git a/src/megatron/energon/dataloader/asynchronous/fork.py b/src/megatron/energon/dataloader/asynchronous/fork.py index d8f12a0f..a99625e6 100644 --- a/src/megatron/energon/dataloader/asynchronous/fork.py +++ b/src/megatron/energon/dataloader/asynchronous/fork.py @@ -15,6 +15,8 @@ WorkerResult, ) +DEBUG_LEVEL = 1 + class ForkAsynchronous(Asynchronous): """Mixin for asynchronous workers that use processes.""" @@ -76,19 +78,23 @@ def _worker_run( try: super()._worker_run(cmd_queue, result_queue) finally: - print(f"[{self._name}] shutting down\n", end="") + if DEBUG_LEVEL >= 1: + print(f"[{self._name}] shutting down\n", end="") worker_exit_evt.set() - print( - f"[{self._name}] shutting down, wait for parent_check_thread\n", - end="", - ) + if DEBUG_LEVEL >= 1: + print( + f"[{self._name}] shutting down, wait for parent_check_thread\n", + end="", + ) parent_check_thread.join() - print(f"[{self._name}] shutting down, close queues\n", end="") + if DEBUG_LEVEL >= 1: + print(f"[{self._name}] shutting down, close queues\n", end="") result_queue.close() result_queue.join_thread() cmd_queue.close() cmd_queue.cancel_join_thread() - print(f"[{self._name}] shutting down, done\n", end="") + if DEBUG_LEVEL >= 1: + print(f"[{self._name}] shutting down, done\n", end="") def _in_worker(self) -> bool: return torch.multiprocessing.current_process() == self._process diff --git a/src/megatron/energon/dataloader/dataloader.py b/src/megatron/energon/dataloader/dataloader.py index bfcc9faf..dbbd763a 100644 --- a/src/megatron/energon/dataloader/dataloader.py +++ b/src/megatron/energon/dataloader/dataloader.py @@ -6,6 +6,7 @@ import warnings from typing import ( Any, + ClassVar, Generator, Generic, Literal, @@ -19,10 +20,14 @@ from megatron.energon.cache.base import CachePool from megatron.energon.dataloader.future import CallableFuture, Future from megatron.energon.dataloader.pin_memory import NoPinMemory, PinMemory, PinMemoryThread -from megatron.energon.dataloader.workers.base_worker import DataLoaderWorker +from megatron.energon.dataloader.workers.base_worker import ( + DataLoaderNoWorker, + DataLoaderWorker, + WorkerState, +) from megatron.energon.dataloader.workers.fork_worker import ForkDataLoaderWorker +from megatron.energon.edataclass import edataclass from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key -from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key from megatron.energon.wrappers.batch_dataset import BatchDataset @@ -31,6 +36,20 @@ TSample = TypeVar("TSample", covariant=True) +DEBUG_LEVEL = 1 + + +@edataclass +class RankState: + """ + State of a rank. + """ + + prefetched_samples_keys: list[Any] + worker_states: list[WorkerState | None] + next_worker_id: int + micro_batch_size: int | None + class WorkerType(Protocol[TSample]): """Protocol for a worker type, i.e. for the constructor of a worker class.""" @@ -51,11 +70,14 @@ class DataLoader(Generic[TSample]): to avoid leaking workers (fixes a bug). """ + _next_id: ClassVar[int] = 0 + _id: int + _workers: list[DataLoaderWorker[TSample]] | None = None _exhausted_workers: list[bool] _next_worker_id: int = 0 - _restore_state: FlexState | None = None + _restore_state: RankState | None = None _dataset: SavableDataset _worker_config: WorkerConfig @@ -109,8 +131,11 @@ def __init__( If "automatic", the memory is pinned automatically if cuda is available. If a `PinMemory` instance, the instance may only be used for one `DataLoader`. """ + self._id = DataLoader._next_id + DataLoader._next_id += 1 + if dataset.worker_config.num_workers == 0 and worker_type == ForkDataLoaderWorker: - worker_type = DataLoaderWorker + worker_type = DataLoaderNoWorker if watchdog_timeout_seconds is not None: dataset = WatchdogDataset( @@ -168,7 +193,7 @@ def _start(self) -> None: if self._restore_state is None: worker_states = [None] * self._worker_config.safe_num_workers else: - worker_states = self._restore_state["worker_states"] + worker_states = self._restore_state.worker_states assert len(worker_states) == self._worker_config.safe_num_workers, ( "Number of initial states must match number of workers" @@ -185,11 +210,11 @@ def _start(self) -> None: ) for sample_key in prefetched_samples_keys ] - for prefetched_samples_keys in self._restore_state["prefetched_samples_keys"] + for prefetched_samples_keys in self._restore_state.prefetched_samples_keys ] - self._next_worker_id = self._restore_state["next_worker_id"] + self._next_worker_id = self._restore_state.next_worker_id self._exhausted_workers = [ - False if worker_state is None else worker_state["exhausted"] + False if worker_state is None else worker_state.exhausted for worker_state in worker_states ] # State was restored, clear @@ -243,6 +268,8 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: for worker in self._workers: worker.new_iter() self._exhausted_workers = [False] * self._worker_config.safe_num_workers + # Ensure deterministic interleaving across epochs by starting from worker 0 + self._next_worker_id = 0 # For all workers, enqueue prefetching samples. for worker_idx, (worker, exhausted) in enumerate( @@ -261,19 +288,23 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: # - Pop the first sample future from the prefetching samples. # - Get the sample from the sample future (may wait for the sample to be prefetched). # - Yield the sample. - print(f"{self._exhausted_workers=}\n", end="") + if DEBUG_LEVEL >= 1: + print(f"{self._exhausted_workers=}\n", end="") while not all(self._exhausted_workers): # Get the next worker to prefetch samples from. worker_idx = self._next_worker_id worker = self._workers[worker_idx] - print(f"{worker_idx=} {worker=}\n", end="") + if DEBUG_LEVEL >= 2: + print(f"{worker_idx=} {worker=}\n", end="") self._next_worker_id = (worker_idx + 1) % self._worker_config.safe_num_workers if self._exhausted_workers[worker_idx]: - print(f"{worker_idx=} exhausted, continue with next worker\n", end="") + if DEBUG_LEVEL >= 1: + print(f"{worker_idx=} exhausted, continue with next worker\n", end="") continue # Pop the first sample future from the prefetching samples. sample_future = self._prefetching_samples[worker_idx].pop(0) - print(f"{sample_future=}\n", end="") + if DEBUG_LEVEL >= 2: + print(f"{sample_future=}\n", end="") # Prefetch samples from the worker. while len(self._prefetching_samples[worker_idx]) < self._prefetch_factor: # Add a new sample future to the prefetching samples if the worker has not prefetched enough samples. @@ -284,13 +315,15 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: # Get the sample from the sample future (may wait for the sample to be ready). sample = sample_future.get() except StopIteration: - print(f"{worker_idx=} exhausted, remove from prefetching samples\n", end="") + if DEBUG_LEVEL >= 1: + print(f"{worker_idx=} exhausted, remove from prefetching samples\n", end="") # If the sample future raises StopIteration, remove the worker from the list. self._prefetching_samples[worker_idx] = [] self._exhausted_workers[worker_idx] = True continue else: - print(f"{worker_idx=} got sample, yield\n", end="") + if DEBUG_LEVEL >= 2: + print(f"{worker_idx=} got sample, yield\n", end="") # Yield the sample. yield sample @@ -298,10 +331,19 @@ def __iter__(self) -> Generator[TSample, None, None]: # Restart the epoch iterator if was not created yet. Otherwise, the existing epoch iterator will be continued. # That happens e.g. when iteration was interrupted. if self._current_epoch_iter is None: + if DEBUG_LEVEL >= 1: + print("DL: Starting epoch iterator") self._current_epoch_iter = self._epoch_iter() + else: + if DEBUG_LEVEL >= 1: + print("DL: Continuing epoch iterator") assert self._current_epoch_iter is not None - yield from self._current_epoch_iter + # Important: Do not use yield from here, as it will delegate .close to the inner generator. + for sample in self._current_epoch_iter: + yield sample # Reset the epoch iterator, it was exhausted. + if DEBUG_LEVEL >= 1: + print("DL: Closing epoch iterator") self._current_epoch_iter.close() self._current_epoch_iter = None @@ -319,29 +361,33 @@ def _get_batch_size(self) -> int | None: else: return None - def save_state_rank(self) -> FlexState: - # TODO: The redist tool must be able to change the batch size. - # That means that the redist tool shall split a saved restore key for the "BatchDataset". - # It should also change the saved micro batch size to match that. - # TODO @pfischer: Add changing the batch size to the docs. + def save_state_rank(self) -> RankState: + if self._restore_state is not None: + return self._restore_state prefetched_samples_keys = [ [get_sample_restore_key(sample_fut.get()) for sample_fut in prefetching_sample] for prefetching_sample in self._prefetching_samples ] + worker_states: list[WorkerState | None] if self._workers is None: worker_states = [None] * self._worker_config.safe_num_workers else: worker_states = [worker.save_state() for worker in self._workers] - return FlexState( - __class__=type(self).__name__, + # Make sure that the exhausted_workers match the individual worker states + assert all( + worker_state is None or worker_state.exhausted == exhausted_worker + for worker_state, exhausted_worker in zip(worker_states, self._exhausted_workers) + ), "Exhausted workers mismatch" + + return RankState( prefetched_samples_keys=prefetched_samples_keys, worker_states=worker_states, next_worker_id=self._next_worker_id, micro_batch_size=self._get_batch_size(), ) - def save_state_global(self, global_dst_rank: int) -> Sequence[FlexState | None] | None: + def save_state_global(self, global_dst_rank: int) -> Sequence[RankState | None] | None: """ Saves the state of the dataset globally, collecting the state from all ranks using torch distributed. Allows for restoring the state later using `restore_state_global`, given the @@ -365,7 +411,7 @@ def save_state_global(self, global_dst_rank: int) -> Sequence[FlexState | None] # Gather the merged states if self._worker_config.world_size > 1: - output: Sequence[FlexState | None] | None + output: Sequence[RankState | None] | None if self._worker_config.global_rank() == global_dst_rank: output = [None] * self._worker_config.world_size else: @@ -394,7 +440,7 @@ def save_state_global(self, global_dst_rank: int) -> Sequence[FlexState | None] # Not distributed -> return the merged state return [merged_state] - def restore_state_rank(self, state: FlexState | None) -> None: + def restore_state_rank(self, state: RankState | None) -> None: """ Restore the state of the DataLoader on the current rank. The state is actually restored when the processes are started, in the iterator. @@ -408,15 +454,14 @@ def restore_state_rank(self, state: FlexState | None) -> None: # Assume initial state. return - assert isinstance(state, FlexState) - assert state["__class__"] == type(self).__name__, "DataLoader type mismatch" - assert state["micro_batch_size"] == self._get_batch_size(), "Micro batch size mismatch" + assert isinstance(state, RankState) + assert state.micro_batch_size == self._get_batch_size(), "Micro batch size mismatch" self._restore_state = state def restore_state_global( self, - state: Sequence[FlexState | None] | None, + state: Sequence[RankState | None] | None, *, src_rank: int | None = None, ) -> None: @@ -514,7 +559,7 @@ def restore_sample(self, restore_key: tuple) -> TSample: finally: self._worker_config.worker_deactivate() - def with_restored_state_rank(self, state: FlexState | None) -> "DataLoader[TSample]": + def with_restored_state_rank(self, state: RankState | None) -> "DataLoader[TSample]": """ Use this data loader and restore the state. Useful for chaining commands. See `save_state_rank` for more details. """ @@ -522,7 +567,7 @@ def with_restored_state_rank(self, state: FlexState | None) -> "DataLoader[TSamp return self def with_restored_state_global( - self, state: Sequence[FlexState | None] | None, src_rank: int | None = None + self, state: Sequence[RankState | None] | None, src_rank: int | None = None ) -> "DataLoader[TSample]": """ Use this data loader and restore the state. Useful for chaining commands. See `save_state_global` for more details. @@ -539,4 +584,4 @@ def config(self) -> dict[str, Any]: return self._dataset.config() def __str__(self) -> str: - return f"DataLoader(prefetch_factor={self._prefetch_factor}, worker_type={self._worker_type.__name__})" + return f"DataLoader(_id={self._id}, prefetch_factor={self._prefetch_factor}, worker_type={self._worker_type.__name__})" diff --git a/src/megatron/energon/dataloader/pin_memory.py b/src/megatron/energon/dataloader/pin_memory.py index bc35f2c6..60b29a00 100644 --- a/src/megatron/energon/dataloader/pin_memory.py +++ b/src/megatron/energon/dataloader/pin_memory.py @@ -12,6 +12,8 @@ TSample = TypeVar("TSample") T = TypeVar("T") +DEBUG_LEVEL = 1 + class PinMemory(Generic[TSample]): """Base class for pinning memory of samples. @@ -79,11 +81,14 @@ def _worker_run(self, *args, **kwargs) -> None: super()._worker_run(*args, **kwargs) def _wrk_pin_memory(self, sample: Future[TSample]) -> TSample: - print( - f"[{self._name}] Pinning memory of sample {sample}, waiting for sample data\n", end="" - ) + if DEBUG_LEVEL >= 2: + print( + f"[{self._name}] Pinning memory of sample {sample}, waiting for sample data\n", + end="", + ) sample_data = sample.get() - print(f"[{self._name}] Got sample data\n", end="") + if DEBUG_LEVEL >= 2: + print(f"[{self._name}] Got sample data\n", end="") return self._pin_memory(sample_data) def __call__(self, sample: Future[TSample]) -> Future[TSample]: diff --git a/src/megatron/energon/dataloader/workers/async_worker.py b/src/megatron/energon/dataloader/workers/async_worker.py index d4a4e90b..0647264c 100644 --- a/src/megatron/energon/dataloader/workers/async_worker.py +++ b/src/megatron/energon/dataloader/workers/async_worker.py @@ -10,10 +10,9 @@ WorkerResult, ) from megatron.energon.dataloader.future import Future -from megatron.energon.dataloader.workers.base_worker import DataLoaderWorker +from megatron.energon.dataloader.workers.base_worker import DataLoaderWorker, WorkerState from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.rng import SystemRng -from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig TSample = TypeVar("TSample", covariant=True) @@ -56,7 +55,7 @@ def _wrk_prefetch_next(self) -> TSample: # so immediately resolve the future to the result (get returns immediately). return super().prefetch_next().get() - def dataset_init(self, initial_state: FlexState | None) -> None: + def dataset_init(self, initial_state: WorkerState | None) -> None: if self._in_worker(): return super().dataset_init(initial_state) else: @@ -74,7 +73,7 @@ def prefetch_next(self) -> Future[TSample]: return super().prefetch_next() return self._worker_call(self._wrk_prefetch_next) - def save_state(self) -> FlexState: + def save_state(self) -> WorkerState: if self._in_worker(): return super().save_state() else: diff --git a/src/megatron/energon/dataloader/workers/base_worker.py b/src/megatron/energon/dataloader/workers/base_worker.py index 6fcc0bb2..af7ec73f 100644 --- a/src/megatron/energon/dataloader/workers/base_worker.py +++ b/src/megatron/energon/dataloader/workers/base_worker.py @@ -4,14 +4,33 @@ from megatron.energon.cache.base import CachePool from megatron.energon.dataloader.future import DoneFuture, ExceptionFuture, Future +from megatron.energon.edataclass import edataclass from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key -from megatron.energon.rng import SystemRng +from megatron.energon.rng import SystemRng, SystemRngState from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig TSample = TypeVar("TSample", covariant=True) +@edataclass +class WorkerState: + """ + State of a worker. + """ + + rng: SystemRngState + dataset: FlexState + exhausted: bool + sample_index: int + + def __str__(self): + from hashlib import sha256 + + rng_hash = sha256(str(self.rng).encode()).hexdigest() + return f"WorkerState(dataset={self.dataset}, exhausted={self.exhausted}, sample_index={self.sample_index}, rng_hash={rng_hash})" + + class DataLoaderWorker(Generic[TSample]): """ A worker for a :class:`DataLoader`. @@ -89,7 +108,7 @@ def __del__(self) -> None: # ------------------------------------------------------------------------------------------------ # Section: Worker methods - def dataset_init(self, state: FlexState | None) -> None: + def dataset_init(self, state: WorkerState | None) -> None: """ Initialize the worker (may restore the state). Calls `new_iter` if the worker is not exhausted and also initially (`state=None`). @@ -110,13 +129,13 @@ def dataset_init(self, state: FlexState | None) -> None: self.new_iter() print("dataset_init new_iter\n", end="") else: - assert state["__class__"] == "DataLoaderWorker", "state type mismatch" - self._sample_index = state["sample_index"] - SystemRng.restore_state(state["rng"]) - self.dataset.restore_state(state["dataset"]) - if not state["exhausted"]: + print(f"dataset_init restore_state: {state=}\n", end="") + self._sample_index = state.sample_index + SystemRng.restore_state(state.rng) + self.dataset.restore_state(state.dataset) + if not state.exhausted: self.new_iter() - assert self._exhausted == state["exhausted"], "Exhausted state mismatch" + assert self._exhausted == state.exhausted, "Exhausted state mismatch" def new_iter(self) -> None: """ @@ -161,15 +180,24 @@ def prefetch_next(self) -> Future[TSample]: self.worker_config.worker_deactivate() return DoneFuture(next_sample) - def save_state(self) -> FlexState: + def save_state(self) -> WorkerState: """ Save the state of the worker. """ # This is called in the worker context (process/thread). - return FlexState( - __class__="DataLoaderWorker", + print(f"save_state: {self._sample_index=}, {self._exhausted=}\n", end="") + return WorkerState( rng=SystemRng.save_state(), dataset=self.dataset.save_state(), exhausted=self._exhausted, sample_index=self._sample_index, ) + + +class DataLoaderNoWorker(DataLoaderWorker[TSample], Generic[TSample]): + """ + DataLoader without async worker. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index 7c1a9cc1..02518622 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -260,7 +260,7 @@ class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC) How dataset state saving works: 1. The dataset state needs to be saved in all forked worker processes which contain a copy of - the main dataset instance (see :class:`megatron.energon.SavableDataLoader`). Each worker returns + the main dataset instance (see :class:`megatron.energon.DataLoader`). Each worker returns only its own state. 2. The main process merges the states via the :meth:`megatron.energon.SavableDataset.merge_states` method in the main process on the main dataset instance (which doesn't hold the worker states, diff --git a/src/megatron/energon/flavors/webdataset/error_handler.py b/src/megatron/energon/flavors/webdataset/error_handler.py index f12583a1..9829b4b3 100644 --- a/src/megatron/energon/flavors/webdataset/error_handler.py +++ b/src/megatron/energon/flavors/webdataset/error_handler.py @@ -35,6 +35,7 @@ def error_handler( Tuple[Union[T_sample, dict, FilteredSample, None], ...], ], ): + sources: list[SourceInfo] | None if isinstance(sample, dict): key = sample.get("__key__") sources = sample.get("__sources__") diff --git a/src/megatron/energon/rng.py b/src/megatron/energon/rng.py index 8eb2f008..cbb61324 100644 --- a/src/megatron/energon/rng.py +++ b/src/megatron/energon/rng.py @@ -17,6 +17,31 @@ T = TypeVar("T") +@edataclass +class WorkerRngState: + rng: Any + + def _hashable_value(self, value: Any) -> Any: + if isinstance(value, (int, float, bool, str)) or value is None: + return value + elif isinstance(value, torch.Tensor): + return self._hashable_value(value.tolist()) + elif isinstance(value, numpy.ndarray): + return self._hashable_value(value.tolist()) + elif isinstance(value, Mapping): + return tuple( + (self._hashable_value(k), self._hashable_value(v)) for k, v in value.items() + ) + elif isinstance(value, Sequence): + return tuple(self._hashable_value(v) for v in value) + else: + raise ValueError(f"Cannot hash value of type {type(value)}: {value!r}") + + def __repr__(self): + # If the hash is the same, the state is the same. Should suffice to identify the state. + return f"WorkerRngState(hash={hash(self._hashable_value((self.rng)))})" + + class WorkerRng(Savable): """Helper class for getting a worker random generator, which is still in itself deterministic. If not in a worker, uses the global random generator's seed to initialize a new rng.""" @@ -79,14 +104,16 @@ def shuffle(self, l: List[T]) -> List[T]: def rand_pop(self, l: List[T]) -> T: return l.pop(self.randbelow(len(l))) - def save_state(self) -> FlexState: - return FlexState(rng=None if self.rng is None else bytes(self.rng.get_state().tolist())) + def save_state(self) -> WorkerRngState: + return WorkerRngState( + rng=None if self.rng is None else bytes(self.rng.get_state().tolist()) + ) - def restore_state(self, state: FlexState): - if state["rng"] is None: + def restore_state(self, state: WorkerRngState): + if state.rng is None: self._restore_state = None else: - self._restore_state = state["rng"] + self._restore_state = state.rng class UserRng: @@ -210,4 +237,5 @@ def get_seed_from_args(*args: Any) -> int: @staticmethod def seed_args(*args: Any) -> None: """Seeds the global random generators deterministically from the given arguments.""" + print(f"Seeding with args: {args}") SystemRng.seed(SystemRng.get_seed_from_args(*args)) diff --git a/src/megatron/energon/savable_loader.py b/src/megatron/energon/savable_loader.py deleted file mode 100644 index a8a5d777..00000000 --- a/src/megatron/energon/savable_loader.py +++ /dev/null @@ -1,1406 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. -# SPDX-License-Identifier: BSD-3-Clause - -import gc -import multiprocessing -import queue -import threading -import time -import traceback -import weakref -from functools import partial -from typing import ( - Any, - ClassVar, - Dict, - Generic, - Iterator, - List, - Optional, - Sequence, - Tuple, - TypeVar, - Union, -) - -import numpy as np -import torch -import torch.distributed -import torch.multiprocessing -from torch.utils.data import DataLoader, IterableDataset - -from megatron.energon.cache import CachePool -from megatron.energon.edataclass import edataclass -from megatron.energon.errors import deprecated -from megatron.energon.flavors.base_dataset import ( - FlexState, - SavableDataset, - State, - add_sample_restore_key, -) -from megatron.energon.rng import SystemRng, SystemRngState -from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset -from megatron.energon.wrappers.batch_dataset import BatchDataset -from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset, gc_init_worker -from megatron.energon.wrappers.log_sample_dataset import default_get_keys -from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset - -T = TypeVar("T") - - -def _init_worker(seed_per_worker: List[int], worker_id: int): - """Initializes the the worker process. - - Sets the random seeds and prepare EPath for the forked worker process. - """ - gc_init_worker(worker_id) - - worker_seed = seed_per_worker[worker_id] - - SystemRng.seed(worker_seed) - - -class SimpleSavableDatasetWrapper(BaseWrapperDataset[T, Tuple[int, int, T]], Generic[T]): - """Wrapper for non-multiprocessing savable datasets. Restarts the inner dataset. This class is - not intended to be used directly.""" - - #: The cache pool to use for the dataset. - cache_pool: CachePool - - _state_restored: bool - _sample_index: int - - _savable_fields = ("_sample_index",) - - def __init__( - self, dataset: SavableDataset[T], worker_config: WorkerConfig, cache_pool: CachePool - ): - """ - Args: - dataset: The dataset to wrap. - worker_config: The worker config to use for the dataset. - cache_pool: The cache pool to use for the dataset. - """ - super().__init__(dataset, worker_config=worker_config) - self.cache_pool = cache_pool - - def reset_state_own(self) -> None: - self._sample_index = 0 - self._state_restored = False - - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - - @property - def __len__(self): - # Note: This disables hasattr(self, "__len__"), because that attr will - raise AttributeError("Disabled direct length access to avoid DataLoader warnings.") - - def __iter__(self): - self._state_restored = True - worker_id = self.worker_config.rank_worker_id() - global_worker_id = self.worker_config.global_worker_id() - while self._state_restored: - self._state_restored = False - self.worker_config.worker_activate(self._sample_index, cache_pool=self.cache_pool) - worker_active = True - try: - for src_data in self.dataset: - self.worker_config.worker_deactivate() - worker_active = False - sample_index = self._sample_index - src_data = add_sample_restore_key( - src_data, global_worker_id, sample_index, src=self - ) - self._sample_index += 1 - yield worker_id, sample_index, src_data - if self._state_restored: - # Restart iterator after restore - break - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - finally: - if worker_active: - self.worker_config.worker_deactivate() - - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: - id, global_worker_id, sample_idx = restore_key[:3] - assert id == type(self).__name__ - restore_key = restore_key[3:] - self.worker_config.worker_activate( - sample_idx, override_global_rank=global_worker_id, cache_pool=self.cache_pool - ) - try: - return add_sample_restore_key( - self.dataset.restore_sample(restore_key), - global_worker_id, - sample_idx, - src=self, - ) - finally: - self.worker_config.worker_deactivate() - - def config(self) -> Dict[str, Any]: - return self.dataset.config() - - def __str__(self): - return f"SimpleSavableDatasetWrapper(dataset={self.dataset})" - - -@edataclass -class SavableDatasetState(State): - """State of the dataset wrapper. It stores the global random states and the index of the next - sample to be returned from the dataset. This class is not intended to be used directly, but by - :class:`megatron.energon.SavableDatasetWrapper`.""" - - #: The state of all the system random number generators - rng: SystemRngState - #: The state of the savable dataset - dataset_state: FlexState - #: Index of the next sample to be returned from the dataset - sample_index: int - - def __repr__(self): - return f"SavableDatasetState(rng={self.rng!r}, sample_index={self.sample_index})" - - -@edataclass -class SavableCheckpoint: - """Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. An instance is created - regularly to be able to save the state of the dataset wrapper before the currently emitted - sample. - """ - - #: The state of the wrapper - state: Optional[SavableDatasetState] - #: The time at which the checkpoint was created - checkpoint_time: float - #: Index of the next sample to be returned from the dataset after restoring the checkpoint - sample_index: int - - -@edataclass -class SavableDatasetCheckpoint(State): - """Checkpoint data for :class:`megatron.energon.SavableDatasetWrapper`. The checkpoint state - represents a state before that checkpoint, with an offset (i.e. samples to be skipped).""" - - #: The state of the wrapper at the sample index when the checkpoint was created. - state: Optional[SavableDatasetState] - #: Offset of the checkpoint to the actual sample index to be restored. - offset: int - - -class SavableDatasetWrapper(IterableDataset[Tuple[int, int, T]], Generic[T]): - """Internal class for wrapping a savable dataset for a worker process. Provides communication - with the :class:`megatron.energon.SavableDataLoader`. This class is not intended to be used directly. - See :class:`megatron.energon.SavableDataLoader` for more information.""" - - #: The wrapped dataset - dataset: SavableDataset[T] - #: The configuration of the worker process - worker_config: WorkerConfig - #: The time interval in seconds to wait at minimum between two checkpoints - checkpoint_every_sec: float - #: The minimum number of samples to be emitted between two checkpoints. Should be `number of - # workers * 2`. - checkpoint_every_min_n_samples: int - #: The number of checkpoints to keep in memory, before discarding. Should be 2. - n_checkpoints: int - #: The cache pool to use for the dataset. - cache_pool: CachePool - #: The queue of the worker process to receive commands from the `SavableDataLoader`. - _cmd_queues: List[torch.multiprocessing.Queue] - #: The queue of the worker process to send results to the `SavableDataLoader`. - _result_queues: List[torch.multiprocessing.Queue] - - _sample_index: int = 0 - _worker_offset: int = 0 - _last_checkpoints: List[SavableCheckpoint] - - _workers_restore_from: List[Optional[SavableDatasetState]] = list() - _workers_skip_samples: List[int] - - _running: bool = False - _command_lock: Optional[threading.RLock] = None - _cmd_thread: Optional[threading.Thread] = None - - def __init__( - self, - dataset: SavableDataset[T], - worker_config: WorkerConfig, - checkpoint_every_sec: float, - checkpoint_every_min_n_samples: int, - n_checkpoints: int = 2, - *, - cmd_queues: List[torch.multiprocessing.Queue], - result_queues: List[torch.multiprocessing.Queue], - cache_pool: CachePool, - ): - """ - Create the savable dataset wrapper for multiprocessing data loading. - - Args: - dataset: The dataset to wrap - worker_config: The worker config as used by all datasets - checkpoint_every_sec: The time interval in seconds to wait at minimum between two - checkpoints. - checkpoint_every_min_n_samples: The minimum number of samples to be emitted between - two checkpoints. Should be `number of workers * 2`. - n_checkpoints: Number of checkpoints to keep. - cmd_queues: The command queues for communicating with the worker processes. - result_queues: The result queues for communicating with the worker processes. - cache_pool: The cache pool to use for the dataset. - """ - num_workers = max(worker_config.num_workers, 1) - - self.dataset = dataset - self.worker_config = worker_config - self.checkpoint_every_sec = checkpoint_every_sec - self.checkpoint_every_min_n_samples = checkpoint_every_min_n_samples - self.n_checkpoints = n_checkpoints - self._last_checkpoints = [ - SavableCheckpoint(state=None, checkpoint_time=time.perf_counter(), sample_index=-1) - ] - self._workers_restore_from = [None] * num_workers - self._workers_skip_samples = [0] * num_workers - self._cmd_queues = cmd_queues - self._result_queues = result_queues - self.cache_pool = cache_pool - - @staticmethod - def _command_thread(self: "SavableDatasetWrapper"): - """The internal thread, which processes the command and result queues. This thread is - static, because `self` is actually passed as weakref proxy, to avoid keeping the dataset - alive via the thread. - """ - # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread starting") - assert self._command_lock is not None - - try: - while self._running: - try: - cmd_args = self._cmd_queues[self._worker_id].get(timeout=1) - except queue.Empty: - continue - # print(f"recv cmd {cmd_args}") - with self._command_lock: - cmd = cmd_args[0] - if cmd is None: - break - try: - fn = getattr(self, cmd) - self._result_queues[self._worker_id].put( - {self._worker_id: fn(*cmd_args[1:])} - ) - # print(f"result sent") - except Exception as e: - traceback.print_exc() - self._result_queues[self._worker_id].put({self._worker_id: e}) - # print(f"exc sent") - except BaseException: - traceback.print_exc() - raise - finally: - pass - # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker command thread closing") - - def len_worker(self, worker_idx: int | None = None) -> int: - return self.dataset.len_worker(worker_idx) - - def len_rank(self): - return self.dataset.len_rank() - - @property - def __len__(self): - # Note: This disables hasattr(self, "__len__"), because that attr will - raise AttributeError("Disabled direct length access to avoid DataLoader warnings.") - - def __del__(self): - if self._cmd_thread is not None: - # print(f"{id(self)}:{multiprocessing.current_process().ident} Closing cmd thread") - self._running = False - self._cmd_thread.join() - self._command_lock = None - self._cmd_thread = None - # print(f"{id(self)}:{multiprocessing.current_process().ident} Cmd thread closed") - - def __iter__(self): - # First: Set the worker offset globally for the current worker - WorkerConfig.worker_id_offset = self._worker_offset - self._worker_id = self.worker_config.rank_worker_id() - global_worker_id = self.worker_config.global_worker_id() - if self._cmd_thread is None: - self._running = True - self._command_lock = threading.RLock() - weakref_self = weakref.proxy(self) - self._cmd_thread = threading.Thread( - target=SavableDatasetWrapper._command_thread, - name="command_thread", - args=(weakref_self,), - daemon=True, - ) - self._cmd_thread.start() - # atexit.register(lambda: weakref_self.__del__()) - try: - assert self._command_lock is not None - with self._command_lock: - if self._workers_restore_from[self._worker_id] is not None: - my_state = self._workers_restore_from[self._worker_id] - my_ds_state = my_state.dataset_state - assert my_state is not None - self.dataset.reset_state() - if my_ds_state is not None: - self.dataset.restore_state(my_ds_state) - self._restore_state(my_state) - self._workers_restore_from[self._worker_id] = None - else: - self.dataset.reset_state() - # Store the initial state of the worker if we stop before the first sample - self._store_checkpoint() - # If skipping, also restart the iterator to reach the start of the restored - # checkpoint - last_was_skip = True - while last_was_skip: - dataset_has_samples = False - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - try: - for src_data in self.dataset: - self.worker_config.worker_deactivate() - worker_active = False - dataset_has_samples = True - if self._workers_skip_samples[self._worker_id] > 0: - # Skip ahead to reach the start of the restored checkpoint - # print(f"Skip [{self._sample_index}:{self._worker_id}] {src_data}") - self._workers_skip_samples[self._worker_id] -= 1 - self._sample_index += 1 - last_was_skip = True - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - continue - last_was_skip = False - sample_index = self._sample_index - add_sample_restore_key( - src_data, global_worker_id, sample_index, src=self - ) - self._sample_index += 1 - self._store_checkpoint() - try: - self._command_lock.release() - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock released") - # Commands may be executed only when data was yielded, not during - # iteration fetching. - # print(f"Yield next data [{sample_index}:{self._worker_id}] {src_data}") - yield self._worker_id, sample_index, src_data - finally: - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquiring") - self._command_lock.acquire() - # print(f"{id(self)}:{multiprocessing.current_process().ident} Lock acquired") - self.worker_config.worker_activate( - self._sample_index, cache_pool=self.cache_pool - ) - worker_active = True - finally: - if worker_active: - self.worker_config.worker_deactivate() - - # If the dataset is empty, don't try again and again - if not dataset_has_samples: - break - finally: - # print(f"{id(self)}:{multiprocessing.current_process().ident} Worker iter closing") - # Always store a final checkpoint (it's likely to be saved) - self._store_checkpoint(force=True) - - def _store_checkpoint(self, force: bool = False) -> None: - """ - Internally create a checkpoint for the current state. This is required to store states - from the past, which have already been yielded here, but not yet been retrieved from the - intermediate queues. - - Args: - force: If true, ignore time or frequency condition. - """ - if ( - force - or ( - self._last_checkpoints[-1].checkpoint_time + self.checkpoint_every_sec - < time.perf_counter() - and self._last_checkpoints[-1].sample_index + self.checkpoint_every_min_n_samples - <= self._sample_index - ) - or self._sample_index <= 1 - ): - # print(f"Storing checkpoint at {self._worker_id}:{self._sample_index}") - self._last_checkpoints.append( - SavableCheckpoint( - state=self._save_state(), - checkpoint_time=time.perf_counter(), - sample_index=self._sample_index, - ) - ) - if len(self._last_checkpoints) > self.n_checkpoints: - self._last_checkpoints.pop(0) - - def _save_state(self) -> SavableDatasetState: - """Saves the internal state""" - return SavableDatasetState( - rng=SystemRng.save_state(), - dataset_state=self.dataset.save_state(), - sample_index=self._sample_index, - ) - - def _restore_state(self, state: SavableDatasetState) -> None: - """Restores the internal worker state""" - assert torch.utils.data.get_worker_info() is not None, "Can only restore in worker process" - if state.rng is None: - SystemRng.seed(torch.initial_seed() & 0xFFFFFFFF) - else: - SystemRng.restore_state(state.rng) - - self._sample_index = state.sample_index - self._last_checkpoints = [ - SavableCheckpoint( - state=self._save_state(), - checkpoint_time=time.perf_counter(), - sample_index=self._sample_index, - ) - ] - - def get_checkpoint(self, last_sample_indexes: List[int]) -> SavableDatasetCheckpoint: - """ - Get a checkpoint given the last emitted sample indexes for all workers. - - Args: - last_sample_indexes: The last emitted sample indexes for all workers. - - Returns: - The found checkpoint including the offset to the next sample index - """ - sample_index = last_sample_indexes[self._worker_id] + 1 - for checkpoint in reversed(self._last_checkpoints): - if checkpoint.sample_index <= sample_index: - # print(f"Found cp for {sample_index} at {checkpoint.sample_index}") - return SavableDatasetCheckpoint( - state=checkpoint.state, - offset=sample_index - checkpoint.sample_index, - ) - - # Immediate save after restore - if len(self._last_checkpoints) == 0: - return SavableDatasetCheckpoint( - state=self._workers_restore_from[self._worker_id], - offset=self._workers_skip_samples[self._worker_id], - ) - raise ValueError("No checkpoint found") - - def restore_checkpoint( - self, - worker_states: Optional[List[SavableDatasetCheckpoint]], - worker_offset: int, - ) -> None: - """ - Restores the merged checkpoint from all worker processes. - - Args: - worker_states: The state to restore for each worker - worker_offset: The offset of the last worker which has emitted a sample. This will be - set in all worker processes to ensure the right worker starts as first. - """ - assert torch.utils.data.get_worker_info() is None, "Cannot restore in worker process" - num_workers = max(self.worker_config.num_workers, 1) - - if worker_states is None: - self._workers_restore_from = [None] * num_workers - assert worker_offset == 0 - self._worker_offset = 0 - self._workers_skip_samples = [0] * num_workers - else: - assert isinstance(worker_states, list) - assert len(worker_states) == num_workers - assert isinstance(worker_states[0], SavableDatasetCheckpoint) - - self._worker_offset = worker_offset - - # Tear the state_list apart (which has len=num_workers) - # and store the states in the internal arrays - self._workers_restore_from = [state.state for state in worker_states] - self._workers_skip_samples = [state.offset for state in worker_states] - - def get_initial_checkpoint(self) -> Optional[List[SavableDatasetCheckpoint]]: - """ - Get the initial checkpoint for all worker processes if they have not started yet. - - Returns: - The initial checkpoint for all worker processes and the worker offset. - """ - assert torch.utils.data.get_worker_info() is None, ( - "Cannot get initial checkpoint in worker process" - ) - if all(s is None for s in self._workers_restore_from): - assert all(s == 0 for s in self._workers_skip_samples) - # Initial state, no checkpoint - return None - - return [ - SavableDatasetCheckpoint( - state=state, - offset=offset, - ) - for state, offset in zip(self._workers_restore_from, self._workers_skip_samples) - ] - - def can_restore_sample(self) -> bool: - return self.dataset.can_restore_sample() - - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: - id, global_worker_id, sample_idx = restore_key[:3] - assert id == type(self).__name__ - restore_key = restore_key[3:] - self.worker_config.worker_activate(sample_idx, override_global_rank=global_worker_id) - try: - return add_sample_restore_key( - self.dataset.restore_sample(restore_key), - global_worker_id, - sample_idx, - src=self, - ) - finally: - self.worker_config.worker_deactivate() - - def config(self) -> Dict[str, Any]: - return self.dataset.config() - - def __str__(self): - return f"SavableDatasetWrapper(dataset={self.dataset})" - - -@edataclass -class SavableDataLoaderState(State): - """Saved state of the :class:`megatron.energon.SavableDataLoader`. Contains the state for all worker - processed of a single rank.""" - - #: The internal state of the dataset (for each worker process) - worker_states: List[Union[SavableDatasetCheckpoint, FlexState]] - #: Which worker will be the next to emit a sample. Used to restore the proper order - next_worker_id: int - - #: The micro batch size that was used, if available. - #: On restore, this is used to compare the new and old micro batch size. - micro_batch_size: Optional[int] - - -class SavableDataLoader(DataLoader[T], Generic[T]): - """DataLoader that supports saving and restoring the state of the dataset. - When restoring, the dataloader and dataset must be instantiated with the exactly same - parameters. - - How this works (for no worker processes) - ---------------------------------------- - - 1. The state of the dataset is saved using :meth:`megatron.energon.SavableDataset.save_state` - 2. (for compatibility) The state of the dataset is converted to using inner arrays using - :meth:`megatron.energon.SavableDataset.merge_states`. - 3. The state can be restored using :meth:`megatron.energon.SavableDataset.restore_state` given the - previously saved (and merged) state. - - How this works (for worker processes) - ------------------------------------- - - - First issue is, that worker processes work with internal queues between processes to pass - loaded samples to the main process (also to perform collating). This means that the whole - state of the dataset is not directly accessible from the main process. - - To solve this issue, the dataset regularly saves a checkpoint of its state to be able to - resume from that state (and skip the samples that have already been yielded). - - To have a consistent state, the sample index from the latest yielded samples is saved for all - worker instances. Thus, the main process knows exactly which sample indexes should come next - from which worker. - - Internally, pytorch iterates through the workers in order to retrieve the next worker's - samples. Unfortunately, that next worker index cannot be restored in pytorch's dataloader, - thus the workers are shifted internally by that offset - (see :attr:`megatron.energon.WorkerConfig.worker_id_offset`). - - 1. The dataset is wrapped in a :class:`megatron.energon.SavableDatasetWrapper`. This allows the main - process to communicate with the worker and send commands to the workers and retrieve the - results. - 2. The state of the dataset is saved using - :meth:`megatron.energon.SavableDatasetWrapper.get_checkpoint`. This gives the last checkpoint - from the requested sample index and stores the offset (i.e. number of samples to skip) from - that checkpoint. - 3. The state is merged using :meth:`megatron.energon.SavableDatasetWrapper.merge_checkpoints`. This - merges the states of all workers and returns a single state that can be used to restore the - state of the dataset. - 4. The state can be restored using :meth:`megatron.energon.SavableDatasetWrapper.restore_state` - before a worker is started, such that all workers initially receive the same state array. - The worker firstly sets the worker index offset, then uses its (shifted) own index to get its - required state from the merged state array. - - """ - - #: The worker config - worker_config: WorkerConfig - #: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper` - dataset: Union[SavableDatasetWrapper[T], SimpleSavableDatasetWrapper[T]] - - #: The global ID counter - _next_id: ClassVar[int] = 0 - #: Class instance id - id: int = 0 - - #: The queues used to send commands to the workers - cmd_queues: List[torch.multiprocessing.Queue] - #: The queues used to receive results from the workers - result_queues: List[torch.multiprocessing.Queue] - - #: Instance of the current data iterator. There shall be only one active iterator, such that the - # dataset is not iterated multiple times in parallel. The state will proceed. - _persistent_iterator: Optional[Iterator[T]] = None - #: Whether the dataloader has running workers. - _has_workers: bool = False - #: The index of the current worker. -1 if not started yet. - _worker_sample_counters: List[int] - #: Id of the next worker to retrieve data from - _next_worker_id: int = 0 - #: Global index of the last yielded sample - _global_sample_idx: int = 0 - #: Current iterator index of the last yielded sample - _sample_idx: int = 0 - - def __init__( - self, - dataset: SavableDataset[T], - *, - checkpoint_every_sec: float = 60, - checkpoint_every_min_n_samples: Optional[int] = None, - n_checkpoints: Optional[int] = None, - gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER, - gc_freeze_at_start: bool = True, - prefetch_factor: int = 2, - cache_pool: Optional[CachePool] = None, - watchdog_timeout_seconds: Optional[float] = 60, - watchdog_initial_timeout_seconds: Optional[float] = None, - fail_on_timeout: bool = False, - ): - """ - Create the dataloader supporting saving and restoring the state. - - Args: - dataset: The dataset to load. - worker_config: The worker config to use - checkpoint_every_sec: This is the time in seconds after which a checkpoint is saved. - It may take the same duration to restore a checkpoint, but introduces additional - overhead during reading data from the dataset, so this should be chosen accordingly. - Only applies if using workers. - checkpoint_every_min_n_samples: Overwrites the minimum number of samples between - checkpoints. Defaults to `number of workers * 2`. Only applies if using workers. - n_checkpoints: The number of checkpoints to keep in memory. Only applies if using - workers. If None, computes a suitable value. - gc_collect_every_n_steps: The number of steps after which the garbage collector is - called. As we're usually handling large (but few) tensors here, and the python - garbage collection is already full of objects just by importing, this can improve - the memory footprint quite a lot, and may even be necessary to avoid memory - overflow. - gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker - processes. This improves the garbage collection performance by a lot. - In rare cases, this may cause issues and can be disabled. Keep enabled if you - experience no issues. - cache_pool: If set, the cache pool to use for the dataset. - watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled. - watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. - fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. - """ - self.worker_config = dataset.worker_config - self.id = self.next_id() - - dataset = WatchdogDataset( - dataset, - worker_config=self.worker_config, - timeout_seconds=watchdog_timeout_seconds, - initial_timeout_seconds=watchdog_initial_timeout_seconds, - fail_on_timeout=fail_on_timeout, - ) - - if gc_collect_every_n_steps > 0: - dataset = GcDataset( - dataset, - worker_config=self.worker_config, - every_n_iter=gc_collect_every_n_steps, - freeze=gc_freeze_at_start, - ) - - self.cmd_queues = [multiprocessing.Queue() for _ in range(self.worker_config.num_workers)] - self.result_queues = [ - multiprocessing.Queue() for _ in range(self.worker_config.num_workers) - ] - - num_procs = max(self.worker_config.num_workers, 1) - - if n_checkpoints is None: - n_checkpoints = prefetch_factor * num_procs + 1 - - if self.worker_config.num_workers > 0: - if checkpoint_every_min_n_samples is None: - checkpoint_every_min_n_samples = self.worker_config.num_workers * 2 - - dataset = SavableDatasetWrapper( - dataset, - self.worker_config, - checkpoint_every_sec=checkpoint_every_sec, - checkpoint_every_min_n_samples=checkpoint_every_min_n_samples, - n_checkpoints=n_checkpoints, - cmd_queues=self.cmd_queues, - result_queues=self.result_queues, - cache_pool=cache_pool, - ) - else: - dataset = SimpleSavableDatasetWrapper( - dataset, self.worker_config, cache_pool=cache_pool - ) - - self._worker_sample_counters = [-1] * num_procs - - kwargs = {} - if self.worker_config.num_workers > 0: - kwargs["persistent_workers"] = True - kwargs["prefetch_factor"] = prefetch_factor - - # Assert that prefetch_factor works well with num_checkpoints. - # This ensures that the oldest checkpoint is old enough to cover - # all the buffered samples in the torch dataloader. - assert prefetch_factor * num_procs + 1 <= n_checkpoints, ( - "When increasing prefetch_factor, also increase n_checkpoints, so that " - "the number of checkpoints is at least as large as num_workers * prefetch_factor + 1" - ) - - # Compute seeds for each worker, based on current rank - seed_per_worker = [ - self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers) - ] - - super().__init__( - dataset, - batch_size=None, - shuffle=False, - num_workers=self.worker_config.num_workers, - pin_memory=True, - worker_init_fn=partial(_init_worker, seed_per_worker), - **kwargs, - ) - - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "SavableLoader.__init__", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "config": dataset.config(), - } - ) - - @staticmethod - def next_id() -> int: - next_id = SavableDataLoader._next_id - SavableDataLoader._next_id += 1 - return next_id - - def __len__(self): - # We override this, because otherwise we'll see warnings - return self.dataset.len_rank() - - def __iter__(self): - def _inner_generator(iterator): - iter_idx = 0 - id = self.next_id() - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "SavableDataLoader.iter", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": id, - } - ) - try: - for worker_id, sample_idx, sample in iterator: - self._worker_sample_counters[worker_id] = sample_idx - # If the next sample will be from the first worker, we can safely resume - self._next_worker_id = (worker_id + 1) % max(self.num_workers, 1) - # self._debugf.write( - # f"[w={worker_id}, s={sample_idx}] {self._sample_str(sample)}\n" - # ) - # self._debugf.flush() - if self.worker_config.should_log(level=1): - keys = default_get_keys(sample) - self.worker_config.worker_log( - { - **{ - "t": "SavableDataLoader.yield", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": id, - "worker_id": worker_id, - "worker_idx": sample_idx, - "idx": self._sample_idx, - "iter_idx": iter_idx, - "global_idx": self._global_sample_idx, - }, - **({} if keys is None else {"keys": keys}), - } - ) - self._sample_idx += 1 - self._global_sample_idx += 1 - iter_idx += 1 - yield sample - self._persistent_iterator = None - self._next_worker_id = 0 - finally: - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "SavableDataLoader.StopIteration", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": self.id, - } - ) - - if self.num_workers > 0: - # Always keep same iterator alive, as long as it yields data - if self._persistent_iterator is None: - self._persistent_iterator = _inner_generator(super().__iter__()) - self._sample_idx = 0 - self._has_workers = True - # print("New Iterator", self._persistent_iterator) - return self._persistent_iterator - else: - return _inner_generator(super().__iter__()) - - def _worker_command(self, *cmd_args) -> List[Any]: - """Executes a command in all workers and returns the results.""" - # print(f"cmd: {cmd_args}") - for cmd_queue in self.cmd_queues: - cmd_queue.put(cmd_args) - # print(f"waiting for res") - assert len(self.result_queues) == self.worker_config.num_workers - res = {k: v for results_queue in self.result_queues for k, v in results_queue.get().items()} - res = [res[i] for i in range(len(res))] - # print(f"res: {res}") - for r in res: - if isinstance(r, Exception): - raise r - return res - - def _get_batch_size(self) -> Optional[int]: - """Try to infer micro batch size from the dataset""" - if isinstance(self.dataset, (SavableDatasetWrapper, SimpleSavableDatasetWrapper)): - dataset = self.dataset.dataset - else: - dataset = self.dataset - - if ( - isinstance(dataset, BaseWrapperDataset) - and (bds := dataset._find_wrapped_dataset(BatchDataset)) is not None - ): - assert isinstance(bds, BatchDataset) - return bds.batch_size - else: - return None - - def save_state_rank(self) -> Optional[SavableDataLoaderState]: - """ - Saves the state of the dataset for the current rank. Allows for restoring the state later - using `restore_state_rank`, given the result of this method. - - Returns: - The state of the dataset. - """ - # Fetch current rank's worker's state - if self.num_workers == 0: - # No workers configured - assert isinstance(self.dataset, SimpleSavableDatasetWrapper) - worker_states = [self.dataset.save_state()] - assert self._next_worker_id == 0 - elif self._has_workers: - # Fetch from worker processes - worker_states = self._worker_command("get_checkpoint", self._worker_sample_counters) - else: - # Workers configured, but not started yet. - # If a state has already been restored, it will be returned. - assert isinstance(self.dataset, SavableDatasetWrapper) - worker_states = self.dataset.get_initial_checkpoint() - - if worker_states is None: - return None - - # Merge the states - merged_state = SavableDataLoaderState( - worker_states=worker_states, - next_worker_id=self._next_worker_id, - micro_batch_size=self._get_batch_size(), - ) - - # Not distributed -> return the merged state - return merged_state - - def restore_state_rank(self, state: Optional[SavableDataLoaderState]) -> None: - """ - Restores the saved state for the current rank. - - Args: - state: The state to restore, as saved by `save_state_rank`. - """ - assert not self._has_workers, "Cannot restore state while workers are running" - if state is None: - # Assume initial state - return - assert isinstance(state, SavableDataLoaderState) - - old_micro_batch_size = state.micro_batch_size - micro_batch_size = self._get_batch_size() - - if self.num_workers == 0: - # No workers configured - assert isinstance(self.dataset, SimpleSavableDatasetWrapper) - assert micro_batch_size == old_micro_batch_size, ( - "Changing micro batch size is not allowed without workers" - ) - - assert len(state.worker_states) == 1 - assert isinstance(state.worker_states[0], FlexState) - self.dataset.restore_state(state.worker_states[0]) - else: - # Workers configured - assert isinstance(self.dataset, SavableDatasetWrapper) - assert all(isinstance(s, SavableDatasetCheckpoint) for s in state.worker_states) - - # Check batch sizes (before and after) - if micro_batch_size != old_micro_batch_size: - assert micro_batch_size is not None and old_micro_batch_size is not None, ( - "Cannot resume with different batching mode " - "(batching to non-batching or vice versa)" - ) - - if micro_batch_size > old_micro_batch_size: - raise ValueError( - "Resuming with larger micro batch size is not allowed: " - f"{micro_batch_size} > {state.micro_batch_size}" - ) - elif ( - micro_batch_size < old_micro_batch_size - and old_micro_batch_size % micro_batch_size != 0 - ): - raise ValueError( - "Resuming with smaller micro batch size only allowed if the old " - f"micro batch size is a multiple of the new one: {micro_batch_size} < {state.micro_batch_size}" - ) - batch_size_ratio = old_micro_batch_size // micro_batch_size - for worker_state in state.worker_states: - assert isinstance(worker_state, SavableDatasetCheckpoint) - # When resuming with a smaller micro batch size, the offset must be scaled - # up to the new micro batch size to skip the same number of samples as before. - worker_state.offset *= batch_size_ratio - - self.dataset.restore_checkpoint(state.worker_states, worker_offset=state.next_worker_id) - - # Initialize the worker-sample counters so that every worker owns a valid - # "last emitted sample" index. Workers that have not emitted anything yet keep - # the default value ``-1``. - - assert isinstance(state.worker_states, list) - - self._worker_sample_counters = [ - ( - ws.state.sample_index - 1 - if (isinstance(ws, SavableDatasetCheckpoint) and ws.state is not None) - else -1 - ) - for ws in state.worker_states - ] - - self._next_worker_id = state.next_worker_id - - @deprecated( - "`save_state` is deprecated and was renamed to `save_state_global` and will be removed " - "in a future update. If you actually do not want to gather the states to a rank, use " - "`save_state_rank` instead." - ) - def save_state(self, dst_rank: int) -> Optional[Sequence[Optional[SavableDataLoaderState]]]: - """Deprecated. Use `save_state_global` (or `save_state_rank`) instead.""" - - return self.save_state_global(dst_rank) - - def save_state_global( - self, global_dst_rank: int - ) -> Optional[Sequence[Optional[SavableDataLoaderState]]]: - """ - Saves the state of the dataset globally, collecting the state from all ranks using torch - distributed. Allows for restoring the state later using `restore_state_global`, given the - result of this method. - Typical scenario: Save the state to disk only on the `dst_rank`, the other ranks do not - save the state. Later, restore the state either only loaded on the `dst_rank` or - loading on all ranks separately using `restore_state_global`. - - Note: If you want to save/restore the state per rank separately, use `save_state_rank` and - the corresponding `restore_state_rank`. Also, these do not rely on torch distributed. - - Args: - global_dst_rank: The state will be gathered to this rank. The rank refers to the - global rank, not the rank within the data parallel group. - - Returns: - The state of the dataset (or `None`, if not on `dst_rank`). - """ - # Fetch current rank's worker's state - merged_state = self.save_state_rank() - - # Gather the merged states - if self.worker_config.world_size > 1: - output: Optional[Sequence[Optional[SavableDataLoaderState]]] - if self.worker_config.global_rank() == global_dst_rank: - output = [None] * self.worker_config.world_size - else: - # Check if the global_dst_rank is in the same group at all - if self.worker_config.data_parallel_group is not None: - try: - _ = torch.distributed.get_group_rank( - self.worker_config.data_parallel_group, global_dst_rank - ) - except RuntimeError: - raise ValueError( - f"global_dst_rank {global_dst_rank} is not in the group of the current rank's worker config" - ) - - output = None - - torch.distributed.gather_object( - merged_state, - output, - global_dst_rank, - group=self.worker_config.data_parallel_group, - ) - - return output - else: - # Not distributed -> return the merged state - return [merged_state] - - @deprecated( - "`restore_state` was renamed to `restore_state_global` and will be removed in a future update." - ) - def restore_state( - self, - state: Optional[Sequence[Optional[SavableDataLoaderState]]], - ) -> None: - """Deprecated. Use `restore_state_global` (or `restore_state_rank`) instead.""" - - return self.restore_state_global(state) - - def restore_state_global( - self, - state: Optional[Sequence[Optional[SavableDataLoaderState]]], - *, - src_rank: Optional[int] = None, - ) -> None: - """ - Restores the saved state from `save_state_global` (in torch distributed setup). - The global state needs be loaded on every rank that has a data loader instance. - - Optionally, one can specify a src_rank and only provide the state once. - In case of multiple data parallel groups, you must provide the state once - in each data parallel group. In this case the `src_rank` is the rank within the - data parallel group. - - Args: - state: The state to restore, as saved by `save_state_global`. - src_rank: The rank from which the state is broadcasted (within the data parallel group, if using DP groups). - """ - - assert self._persistent_iterator is None, "Cannot restore state while workers are running" - - # Only restore multi-rank if state is actually a list and we are in a distributed setup. - # Otherwise treat as single rank state. - if src_rank is None or self.worker_config.world_size == 1: - assert isinstance(state, list), "State must be a list in distributed setup" - assert len(state) == self.worker_config.world_size, ( - "State must be a list of size world_size" - ) - - # All ranks have the state - # Select the state of the current rank - rank_state = state[self.worker_config.rank] - else: - if self.worker_config.data_parallel_group is not None: - # Only the src_rank has the state within this dp group - try: - global_src_rank = torch.distributed.get_global_rank( - self.worker_config.data_parallel_group, src_rank - ) - except RuntimeError: - raise ValueError( - f"src_rank {src_rank} is not in the group of the current rank's worker config" - ) - else: - # If no DP group is given, we assume the global rank is - # the same as the data parallel rank - global_src_rank = src_rank - - if self.worker_config.rank != src_rank: - # Send the state to all other ranks - assert state is None - # Must still be a list of Nones - state = [None] * self.worker_config.world_size - else: - assert isinstance(state, list), "State must be a list in distributed setup" - assert len(state) == self.worker_config.world_size, ( - "State must be a list of size world_size" - ) - - local_object = [None] - torch.distributed.scatter_object_list( - local_object, - state, - src=global_src_rank, - group=self.worker_config.data_parallel_group, - ) - rank_state = local_object[0] - - self.restore_state_rank(rank_state) - - def can_restore_sample(self) -> bool: - return self.dataset.can_restore_sample() - - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: - """Restores a sample from a key. This is useful to debug the dataset.""" - return self.dataset.restore_sample(restore_key) - - def config(self): - """Get the configuration, which defines the dataset. Useful in conjunction with `save_state` - and `restore_state` to match the configuration as well.""" - return { - "type": type(self).__qualname__, - "num_workers": self.num_workers, - "persistent_workers": self.persistent_workers, - "pin_memory": self.pin_memory, - "prefetch_factor": None if self.num_workers == 0 else self.prefetch_factor, - "dataset": self.dataset.config(), - } - - -class BasicDataLoader(DataLoader[T], Generic[T]): - """DataLoader that supports debugging the dataset without saving capability (e.g. for val/eval).""" - - #: The worker config - worker_config: WorkerConfig - #: The wrapped dataset. For multiprocessing, this is a :class:`megatron.energon.SavableDatasetWrapper` - dataset: Union[SavableDatasetWrapper[T], SavableDataset[T]] - - id: int - _sample_idx: int = 0 - - def __init__( - self, - dataset: SavableDataset[T], - gc_collect_every_n_steps: int = GC_DEFAULT_EVERY_N_ITER, - gc_freeze_at_start: bool = True, - prefetch_factor: int = 2, - cache_pool: Optional[CachePool] = None, - watchdog_timeout_seconds: Optional[float] = 60, - watchdog_initial_timeout_seconds: Optional[float] = None, - fail_on_timeout: bool = False, - ): - """ - Create the dataloader supporting saving and restoring the state. - - Args: - dataset: The dataset to load. - gc_collect_every_n_steps: The number of steps after which the garbage collector is - called. As we're usually handling large (but few) tensors here, and the python - garbage collection is already full of objects just by importing, this can improve - the memory footprint quite a lot, and may even be necessary to avoid memory - overflow. - gc_freeze_at_start: If true, the garbage collector is frozen at the start of the worker - processes. This improves the garbage collection performance by a lot. - In rare cases, this may cause issues and can be disabled. Keep enabled if you - experience no issues. - cache_pool: If set, the cache pool to use for the dataset. - watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled. - watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. - fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. - """ - self.worker_config = dataset.worker_config - - self.id = SavableDataLoader.next_id() - - dataset = WatchdogDataset( - dataset, - worker_config=self.worker_config, - timeout_seconds=watchdog_timeout_seconds, - initial_timeout_seconds=watchdog_initial_timeout_seconds, - fail_on_timeout=fail_on_timeout, - ) - - if gc_collect_every_n_steps > 0: - dataset = GcDataset( - dataset, - worker_config=self.worker_config, - every_n_iter=gc_collect_every_n_steps, - freeze=gc_freeze_at_start, - ) - - dataset = SimpleSavableDatasetWrapper( - dataset, worker_config=self.worker_config, cache_pool=cache_pool - ) - - self._worker_sample_counters = [0] * max(self.worker_config.num_workers, 1) - - kwargs = {} - if self.worker_config.num_workers > 0: - # These must not be specified for num_workers =0 - kwargs["persistent_workers"] = True - kwargs["prefetch_factor"] = prefetch_factor - - seed_per_worker = [ - self.worker_config.worker_seed(i) for i in range(self.worker_config.num_workers) - ] - - gc.collect() # This ensures that we don't include any old worker refs in the newly forked worker processes - - super().__init__( - dataset, - batch_size=None, - shuffle=False, - num_workers=self.worker_config.num_workers, - pin_memory=True, - worker_init_fn=partial(_init_worker, seed_per_worker), - **kwargs, - ) - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "BasicDataLoader.__init__", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "config": self.config(), - } - ) - - def __len__(self): - # We override this, because otherwise we'll see warnings - return self.dataset.len_rank() - - def __iter__(self): - def _inner_generator(iterator): - iter_idx = 0 - id = SavableDataLoader.next_id() - - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "BasicDataLoader.iter", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": id, - } - ) - - try: - for worker_id, sample_idx, sample in iterator: - # If the next sample will be from the first worker, we can safely resume - if self.worker_config.should_log(level=1): - keys = default_get_keys(sample) - self.worker_config.worker_log( - { - **{ - "t": "BasicDataLoader.yield", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": self.id, - "worker_id": worker_id, - "worker_idx": sample_idx, - "idx": iter_idx, - "iter_idx": iter_idx, - "global_idx": self._sample_idx, - }, - **({} if keys is None else {"keys": keys}), - } - ) - self._sample_idx += 1 - iter_idx += 1 - yield sample - finally: - if self.worker_config.should_log(level=1): - self.worker_config.worker_log( - { - "t": "BasicDataLoader.StopIteration", - "r": self.worker_config.rank, - "w": None, - "id": self.id, - "iter_id": id, - } - ) - - return _inner_generator(super().__iter__()) - - def config(self): - """Get the configuration, which defines the dataset. Useful in conjunction with `save_state` - and `restore_state` to match the configuration as well.""" - return { - "type": type(self).__qualname__, - "num_workers": self.worker_config.num_workers, - "persistent_workers": self.persistent_workers, - "pin_memory": self.pin_memory, - "prefetch_factor": None if self.num_workers == 0 else self.prefetch_factor, - "dataset": self.dataset.config(), - } - - def can_restore_sample(self) -> bool: - return self.dataset.can_restore_sample() - - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T: - """Restores a sample from a key. This is useful to debug the dataset.""" - return self.dataset.restore_sample(restore_key) - - -def _sample_str(self, sample): - """Returns a human readable debug string for a single sample, also uniquely identifying it.""" - import dataclasses - import hashlib - - if isinstance(sample, torch.Tensor): - return f"Tensor(shape={sample.shape}, dtype={sample.dtype}, sha256={hashlib.sha256(sample.detach().cpu().numpy().tobytes()).hexdigest()!r})" - elif isinstance(sample, np.ndarray): - return f"ndarray(shape={sample.shape}, dtype={sample.dtype}, sha256={hashlib.sha256(sample.tobytes()).hexdigest()!r})" - elif isinstance(sample, (str, bytes)): - if len(sample) > 100: - return f"{sample[:100]!r}..." - return repr(sample) - elif isinstance(sample, (int, float)): - return repr(sample) - elif isinstance(sample, dict): - innerstr = ", ".join(f"{k}={self._sample_str(v)}, " for k, v in sample.items()) - return f"dict({innerstr})" - elif isinstance(sample, (list, tuple)): - innerstr = ", ".join(self._sample_str(v) for v in sample) - return f"{type(sample)}({innerstr})" - elif dataclasses.is_dataclass(sample): - innerstr = ", ".join( - f"{field.name}={self._sample_str(getattr(sample, field.name))}, " - for field in dataclasses.fields(sample) - ) - return f"{type(sample).__name__}({innerstr})" - else: - return repr(sample) diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index 48a0ad6e..ddb39b4e 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -1079,10 +1079,10 @@ def cache(self) -> CachePool: assert WorkerConfig.active_worker_config is not None, ( "The cache can only be fetched within the worker, and to be usable, you must use the get_(savable_)loader methods provided from the package." ) - assert WorkerConfig.active_worker_config._cache_pool is not None, ( + assert WorkerConfig.active_worker_config._active_state.cache_pool is not None, ( "Cache pool must be set by the loader." ) - return WorkerConfig.active_worker_config._cache_pool + return WorkerConfig.active_worker_config._active_state.cache_pool class DefaultTaskEncoder( diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index 22027bc5..8a549355 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -105,10 +105,6 @@ def worker_activate( WorkerConfig._active_state.worker_config = self WorkerConfig._active_state.override_global_rank = override_global_rank WorkerConfig._active_state.cache_pool = cache_pool - print( - f"worker_activate {self.rank} {self.num_workers} {self.rank_worker_id()} on {threading.get_ident()}\n", - end="", - ) def worker_push_sample_index(self, sample_index: int): """Pushes a new sample index to the sample index stack. Should be set by wrapping datasets @@ -217,7 +213,7 @@ def rank_worker_id(self) -> int: # worker (id=0) corresponds to the logical worker that should emit the # next sample. For example, if `worker_id_offset` is 1, logical worker # 1 becomes the first to emit a sample, shifting the ordering forward. - return (worker_info.id + self.worker_id_offset) % worker_info.num_workers + return (worker_info.id + self.worker_id_offset) % max(worker_info.num_workers, 1) def assert_worker(self): """Checks if the current process is a worker (if configured so), and that the workers are diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index edfcfd47..10b534f6 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -37,7 +37,7 @@ class BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_ drop_last: bool error_handler: Callable[[Exception, list[T_batch_sample], Sequence[SourceInfo]], None] _sample_index: SampleIndex - _generator_sample_keys: Optional[Any] + _generator_sample_keys: Optional[list[Any]] _generator_offset: Optional[int] _savable_fields = ("_sample_index", "_generator_sample_keys", "_generator_offset") diff --git a/src/megatron/energon/wrappers/buffer.py b/src/megatron/energon/wrappers/buffer.py index af7afb63..3b0d0b32 100644 --- a/src/megatron/energon/wrappers/buffer.py +++ b/src/megatron/energon/wrappers/buffer.py @@ -15,7 +15,8 @@ Union, ) -from megatron.energon.flavors.base_dataset import FlexState, SavableDataset +from megatron.energon.edataclass import edataclass +from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.savable import Savable from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import get_sample_restore_key @@ -23,6 +24,13 @@ T_sample = TypeVar("T_sample") +@edataclass +class SavableSampleBufferState: + """State of a SavableSampleBuffer.""" + + restore_keys: List[Tuple[Union[str, int, tuple], ...]] + + class SavableSampleBuffer(Savable, Generic[T_sample]): """A buffer of samples, savable. State is shared, create a state-local instance.""" @@ -103,21 +111,15 @@ def len_worker(self, worker_idx: int | None = None) -> int: def len_rank(self) -> int: raise NotImplementedError("len_rank is not available for SavableSampleBuffer") - def save_state(self) -> FlexState: + def save_state(self) -> SavableSampleBufferState: # Don't call super().save_state() because we don't want to save the wrapped datasets # Just save the own state - return FlexState( - __class__=type(self).__name__, - _restore_keys=self._restore_keys, - ) + return SavableSampleBufferState(restore_keys=self._restore_keys.copy()) - def restore_state(self, state: FlexState) -> None: + def restore_state(self, state: SavableSampleBufferState) -> None: # Don't call super().restore_state() because we don't want to restore the wrapped datasets # Just restore the own state - assert state["__class__"] == type(self).__name__, ( - f"Expected class {type(self).__name__}, got {state['__class__']}" - ) - self._restore_keys = state["_restore_keys"].copy() + self._restore_keys = state.restore_keys.copy() self._restore_pending = True def restore_key(self) -> Tuple[Union[str, int], ...]: diff --git a/src/megatron/energon/wrappers/gc_dataset.py b/src/megatron/energon/wrappers/gc_dataset.py index bcb89584..82fdc375 100644 --- a/src/megatron/energon/wrappers/gc_dataset.py +++ b/src/megatron/energon/wrappers/gc_dataset.py @@ -11,6 +11,7 @@ from torch.distributed.distributed_c10d import reduce_op from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -119,6 +120,14 @@ def __iter__(self) -> Iterator[T_sample]: if self.freeze: gc.unfreeze() + def save_state(self) -> FlexState: + # Just delegate, make self transparent + return self.dataset.save_state() + + def restore_state(self, state: FlexState): + # Just delegate, make self transparent + return self.dataset.restore_state(state) + def config(self) -> Dict[str, Any]: # This is transparent, no config to be saved (it does not affect the dataset) return self.dataset.config() diff --git a/src/megatron/energon/wrappers/limit_dataset.py b/src/megatron/energon/wrappers/limit_dataset.py index 0090e97f..62425ec9 100644 --- a/src/megatron/energon/wrappers/limit_dataset.py +++ b/src/megatron/energon/wrappers/limit_dataset.py @@ -103,7 +103,6 @@ def __iter__(self) -> Iterator[T_sample]: ) # Reset the inner dataset - self.dataset.reset_state() self.current_offset = 0 if self.reset_after_epoch: self.dataset.reset_state() diff --git a/src/megatron/energon/wrappers/log_sample_dataset.py b/src/megatron/energon/wrappers/log_sample_dataset.py index a0851693..a6db5c38 100644 --- a/src/megatron/energon/wrappers/log_sample_dataset.py +++ b/src/megatron/energon/wrappers/log_sample_dataset.py @@ -4,6 +4,7 @@ from typing import Any, Callable, Dict, Generic, Iterator, List, Literal, Optional, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -96,6 +97,14 @@ def __iter__(self) -> Iterator[T_sample]: self._step += 1 yield sample + def save_state(self) -> FlexState: + # Just delegate, make self transparent + return self.dataset.save_state() + + def restore_state(self, state: FlexState): + # Just delegate, make self transparent + return self.dataset.restore_state(state) + def config(self) -> Dict[str, Any]: # Transparent logger, it won't change the samples return self.dataset.config() diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index 2727ce9b..8d48742f 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -398,19 +398,15 @@ def restore_sample(self, restore_key: Any) -> T_sample: # We need to store multiple indices to restore a batch. self.assert_can_restore() if inspect.isgeneratorfunction(self.final_packer): - id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key assert id == type(self).__name__ else: - id, pack_idx, *pack_restore_keys = restore_key id, pack_idx, *pack_restore_keys = restore_key assert id == type(self).__name__ pack = [] for inner_idx in pack_restore_keys: if self.sample_encoder is not None: - id, sample_idx, *inner_idx = inner_idx - assert id == type(self).__name__ id, sample_idx, *inner_idx = inner_idx assert id == type(self).__name__ assert isinstance(sample_idx, int) diff --git a/src/megatron/energon/wrappers/watchdog_dataset.py b/src/megatron/energon/wrappers/watchdog_dataset.py index 07c85d11..6bd65cb2 100644 --- a/src/megatron/energon/wrappers/watchdog_dataset.py +++ b/src/megatron/energon/wrappers/watchdog_dataset.py @@ -5,6 +5,7 @@ from typing import Any, Dict, Generic, Iterator, Optional, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.state import FlexState from megatron.energon.watchdog import Watchdog from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -15,6 +16,10 @@ class WatchdogDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): """This dataset wraps another dataset and watches the time it takes to yield samples.""" + timeout_seconds: Optional[float] + initial_timeout_seconds: Optional[float] + fail_on_timeout: bool + def __init__( self, dataset: SavableDataset[T_sample], @@ -65,6 +70,14 @@ def __iter__(self) -> Iterator[T_sample]: ) yield from watchdog.watch_iter(self.dataset) + def save_state(self) -> FlexState: + # Just delegate, make self transparent + return self.dataset.save_state() + + def restore_state(self, state: FlexState): + # Just delegate, make self transparent + return self.dataset.restore_state(state) + def config(self) -> Dict[str, Any]: # Watchdog is transparent, it won't change the samples return self.dataset.config() diff --git a/tests/test_crudedataset.py b/tests/test_crudedataset.py index 213dd229..3e2ec4bc 100644 --- a/tests/test_crudedataset.py +++ b/tests/test_crudedataset.py @@ -95,6 +95,7 @@ class CookingTaskEncoder(DefaultTaskEncoder[TextSample, TextSample, TextBatch, T Cooker(cook_aux, has_subflavors={"crude_type": "aux_random_access"}), ] + @stateless def batch(self, samples: List[TextSample]) -> TextBatch: return TextBatch.from_samples( samples, @@ -163,6 +164,7 @@ def pack_selected_samples(self, samples: List[LazyTextSample]) -> TextSample: text=samples[0].txt + "|" + next_txt, ) + @stateless def batch(self, samples: List[TextSample]) -> TextBatch: return TextBatch.from_samples( samples, @@ -196,6 +198,7 @@ def pack_selected_samples(self, samples: List[TextSample]) -> TextSample: assert len(samples) == 1 return samples[0] + @stateless def batch(self, samples: List[TextSample]) -> TextBatch: return TextBatch.from_samples( samples, diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 76b66e70..126b4de7 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -22,6 +22,24 @@ ) from megatron.energon.dataloader import DataLoader, ForkDataLoaderWorker, ThreadDataLoaderWorker from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.task_encoder.base import DefaultTaskEncoder, stateless + + +class VerifyWorkerTaskEncoder(DefaultTaskEncoder): + def __init__(self, expected_num_workers: int): + self.expected_num_workers = expected_num_workers + super().__init__() + + @stateless + def encode_sample(self, sample): + sample = super().encode_sample(sample) + worker_info = torch.utils.data.get_worker_info() + if self.expected_num_workers > 0: + assert worker_info is not None + assert worker_info.num_workers == self.expected_num_workers + else: + assert worker_info is None + return sample class TestDataloader(unittest.TestCase): @@ -112,6 +130,7 @@ def test_dataloader_no_workers(self): shuffle_buffer_size=None, max_samples_per_sequence=None, repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=0), ), ) as train_loader: assert len(train_loader) == 6, len(train_loader) @@ -139,6 +158,7 @@ def test_dataloader_no_workers(self): shuffle_buffer_size=None, max_samples_per_sequence=None, repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=0), ), ).with_restored_state_rank(state1) as train_loader: cmp_order2 = [ @@ -164,6 +184,7 @@ def test_dataloader_fork(self): shuffle_buffer_size=None, max_samples_per_sequence=None, repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=2), ), prefetch_factor=2, worker_type=ForkDataLoaderWorker, @@ -199,6 +220,9 @@ def test_dataloader_fork(self): shuffle_buffer_size=None, max_samples_per_sequence=None, repeat=False, + task_encoder=VerifyWorkerTaskEncoder( + expected_num_workers=worker_config.num_workers + ), ), prefetch_factor=2, worker_type=ForkDataLoaderWorker, @@ -236,6 +260,9 @@ def test_dataloader_fork_multi_parallel(self): shuffle_buffer_size=None, max_samples_per_sequence=None, repeat=False, + task_encoder=VerifyWorkerTaskEncoder( + expected_num_workers=worker_config_r0.num_workers + ), ), prefetch_factor=2, worker_type=ForkDataLoaderWorker, @@ -263,6 +290,9 @@ def test_dataloader_fork_multi_parallel(self): shuffle_buffer_size=None, max_samples_per_sequence=None, repeat=False, + task_encoder=VerifyWorkerTaskEncoder( + expected_num_workers=worker_config_r1.num_workers + ), ), prefetch_factor=2, worker_type=ForkDataLoaderWorker, @@ -300,6 +330,7 @@ def test_dataloader_fork_multi_parallel(self): train_loader_r1.shutdown() def test_dataloader_thread(self): + assert False, "Not done yet" torch.manual_seed(42) worker_config = WorkerConfig( rank=0, @@ -317,6 +348,9 @@ def test_dataloader_thread(self): shuffle_buffer_size=None, max_samples_per_sequence=None, repeat=False, + task_encoder=VerifyWorkerTaskEncoder( + expected_num_workers=worker_config.num_workers + ), ), prefetch_factor=2, worker_type=ThreadDataLoaderWorker, @@ -349,6 +383,9 @@ def test_dataloader_thread(self): shuffle_buffer_size=None, max_samples_per_sequence=None, repeat=False, + task_encoder=VerifyWorkerTaskEncoder( + expected_num_workers=worker_config.num_workers + ), ), prefetch_factor=2, worker_type=ThreadDataLoaderWorker, diff --git a/tests/test_dataset.py b/tests/test_dataset.py index c3c5f319..b0bf9bf7 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -35,7 +35,6 @@ MapDataset, MixBatchDataset, Sample, - SavableDataLoader, TaskEncoder, WorkerConfig, generic_batch, @@ -45,6 +44,7 @@ get_val_dataset, homogeneous_concat_mix, ) +from megatron.energon.dataloader.dataloader import DataLoader from megatron.energon.dataset_config import get_dataset_from_config from megatron.energon.edataclass import edataclass from megatron.energon.flavors import BaseWebdatasetFactory @@ -877,6 +877,8 @@ def test_val_limit(self): samples = [[batch.__key__ for batch in loader] for _ in range(10)] print(samples) + for s in samples: + print(" -", s) assert all(samples[0] == one_ep_samples for one_ep_samples in samples) worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) @@ -894,8 +896,13 @@ def test_val_limit(self): assert len(loader) == 3 samples_wrk2 = [[batch.__key__ for batch in loader] for _ in range(10)] - print(samples) - assert all(samples_wrk2[0] == one_ep_samples for one_ep_samples in samples_wrk2) + print(samples_wrk2) + for s in samples_wrk2: + print(" -", s) + assert all( + all(a == b for a, b in zip(samples_wrk2[0], one_ep_samples)) + for one_ep_samples in samples_wrk2 + ) def test_current_batch_index(self): # Tests if the get_current_batch_index works properly @@ -1260,7 +1267,6 @@ def encode_sample(self, sample): shuffle_buffer_size=20, max_samples_per_sequence=10, ), - worker_config=worker_config_r0, ) loader_r1 = get_savable_loader( get_train_dataset( @@ -1271,7 +1277,6 @@ def encode_sample(self, sample): shuffle_buffer_size=20, max_samples_per_sequence=10, ), - worker_config=worker_config_r1, ) batches = list(zip(range(20), loader)) @@ -1328,7 +1333,6 @@ def encode_sample(self, sample): shuffle_buffer_size=20, max_samples_per_sequence=10, ), - worker_config=worker_config_r0, ) loader.restore_state_rank(state) @@ -1659,7 +1663,7 @@ def test_debug_dataset(self): ) # Reset this to 0 to make sure the test is deterministic - SavableDataLoader._next_id = 0 + DataLoader._next_id = 0 loader = get_savable_loader( get_val_dataset( diff --git a/tests/test_dataset_det.py b/tests/test_dataset_det.py index 911ee2f8..6719dbe1 100644 --- a/tests/test_dataset_det.py +++ b/tests/test_dataset_det.py @@ -362,8 +362,10 @@ class TestTaskEncoder(DefaultTaskEncoder): def encode_sample(self, sample: TextSample) -> TextSample: rand_str = ( f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}" + + f"_{WorkerConfig.active_worker_config.worker_seed()}" + f"_{self.current_batch_index}_{self.current_sample_index}" ) + print(f"For sample {sample.__restore_key__}: {sample.text}{rand_str}") return TextSample( __key__=sample.__key__, @@ -410,12 +412,14 @@ def encode_sample(self, sample: TextSample) -> TextSample: # Then save state state = loader1a.save_state_rank() + print("iterating loader1a") # Load another 20 samples data_post = [data.text[0] for idx, data in zip(range(20), loader1a)] # Restore state loader1b.restore_state_rank(state) + print("iterating loader1b") # Load 20 samples again data_restored = [data.text[0] for idx, data in zip(range(20), loader1b)] diff --git a/tests/test_metadataset.py b/tests/test_metadataset.py index d362d616..6567374d 100644 --- a/tests/test_metadataset.py +++ b/tests/test_metadataset.py @@ -3,6 +3,7 @@ """This module defines tests for meta datasets.""" +import dataclasses import gc import logging import sys @@ -14,6 +15,7 @@ from pathlib import Path from typing import Any, Iterable +import numpy as np import torch import webdataset as wds @@ -73,19 +75,21 @@ def assert_nested_equal(a: Any, b: Any, path: str = "") -> None: and other basic types) are equal. If they are not equal, prints the path of the first mismatch and raises an AssertionError. - :param a: First nested structure to compare. - :param b: Second nested structure to compare. - :param path: Internal parameter used to pass the current traversal path (do not set this manually). - :raises AssertionError: If a mismatch is found. + Args: + a: First nested structure to compare. + b: Second nested structure to compare. + path: Internal parameter used to pass the current traversal path (do not set this manually). + + Raises: + AssertionError: If a mismatch is found. """ - # Check if types differ if type(a) is not type(b): + # Check if types differ mismatch_details = f"Type mismatch at {path or ''}: {type(a)} != {type(b)}" print(mismatch_details) raise AssertionError(mismatch_details) - - # If they are both dictionaries, compare each key and value if isinstance(a, dict): + # If they are both dictionaries, compare each key and value # Check if they have the same keys a_keys = set(a.keys()) b_keys = set(b.keys()) @@ -102,9 +106,8 @@ def assert_nested_equal(a: Any, b: Any, path: str = "") -> None: for key in a: sub_path = f"{path}['{key}']" if path else f"['{key}']" assert_nested_equal(a[key], b[key], sub_path) - - # If they are lists (or tuples), compare elements in order elif isinstance(a, (list, tuple)): + # If they are lists (or tuples), compare elements in order if len(a) != len(b): mismatch_details = f"Length mismatch at {path or ''}: {len(a)} != {len(b)}" print(mismatch_details) @@ -112,9 +115,31 @@ def assert_nested_equal(a: Any, b: Any, path: str = "") -> None: for index, (item_a, item_b) in enumerate(zip(a, b)): sub_path = f"{path}[{index}]" if path else f"[{index}]" assert_nested_equal(item_a, item_b, sub_path) - - # Otherwise, compare values directly + elif isinstance(a, torch.Tensor): + if a.shape != b.shape: + mismatch_details = f"Shape mismatch at {path or ''}: {a.shape} != {b.shape}" + print(mismatch_details) + raise AssertionError(mismatch_details) + if not torch.all(a == b): + mismatch_details = f"Value mismatch at {path or ''}: {repr(a)} != {repr(b)}" + print(mismatch_details) + raise AssertionError(mismatch_details) + elif isinstance(a, np.ndarray): + if a.shape != b.shape: + mismatch_details = f"Shape mismatch at {path or ''}: {a.shape} != {b.shape}" + print(mismatch_details) + raise AssertionError(mismatch_details) + if not np.all(a == b): + mismatch_details = f"Value mismatch at {path or ''}: {repr(a)} != {repr(b)}" + print(mismatch_details) + raise AssertionError(mismatch_details) + elif dataclasses.is_dataclass(a): + for field in dataclasses.fields(a): + assert_nested_equal( + getattr(a, field.name), getattr(b, field.name), f"{path}.{field.name}" + ) else: + # Otherwise, compare values directly if a != b: mismatch_details = f"Value mismatch at {path or ''}: {repr(a)} != {repr(b)}" print(mismatch_details) @@ -288,8 +313,8 @@ def test_metadataset(self): for idx, data in zip(range(55), train_loader1) for subflavor in data.__subflavors__ ] - print(train_subflavors[:10]) - print(Counter(train_subflavors)) + print("train_subflavors[:10]", train_subflavors[:10]) + print("Counter(train_subflavors)", Counter(train_subflavors)) assert len(Counter(train_subflavors)) == 2 assert all(250 <= v <= 300 for v in Counter(train_subflavors).values()) @@ -737,222 +762,203 @@ def new_loader(): print(loader.config()) print() reference_config = { - "type": "SavableDataLoader", - "num_workers": 0, - "persistent_workers": False, - "pin_memory": True, - "prefetch_factor": None, + "type": "MapDataset", "dataset": { - "type": "MapDataset", + "type": "BatchDataset", + "batch_size": 10, + "batcher": "megatron.energon.task_encoder.base.DefaultTaskEncoder.batch", + "batcher_stateless": True, + "drop_last": False, + "error_handler": "megatron.energon.wrappers._log_exception.log_exception", + "worker_config": wrk_cfg, "dataset": { - "type": "BatchDataset", - "batch_size": 10, - "batcher": "megatron.energon.task_encoder.base.DefaultTaskEncoder.batch", - "batcher_stateless": True, - "drop_last": False, - "error_handler": "megatron.energon.wrappers._log_exception.log_exception", - "worker_config": wrk_cfg, + "type": "MapDataset", "dataset": { - "type": "MapDataset", - "dataset": { - "type": "BlendDataset", - "dataset_weights": [ - ( - { - "type": "RepeatDataset", + "type": "BlendDataset", + "dataset_weights": [ + ( + { + "type": "RepeatDataset", + "dataset": { + "type": "MapDataset", "dataset": { - "type": "MapDataset", - "dataset": { - "type": "WebdatasetSampleLoaderDataset", - "joins": 1, - "len": 55, - "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], - "worker_config": wrk_cfg, - "shuffle_over_epochs": 6, - "parallel_slice_iters": 2, - }, - "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", - "map_fn_config": { - "type": "StandardWebdatasetFactory", - "training": True, - "_path": str(self.dataset_path / "ds1"), - "shards": [ - { - "name": "parts/data-0.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds1/parts/data-0.tar" - ), - }, - { - "name": "parts/data-1.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds1/parts/data-1.tar" - ), - }, - { - "name": "parts/data-2.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds1/parts/data-2.tar" - ), - }, - { - "name": "parts/data-3.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds1/parts/data-3.tar" - ), - }, - { - "name": "parts/data-4.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds1/parts/data-4.tar" - ), - }, - { - "name": "parts/data-5.tar", - "count": 5, - "_path": str( - self.dataset_path - / "ds1/parts/data-5.tar" - ), - }, - ], - "sample_excludes": [], - "shuffle_over_epochs": 6, - "parallel_shard_iters": 2, - "max_samples_per_sequence": None, - "subflavors": { - "source": "metadataset.yaml", - "dataset.yaml": True, - "number": 43, - "mds": "mds", - "__subflavor__": "ds1", + "type": "WebdatasetSampleLoaderDataset", + "joins": 1, + "len": 55, + "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], + "worker_config": wrk_cfg, + "shuffle_over_epochs": 6, + "parallel_slice_iters": 2, + }, + "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", + "map_fn_config": { + "type": "StandardWebdatasetFactory", + "training": True, + "_path": str(self.dataset_path / "ds1"), + "shards": [ + { + "name": "parts/data-0.tar", + "count": 10, + "_path": str( + self.dataset_path / "ds1/parts/data-0.tar" + ), + }, + { + "name": "parts/data-1.tar", + "count": 10, + "_path": str( + self.dataset_path / "ds1/parts/data-1.tar" + ), + }, + { + "name": "parts/data-2.tar", + "count": 10, + "_path": str( + self.dataset_path / "ds1/parts/data-2.tar" + ), + }, + { + "name": "parts/data-3.tar", + "count": 10, + "_path": str( + self.dataset_path / "ds1/parts/data-3.tar" + ), + }, + { + "name": "parts/data-4.tar", + "count": 10, + "_path": str( + self.dataset_path / "ds1/parts/data-4.tar" + ), }, - "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", - "image_decode": "torchrgb", - "av_decode": "AVDecoder", - "video_decode_audio": False, - "guess_content": False, + { + "name": "parts/data-5.tar", + "count": 5, + "_path": str( + self.dataset_path / "ds1/parts/data-5.tar" + ), + }, + ], + "sample_excludes": [], + "shuffle_over_epochs": 6, + "parallel_shard_iters": 2, + "max_samples_per_sequence": None, + "subflavors": { + "source": "metadataset.yaml", + "dataset.yaml": True, + "number": 43, + "mds": "mds", + "__subflavor__": "ds1", }, - "map_fn_stateless": True, + "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", + "image_decode": "torchrgb", + "av_decode": "AVDecoder", + "video_decode_audio": False, + "guess_content": False, }, - "repeats": None, - "worker_config": wrk_cfg, + "map_fn_stateless": True, }, - 0.5, - ), - ( - { - "type": "RepeatDataset", + "repeats": None, + "worker_config": wrk_cfg, + }, + 0.5, + ), + ( + { + "type": "RepeatDataset", + "dataset": { + "type": "MapDataset", "dataset": { - "type": "MapDataset", - "dataset": { - "type": "WebdatasetSampleLoaderDataset", - "joins": 1, - "len": 55, - "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], - "worker_config": wrk_cfg, - "shuffle_over_epochs": 2, - "parallel_slice_iters": 2, - }, - "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", - "map_fn_config": { - "type": "StandardWebdatasetFactory", - "training": True, - "_path": str(self.dataset_path / "ds2"), - "shards": [ - { - "name": "parts/data-0.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds2/parts/data-0.tar" - ), - }, - { - "name": "parts/data-1.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds2/parts/data-1.tar" - ), - }, - { - "name": "parts/data-2.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds2/parts/data-2.tar" - ), - }, - { - "name": "parts/data-3.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds2/parts/data-3.tar" - ), - }, - { - "name": "parts/data-4.tar", - "count": 10, - "_path": str( - self.dataset_path - / "ds2/parts/data-4.tar" - ), - }, - { - "name": "parts/data-5.tar", - "count": 5, - "_path": str( - self.dataset_path - / "ds2/parts/data-5.tar" - ), - }, - ], - "sample_excludes": [], - "shuffle_over_epochs": 2, - "parallel_shard_iters": 2, - "max_samples_per_sequence": None, - "subflavors": { - "source": "metadataset.yaml", - "dataset.yaml": True, - "number": 44, - "mds": "mds", - "__subflavor__": "ds2", + "type": "WebdatasetSampleLoaderDataset", + "joins": 1, + "len": 55, + "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], + "worker_config": wrk_cfg, + "shuffle_over_epochs": 2, + "parallel_slice_iters": 2, + }, + "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", + "map_fn_config": { + "type": "StandardWebdatasetFactory", + "training": True, + "_path": str(self.dataset_path / "ds2"), + "shards": [ + { + "name": "parts/data-0.tar", + "count": 10, + "_path": str( + self.dataset_path / "ds2/parts/data-0.tar" + ), + }, + { + "name": "parts/data-1.tar", + "count": 10, + "_path": str( + self.dataset_path / "ds2/parts/data-1.tar" + ), }, - "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", - "image_decode": "torchrgb", - "av_decode": "AVDecoder", - "video_decode_audio": False, - "guess_content": False, + { + "name": "parts/data-2.tar", + "count": 10, + "_path": str( + self.dataset_path / "ds2/parts/data-2.tar" + ), + }, + { + "name": "parts/data-3.tar", + "count": 10, + "_path": str( + self.dataset_path / "ds2/parts/data-3.tar" + ), + }, + { + "name": "parts/data-4.tar", + "count": 10, + "_path": str( + self.dataset_path / "ds2/parts/data-4.tar" + ), + }, + { + "name": "parts/data-5.tar", + "count": 5, + "_path": str( + self.dataset_path / "ds2/parts/data-5.tar" + ), + }, + ], + "sample_excludes": [], + "shuffle_over_epochs": 2, + "parallel_shard_iters": 2, + "max_samples_per_sequence": None, + "subflavors": { + "source": "metadataset.yaml", + "dataset.yaml": True, + "number": 44, + "mds": "mds", + "__subflavor__": "ds2", }, - "map_fn_stateless": True, + "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", + "image_decode": "torchrgb", + "av_decode": "AVDecoder", + "video_decode_audio": False, + "guess_content": False, }, - "repeats": None, - "worker_config": wrk_cfg, + "map_fn_stateless": True, }, - 0.5, - ), - ], - "worker_config": wrk_cfg, - }, - "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_sample", - "map_fn_stateless": True, + "repeats": None, + "worker_config": wrk_cfg, + }, + 0.5, + ), + ], + "worker_config": wrk_cfg, }, + "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_sample", + "map_fn_stateless": True, }, - "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_batch", - "map_fn_stateless": True, }, + "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_batch", + "map_fn_stateless": True, } print("Comparing dataset configs in test_save_restore_state_train.") assert_nested_equal(loader.config(), reference_config) @@ -1357,6 +1363,13 @@ def test_save_restore_next(self): print("same_state:", same_state) assert same_state == state_initial + # This will propagate the state to the workers. + second_loader._start() + # Save the state again, to check that it is the same as the just restored state + same_state = second_loader.save_state_rank() + print("same_state:", same_state) + assert_nested_equal(same_state, state_initial) + for offset in range(10): try: # Save state and restore in next loader From 00263e32376c4e267f76a3ebc5e83ed42ac1c756 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 19 Aug 2025 15:14:39 +0200 Subject: [PATCH 15/36] Implement RestoreKey as dataclass. Fix redistribution of checkpoints --- src/megatron/energon/dataloader/dataloader.py | 27 ++-- .../energon/dataloader/workers/base_worker.py | 17 ++- src/megatron/energon/flavors/base_dataset.py | 88 +++++++---- .../energon/flavors/webdataset/itar_reader.py | 8 +- .../flavors/webdataset/joined_webdataset.py | 3 +- .../flavors/webdataset/sample_loader.py | 24 +-- .../energon/flavors/webdataset/structs.py | 17 ++- src/megatron/energon/task_encoder/base.py | 4 +- src/megatron/energon/tools/checkpoint.py | 138 ++++++++++++++---- src/megatron/energon/wrappers/base.py | 53 +++++-- .../energon/wrappers/batch_dataset.py | 76 ++++++---- .../energon/wrappers/blend_dataset.py | 21 ++- src/megatron/energon/wrappers/buffer.py | 21 ++- .../energon/wrappers/concat_dataset.py | 24 ++- .../energon/wrappers/group_batch_dataset.py | 29 +++- .../energon/wrappers/iter_map_dataset.py | 47 +++--- src/megatron/energon/wrappers/map_dataset.py | 84 ++++++----- .../energon/wrappers/packing_dataset.py | 83 ++++++----- .../wrappers/shuffle_buffer_dataset.py | 5 +- .../wrappers/task_encoder_state_dataset.py | 11 +- tests/test_dataset.py | 4 +- tests/test_dataset_det.py | 82 ++++++----- 22 files changed, 561 insertions(+), 305 deletions(-) diff --git a/src/megatron/energon/dataloader/dataloader.py b/src/megatron/energon/dataloader/dataloader.py index dbbd763a..dd2bed2a 100644 --- a/src/megatron/energon/dataloader/dataloader.py +++ b/src/megatron/energon/dataloader/dataloader.py @@ -23,11 +23,12 @@ from megatron.energon.dataloader.workers.base_worker import ( DataLoaderNoWorker, DataLoaderWorker, + WorkerSampleRestoreKey, WorkerState, ) from megatron.energon.dataloader.workers.fork_worker import ForkDataLoaderWorker from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset, set_sample_restore_key from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key from megatron.energon.wrappers.batch_dataset import BatchDataset @@ -45,9 +46,14 @@ class RankState: State of a rank. """ - prefetched_samples_keys: list[Any] + #: This is a list (per worker) of lists of (batch) sample keys, which have been (asynchronously) prefetched from workers + # but not been fetched yet by iterating. + prefetched_samples_keys: list[list[RestoreKey | None]] + #: This is a list of worker states, which have been saved from the workers (or `None` for the initial state). worker_states: list[WorkerState | None] + #: The next worker ID to prefetch from (i.e. append to the prefetched samples). next_worker_id: int + #: The micro batch size of the dataset, or `None` if not known. Needed for redistributing the state. micro_batch_size: int | None @@ -533,7 +539,7 @@ def restore_state_global( self.restore_state_rank(rank_state) - def restore_sample(self, restore_key: tuple) -> TSample: + def restore_sample(self, restore_key: RestoreKey) -> TSample: """ Restore a sample from a restore key. @@ -543,18 +549,15 @@ def restore_sample(self, restore_key: tuple) -> TSample: Returns: The restored sample. """ - id, global_worker_id, sample_idx = restore_key[:3] - assert id == "DataLoaderWorker", f"id {id} != DataLoaderWorker" - restore_key = restore_key[3:] + assert isinstance(restore_key, WorkerSampleRestoreKey) self._worker_config.worker_activate( - sample_idx, override_global_rank=global_worker_id, cache_pool=self._cache_pool + restore_key.sample_idx, + override_global_rank=restore_key.worker_id, + cache_pool=self._cache_pool, ) try: - return add_sample_restore_key( - self._dataset.restore_sample(restore_key), - global_worker_id, - sample_idx, - src=DataLoaderWorker.__name__, + return set_sample_restore_key( + self._dataset.restore_sample(restore_key.inner), restore_key ) finally: self._worker_config.worker_deactivate() diff --git a/src/megatron/energon/dataloader/workers/base_worker.py b/src/megatron/energon/dataloader/workers/base_worker.py index af7ec73f..fcf5b932 100644 --- a/src/megatron/energon/dataloader/workers/base_worker.py +++ b/src/megatron/energon/dataloader/workers/base_worker.py @@ -1,18 +1,26 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +from dataclasses import dataclass from typing import Generic, TypeVar from megatron.energon.cache.base import CachePool from megatron.energon.dataloader.future import DoneFuture, ExceptionFuture, Future from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.rng import SystemRng, SystemRngState from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.base import WrappedRestoreKey, wrap_sample_restore_key TSample = TypeVar("TSample", covariant=True) +@dataclass(kw_only=True, slots=True, frozen=True) +class WorkerSampleRestoreKey(WrappedRestoreKey): + worker_id: int + sample_idx: int + + @edataclass class WorkerState: """ @@ -170,8 +178,11 @@ def prefetch_next(self) -> Future[TSample]: try: next_sample = next(self._dataset_iter) self._sample_index += 1 - next_sample = add_sample_restore_key( - next_sample, self._global_worker_id, sample_idx, src=DataLoaderWorker.__name__ + next_sample = wrap_sample_restore_key( + next_sample, + WorkerSampleRestoreKey, + worker_id=self._global_worker_id, + sample_idx=sample_idx, ) except StopIteration as e: self._exhausted = True diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index 02518622..9a3bfb0a 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -27,6 +27,7 @@ from torch.utils.data import IterableDataset from typing_extensions import Self +import megatron.energon from megatron.energon.cache import FileStore from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath @@ -130,7 +131,8 @@ class Sample(ABC, PinMemoryMixin, ExtendableDataclassMixin): __key__: str #: Key for restoring the sample. This is used to restore the sample from a checkpoint. It # should be a (nested) tuple of strings and integers, which can be used to index the dataset. - __restore_key__: Tuple[Union[str, int, tuple], ...] + # May be None in some cases, but it may then not be restorable. + __restore_key__: "RestoreKey | None" #: A dataset may define a subflavors to distinguish between samples of the same sample type. __subflavors__: Optional[Dict[str, Any]] = None @@ -396,16 +398,18 @@ def assert_can_restore(self) -> None: """Asserts that the dataset can restore a sample from a key.""" assert self.can_restore_sample(), "This dataset cannot restore samples." - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: + def restore_sample(self, restore_key: "RestoreKey") -> T_sample: """ - Generic key type, because it might be either an integer (for a core dataset), or something - more complex (e.g. for blended datasets). + Restores a sample from a restore key. - Default raises an exception (assumed non-deterministic if not implemented, does not - guarantee determinism). + Args: + restore_key: The restore key to restore the sample from. + + Returns: + The restored sample. """ raise NotImplementedError( - "This dataset does not support indexing, because it is not safely deterministic." + "This dataset does not support restoring, because it is not safely deterministic." ) if THREAD_SAFE: @@ -461,43 +465,65 @@ def __len__(self) -> int: ... -def add_sample_restore_key( - sample: T_sample, *key: Union[int, str], src: Any, fail_otherwise: bool = False -) -> T_sample: - """Adds a key to a sample. The sample must be a valid `Sample` or dict containing - __restore_key__, which is a tuple of keys that can be used to restore the inner sample. - This restore key is prepended with the `key`.""" - if not isinstance(src, str): - src = type(src).__name__ - if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"): - try: - sample.__restore_key__ = (src, *key, *sample.__restore_key__) - except KeyError: - pass - elif isinstance(sample, dict) and "__restore_key__" in sample: - sample["__restore_key__"] = (src, *key, *sample["__restore_key__"]) - elif fail_otherwise: - raise RuntimeError( - "Did not yield a sample with a restore key, but is marked stateless/deterministic." +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class RestoreKey(ABC): + """Base class for restore keys.""" + + def _tupleify(self, value: Any) -> Any: + if isinstance(value, (int, str, float, bool)): + return value + elif isinstance(value, RestoreKey): + return value.as_tuple() + elif isinstance(value, (list, tuple)): + return tuple(self._tupleify(v) for v in value) + else: + return value + + def as_tuple(self) -> tuple[Any, ...]: + return ( + self.__class__.__name__, + *( + getattr(self, field.name).json() + if isinstance(getattr(self, field.name), RestoreKey) + else getattr(self, field.name) + for field in dataclasses.fields(self) + ), ) - return sample + + @staticmethod + def _untupleify(value: Any) -> Any: + if isinstance(value, (int, str, float, bool)): + return value + elif isinstance(value, RestoreKey): + return value.from_tuple(value) + elif isinstance(value, (list, tuple)): + if isinstance(value[0], str) and hasattr(megatron.energon, value[0]): + return getattr(megatron.energon, value[0]).from_tuple(value[1:]) + else: + return tuple(RestoreKey._untupleify(v) for v in value) + + @staticmethod + def from_tuple(json: tuple[Any, ...]) -> "RestoreKey": + cls = getattr(megatron.energon, json[0]) + kwargs = {} + for field in dataclasses.fields(cls): + kwargs[field.name] = RestoreKey._untupleify(json[1:]) + return cls(**kwargs) def set_sample_restore_key( - sample: T_sample, *key: Union[int, str], src: Any, fail_otherwise: bool = False + sample: T_sample, restore_key: RestoreKey, fail_otherwise: bool = False ) -> T_sample: """Sets the restore key for a sample. The sample must be a valid `Sample` or dict containing __restore_key__, which is a tuple of keys that can be used to restore the inner sample. This restore key is prepended with the `key`.""" - if not isinstance(src, str): - src = type(src).__name__ if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"): try: - sample.__restore_key__ = (src, *key) + sample.__restore_key__ = restore_key except KeyError: pass elif isinstance(sample, dict) and "__restore_key__" in sample: - sample["__restore_key__"] = (src, *key) + sample["__restore_key__"] = restore_key elif fail_otherwise: raise RuntimeError( "Did not yield a sample with a restore key, but is marked stateless/deterministic." diff --git a/src/megatron/energon/flavors/webdataset/itar_reader.py b/src/megatron/energon/flavors/webdataset/itar_reader.py index d2bdd826..06d058fe 100644 --- a/src/megatron/energon/flavors/webdataset/itar_reader.py +++ b/src/megatron/energon/flavors/webdataset/itar_reader.py @@ -26,7 +26,11 @@ ITarSamplePointer, ) from megatron.energon.flavors.webdataset.metadata import get_info_shard_files -from megatron.energon.flavors.webdataset.structs import FilteredSample, ShardInfo +from megatron.energon.flavors.webdataset.structs import ( + FilteredSample, + ShardInfo, + WebdatasetRestoreKey, +) from megatron.energon.source_info import SourceInfo T_index = TypeVar("T_index", covariant=False) @@ -188,7 +192,7 @@ def _get_item_by_sample_pointer( return FilteredSample( __key__=f"{shard_name}/{sample_base_name}", __shard__=self.tar_filenames[sample_pointer.tar_file_id], - __restore_key__=("Webdataset", restore_index), + __restore_key__=WebdatasetRestoreKey(index=restore_index), __sources__=( SourceInfo( dataset_path=self.base_path, diff --git a/src/megatron/energon/flavors/webdataset/joined_webdataset.py b/src/megatron/energon/flavors/webdataset/joined_webdataset.py index 328bb509..6bbb1a69 100644 --- a/src/megatron/energon/flavors/webdataset/joined_webdataset.py +++ b/src/megatron/energon/flavors/webdataset/joined_webdataset.py @@ -226,8 +226,7 @@ def load_sample(self, samples: RawSampleData) -> T_sample: # Then combine the loaded smaples into the final type return set_sample_restore_key( self._sample_joiner(*loaded_samples), - *samples.__restore_key__, - src=self, + samples.__restore_key__, fail_otherwise=True, ) diff --git a/src/megatron/energon/flavors/webdataset/sample_loader.py b/src/megatron/energon/flavors/webdataset/sample_loader.py index 78c2399f..9c3838d9 100644 --- a/src/megatron/energon/flavors/webdataset/sample_loader.py +++ b/src/megatron/energon/flavors/webdataset/sample_loader.py @@ -1,14 +1,14 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, Dict, Generator, Iterator, List, Optional, Sequence, Tuple import torch from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import FlexState, SavableDataset +from megatron.energon.flavors.base_dataset import FlexState, RestoreKey, SavableDataset from megatron.energon.flavors.webdataset.itar_reader import ITarReader -from megatron.energon.flavors.webdataset.structs import FilteredSample +from megatron.energon.flavors.webdataset.structs import FilteredSample, WebdatasetRestoreKey from megatron.energon.rng import WorkerRng from megatron.energon.worker import WorkerConfig @@ -18,7 +18,7 @@ class RawSampleData: """Represents the iteration state of a single slice slice to the index.""" #: Index of the sample. This is also the restore key - __restore_key__: Tuple[str, int] + __restore_key__: WebdatasetRestoreKey #: The sample data data: Tuple[Optional[FilteredSample], ...] @@ -140,7 +140,7 @@ def ensure_slice_offsets(self) -> None: def _get_sample(self, index: int) -> RawSampleData: return RawSampleData( - __restore_key__=("Webdataset", index), + __restore_key__=WebdatasetRestoreKey(index=index), data=tuple(reader[index] for reader in self.join_readers), ) @@ -369,7 +369,7 @@ def slice_at(idx: int) -> SliceState: "t": "WebdatasetSampleLoaderDataset._slices_iter.yield", "r": self.worker_config.rank, "w": self.worker_config.rank_worker_id(), - "index": sample.__restore_key__[1], + "index": sample.__restore_key__.index, "key": sample.data[0]["__key__"], "shard": sample.data[0]["__shard__"], "count": self._sample_count, @@ -443,13 +443,13 @@ def can_restore_sample(self) -> bool: def assert_can_restore(self) -> None: pass - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> RawSampleData: - # Key is: ("Webdataset", index) + def restore_sample(self, restore_key: RestoreKey) -> RawSampleData: # The key is joined in the dataset's typed joining (i.e. load_sample of JoinedWebdatasetFactory). - id, index = restore_key - assert id == "Webdataset" - assert isinstance(index, int) - return self._get_sample(index) + assert isinstance(restore_key, WebdatasetRestoreKey) + assert isinstance(restore_key.index, int), ( + "WebdatasetRestoreKey.index must be an integer, cannot restore by sample key" + ) + return self._get_sample(restore_key.index) def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/flavors/webdataset/structs.py b/src/megatron/energon/flavors/webdataset/structs.py index 6794e158..79bafd53 100644 --- a/src/megatron/energon/flavors/webdataset/structs.py +++ b/src/megatron/energon/flavors/webdataset/structs.py @@ -1,11 +1,12 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from dataclasses import field -from typing import Dict, List, Optional, Tuple, TypedDict +from dataclasses import dataclass, field +from typing import Dict, List, Optional, TypedDict from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath +from megatron.energon.flavors.base_dataset import RestoreKey from megatron.energon.source_info import SourceInfo @@ -42,6 +43,15 @@ class ShardInfo: count: int +@dataclass(kw_only=True, slots=True, frozen=True) +class WebdatasetRestoreKey(RestoreKey): + """Restore key for the webdataset sample loader.""" + + #: The global index of the sample. Either an integer representing the global sample index, + # or a string representing the sample key within the shard. + index: int | str + + class FilteredSample(TypedDict): """This is just a definition for the internal loaders. Not exposed to the user.""" @@ -52,8 +62,7 @@ class FilteredSample(TypedDict): #: The base name of the shard file e.g. "shard_000" __shard__: str #: Globally unique key to restore a sample from disk. - #: For example `("Webdataset", 123)` would restore the sample at index 123. - __restore_key__: Tuple[str, int] + __restore_key__: WebdatasetRestoreKey #: The source information for the sample. __sources__: tuple[SourceInfo, ...] diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index ddb39b4e..8f31ba84 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -38,7 +38,7 @@ SampleDecoder, SavableDataset, ) -from megatron.energon.flavors.base_dataset import ExtendableDataclassMixin +from megatron.energon.flavors.base_dataset import ExtendableDataclassMixin, RestoreKey from megatron.energon.metadataset.loader_interface import DatasetBlendMode, LoadedDataset from megatron.energon.rng import SystemRng, UserRng from megatron.energon.savable import Savable @@ -325,7 +325,7 @@ class Batch(PinMemoryMixin, ExtendableDataclassMixin): __key__: list[str] #: Key for restoring the sample. This is used to restore the sample from a checkpoint. It # should be a (nested) tuple of strings and integers, which can be used to index the dataset. - __restore_key__: Tuple[Union[str, int, tuple], ...] + __restore_key__: Tuple[RestoreKey | None, ...] #: A dataset may define a subflavors to distinguish between samples of the same sample type. __subflavors__: Optional[list[Optional[Dict[str, Any]]]] = None diff --git a/src/megatron/energon/tools/checkpoint.py b/src/megatron/energon/tools/checkpoint.py index 08115a35..30769065 100644 --- a/src/megatron/energon/tools/checkpoint.py +++ b/src/megatron/energon/tools/checkpoint.py @@ -1,17 +1,22 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +import dataclasses import re -from typing import List, Optional +from typing import Callable, Generator, List, Optional import click import torch +from megatron.energon.dataloader.dataloader import RankState +from megatron.energon.dataloader.workers.base_worker import WorkerState from megatron.energon.epathlib import EPath -from megatron.energon.savable_loader import SavableDataLoaderState +from megatron.energon.flavors.base_dataset import RestoreKey +from megatron.energon.wrappers.base import WrappedRestoreKey +from megatron.energon.wrappers.batch_dataset import BatchRestoreKey -def natural_sort_key(s): +def natural_sort_key(s: str) -> List[str | int]: """ Function to use for natural sorting of filenames. @@ -21,7 +26,7 @@ def natural_sort_key(s): return [int(text) if text.isdigit() else text.lower() for text in re.split(r"(\d+)", s)] -def detect_and_replicate_pattern(file_list): +def detect_and_replicate_pattern(file_list: List[str]) -> Callable[[int], str]: """ Given a list of file paths, detect the single numeric pattern and return a function that, when called with integer n (starting from 0), generates @@ -156,7 +161,7 @@ def __init__(self, state_files: List[EPath]): else: self.megatron_style = False - if isinstance(first_state, SavableDataLoaderState): + if isinstance(first_state, RankState): if self.megatron_style: self.rank_states = [first_state] + [ torch.load(str(state_file), weights_only=False)["dataloader_state_dict"] @@ -170,7 +175,7 @@ def __init__(self, state_files: List[EPath]): self.is_global_checkpoint = False elif isinstance(first_state, list): assert len(state_files) == 1, "Global checkpoint must contain exactly one file" - assert all(isinstance(state, SavableDataLoaderState) for state in first_state) + assert all(isinstance(state, RankState) for state in first_state) self.rank_states = first_state self.is_global_checkpoint = True else: @@ -184,9 +189,12 @@ def __init__(self, state_files: List[EPath]): self.rank_num_workers[0] == num_workers for num_workers in self.rank_num_workers ), "All ranks must have the same number of workers." - def write_new_states_to_folder( - self, output_folder: EPath, new_states: List[SavableDataLoaderState] - ): + assert all( + rank_state.micro_batch_size == self.rank_states[0].micro_batch_size + for rank_state in self.rank_states[1:] + ), "All ranks must have the same micro batch size." + + def write_new_states_to_folder(self, output_folder: EPath, new_states: List[RankState]): for rank_idx, rank_state in enumerate(new_states): output_file = output_folder / self.file_pattern_func(rank_idx) if self.megatron_style: @@ -197,20 +205,66 @@ def write_new_states_to_folder( else: torch.save(rank_state, str(output_file)) - def get_num_ranks(self): + def get_num_ranks(self) -> int: return len(self.rank_states) - def get_num_workers(self): + def get_num_workers(self) -> int: return self.rank_num_workers[0] - def get_micro_batch_size(self): + def get_micro_batch_size(self) -> int | None: return self.rank_states[0].micro_batch_size - def __iter__(self): - """Iterates the SavableDatasetCheckpoints of mulitple ranks in a round-robin fashion.""" - for rank, state in enumerate(self.rank_states): - for worker_state in state.worker_states: - yield worker_state + def __iter__(self) -> Generator[tuple[WorkerState | None, list[RestoreKey | None]], None, None]: + """Iterates the WorkerStates of multiple ranks in a round-robin fashion.""" + for rank_state in self.rank_states: + for worker_state, prefetched_samples_keys in zip( + rank_state.worker_states, rank_state.prefetched_samples_keys + ): + yield worker_state, prefetched_samples_keys + + +def split_batch_restore_key( + restore_key: RestoreKey | None, batch_split_factor: int +) -> list[RestoreKey | None]: + """Split the given restore_key into multiple restore keys, one for each batch.""" + if restore_key is None: + raise ValueError("Cannot split None restore key") + if isinstance(restore_key, BatchRestoreKey): + # Split the inner keys into batch_split_factor keys + # Duplicate the sample_idx for each batch + assert len(restore_key.inner) % batch_split_factor == 0, ( + "Batch size must be a multiple of the batch split factor" + ) + split_size = len(restore_key.inner) // batch_split_factor + return [ + BatchRestoreKey( + inner=tuple(restore_key.inner[i : i + split_size]), + sample_idx=restore_key.sample_idx, + ) + for i in range(0, len(restore_key.inner), split_size) + ] + elif isinstance(restore_key, WrappedRestoreKey): + inner_restore_keys = split_batch_restore_key(restore_key.inner, batch_split_factor) + inner_kwargs = dataclasses.asdict(restore_key) + inner_kwargs.pop("inner") + return [ + type(restore_key)(**inner_kwargs, inner=inner_restore_key) + for inner_restore_key in inner_restore_keys + ] + else: + raise ValueError(f"Unsupported restore key type for splitting batch: {type(restore_key)}") + + +def split_batch_restore_keys( + restore_keys: list[RestoreKey | None], batch_split_factor: int +) -> list[RestoreKey | None]: + if batch_split_factor == 1: + return restore_keys + return [ + new_restore_key + for restore_key in restore_keys + for new_restore_key in split_batch_restore_key(restore_key, batch_split_factor) + ] @click.command(name="redist") @@ -227,8 +281,12 @@ def __iter__(self): @click.option( "--new-world-size", type=int, help="Number of ranks to redistribute to", required=False ) +@click.option("--new-micro-batch-size", type=int, help="New micro batch size", required=False) def command_redist( - input_files: List[EPath], output_path: EPath, new_world_size: Optional[int] = None + input_files: List[EPath], + output_path: EPath, + new_world_size: Optional[int] = None, + new_micro_batch_size: Optional[int] = None, ): """Redistribute a checkpoint. @@ -267,22 +325,52 @@ def command_redist( # Ensure output directory exists output_path.mkdir(exist_ok=True, parents=True) - new_rank_states = [list() for _ in range(new_world_size)] + # A list (rank) of lists (workers) of (worker_state, prefetched_sample_keys) for each new rank + new_rank_states = [[] for _ in range(new_world_size)] rsi_iter = iter(rsi) for rank_idx in range(new_world_size): for _ in range(new_workers_per_rank): - state = next(rsi_iter) - new_rank_states[rank_idx].append(state) + worker_state, prefetched_sample_keys = next(rsi_iter) + new_rank_states[rank_idx].append((worker_state, prefetched_sample_keys)) assert all( - len(new_rank_states[0]) == len(new_rank_states[rank]) for rank in range(1, new_world_size) + len(new_rank_states[0]) == len(rank_states) for rank_states in new_rank_states[1:] ), "All ranks must have the same number of workers, also for the new distribution." + # Check batch sizes (before and after) + old_micro_batch_size = rsi.get_micro_batch_size() + if old_micro_batch_size is not None and new_micro_batch_size != old_micro_batch_size: + assert new_micro_batch_size is not None and old_micro_batch_size is not None, ( + "Cannot resume with different batching mode (batching to non-batching or vice versa)" + ) + + if new_micro_batch_size > old_micro_batch_size: + raise ValueError( + "Resuming with larger micro batch size is not allowed: " + f"{new_micro_batch_size} > {old_micro_batch_size}" + ) + elif ( + new_micro_batch_size < old_micro_batch_size + and old_micro_batch_size % new_micro_batch_size != 0 + ): + raise ValueError( + "Resuming with smaller micro batch size only allowed if the old " + f"micro batch size is a multiple of the new one: {new_micro_batch_size} < {old_micro_batch_size}" + ) + batch_split_factor = old_micro_batch_size // new_micro_batch_size + print(f"Splitting batches by {batch_split_factor}x") + else: + batch_split_factor = 1 + new_states = [ - SavableDataLoaderState( - worker_states=new_rank_state, + RankState( + worker_states=[worker_state for worker_state, prefetched_sample_keys in new_rank_state], next_worker_id=0, # Reset the next worker ID - micro_batch_size=rsi.get_micro_batch_size(), + micro_batch_size=new_micro_batch_size, + prefetched_samples_keys=[ + split_batch_restore_keys(prefetched_sample_keys, batch_split_factor) + for worker_state, prefetched_sample_keys in new_rank_state + ], ) for new_rank_state in new_rank_states ] diff --git a/src/megatron/energon/wrappers/base.py b/src/megatron/energon/wrappers/base.py index 2bfe4618..f90bca21 100644 --- a/src/megatron/energon/wrappers/base.py +++ b/src/megatron/energon/wrappers/base.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +import dataclasses import threading from abc import ABC, abstractmethod from contextlib import contextmanager @@ -8,9 +9,9 @@ from megatron.energon.flavors.base_dataset import ( FlexState, + RestoreKey, Sample, SavableDataset, - add_sample_restore_key, ) from megatron.energon.savable import Savable from megatron.energon.worker import WorkerConfig @@ -80,20 +81,6 @@ def _find_wrapped_dataset(self, cls: Type[SavableDataset]) -> Optional[SavableDa return res return None - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out: - if len(self.datasets) == 1: - return self.datasets[0].restore_sample(restore_key) - else: - id, ds_idx = restore_key[:2] - assert id == type(self).__name__ - restore_key = restore_key[2:] - assert isinstance(ds_idx, int) - return add_sample_restore_key( - self.datasets[ds_idx].restore_sample(restore_key), - ds_idx, - src=self, - ) - @abstractmethod def reset_state_own(self) -> None: """Resets the state of the dataset, excl. the inner datasets.""" @@ -120,6 +107,10 @@ def restore_state(self, state: FlexState) -> None: super().restore_state(state) + def restore_sample(self, restore_key: RestoreKey) -> T_sample_out: + assert len(self.datasets) == 1, "Must be implemented by subclass" + return self.dataset.restore_sample(restore_key) + class SampleIndex(Savable): """A simple class to hold the sample index for one worker.""" @@ -186,7 +177,37 @@ def restore_state(self, state: Optional[int]) -> None: self.current_idx = state -def get_sample_restore_key(sample: Any) -> Optional[Union[str, int]]: +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class WrappedRestoreKey(RestoreKey): + inner: RestoreKey + + +def wrap_sample_restore_key( + sample: T_sample, t: "Type[WrappedRestoreKey]", fail_otherwise: bool = False, **kwargs +) -> T_sample: + """Adds a key to a sample. The sample must be a valid `Sample` or dict containing + __restore_key__, which is a tuple of keys that can be used to restore the inner sample. + This restore key is prepended with the `key`.""" + if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"): + try: + sample.__restore_key__ = t(inner=sample.__restore_key__, **kwargs) + except KeyError: + pass + elif isinstance(sample, dict) and "__restore_key__" in sample: + sample["__restore_key__"] = t(inner=sample["__restore_key__"], **kwargs) + elif fail_otherwise: + raise RuntimeError( + "Did not yield a sample with a restore key, but is marked stateless/deterministic." + ) + return sample + + +@dataclasses.dataclass(kw_only=True, slots=True, frozen=True) +class MultiWrappedRestoreKey(RestoreKey): + inner: tuple[RestoreKey | None, ...] + + +def get_sample_restore_key(sample: Any) -> RestoreKey | None: """Gets the restore key from an arbitrary sample.""" if isinstance(sample, Sample) or hasattr(sample, "__restore_key__"): return sample.__restore_key__ diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index 10b534f6..3aa9bfcd 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import inspect +from dataclasses import dataclass from typing import ( Any, Callable, @@ -12,7 +13,6 @@ List, Optional, Sequence, - Tuple, TypeVar, Union, ) @@ -22,13 +22,29 @@ from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + MultiWrappedRestoreKey, + RestoreKey, + SampleIndex, + get_sample_restore_key, +) from megatron.energon.wrappers.skip import SkipSample T_batch = TypeVar("T_batch", covariant=True) T_batch_sample = TypeVar("T_batch_sample", covariant=True) +@dataclass(kw_only=True, slots=True, frozen=True) +class BatchRestoreKey(MultiWrappedRestoreKey): + sample_idx: int + + +@dataclass(kw_only=True, slots=True, frozen=True) +class BatchGenRestoreKey(BatchRestoreKey): + gen_idx: int | None = None + + class BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch]): """This dataset wrapper transforms a dataset of samples into a dataset of batches.""" @@ -97,7 +113,7 @@ def len_worker(self, worker_idx: int | None = None) -> int: def __iter__(self) -> Iterator[T_batch]: batch: List[T_batch_sample] = [] - sample_restore_keys = [] + sample_restore_keys: list[RestoreKey | None] = [] last_batch_failures = 0 @@ -121,10 +137,11 @@ def __iter__(self) -> Iterator[T_batch]: self._generator_offset = batch_sub_idx + 1 yield set_sample_restore_key( inner_batch_sample, - sample_idx, - batch_sub_idx, - *sample_restore_keys, - src=self, + BatchGenRestoreKey( + sample_idx=sample_idx, + gen_idx=batch_sub_idx, + inner=tuple(sample_restore_keys), + ), ) self._generator_sample_keys = None self._generator_offset = None @@ -150,16 +167,20 @@ def flush() -> Generator[T_batch, None, None]: self._generator_offset = batch_sub_idx + 1 yield set_sample_restore_key( inner_batch_sample, - sample_idx, - batch_sub_idx, - *sample_restore_keys, - src=self, + BatchGenRestoreKey( + sample_idx=sample_idx, + gen_idx=batch_sub_idx, + inner=tuple(sample_restore_keys), + ), ) self._generator_sample_keys = None self._generator_offset = None else: last_batch_failures = 0 - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + set_sample_restore_key( + batch_sample, + BatchRestoreKey(sample_idx=sample_idx, inner=tuple(sample_restore_keys)), + ) yield batch_sample sample_restore_keys.clear() except GeneratorExit: @@ -202,40 +223,35 @@ def assert_can_restore(self) -> None: ) super().assert_can_restore() - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_batch: + def restore_sample(self, restore_key: RestoreKey) -> T_batch: # We need to store multiple indices to restore a batch. self.assert_can_restore() + assert isinstance(restore_key, BatchRestoreKey) if inspect.isgeneratorfunction(self.batcher): - id, sample_idx, batch_sub_idx, *samples_restore_keys = restore_key - assert id == type(self).__name__ - else: - id, sample_idx, *samples_restore_keys = restore_key - assert id == type(self).__name__ - batch = [self.dataset.restore_sample(inner_idx) for inner_idx in samples_restore_keys] - with SampleIndex(self.worker_config, src=self).ctx(sample_idx): + assert isinstance(restore_key, BatchGenRestoreKey) + batch = [self.dataset.restore_sample(inner_idx) for inner_idx in restore_key.inner] + with SampleIndex(self.worker_config, src=self).ctx(restore_key.sample_idx): batch_sample = self.batcher(batch) if isinstance(batch_sample, Generator): assert inspect.isgeneratorfunction(self.batcher), ( f"Generator in {self.batcher} but not marked as such." ) + assert isinstance(restore_key, BatchGenRestoreKey) for cur_batch_sub_idx, (sample_idx, inner_batch_sample) in enumerate( - SampleIndex(self.worker_config, src=self).iter_ctx(batch_sample, sample_idx) + SampleIndex(self.worker_config, src=self).iter_ctx( + batch_sample, restore_key.sample_idx + ) ): - if cur_batch_sub_idx == batch_sub_idx: + if cur_batch_sub_idx == restore_key.gen_idx: return set_sample_restore_key( inner_batch_sample, - sample_idx, - batch_sub_idx, - *samples_restore_keys, - src=self, + restore_key, ) - assert False, f"Batch sub-index {batch_sub_idx} not found in batch" + assert False, f"Batch sub-index {restore_key.gen_idx} not found in batch" else: return set_sample_restore_key( batch_sample, - sample_idx, - *samples_restore_keys, - src=self, + restore_key, ) def config(self) -> Dict[str, Any]: diff --git a/src/megatron/energon/wrappers/blend_dataset.py b/src/megatron/energon/wrappers/blend_dataset.py index 7b60ce5a..c8f49b5d 100644 --- a/src/megatron/energon/wrappers/blend_dataset.py +++ b/src/megatron/energon/wrappers/blend_dataset.py @@ -1,18 +1,29 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +from dataclasses import dataclass from typing import Any, Dict, Iterator, List, Sequence, Tuple, TypeVar import torch -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.rng import WorkerRng from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + RestoreKey, + WrappedRestoreKey, + wrap_sample_restore_key, +) T_sample = TypeVar("T_sample") +@dataclass(kw_only=True, slots=True, frozen=True) +class BlendRestoreKey(WrappedRestoreKey): + dataset_idx: int + + class BlendDataset(BaseWrapperDataset[T_sample, T_sample]): """ This dataset wrapper blends multiple iterable datasets together give a weighting. @@ -104,10 +115,14 @@ def __iter__(self) -> Iterator[T_sample]: if all(dataset_iter is None for dataset_iter in dataset_iters): break else: - yield add_sample_restore_key(sample, ds_idx, src=self) + yield wrap_sample_restore_key(sample, BlendRestoreKey, dataset_idx=ds_idx) self.exhausted = [False] * len(self.dataset_weights) + def restore_sample(self, restore_key: RestoreKey) -> T_sample: + assert isinstance(restore_key, BlendRestoreKey) + return self.datasets[restore_key.dataset_idx].restore_sample(restore_key.inner) + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/buffer.py b/src/megatron/energon/wrappers/buffer.py index 3b0d0b32..0d2eb2b8 100644 --- a/src/megatron/energon/wrappers/buffer.py +++ b/src/megatron/energon/wrappers/buffer.py @@ -10,13 +10,12 @@ List, Optional, Sequence, - Tuple, TypeVar, Union, ) from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset from megatron.energon.savable import Savable from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import get_sample_restore_key @@ -28,14 +27,14 @@ class SavableSampleBufferState: """State of a SavableSampleBuffer.""" - restore_keys: List[Tuple[Union[str, int, tuple], ...]] + restore_keys: list[RestoreKey | None] class SavableSampleBuffer(Savable, Generic[T_sample]): """A buffer of samples, savable. State is shared, create a state-local instance.""" - _buffer: List[T_sample] - _restore_keys: List[Tuple[Union[str, int, tuple], ...]] + _buffer: list[T_sample] + _restore_keys: list[RestoreKey | None] _restore_pending: bool = False @@ -73,12 +72,12 @@ def pop(self, index: int) -> T_sample: self._restore_keys.pop(index) return self._buffer.pop(index) - def flush(self) -> Tuple[List[T_sample], Tuple[Any, ...]]: + def flush(self) -> tuple[list[T_sample], tuple[RestoreKey | None, ...]]: buffer = list(self._buffer) - restore_key = tuple(self._restore_keys) + restore_keys = tuple(self._restore_keys) self._buffer.clear() self._restore_keys.clear() - return buffer, restore_key + return buffer, restore_keys @property def buffer(self) -> List[T_sample]: @@ -122,12 +121,12 @@ def restore_state(self, state: SavableSampleBufferState) -> None: self._restore_keys = state.restore_keys.copy() self._restore_pending = True - def restore_key(self) -> Tuple[Union[str, int], ...]: + def restore_key(self) -> tuple[RestoreKey | None, ...]: return tuple(self._restore_keys) def restore_samples( - self, index: Tuple[Union[str, int, tuple], ...] - ) -> Tuple[Tuple[Union[str, int, tuple], ...], List[T_sample]]: + self, index: tuple[RestoreKey | None, ...] + ) -> tuple[tuple[RestoreKey | None, ...], list[T_sample]]: buffer = [] restore_keys = [] for sub_index in index: diff --git a/src/megatron/energon/wrappers/concat_dataset.py b/src/megatron/energon/wrappers/concat_dataset.py index 83e35660..38c7df7d 100644 --- a/src/megatron/energon/wrappers/concat_dataset.py +++ b/src/megatron/energon/wrappers/concat_dataset.py @@ -1,15 +1,25 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +from dataclasses import dataclass from typing import Any, Dict, Generic, Iterator, TypeVar -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + WrappedRestoreKey, + wrap_sample_restore_key, +) T_sample = TypeVar("T_sample") +@dataclass(kw_only=True, slots=True, frozen=True) +class ConcatRestoreKey(WrappedRestoreKey): + dataset_idx: int + + class ConcatDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): """ This dataset wrapper concatenates multiple iterable datasets together. The datasets must be @@ -35,12 +45,16 @@ def len_worker(self, worker_idx: int | None = None) -> int: def __iter__(self) -> Iterator[T_sample]: for ds_idx, dataset in enumerate(self.datasets): for sample in dataset: - yield add_sample_restore_key( + yield wrap_sample_restore_key( sample, - ds_idx, - src=self, + ConcatRestoreKey, + dataset_idx=ds_idx, ) + def restore_sample(self, restore_key: RestoreKey) -> T_sample: + assert isinstance(restore_key, ConcatRestoreKey) + return self.datasets[restore_key.dataset_idx].restore_sample(restore_key.inner) + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index 2ca78620..f0e8912f 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import inspect +from dataclasses import dataclass from typing import ( Any, Callable, @@ -22,6 +23,7 @@ from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError from megatron.energon.flavors.base_dataset import ( FlexState, + RestoreKey, SavableDataset, set_sample_restore_key, ) @@ -29,7 +31,11 @@ from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + MultiWrappedRestoreKey, + SampleIndex, +) from megatron.energon.wrappers.buffer import SavableSampleBuffer from megatron.energon.wrappers.skip import SkipSample @@ -54,6 +60,11 @@ def restore_state(self, state: FlexState): self.samples.restore_state(state["samples"]) +@dataclass(kw_only=True, slots=True, frozen=True) +class GroupBatchRestoreKey(MultiWrappedRestoreKey): + sample_idx: int + + class GroupBatchDataset( BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_sample, T_batch] ): @@ -163,7 +174,10 @@ def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: f"Batcher {self.batcher} returned a generator, which is not supported for grouped batching yet." ) last_batch_failures = 0 - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + set_sample_restore_key( + batch_sample, + GroupBatchRestoreKey(sample_idx=sample_idx, inner=sample_restore_keys), + ) yield batch_sample except SkipSample: pass @@ -246,14 +260,13 @@ def assert_can_restore(self) -> None: ) super().assert_can_restore() - def restore_sample(self, index: Tuple[Union[str, int, tuple], ...]) -> T_batch: + def restore_sample(self, index: RestoreKey) -> T_batch: self.assert_can_restore() - id, sample_idx, *sample_restore_keys = index - assert id == type(self).__name__ - batch = [self.dataset.restore_sample(inner_idx) for inner_idx in sample_restore_keys] - with SampleIndex(self.worker_config, src=self).ctx(sample_idx): + assert isinstance(index, GroupBatchRestoreKey) + batch = [self.dataset.restore_sample(inner_idx) for inner_idx in index.inner] + with SampleIndex(self.worker_config, src=self).ctx(index.sample_idx): batch_sample = self.batcher(batch) - set_sample_restore_key(batch_sample, sample_idx, *sample_restore_keys, src=self) + set_sample_restore_key(batch_sample, index) return batch_sample def config(self) -> Dict[str, Any]: diff --git a/src/megatron/energon/wrappers/iter_map_dataset.py b/src/megatron/energon/wrappers/iter_map_dataset.py index 41c05a15..fc564dea 100644 --- a/src/megatron/energon/wrappers/iter_map_dataset.py +++ b/src/megatron/energon/wrappers/iter_map_dataset.py @@ -1,6 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +from dataclasses import dataclass from typing import ( Any, Callable, @@ -9,7 +10,6 @@ Generic, Iterator, Optional, - Tuple, TypeVar, Union, ) @@ -17,16 +17,27 @@ from torch.utils.data import IterableDataset from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError -from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset, set_sample_restore_key from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + MultiWrappedRestoreKey, + SampleIndex, + get_sample_restore_key, +) T_sample = TypeVar("T_sample") T_sample_out = TypeVar("T_sample_out") +@dataclass(kw_only=True, slots=True, frozen=True) +class IterMapRestoreKey(MultiWrappedRestoreKey): + sample_idx: int + iter_idx: int + + class IterMapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]): """This dataset wrapper applies a custom function to transform the stream of samples and yield a new stream of samples. @@ -97,7 +108,7 @@ def __iter__(self) -> Iterator[T_sample_out]: # This is the sample index within the currently yielded sample iter_idx = 0 sample_idx = 0 - sample_restore_keys = [] + sample_restore_keys: list[RestoreKey | None] = [] def reset_idx_iter() -> Generator[T_sample, None, None]: # Resets the inner sample index @@ -116,10 +127,11 @@ def reset_idx_iter() -> Generator[T_sample, None, None]: for sample_idx, sample in self._sample_index.iter_ctx(self.iter_map_fn(ds_iter)): yield set_sample_restore_key( sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, + IterMapRestoreKey( + sample_idx=sample_idx, + iter_idx=iter_idx, + inner=tuple(sample_restore_keys), + ), ) sample_restore_keys.clear() iter_idx += 1 @@ -139,31 +151,26 @@ def assert_can_restore(self) -> None: ) super().assert_can_restore() - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: + def restore_sample(self, restore_key: RestoreKey) -> T_sample: self.assert_can_restore() - id, sample_idx, iter_idx, *sample_restore_keys = restore_key - assert id == type(self).__name__ - assert isinstance(iter_idx, int) + assert isinstance(restore_key, IterMapRestoreKey) inner_iter = iter( self.iter_map_fn( - (self.dataset.restore_sample(inner_index) for inner_index in sample_restore_keys) + (self.dataset.restore_sample(inner_index) for inner_index in restore_key.inner) ) ) try: sample_index = SampleIndex(self.worker_config, src=self) # Skip inner yielded samples to get the correct sample - for skip_idx in range(iter_idx): - with sample_index.ctx(sample_idx - iter_idx + skip_idx): + for skip_idx in range(restore_key.iter_idx): + with sample_index.ctx(restore_key.sample_idx - restore_key.iter_idx + skip_idx): next(inner_iter) # This is the sample to restore - with sample_index.ctx(sample_idx): + with sample_index.ctx(restore_key.sample_idx): sample = next(inner_iter) return set_sample_restore_key( sample, - sample_idx, - iter_idx, - *sample_restore_keys, - src=self, + restore_key, ) except StopIteration: raise RuntimeError( diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index 91810f16..4f50ccc2 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import inspect +from dataclasses import dataclass from typing import ( Any, Callable, @@ -11,23 +12,38 @@ Iterator, Optional, Sequence, - Tuple, TypeVar, Union, ) from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError -from megatron.energon.flavors.base_dataset import SavableDataset, add_sample_restore_key +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset, set_sample_restore_key from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + SampleIndex, + WrappedRestoreKey, + get_sample_restore_key, + wrap_sample_restore_key, +) from megatron.energon.wrappers.skip import SkipSample T_sample = TypeVar("T_sample") T_sample_out = TypeVar("T_sample_out") +@dataclass(kw_only=True, slots=True, frozen=True) +class MapRestoreKey(WrappedRestoreKey): + sample_idx: int + + +@dataclass(kw_only=True, slots=True, frozen=True) +class MapGenRestoreKey(MapRestoreKey): + gen_idx: int + + class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T_sample_out]): """This dataset wrapper applies a custom function to transform each sample.""" @@ -36,8 +52,8 @@ class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T stateless_map_fn: bool map_fn_config: Optional[Union[Dict[str, Any], Callable[[], Dict[str, Any]]]] _sample_index: SampleIndex - _generator_sample_key: Optional[Any] - _generator_offset: Optional[int] + _generator_sample_key: RestoreKey | None + _generator_offset: int | None _savable_fields = ( "_sample_index", @@ -107,11 +123,11 @@ def __iter__(self) -> Iterator[T_sample_out]: # Skip other samples if idx >= target_offset: self._generator_offset = idx + 1 - yield add_sample_restore_key( + yield wrap_sample_restore_key( inner_sample, - sample_idx, - idx, - src=self, + MapGenRestoreKey, + sample_idx=sample_idx, + gen_idx=idx, ) self._generator_sample_key = None self._generator_offset = None @@ -134,20 +150,20 @@ def __iter__(self) -> Iterator[T_sample_out]: ): self._generator_offset = idx + 1 last_map_failures = 0 - yield add_sample_restore_key( + yield wrap_sample_restore_key( inner_sample, - sample_idx, - idx, - src=self, + MapGenRestoreKey, + sample_idx=sample_idx, + gen_idx=idx, ) self._generator_sample_key = None self._generator_offset = None else: last_map_failures = 0 - yield add_sample_restore_key( + yield wrap_sample_restore_key( mapped_sample, - sample_idx, - src=self, + MapRestoreKey, + sample_idx=sample_idx, ) except GeneratorExit: raise @@ -176,34 +192,34 @@ def assert_can_restore(self) -> None: ) super().assert_can_restore() - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample_out: + def restore_sample(self, restore_key: RestoreKey) -> T_sample_out: self.assert_can_restore() + assert isinstance(restore_key, MapRestoreKey), ( + f"Expected MapRestoreKey, got {type(restore_key)}" + ) if inspect.isgeneratorfunction(self.map_fn): - id, sample_idx, local_idx = restore_key[:3] - assert id == type(self).__name__ - restore_key = restore_key[3:] - assert isinstance(local_idx, int) - else: - id, sample_idx = restore_key[:2] - assert id == type(self).__name__ - restore_key = restore_key[2:] - inner_sample = self.dataset.restore_sample(restore_key) - with SampleIndex(self.worker_config, src=self).ctx(sample_idx): + assert isinstance(restore_key, MapGenRestoreKey) + inner_sample = self.dataset.restore_sample(restore_key.inner) + with SampleIndex(self.worker_config, src=self).ctx(restore_key.sample_idx): mapped_sample = self.map_fn(inner_sample) if isinstance(mapped_sample, Generator): assert inspect.isgeneratorfunction(self.map_fn), ( f"Generator in {self.map_fn} but not marked as such." ) + assert isinstance(restore_key, MapGenRestoreKey) for idx, (sample_idx, res_sample) in enumerate( - SampleIndex(self.worker_config, src=self).iter_ctx(mapped_sample, sample_idx) + SampleIndex(self.worker_config, src=self).iter_ctx( + mapped_sample, restore_key.sample_idx + ) ): - if idx == local_idx: - return add_sample_restore_key(res_sample, sample_idx, local_idx, src=self) - assert False, ( - "Generator did not yield enough samples, but is marked stateless/deterministic." - ) + if idx == restore_key.gen_idx: + return set_sample_restore_key( + res_sample, + restore_key, + ) + assert False, f"Generator sub-index {restore_key.gen_idx} not yielded by {self.map_fn}" else: - return add_sample_restore_key(mapped_sample, sample_idx, src=self) + return set_sample_restore_key(mapped_sample, restore_key) def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index 8d48742f..1091884f 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -3,6 +3,7 @@ import contextlib import inspect +from dataclasses import dataclass from typing import ( Any, Callable, @@ -19,14 +20,21 @@ from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError from megatron.energon.flavors.base_dataset import ( + RestoreKey, SavableDataset, - add_sample_restore_key, set_sample_restore_key, ) from megatron.energon.source_info import SourceInfo from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers._log_exception import log_exception -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + MultiWrappedRestoreKey, + SampleIndex, + WrappedRestoreKey, + get_sample_restore_key, + wrap_sample_restore_key, +) from megatron.energon.wrappers.buffer import SavableSampleBuffer from megatron.energon.wrappers.skip import SkipSample @@ -35,6 +43,21 @@ T_batch_sample = TypeVar("T_batch_sample") +@dataclass(kw_only=True, slots=True, frozen=True) +class EncodePackRestoreKey(WrappedRestoreKey): + sample_idx: int + + +@dataclass(kw_only=True, slots=True, frozen=True) +class PackingRestoreKey(MultiWrappedRestoreKey): + pack_idx: int + + +@dataclass(kw_only=True, slots=True, frozen=True) +class PackingGenRestoreKey(PackingRestoreKey): + gen_idx: int + + class PackingDataset( BaseWrapperDataset[T_sample, T_encoded_sample, T_batch_sample], Generic[T_sample, T_encoded_sample, T_batch_sample], @@ -221,10 +244,10 @@ def encode_pack_samples(pack: List[T_sample]) -> List[T_encoded_sample]: encoded_sample = self.sample_encoder(sample) assert not isinstance(encoded_sample, Generator), "Generator not supported" encoded_pack.append( - add_sample_restore_key( + wrap_sample_restore_key( encoded_sample, - encode_idx, - src=self, + EncodePackRestoreKey, + sample_idx=encode_idx, ) ) except SkipSample: @@ -310,17 +333,14 @@ def next_final_pack() -> Generator[T_batch_sample, None, None]: ): yield set_sample_restore_key( inner_batch_sample, - pack_idx, - pack_sub_idx, - *pack_restore_keys, - src=self, + PackingGenRestoreKey( + pack_idx=pack_sub_idx, gen_idx=pack_sub_idx, inner=pack_restore_keys + ), ) else: yield set_sample_restore_key( final_packed_sample, - pack_idx, - *pack_restore_keys, - src=self, + PackingRestoreKey(pack_idx=pack_idx, inner=pack_restore_keys), ) except SkipSample: pass @@ -394,49 +414,44 @@ def assert_can_restore(self): ) super().assert_can_restore() - def restore_sample(self, restore_key: Any) -> T_sample: + def restore_sample(self, restore_key: RestoreKey) -> T_sample: # We need to store multiple indices to restore a batch. self.assert_can_restore() + assert isinstance(restore_key, PackingRestoreKey) if inspect.isgeneratorfunction(self.final_packer): - id, pack_idx, pack_sub_idx, *pack_restore_keys = restore_key - assert id == type(self).__name__ - else: - id, pack_idx, *pack_restore_keys = restore_key - assert id == type(self).__name__ + assert isinstance(restore_key, PackingGenRestoreKey) pack = [] - for inner_idx in pack_restore_keys: + for inner_key in restore_key.inner: if self.sample_encoder is not None: - id, sample_idx, *inner_idx = inner_idx - assert id == type(self).__name__ - assert isinstance(sample_idx, int) - sample = self.dataset.restore_sample(inner_idx) + assert isinstance(inner_key, EncodePackRestoreKey) + encode_key = inner_key + inner_key = inner_key.inner + sample = self.dataset.restore_sample(inner_key) if self.sample_encoder is not None: - with SampleIndex(self.worker_config, src=self).ctx(sample_idx): + with SampleIndex(self.worker_config, src=self).ctx(encode_key.sample_idx): sample = self.sample_encoder(sample) assert not isinstance(sample, Generator), "Generator not supported" - sample = add_sample_restore_key(sample, sample_idx, src=self) + sample = set_sample_restore_key(sample, encode_key) pack.append(sample) - with SampleIndex(self.worker_config, src=self).ctx(pack_idx): + with SampleIndex(self.worker_config, src=self).ctx(restore_key.pack_idx): final_pack = self.final_packer(pack) if isinstance(final_pack, Generator): assert inspect.isgeneratorfunction(self.final_packer), ( f"Generator in {self.final_packer} but not marked as such." ) - for cur_batch_sub_idx, (pack_idx, inner_batch_sample) in enumerate( - SampleIndex(self.worker_config, src=self).iter_ctx(final_pack, pack_idx) + assert isinstance(restore_key, PackingGenRestoreKey) + for pack_sub_idx, (pack_idx, inner_batch_sample) in enumerate( + SampleIndex(self.worker_config, src=self).iter_ctx(final_pack, restore_key.pack_idx) ): - if cur_batch_sub_idx == pack_sub_idx: + if pack_sub_idx == restore_key.gen_idx: return set_sample_restore_key( inner_batch_sample, - pack_idx, - pack_sub_idx, - *pack_restore_keys, - src=self, + restore_key, ) assert False, f"Pack sub-index {pack_sub_idx} not found in pack" else: - return set_sample_restore_key(final_pack, pack_idx, *pack_restore_keys, src=self) + return set_sample_restore_key(final_pack, restore_key) def config(self) -> Dict[str, Any]: return { diff --git a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py index 9337f1c5..ea85bda2 100644 --- a/src/megatron/energon/wrappers/shuffle_buffer_dataset.py +++ b/src/megatron/energon/wrappers/shuffle_buffer_dataset.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from typing import Any, Dict, Generic, Iterator, Tuple, TypeVar, Union +from typing import Any, Dict, Generic, Iterator, TypeVar from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.rng import WorkerRng @@ -52,9 +52,6 @@ def __iter__(self) -> Iterator[T_sample]: pop_idx = self._worker_rng.randbelow(self._active_buffer.len_worker()) yield self._active_buffer.pop(pop_idx) - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: - return self.dataset.restore_sample(restore_key) - def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/wrappers/task_encoder_state_dataset.py b/src/megatron/energon/wrappers/task_encoder_state_dataset.py index 6363f4c9..30b0e75e 100644 --- a/src/megatron/energon/wrappers/task_encoder_state_dataset.py +++ b/src/megatron/energon/wrappers/task_encoder_state_dataset.py @@ -6,13 +6,11 @@ Dict, Generic, Iterator, - Tuple, TypeVar, - Union, ) import megatron.energon -from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset @@ -53,9 +51,10 @@ def __iter__(self) -> Iterator[T_sample]: if not self._task_encoder_was_reset: self._task_encoder_was_reset = True self._task_encoder.reset_state() - yield from self.dataset + for sample in self.dataset: + yield sample - def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_sample: + def restore_sample(self, restore_key: RestoreKey) -> T_sample: inner_sample = self.dataset.restore_sample(restore_key) inner_sample = self._task_encoder.restore_sample(inner_sample) return inner_sample @@ -78,4 +77,4 @@ def config(self) -> Dict[str, Any]: } def __str__(self): - return f"MapDataset(map_fn={self.map_fn}, dataset={self.dataset})" + return f"TaskEncoderStateDataset(map_fn={self.map_fn}, dataset={self.dataset})" diff --git a/tests/test_dataset.py b/tests/test_dataset.py index b0bf9bf7..d07337a0 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1380,7 +1380,7 @@ def pack_selected_samples( ) -> EncodedCaptioningSample: return EncodedCaptioningSample( __key__=",".join([sample.__key__ for sample in samples]), - __restore_key__=(), + __restore_key__=None, image=torch.stack([sample.image for sample in samples]), caption=torch.cat([sample.caption for sample in samples]), ) @@ -1515,7 +1515,7 @@ def pack_selected_samples( ) -> EncodedCaptioningSample: return EncodedCaptioningSample( __key__=",".join([sample.__key__ for sample in samples]), - __restore_key__=(), + __restore_key__=None, image=torch.stack([sample.image for sample in samples]), caption=torch.cat([sample.caption for sample in samples]), ) diff --git a/tests/test_dataset_det.py b/tests/test_dataset_det.py index 6719dbe1..c12cc943 100644 --- a/tests/test_dataset_det.py +++ b/tests/test_dataset_det.py @@ -934,7 +934,7 @@ def test_redist(self): batches_per_rank = [] for rank_config in scenario["configs"]: - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, split_part="train", @@ -944,27 +944,26 @@ def test_redist(self): shuffle_buffer_size=42, max_samples_per_sequence=2, ) - ) - - # Throw away some samples to advance the loader state - num_pre_samples = 20 - for _ in zip(range(num_pre_samples), loader): - pass - - # Save the state to a file - checkpoint_file = self.checkpoint_dir / f"state_rank{rank_config.rank}.pt" - state = loader.save_state_rank() - torch.save(state, str(checkpoint_file)) - checkpoint_files.append(checkpoint_file) - - # Now capture the next micro-batches - micro_batches = [ - data.text - for idx, data in zip( - range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader - ) - ] - batches_per_rank.append(micro_batches) + ) as loader: + # Throw away some samples to advance the loader state + num_pre_samples = 20 + for _ in zip(range(num_pre_samples), loader): + pass + + # Save the state to a file + checkpoint_file = self.checkpoint_dir / f"state_rank{rank_config.rank}.pt" + state = loader.save_state_rank() + torch.save(state, str(checkpoint_file)) + checkpoint_files.append(checkpoint_file) + + # Now capture the next micro-batches + micro_batches = [ + data.text + for idx, data in zip( + range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader + ) + ] + batches_per_rank.append(micro_batches) # Compose global batches global_batches_cur_rank = [] @@ -985,6 +984,7 @@ def test_redist(self): # === Stage 2: Now check that the global batches are the same after redistribution for scenario in scenarios[1:]: + print(f"\n\nRunning scenario {scenario}") # Redistribute the saved state runner = CliRunner() result = runner.invoke( @@ -992,11 +992,15 @@ def test_redist(self): [ "--new-world-size", str(len(scenario["configs"])), + "--new-micro-batch-size", + str(scenario["micro_batch_size"]), *[str(cpt) for cpt in checkpoint_files], str(self.redist_dir), ], ) print(result.output) + if result.exception is not None: + raise result.exception assert result.exception is None, result.exception assert result.exit_code == 0, "Redistribution failed" @@ -1013,7 +1017,11 @@ def test_redist(self): batches_per_rank = [] for rank_config in scenario["configs"]: - loader = get_savable_loader( + state = torch.load( + str(self.redist_dir / f"state_rank{rank_config.rank}.pt"), weights_only=False + ) + + with get_savable_loader( get_train_dataset( self.dataset_path, split_part="train", @@ -1023,20 +1031,14 @@ def test_redist(self): shuffle_buffer_size=42, max_samples_per_sequence=2, ) - ) - - state = torch.load( - str(self.redist_dir / f"state_rank{rank_config.rank}.pt"), weights_only=False - ) - loader.restore_state_rank(state) - - micro_batches = [ - data.text - for idx, data in zip( - range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader - ) - ] - batches_per_rank.append(micro_batches) + ).with_restored_state_rank(state) as loader: + micro_batches = [ + data.text + for idx, data in zip( + range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader + ) + ] + batches_per_rank.append(micro_batches) # Compose global batches global_batches_cur_rank = [] @@ -1064,8 +1066,10 @@ def test_redist(self): for global_batches in global_batches_per_scenario ), "Number of global batches per scenario does not match." - for global_batches in global_batches_per_scenario: - print("= Global batches per scenario") + for idx, (global_batches, scenario) in enumerate( + zip(global_batches_per_scenario, scenarios) + ): + print(f"= Global batches per scenario {idx} {scenario}") for global_batch in global_batches: print(" Global batch: ", global_batch) From 3e4ba10eb05637d31e9a0bb4091f488503ef7fc0 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 19 Aug 2025 15:30:51 +0200 Subject: [PATCH 16/36] Fix usage of dataclasses.asdict --- src/megatron/energon/tools/checkpoint.py | 5 ++++- tests/test_dataset.py | 9 ++++----- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/megatron/energon/tools/checkpoint.py b/src/megatron/energon/tools/checkpoint.py index 30769065..786d26e5 100644 --- a/src/megatron/energon/tools/checkpoint.py +++ b/src/megatron/energon/tools/checkpoint.py @@ -245,7 +245,10 @@ def split_batch_restore_key( ] elif isinstance(restore_key, WrappedRestoreKey): inner_restore_keys = split_batch_restore_key(restore_key.inner, batch_split_factor) - inner_kwargs = dataclasses.asdict(restore_key) + inner_kwargs = { + field.name: getattr(restore_key, field.name) + for field in dataclasses.fields(restore_key) + } inner_kwargs.pop("inner") return [ type(restore_key)(**inner_kwargs, inner=inner_restore_key) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index d07337a0..44c625c5 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -3,7 +3,6 @@ """This module defines tests for the dataset.""" -import dataclasses import gc import io import json @@ -688,7 +687,7 @@ def __init__(self, task_encoder: TaskEncoder, weight: float, target_data_class: def encode_sample(self, sample): sample = super().encode_sample(sample) - return self.target_data_class(**dataclasses.asdict(sample), weight=self.weight) + return self.target_data_class.extend(sample, weight=self.weight) torch.manual_seed(42) @@ -763,7 +762,7 @@ def __init__(self, source: int): self.source = source def encode_batch(self, batch): - return TestBatch(**dataclasses.asdict(batch), source=self.source) + return TestBatch.extend(batch, source=self.source) loader = get_loader( MixBatchDataset( @@ -820,7 +819,7 @@ def __init__(self, source: int, batch_cls: Type[TestBatch1]): self.batch_cls = batch_cls def encode_batch(self, batch): - return self.batch_cls(**dataclasses.asdict(batch), source=self.source) + return self.batch_cls.extend(batch, source=self.source) loader = get_loader( MixBatchDataset( @@ -1580,7 +1579,7 @@ def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, int @stateless def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: - return CaptioningEncodedBatch(**dataclasses.asdict(batch)) + return CaptioningEncodedBatch.extend(batch) worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) loader = get_savable_loader( From 6436e6ac7c30b3f65ef0fa3ee94e9b722717737f Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Wed, 20 Aug 2025 17:49:31 +0200 Subject: [PATCH 17/36] Fix small issues with threaded worker and thread-local states --- .../dataloader/workers/async_worker.py | 23 +++ .../energon/dataloader/workers/fork_worker.py | 4 +- .../dataloader/workers/thread_worker.py | 20 +- src/megatron/energon/flavors/base_dataset.py | 2 +- .../energon/flavors/webdataset/itar_reader.py | 194 ++++++++++++------ .../flavors/webdataset/sample_loader.py | 63 +++--- src/megatron/energon/worker.py | 76 +++++-- tests/test_dataloader.py | 1 - 8 files changed, 244 insertions(+), 139 deletions(-) diff --git a/src/megatron/energon/dataloader/workers/async_worker.py b/src/megatron/energon/dataloader/workers/async_worker.py index 0647264c..25b4f370 100644 --- a/src/megatron/energon/dataloader/workers/async_worker.py +++ b/src/megatron/energon/dataloader/workers/async_worker.py @@ -1,7 +1,10 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause +import threading from typing import Generic, TypeVar +import torch.utils.data._utils + from megatron.energon.cache.base import CachePool from megatron.energon.dataloader.asynchronous import ( Asynchronous, @@ -17,6 +20,26 @@ TSample = TypeVar("TSample", covariant=True) +# Patching the torch worker info. + +_thread_local_worker_info = threading.local() + + +def torch_set_worker_info(id: int, num_workers: int, seed: int, dataset: SavableDataset): + _thread_local_worker_info._worker_info = torch.utils.data._utils.worker.WorkerInfo( + id=id, + num_workers=num_workers, + seed=seed, + dataset=dataset, + ) + + +def _patch_get_worker_info(): + return getattr(_thread_local_worker_info, "_worker_info", None) + + +torch.utils.data.get_worker_info = _patch_get_worker_info + class DataLoaderAsynchronousWorker(DataLoaderWorker[TSample], Asynchronous, Generic[TSample]): """ diff --git a/src/megatron/energon/dataloader/workers/fork_worker.py b/src/megatron/energon/dataloader/workers/fork_worker.py index 39a8d591..9fa92f72 100644 --- a/src/megatron/energon/dataloader/workers/fork_worker.py +++ b/src/megatron/energon/dataloader/workers/fork_worker.py @@ -6,6 +6,7 @@ from megatron.energon.dataloader.asynchronous import ForkAsynchronous from megatron.energon.dataloader.workers.async_worker import ( DataLoaderAsynchronousWorker, + torch_set_worker_info, ) from megatron.energon.wrappers.gc_dataset import gc_init_worker @@ -25,9 +26,8 @@ def _worker_run( result_queue: multiprocessing.Queue, ) -> None: gc_init_worker(self._rank_worker_id) - import torch.utils.data._utils - torch.utils.data._utils.worker._worker_info = torch.utils.data._utils.worker.WorkerInfo( + torch_set_worker_info( id=self._rank_worker_id, num_workers=self.worker_config.num_workers, seed=self._seed, diff --git a/src/megatron/energon/dataloader/workers/thread_worker.py b/src/megatron/energon/dataloader/workers/thread_worker.py index 7f5ea4fe..5550d3b4 100644 --- a/src/megatron/energon/dataloader/workers/thread_worker.py +++ b/src/megatron/energon/dataloader/workers/thread_worker.py @@ -1,33 +1,17 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -import threading from typing import Generic, TypeVar -import torch.utils.data - from megatron.energon.dataloader.asynchronous import ThreadAsynchronous, WorkerCommand, WorkerResult from megatron.energon.dataloader.asynchronous.base import QueueProtocol from megatron.energon.dataloader.workers.async_worker import ( DataLoaderAsynchronousWorker, + torch_set_worker_info, ) TSample = TypeVar("TSample", covariant=True) -_orig_get_worker_info = torch.utils.data.get_worker_info - -_thread_local_worker_info = threading.local() - - -def _patch_get_worker_info(): - if not hasattr(_thread_local_worker_info, "_worker_info"): - _thread_local_worker_info._worker_info = _orig_get_worker_info() - return _orig_get_worker_info() - - -torch.utils.data.get_worker_info = _patch_get_worker_info - - class ThreadDataLoaderWorker( ThreadAsynchronous, DataLoaderAsynchronousWorker[TSample], Generic[TSample] ): @@ -38,7 +22,7 @@ class ThreadDataLoaderWorker( def _worker_run( self, cmd_queue: QueueProtocol[WorkerCommand], result_queue: QueueProtocol[WorkerResult] ) -> None: - _thread_local_worker_info._worker_info = torch.utils.data._utils.worker.WorkerInfo( + torch_set_worker_info( id=self._rank_worker_id, num_workers=self.worker_config.num_workers, seed=self._seed, diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index 9a3bfb0a..fe01bf17 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -108,7 +108,7 @@ class MyExtendedClass(MyBaseClass): The extended dataclass instance. """ assert is_dataclass(cls), "Must be a dataclass" - assert issubclass(cls, type(src)), "Cannot extend class of different type" + # assert issubclass(cls, type(src)), "Cannot extend class of different type" for f in dataclasses.fields(src): if not f.init or f.type is ClassVar or typing.get_origin(f.type) is ClassVar: diff --git a/src/megatron/energon/flavors/webdataset/itar_reader.py b/src/megatron/energon/flavors/webdataset/itar_reader.py index 06d058fe..47a57340 100644 --- a/src/megatron/energon/flavors/webdataset/itar_reader.py +++ b/src/megatron/energon/flavors/webdataset/itar_reader.py @@ -34,6 +34,65 @@ from megatron.energon.source_info import SourceInfo T_index = TypeVar("T_index", covariant=False) +T_key = TypeVar("T_key") +T_value = TypeVar("T_value") + + +class MultiKeyCache(Generic[T_key, T_value]): + """A cache that can store multiple values for the same key.""" + + _size: int + _cache: dict[T_key, list[T_value]] + _lru_keys: list[T_key] + + def __init__(self) -> None: + self._size = 0 + self._cache = {} + self._lru_keys = [] + + @overload + def pop(self, key: None = None) -> T_value: ... + + @overload + def pop(self, key: T_key) -> T_value | None: ... + + def pop(self, key: T_key | None = None) -> T_value | None: + """Pop the value for the given key from the cache. + + If no key is provided, pop the oldest key from the cache. + + Args: + key: The key to pop from the cache. If None, pop the oldest key from the cache. + + Returns: + The value popped from the cache. + """ + if key is None: + key = self._lru_keys.pop(0) + elif key not in self._cache: + return None + else: + self._lru_keys.pop(len(self._lru_keys) - 1 - self._lru_keys[::-1].index(key)) + + l = self._cache[key] + value = l.pop(0) + if len(l) == 0: + del self._cache[key] + self._size -= 1 + return value + + def add(self, key: T_key, value: T_value) -> None: + """Add a value to the cache.""" + if key not in self._cache: + self._cache[key] = [value] + else: + self._cache[key].insert(0, value) + + self._lru_keys.append(key) + self._size += 1 + + def __len__(self) -> int: + return self._size class ITarReader(ABC, Generic[T_index]): @@ -54,7 +113,8 @@ class ITarReader(ABC, Generic[T_index]): tar_filenames: List[str] tar_filepaths: List[EPath] part_filter: Optional[Callable[[str], bool]] - itar_files_cache: Dict[int, ITarFile] + cache_lock: threading.Lock + itar_files_cache: MultiKeyCache[int, ITarFile] sample_filter: Optional[Callable[[str], bool]] def __init__( @@ -74,7 +134,8 @@ def __init__( self.tar_filenames = tar_filenames self.tar_filepaths = tar_filepaths self.part_filter = part_filter - self.itar_files_cache = {} + self.cache_lock = threading.Lock() + self.itar_files_cache = MultiKeyCache() self.itar_cache_size = itar_cache_size self.sample_filter = sample_filter @@ -98,24 +159,24 @@ def _get_itar_sample_pointer(self, idx: T_index) -> ITarSamplePointer: def _get_itarfile_cached(self, tar_file_id: int) -> ITarFile: """ Get the ITarFile object for the given tar file id. - If the file is not already open, open it. If we exceed - the global cache limit, close the least recently used file. + If the file is not already open, open it. """ - if tar_file_id not in self.itar_files_cache: - file_object = self.tar_filepaths[tar_file_id].open(mode="rb") - tar_file = ITarFile.open(fileobj=file_object, mode="r:") - self.itar_files_cache[tar_file_id] = tar_file - - # If we hit the limit of open files, close the least recently used file - while len(self.itar_files_cache) > self.itar_cache_size: - # Get the oldest file - lru_key = next(iter(self.itar_files_cache)) - - self.itar_files_cache[lru_key].fileobj.close() - self.itar_files_cache[lru_key].close() - del self.itar_files_cache[lru_key] - - return self.itar_files_cache[tar_file_id] + with self.cache_lock: + reader = self.itar_files_cache.pop(tar_file_id) + if reader is None: + file_object = self.tar_filepaths[tar_file_id].open(mode="rb") + reader = ITarFile.open(fileobj=file_object, mode="r:") + return reader + + def _update_itarfile_cache(self, tar_file_id: int, reader: ITarFile) -> None: + """ + Update the ITarFile object for the given tar file id. + """ + with self.cache_lock: + while len(self.itar_files_cache) >= self.itar_cache_size: + # Evict the oldest file + self.itar_files_cache.pop().close() + self.itar_files_cache.add(tar_file_id, reader) def _get_item_by_sample_pointer( self, @@ -136,7 +197,6 @@ def _get_item_by_sample_pointer( """ # Open the tar file (cached) - tar_file = self._get_itarfile_cached(sample_pointer.tar_file_id) shard_name = self.tar_filenames[sample_pointer.tar_file_id] sample_base_name = None sample_name = None @@ -144,50 +204,58 @@ def _get_item_by_sample_pointer( file_names: list[str] = [] # Position the tar file at the correct offset - tar_file.offset = sample_pointer.byte_offset - - while tar_file.offset < sample_pointer.byte_offset + sample_pointer.byte_size: - tarinfo = tar_file.next() - if tarinfo is None: - raise ValueError( - f"Unexpected end of tar file: {self.tar_filenames[sample_pointer.tar_file_id]}" - ) - fname = tarinfo.name - if not tarinfo.isfile() or fname is None: - continue - if skip_meta_re.match(fname): - continue - - # Extract the base_name and extension - m = split_name_re.match(fname) - if not m: - continue - cur_base_name, cur_ext = m.groups() - + tar_file = self._get_itarfile_cached(sample_pointer.tar_file_id) + try: + tar_file.offset = sample_pointer.byte_offset + + while tar_file.offset < sample_pointer.byte_offset + sample_pointer.byte_size: + tarinfo = tar_file.next() + if tarinfo is None: + if tar_file.offset == sample_pointer.byte_offset + sample_pointer.byte_size: + break + else: + raise ValueError( + f"Unexpected end of tar file: {self.tar_filenames[sample_pointer.tar_file_id]}" + ) + fname = tarinfo.name + if not tarinfo.isfile() or fname is None: + continue + if skip_meta_re.match(fname): + continue + + # Extract the base_name and extension + m = split_name_re.match(fname) + if not m: + continue + cur_base_name, cur_ext = m.groups() + + if sample_base_name is None: + sample_base_name = cur_base_name + sample_name = f"{shard_name}/{cur_base_name}" + if self.sample_filter is not None and not self.sample_filter(sample_name): + return None + else: + if sample_base_name != cur_base_name: + raise ValueError( + f"Inconsistent sample base name: {sample_base_name} vs {cur_base_name}" + ) + + if entry_match_fn is not None: + # If entry_match_fn is provided, use it to determine if we should take this entry + take_entry = entry_match_fn(fname) + else: + # If no entry_match_fn is provided, use the part_filter to determine if we should take this entry + take_entry = self.part_filter is None or self.part_filter(cur_ext) + + if take_entry: + member_bytes = tar_file.extractfile(tarinfo).read() + group_parts[cur_ext] = member_bytes + file_names.append(fname) if sample_base_name is None: - sample_base_name = cur_base_name - sample_name = f"{shard_name}/{cur_base_name}" - if self.sample_filter is not None and not self.sample_filter(sample_name): - return None - else: - if sample_base_name != cur_base_name: - raise ValueError( - f"Inconsistent sample base name: {sample_base_name} vs {cur_base_name}" - ) - - if entry_match_fn is not None: - # If entry_match_fn is provided, use it to determine if we should take this entry - take_entry = entry_match_fn(fname) - else: - # If no entry_match_fn is provided, use the part_filter to determine if we should take this entry - take_entry = self.part_filter is None or self.part_filter(cur_ext) - - if take_entry: - member_bytes = tar_file.extractfile(tarinfo).read() - group_parts[cur_ext] = member_bytes - file_names.append(fname) - if sample_base_name is None: - raise ValueError(f"No valid files found in sample {sample_pointer}") + raise ValueError(f"No valid files found in sample {sample_pointer}") + finally: + # Return the reader to the cache + self._update_itarfile_cache(sample_pointer.tar_file_id, tar_file) return FilteredSample( __key__=f"{shard_name}/{sample_base_name}", diff --git a/src/megatron/energon/flavors/webdataset/sample_loader.py b/src/megatron/energon/flavors/webdataset/sample_loader.py index 9c3838d9..8d44b411 100644 --- a/src/megatron/energon/flavors/webdataset/sample_loader.py +++ b/src/megatron/energon/flavors/webdataset/sample_loader.py @@ -39,8 +39,9 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): #: The readers for each joined dataset join_readers: Sequence[ITarReader] - #: The offsets of the slice slices to iterate over for the current worker - slice_offsets: Optional[Sequence[int]] + #: The offsets of the slice slices to iterate over for each worker + # On worker initialization, this is set to _slice_offsets for the current worker. + workers_slice_offsets: Sequence[Sequence[int]] # If = 1, every sample is seen exactly once per epoch. If > 1, samples # (or rather slice slices) are shuffled within this number of epochs (i.e. randomly @@ -53,6 +54,9 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): # Worker's random generator _worker_rng: WorkerRng + #: The offsets of the slice slices to iterate over for the current worker + _slice_offsets: Optional[Sequence[int]] + #: The RNG state to be used for regenerating the pending slices _pending_slices_rng_state: Optional[FlexState] #: The number of slices that have already been opened / processed and thus been removed from the @@ -81,6 +85,8 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): "_epoch_sample_count", ) + _state_fields = ("_slice_offsets",) + def __init__( self, join_readers: Sequence[ITarReader], @@ -95,7 +101,7 @@ def __init__( Args: join_readers: A sequence of the joined readers (or just a single reader) to iterate over. - worker_slice_offsets: The offsets of the slice slices to iterate over, for each worker. + workers_sample_slice_offsets: The offsets of the slice slices to iterate over, for each worker. worker_config: The worker configuration. shuffle_over_epochs: If None, disable shuffling. If = 1, every sample is seen exactly once per epoch. @@ -110,14 +116,10 @@ def __init__( super().__init__(worker_config=worker_config) self.join_readers = join_readers + self.workers_slice_offsets = workers_sample_slice_offsets self.shuffle_over_epochs = shuffle_over_epochs self.parallel_slice_iters = parallel_slice_iters - # Store the slices for all workers - # The slices for the current worker, will have to be extracted from this list later - self.workers_slice_offsets = workers_sample_slice_offsets - self.slice_offsets = None - assert shuffle_over_epochs is None or shuffle_over_epochs == -1 or shuffle_over_epochs >= 1 assert self.parallel_slice_iters >= 1 @@ -131,12 +133,7 @@ def reset_state(self) -> None: self._sample_count = 0 self._epoch_count = 0 self._epoch_sample_count = 0 - - def ensure_slice_offsets(self) -> None: - self.worker_config.assert_worker() - - if self.slice_offsets is None: - self.slice_offsets = self.workers_slice_offsets[self.worker_config.rank_worker_id()] + self._slice_offsets = self.workers_slice_offsets[self.worker_config.rank_worker_id()] def _get_sample(self, index: int) -> RawSampleData: return RawSampleData( @@ -146,9 +143,9 @@ def _get_sample(self, index: int) -> RawSampleData: def _slices_once(self) -> List[int]: """Yields the indexes to slice offsets once. Possibly shuffles the list.""" - assert self.slice_offsets is not None + assert self._slice_offsets is not None - num_slices = len(self.slice_offsets) - 1 + num_slices = len(self._slice_offsets) - 1 slices_offset = self._pending_slices_offset if self.shuffle_over_epochs is None: @@ -196,17 +193,17 @@ def _slices_iter(self) -> Generator[RawSampleData, None, None]: """Iterates the samples in a list of slices, possibly using multiple parallel iterators over the slices.""" - assert self.slice_offsets is not None + assert self._slice_offsets is not None active_slice_probs = torch.zeros(self.parallel_slice_iters, dtype=torch.float32) active_slices = self._active_slice_state pending_slice_indexes = self._pending_slice_indexes def slice_at(idx: int) -> SliceState: - assert self.slice_offsets is not None + assert self._slice_offsets is not None return SliceState( index=idx, - current=self.slice_offsets[idx], + current=self._slice_offsets[idx], ) # Weight the slices by their size to get a more even distribution of samples @@ -222,8 +219,8 @@ def slice_at(idx: int) -> SliceState: for idx, slice_state in enumerate(active_slices): if slice_state is not None: active_slice_probs[idx] = ( - self.slice_offsets[slice_state.index + 1] - - self.slice_offsets[slice_state.index] + self._slice_offsets[slice_state.index + 1] + - self._slice_offsets[slice_state.index] ) if self.worker_config.should_log(level=1): @@ -281,8 +278,8 @@ def slice_at(idx: int) -> SliceState: self._pending_slices_offset += 1 slice_state = slice_at(slice_index) active_slice_probs[len(active_slices)] = ( - self.slice_offsets[slice_state.index + 1] - - self.slice_offsets[slice_state.index] + self._slice_offsets[slice_state.index + 1] + - self._slice_offsets[slice_state.index] ) active_slices.append(slice_state) # Fill up the slice iterators with None @@ -316,7 +313,7 @@ def slice_at(idx: int) -> SliceState: slice_state.current += 1 self._sample_count += 1 self._epoch_sample_count += 1 - if slice_state.current >= self.slice_offsets[slice_state.index + 1]: + if slice_state.current >= self._slice_offsets[slice_state.index + 1]: # Iterator exhausted -> take next / remove from list if len(pending_slice_indexes) > 0 or self.shuffle_over_epochs == -1: if len(pending_slice_indexes) > 0: @@ -326,12 +323,12 @@ def slice_at(idx: int) -> SliceState: self._pending_slices_offset += 1 else: # Randomly select a new slice directly (with replacement) - num_slices = len(self.slice_offsets) - 1 + num_slices = len(self._slice_offsets) - 1 next_idx = self._worker_rng.randbelow(num_slices) next_slice_state = slice_at(next_idx) active_slice_probs[slice_idx] = ( - self.slice_offsets[next_slice_state.index + 1] - - self.slice_offsets[next_slice_state.index] + self._slice_offsets[next_slice_state.index + 1] + - self._slice_offsets[next_slice_state.index] ) active_slices[slice_idx] = next_slice_state # print( @@ -410,15 +407,13 @@ def len_worker(self, worker_idx: int | None = None) -> int: def worker_has_samples(self) -> bool: self.worker_config.assert_worker() - self.ensure_slice_offsets() - assert self.slice_offsets is not None - return len(self.slice_offsets) > 1 + assert self._slice_offsets is not None + return len(self._slice_offsets) > 1 def __iter__(self) -> Iterator[RawSampleData]: self.worker_config.assert_worker() - self.ensure_slice_offsets() - assert self.slice_offsets is not None + assert self._slice_offsets is not None if self.worker_config.should_log(level=1): self.worker_config.worker_log( @@ -426,13 +421,13 @@ def __iter__(self) -> Iterator[RawSampleData]: "t": "WebdatasetSampleLoaderDataset.__iter__", "r": self.worker_config.rank, "w": self.worker_config.rank_worker_id(), - "slice_offsets": self.slice_offsets, + "slice_offsets": self._slice_offsets, "parallel_slice_iters": self.parallel_slice_iters, "shuffle_over_epochs": self.shuffle_over_epochs, } ) - if len(self.slice_offsets) <= 1: + if len(self._slice_offsets) <= 1: return yield from self._slices_iter() diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index 8a549355..a77d0f0e 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -13,7 +13,6 @@ import torch.utils.data from megatron.energon.cache import CachePool -from megatron.energon.edataclass import edataclass __all__ = ("WorkerConfig",) @@ -22,16 +21,50 @@ THREAD_SAFE = True -@edataclass -class ActiveWorkerState(threading.local): - #: The current sample index within the current iterating worker - sample_index_stack: Optional[List[int]] = None - #: The global rank override for the worker. Required for restoring samples. - override_global_rank: Optional[int] = None - #: The current cache pool for the worker. - cache_pool: Optional[CachePool] = None - #: The current worker config within the current iterating worker - worker_config: "WorkerConfig | None" = None +class ActiveWorkerState: + """ + Thread local state for the active worker config. + """ + + _thread_local: threading.local + + @property + def sample_index_stack(self) -> Optional[List[int]]: + """The current sample index stack for the worker.""" + return getattr(self._thread_local, "sample_index_stack", None) + + @property + def override_global_rank(self) -> Optional[int]: + """The global rank override for the worker. Required for restoring samples.""" + return getattr(self._thread_local, "override_global_rank", None) + + @property + def cache_pool(self) -> Optional[CachePool]: + """The current cache pool for the worker.""" + return getattr(self._thread_local, "cache_pool", None) + + @property + def worker_config(self) -> "WorkerConfig | None": + return getattr(self._thread_local, "worker_config", None) + + @sample_index_stack.setter + def sample_index_stack(self, value: List[int]): + self._thread_local.sample_index_stack = value + + @override_global_rank.setter + def override_global_rank(self, value: Optional[int]): + self._thread_local.override_global_rank = value + + @cache_pool.setter + def cache_pool(self, value: Optional[CachePool]): + self._thread_local.cache_pool = value + + @worker_config.setter + def worker_config(self, value: "WorkerConfig | None"): + self._thread_local.worker_config = value + + def __init__(self): + self._thread_local = threading.local() class classproperty: @@ -101,6 +134,9 @@ def worker_activate( ): """Activates the worker config for the current worker and sets it as actively iterating. Must be called before next() call on the datasets.""" + assert WorkerConfig._active_state.worker_config is None, ( + f"Worker config already active for thread={threading.get_ident()}" + ) WorkerConfig._active_state.sample_index_stack = [sample_index] WorkerConfig._active_state.worker_config = self WorkerConfig._active_state.override_global_rank = override_global_rank @@ -121,15 +157,15 @@ def worker_pop_sample_index(self): def worker_deactivate(self): """Deactivates the worker config for the current worker and deactivates it for iterating. Must be called after next() call on the datasets.""" - if WorkerConfig.active_worker_config is not None: - assert WorkerConfig._active_state.sample_index_stack is not None - assert len(WorkerConfig._active_state.sample_index_stack) == 1, ( - f"Sample index stack not empty: {WorkerConfig._active_state.sample_index_stack}" - ) - WorkerConfig._active_state.sample_index_stack = None - WorkerConfig._active_state.worker_config = None - WorkerConfig._active_state.override_global_rank = None - WorkerConfig._active_state.cache_pool = None + assert WorkerConfig._active_state.worker_config is self, "Worker config mismatch" + assert WorkerConfig._active_state.sample_index_stack is not None + assert len(WorkerConfig._active_state.sample_index_stack) == 1, ( + f"Sample index stack not empty: {WorkerConfig._active_state.sample_index_stack}" + ) + WorkerConfig._active_state.sample_index_stack = None + WorkerConfig._active_state.worker_config = None + WorkerConfig._active_state.override_global_rank = None + WorkerConfig._active_state.cache_pool = None @property def active_worker_sample_index(self) -> int: diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 126b4de7..2bac88e1 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -330,7 +330,6 @@ def test_dataloader_fork_multi_parallel(self): train_loader_r1.shutdown() def test_dataloader_thread(self): - assert False, "Not done yet" torch.manual_seed(42) worker_config = WorkerConfig( rank=0, From a76d8606d0e079797d2e8c6270b01feb7113bc1d Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 21 Aug 2025 10:52:45 +0200 Subject: [PATCH 18/36] Implement group batch buckets as savable object --- .../energon/wrappers/group_batch_dataset.py | 139 +++++++++++------- 1 file changed, 89 insertions(+), 50 deletions(-) diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index 6d8792fa..b736d081 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -22,7 +22,6 @@ from megatron.energon.edataclass import edataclass from megatron.energon.errors import SYSTEM_EXCEPTIONS, FatalSampleError from megatron.energon.flavors.base_dataset import ( - FlexState, RestoreKey, SavableDataset, set_sample_restore_key, @@ -36,7 +35,7 @@ MultiWrappedRestoreKey, SampleIndex, ) -from megatron.energon.wrappers.buffer import SavableSampleBuffer +from megatron.energon.wrappers.buffer import SavableSampleBuffer, SavableSampleBufferState from megatron.energon.wrappers.skip import SkipSample T_batch = TypeVar("T_batch", covariant=True) @@ -44,20 +43,93 @@ @edataclass -class Bucket(Savable, Generic[T_batch_sample]): +class BucketState: + """State of a bucket. This is used to save and restore the bucket.""" + batch_size: int + samples: SavableSampleBufferState + + +@edataclass +class BucketsState: + """State of the buckets. This is used to save and restore the buckets.""" + + buckets: Dict[Hashable, BucketState] + +@edataclass +class Bucket(Savable, Generic[T_batch_sample]): + """A bucket for a GroupBatchDataset. It contains the samples.""" + + batch_size: int samples: SavableSampleBuffer[T_batch_sample] - def save_state(self) -> FlexState: - return FlexState( + def save_state(self) -> BucketState: + return BucketState( batch_size=self.batch_size, samples=self.samples.save_state(), ) - def restore_state(self, state: FlexState): - self.batch_size = state["batch_size"] - self.samples.restore_state(state["samples"]) + def restore_state(self, state: BucketState): + self.batch_size = state.batch_size + self.samples.restore_state(state.samples) + + +class Buckets(Savable, Generic[T_batch_sample]): + """This class manages the buckets for a GroupBatchDataset. It is a savable object, which can be saved and restored.""" + + _dataset: SavableDataset[T_batch_sample] + _worker_config: WorkerConfig + + _buckets: Dict[Hashable, Bucket[T_batch_sample]] + + def __init__(self, dataset: SavableDataset[T_batch_sample], worker_config: WorkerConfig): + self._dataset = dataset + self._worker_config = worker_config + self._buckets = {} + + def save_state(self) -> BucketsState: + return BucketsState( + buckets={key: bucket.save_state() for key, bucket in self._buckets.items()} + ) + + def restore_state(self, state: BucketsState): + self._buckets = { + key: Bucket( + batch_size=-1, + samples=SavableSampleBuffer(self._dataset, worker_config=self._worker_config), + ) + for key, bucket in state.buckets.items() + } + for key, bucket in self._buckets.items(): + bucket.restore_state(state.buckets[key]) + + def get(self, key: Hashable, batch_size: int | None) -> Bucket[T_batch_sample]: + """Get a bucket for a given key. If the bucket does not exist, create it.""" + bucket = self._buckets.get(key) + if bucket is None: + assert batch_size is not None + self._buckets[key] = bucket = Bucket( + batch_size=batch_size, + samples=SavableSampleBuffer(self._dataset, worker_config=self._worker_config), + ) + else: + assert bucket.batch_size == batch_size, ( + f"Got different batch size for group {key}: {bucket.batch_size} != {batch_size}." + ) + return bucket + + def flush(self) -> Generator[Bucket[T_batch_sample], None, None]: + """Yield all buckets and clear afterwards.""" + yield from self._buckets.values() + self._buckets.clear() + + def clear(self): + self._buckets.clear() + + def worker_start(self): + for bucket in self._buckets.values(): + bucket.samples.worker_start() @dataclass(kw_only=True, slots=True, frozen=True) @@ -82,12 +154,11 @@ class GroupBatchDataset( error_handler: Callable[[Exception, List[T_batch_sample], list[SourceInfo]], None] _group_key_sample_index: SampleIndex _batch_sample_index: SampleIndex - _buckets: Dict[Hashable, Bucket[T_batch_sample]] + _buckets: Buckets _last_batch_failures: int = 0 - _savable_fields = ("_group_key_sample_index", "_batch_sample_index") - # Buckets are saved manually - _state_fields = ("_buckets", "_last_batch_failures") + _savable_fields = ("_group_key_sample_index", "_batch_sample_index", "_buckets") + _state_fields = ("_last_batch_failures",) def __init__( self, @@ -135,24 +206,18 @@ def __init__( def reset_state_own(self) -> None: self._group_key_sample_index = SampleIndex(self.worker_config, src=self) self._batch_sample_index = SampleIndex(self.worker_config, src=self) - self._buckets = {} + self._buckets = Buckets(self.dataset, self.worker_config) def len_worker(self, worker_idx: int | None = None) -> int: # Return an upper bound. This is for sure not correct. return self.dataset.len_worker(worker_idx) def __iter__(self) -> Iterator[T_batch]: - buckets = self._buckets - - if buckets is None: - buckets = self._buckets = dict() - # Load saved state if available - for bucket in buckets.values(): - bucket.samples.worker_start() + self._buckets.worker_start() # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial GroupBatchDataset state:\n", end="") - # for bucket_key, bucket in buckets.items(): + # for bucket_key, bucket in self._buckets._buckets.items(): # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{bucket_key}] (bs={bucket.batch_size}, len(samples)={len(bucket.samples)}):\n", end="") # bucket.samples.debug_print(" ") # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] initial done\n", end="") @@ -160,7 +225,7 @@ def __iter__(self) -> Iterator[T_batch]: def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: # Debug print the state # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] flush GroupBatchDataset state:\n", end="") - # for dbg_bucket_key, dbg_bucket in buckets.items(): + # for dbg_bucket_key, dbg_bucket in self._buckets._buckets.items(): # print(f"[wrk={worker_idx}, s={self._batch_sample_index.current_idx}] - Bucket [{dbg_bucket_key}{'*' if dbg_bucket_key == bucket_key else ''}] (bs={dbg_bucket.batch_size}, len(samples)={len(dbg_bucket.samples)}):\n", end="") # dbg_bucket.samples.debug_print(" ") batch_items, sample_restore_keys = bucket.samples.flush() @@ -211,44 +276,18 @@ def flush(bucket: Bucket[T_batch_sample]) -> Generator[T_batch, None, None]: except Exception as e: self.error_handler(e, [sample]) continue - bucket = buckets.get(bucket_key) - if bucket is None: - assert batch_size is not None - buckets[bucket_key] = bucket = Bucket( - batch_size=batch_size, - samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config), - ) - else: - assert bucket.batch_size == batch_size, ( - f"Got different batch size for group {bucket_key}: {bucket.batch_size} != {batch_size}." - ) + bucket = self._buckets.get(bucket_key, batch_size) bucket.samples.append(sample) if bucket.samples.len_worker() >= bucket.batch_size: yield from flush(bucket) # Flush out last samples if not self.drop_last: - for bucket in buckets.values(): + for bucket in self._buckets.flush(): if bucket.samples.len_worker() > 0: yield from flush(bucket) # Clear the buckets self._buckets.clear() - def save_state(self) -> FlexState: - return FlexState( - **super().save_state(), - buckets={key: bucket.save_state() for key, bucket in self._buckets.items()}, - ) - - def restore_state(self, state: FlexState) -> None: - super().restore_state(state) - - for key, bucket_state in state["buckets"].items(): - self._buckets[key] = Bucket( - batch_size=-1, - samples=SavableSampleBuffer(self.dataset, worker_config=self.worker_config), - ) - self._buckets[key].restore_state(bucket_state) - def can_restore_sample(self) -> bool: return super().can_restore_sample() and self.batcher_stateless From 7c0bec477cfb5c037af8a468c29e650b6f95d694 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 21 Aug 2025 11:38:45 +0200 Subject: [PATCH 19/36] Make analyze_debug and debug output work again --- src/megatron/energon/dataloader/dataloader.py | 68 +++++++++++++++++ src/megatron/energon/flavors/base_dataset.py | 7 +- src/megatron/energon/tools/analyze_debug.py | 74 ++++++------------- tests/test_dataset.py | 10 ++- 4 files changed, 98 insertions(+), 61 deletions(-) diff --git a/src/megatron/energon/dataloader/dataloader.py b/src/megatron/energon/dataloader/dataloader.py index dd2bed2a..63ba28aa 100644 --- a/src/megatron/energon/dataloader/dataloader.py +++ b/src/megatron/energon/dataloader/dataloader.py @@ -33,6 +33,7 @@ from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key from megatron.energon.wrappers.batch_dataset import BatchDataset from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER, GcDataset +from megatron.energon.wrappers.log_sample_dataset import default_get_keys from megatron.energon.wrappers.watchdog_dataset import WatchdogDataset TSample = TypeVar("TSample", covariant=True) @@ -78,6 +79,7 @@ class DataLoader(Generic[TSample]): _next_id: ClassVar[int] = 0 _id: int + _next_epoch_id: int = 0 _workers: list[DataLoaderWorker[TSample]] | None = None _exhausted_workers: list[bool] @@ -96,6 +98,8 @@ class DataLoader(Generic[TSample]): _spawning_process: int + _global_sample_idx: int = 0 + def __init__( self, dataset: SavableDataset, @@ -187,6 +191,17 @@ def __init__( self._spawning_process = os.getpid() + if self._worker_config.should_log(level=1): + self._worker_config.worker_log( + { + "t": "DataLoader.__init__", + "r": self._worker_config.rank, + "w": None, + "id": self._id, + "config": dataset.config(), + } + ) + def _start(self) -> None: """Start the workers and restore the state if available.""" self._workers = [ @@ -265,6 +280,20 @@ def __exit__(self, exc_type, exc_value, traceback) -> None: def _epoch_iter(self) -> Generator[TSample, None, None]: """Iterate over the dataset for one epoch (i.e. all workers StopIteration). One epoch may also be infinite (if looping the dataset).""" + epoch_id = self._next_epoch_id + self._next_epoch_id += 1 + + if self._worker_config.should_log(level=1): + self._worker_config.worker_log( + { + "t": "DataLoader.epoch_iter", + "r": self._worker_config.rank, + "w": None, + "id": self._id, + "epoch_id": epoch_id, + } + ) + if self._workers is None: self._start() assert self._workers is not None, "DataLoader not started" @@ -296,6 +325,7 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: # - Yield the sample. if DEBUG_LEVEL >= 1: print(f"{self._exhausted_workers=}\n", end="") + epoch_sample_idx = 0 while not all(self._exhausted_workers): # Get the next worker to prefetch samples from. worker_idx = self._next_worker_id @@ -326,10 +356,48 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: # If the sample future raises StopIteration, remove the worker from the list. self._prefetching_samples[worker_idx] = [] self._exhausted_workers[worker_idx] = True + if self._worker_config.should_log(level=1): + self._worker_config.worker_log( + { + "t": "DataLoader.epoch_iter.StopIteration", + "r": self._worker_config.rank, + "w": None, + "id": self._id, + "epoch_id": epoch_id, + } + ) continue else: if DEBUG_LEVEL >= 2: print(f"{worker_idx=} got sample, yield\n", end="") + if self._worker_config.should_log(level=1): + keys = default_get_keys(sample) + restore_key = get_sample_restore_key(sample) + self._worker_config.worker_log( + { + **{ + "t": "DataLoader.epoch_iter.yield", + "r": self._worker_config.rank, + "w": None, + "id": self._id, + "epoch_id": epoch_id, + "worker_id": worker_idx, + "worker_sample_idx": restore_key.sample_idx + if isinstance(restore_key, WorkerSampleRestoreKey) + else None, + "epoch_sample_idx": epoch_sample_idx, + "global_sample_idx": self._global_sample_idx, + }, + **({} if keys is None else {"keys": keys}), + **( + {} + if restore_key is None + else {"restore_key": restore_key.as_tuple()} + ), + } + ) + epoch_sample_idx += 1 + self._global_sample_idx += 1 # Yield the sample. yield sample diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index fe01bf17..9e72170f 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -482,12 +482,7 @@ def _tupleify(self, value: Any) -> Any: def as_tuple(self) -> tuple[Any, ...]: return ( self.__class__.__name__, - *( - getattr(self, field.name).json() - if isinstance(getattr(self, field.name), RestoreKey) - else getattr(self, field.name) - for field in dataclasses.fields(self) - ), + *(self._tupleify(getattr(self, field.name)) for field in dataclasses.fields(self)), ) @staticmethod diff --git a/src/megatron/energon/tools/analyze_debug.py b/src/megatron/energon/tools/analyze_debug.py index e1eb1754..c0755d50 100644 --- a/src/megatron/energon/tools/analyze_debug.py +++ b/src/megatron/energon/tools/analyze_debug.py @@ -97,24 +97,6 @@ ) -class YieldBatchLogLine(TypedDict): - # Json example: - # { - # "t": "yield_batch", - # "r": 1, - # "w": 1, - # "m": "train", - # "idx": 1, - # "keys": ["parts/data-train-000051.tar/528866", ...], - # } - t: Literal["yield_batch"] - r: int - w: int - m: Literal["train", "val"] - idx: int - keys: List[str] - - class SampleLoaderYieldLogLine(TypedDict): # Json example: # { @@ -442,7 +424,7 @@ def command( class LoaderInitLogLine(TypedDict): - t: Literal["SavableLoader.__init__", "BasicDataLoader.__init__"] + t: Literal["DataLoader.__init__"] r: int w: None id: int @@ -450,33 +432,32 @@ class LoaderInitLogLine(TypedDict): class LoaderIterLogLine(TypedDict): - t: Literal["SavableDataLoader.iter", "BasicDataLoader.iter"] + t: Literal["DataLoader.epoch_iter"] r: int w: None id: int - iter_id: int + epoch_id: int class LoaderYieldLogLine(TypedDict): - t: Literal["SavableDataLoader.yield", "BasicDataLoader.yield"] + t: Literal["DataLoader.epoch_iter.yield"] r: int w: None id: int - iter_id: int + epoch_id: int worker_id: int - worker_idx: int - idx: int - iter_idx: int - global_idx: int + worker_sample_idx: int + epoch_sample_idx: int + global_sample_idx: int keys: Optional[List[str]] class LoaderStopLogLine(TypedDict): - t: Literal["SavableDataLoader.StopIteration", "BasicDataLoader.StopIteration"] + t: Literal["DataLoader.epoch_iter.StopIteration"] r: int w: None id: int - iter_id: int + epoch_id: int LoaderLines = Union[ @@ -487,14 +468,10 @@ class LoaderStopLogLine(TypedDict): ] LOADER_LOG_LINE_TYPES_T = ( - "SavableLoader.__init__", - "BasicDataLoader.__init__", - "SavableDataLoader.iter", - "BasicDataLoader.iter", - "SavableDataLoader.yield", - "BasicDataLoader.yield", - "SavableDataLoader.StopIteration", - "BasicDataLoader.StopIteration", + "DataLoader.__init__", + "DataLoader.epoch_iter", + "DataLoader.epoch_iter.yield", + "DataLoader.epoch_iter.StopIteration", ) @@ -553,34 +530,29 @@ def loaders(self) -> Dict[int, LoaderInfo]: loaders = {} for log_line in self._iter_log_lines( ( - "SavableLoader.__init__", - "BasicDataLoader.__init__", - "SavableDataLoader.yield", - "BasicDataLoader.yield", + "DataLoader.__init__", + "DataLoader.epoch_iter.yield", ) ): - if log_line["t"] in ("SavableLoader.__init__", "BasicDataLoader.__init__"): + if log_line["t"] == "DataLoader.__init__": loaders[log_line["id"]] = LoaderInfo( id=log_line["id"], modality=self._find_config_modality(log_line["config"]), path=self._find_config_path(log_line["config"]), global_count=0, ) - elif log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield"): - loaders[log_line["id"]].global_count = log_line["global_idx"] + elif log_line["t"] == "DataLoader.epoch_iter.yield": + loaders[log_line["id"]].global_count = log_line["global_sample_idx"] return loaders def log_entries(self, loader_ids: Container[int]) -> Generator[Optional[List[str]], None, None]: idx = self._start_idx - for log_line in self._iter_log_lines(("SavableDataLoader.yield", "BasicDataLoader.yield")): - if ( - log_line["t"] in ("SavableDataLoader.yield", "BasicDataLoader.yield") - and log_line["id"] in loader_ids - ): - assert log_line["global_idx"] >= idx, ( + for log_line in self._iter_log_lines(("DataLoader.epoch_iter.yield",)): + if log_line["t"] == "DataLoader.epoch_iter.yield" and log_line["id"] in loader_ids: + assert log_line["global_sample_idx"] >= idx, ( f"Found entry {log_line} with wrong idx <{idx}" ) - while log_line["global_idx"] != idx: + while log_line["global_sample_idx"] != idx: yield None idx += 1 if "keys" in log_line: diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 44c625c5..74d68c40 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1687,11 +1687,13 @@ def test_debug_dataset(self): with (debug_log_path / "0.jsonl").open() as rf: for line in rf: line_data = json.loads(line) - if line_data["t"] == "SavableDataLoader.yield": - print(line_data) + print(line_data) + if line_data["t"] == "DataLoader.epoch_iter.yield": for i in range(len(collected_keys_order)): - if collected_keys_order[i][line_data["idx"]] is None: - collected_keys_order[i][line_data["idx"]] = line_data["keys"] + if collected_keys_order[i][line_data["epoch_sample_idx"]] is None: + collected_keys_order[i][line_data["epoch_sample_idx"]] = line_data[ + "keys" + ] break else: assert False, "Too many entries for key" From c310c88585802c548bc4396e360d41521cfc9621 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 21 Aug 2025 16:00:12 +0200 Subject: [PATCH 20/36] Add proper favicon --- docs/source/_static/android-chrome-192x192.png | Bin 0 -> 25856 bytes docs/source/_static/android-chrome-512x512.png | Bin 0 -> 98923 bytes docs/source/_static/apple-touch-icon.png | Bin 0 -> 23089 bytes docs/source/_static/favicon-16x16.png | Bin 0 -> 640 bytes docs/source/_static/favicon-32x32.png | Bin 0 -> 1702 bytes docs/source/_static/favicon.ico | Bin 0 -> 15406 bytes docs/source/_static/site.webmanifest | 1 + docs/source/_templates/favicon.html | 5 +++++ docs/source/_templates/layout.html | 8 ++++++++ docs/source/conf.py | 3 +++ 10 files changed, 17 insertions(+) create mode 100644 docs/source/_static/android-chrome-192x192.png create mode 100644 docs/source/_static/android-chrome-512x512.png create mode 100644 docs/source/_static/apple-touch-icon.png create mode 100644 docs/source/_static/favicon-16x16.png create mode 100644 docs/source/_static/favicon-32x32.png create mode 100644 docs/source/_static/favicon.ico create mode 100644 docs/source/_static/site.webmanifest create mode 100644 docs/source/_templates/favicon.html create mode 100644 docs/source/_templates/layout.html diff --git a/docs/source/_static/android-chrome-192x192.png b/docs/source/_static/android-chrome-192x192.png new file mode 100644 index 0000000000000000000000000000000000000000..4b15719ccb1ed281bf69d0f81239239f28d9501c GIT binary patch literal 25856 zcmX`SV|b)p(={C1nb@{xV%s(*w$s7Hww+9@iQVDEw%wW7wln!M*Yn=rkM91_$3FJC z>#SP4s@AGVWko4u1bhTAFfe2pX>rxh-=2RTIGE3G)pa{qFfcSQ8F3NyAK+)cutxgB zZs$3f%kDVQ8}JZ}M$q1g8jO)AE-dio<8eDAuQUy&A1g9v0}ZC=K>fM%4Aqxu$5EwO zC4Ey3Ip8AVqI4Zj05s(+U)}G^$u|H=-4+JaAWRTxYqZB>r?2D2MCV~HAHZ@0A|b`s z9KAINeFUZl>+G%SSbP<1q+u&cl9T{U3;5bNWC=++9>{v(W zFzIZxU{i{d^6x*{*nJ$HULtQOXTIw4!`W@3!)}~b4~od3YKAHq=beo1x3E>_yx7LTh8Hoi_V-ylAhmsRfM^JV6l8} z+ZFy;h`BTB_V>aUR$OI0PBg0>!_S4J^5O6{q77li<`A`pjBp{}yFWDBYRqzqrGL{$ zY0(D}x|?YJy87ShO*wtoWr<^9)>YyBNN;mUBZxY)8?}rBq*Csz<;3D8UXj;4B$-WI zp^CT*$L2WL8muDhN+!3NIGYd3ZGcbgxRXv-I3#ryq+6)lQjwnt+d6)bJw!@(Uj?qM z11H!JxeH2O<`biF{Pdo<;R2>f$63{>I3x(aR%Jv zJ`;z1wzR=Xnu<0(2j1{*Xz>VT$)_lX$m`gP;S{}2uPMGcG@mbp8+9wg^(c5yvO>zQ z`+J!6n8n=-bj#RcNez;WEmk?FwVq;m!`Q7H6W9(oh z@Nz3sdeg1>50+(-vglK!Q%QmWd$Mm{Ki<9C61majagjbZ;`p}_?8w}JQYjz7%xgqhy>_*Qj*n$tv zakb8w*;ZS>o`~18fHC8ryYc^YB3Q?6kzvx-Fd@p9E5j6FnLtrIdVO#=uf+&EXGAEK z*n&!QQlCDaGN75u!AO$0E-q|$Kp~F~Fu*bZ zr+!J^J+M4Kq?gVw0vba{`zaAZz2CxMdg+EA!p@^;N?cp@B(hBwm+bU{yXxs^hMU8R zldnpOgkOrkIs|M8mV}dwY)Lg?tKbZvzNyQv(NiC$Po|I zP&X@;VV4Kp=ocpKo%GhJTZjTtst`{H+Uw(>48os@koMrF@kkcbbWKsija`Mr@a%$HXZ2;HQz{zPLbI=m82X~k zfGOirYp9}r=5MzjepEExB~od*6z;No<2YAa^%o5bC&dYmnP4N>8?pv6Ni+weJW=%D zt2q9Z($8DRT!nc$#EaodwUp^*Nq`1nhlyxzNJEx@9DvT9LC8o@{$2T9y1g+jEN_W? z&?-i&Lbp;{qYYjHI=DGGxdBsqQRO1PC`@0AnotYw{2{msqqCXvm`?A9yv1N%XoPOd zV$EWA*8vDvd^i8a11CxNozX{M%cZvUvBGf_bZ9`Ol6DDyDM5ts&mlc<;jE%Td(Aqq zOQCKif3&5RQu6`0aCL!s;*AP+1fUVoRjr5sq|Mr^7R-<`9D);Y7}W@#BZ(H%q*Yw0 zTz_t@Ntc(I+ZMfW%$oVr*Jjy0p0+sG;v4m?84<3fL^tysok#s*J3BdU{5pF(3Xm*}&`djYcx z+^fxEP8E&5ZZ+#WVA?mLR7IQRgy`R9x7F6SBGuH9t*zA zd){*7Lo?>NEe+x@!XZ=*JaaO0ZS17m%-Gmwv)}DcA+$2RRflYltjXuo^WajhS>%ZR z$uGiEjDsawdUq-W8SAZ3H!}h2TX=w;90#xixUAS(R;1GH5a08?B~i6lL`wnmJpyF8 zJ2x2QJZ#ZH!PwTD%mu3BqPclMh~?gXS~=&jU>PPVfmC@3}_egf(>5b_7=~8B7{RS=wOk0A^1)%q@C`_itL+dj*iE26Jc|xy$MK^U{#a9 z?I53KD4@Pf%2N0~=!zGQB|BpX1VjaMKdAd~Qb&dDU(H%QesxUkFdH$Hp{}~PQZ>Oz z`{oS*{;5rJU@tJxv(XBaM>lTEK9PnM3HUBVnN1}SDO1y2Rra#qet(nZ`8Yn-feJ^3 z*ta40C)!l$m3QQAJj*P1P3px8)p!8cFz!Cgd_t%jFAmA{bi3W${0;IZZ4KG#mg4i4 zA}v5XEzT!69l3>G(h>1!4zxGUTH$O?v^URMbsq`{kcfbZIn4#P)E+hcs_s{6 z^x5NYo_B8;bm5UwHw%-1+6FgN&CdDf!6@o9gUX=%e2)XWkf%QTqdt-XCKme_1`%<)bQVNMEif*=;;V&(1Q zWjLDRtYfmn1f@AuUGd4<-Cn(s7T`v^^@OO2KuC7jjb)%C1y#)R!ci*BFW2o}DH06A z+}UkosuuID?-EJoPRoT7mDCJ+cI6}}&PBD3{tZdTkYN$doi10*cMI{>yjww^YJrZp zTyamSsnjhIRDa}Y^EybUGnqt&AEP8fpBEJFy;Vqgd941@F&i_{59@#57I~?5Wx?_M zf$G*VBd&kA;y%)Nh~|pct^T-q)ZSL9{r@1vqnQi{%~@~%kl?&ANn4GSu7HO* z6z2aCP;W8;w_dl$^WSm6M9|ia5Q=Do}1F)OAO*YQ0 z{e!Ct6PhguQ7AVacl8J;Rvn=(rjhnm&rlQhV0s5?^W6%n5O!)vkD-q&a!^@uOOBG* z6)hmfesLv5i@}?^od#TbpeUa_OUn9@1lFD7nm1z+W9};(-(fp1#=IWY?B}6G6%AHZ z))f!8hjaE|H%Ho=5-^35229%h&XHec3ylZm>WsrN*>1ndh&Gr=MB1%@sGVxlfz`G*Z zqUCf&bsM~Ix>6+=JO^Q`Tq7QnABFm=4}4~os-p2i^wORX^7`Sq0>Y3&4*BY)u-xS= z6~#Y!`WpJVs{-&$hjH1{Vvbpx!QAle=b6F6e1oYvKXa|}=@XvRpms(bbed5VVu zKNcgU3k8ZNf8$%qJ_aRkykTT)6*62EYw{JkE#Y{-syudg%Zvu+{7VUyLqtFCYM)CR z&-rY}(X2y1X^lNiqqpe#YF+Y0Q&3=_P|96WniDR;P#HIQ7#pNJKM5`|v?Obe1s_k& z8Jl?jG3xeWl+3zO3HKU^QA1otRi?w(=IEDQW>Q>f2xh{ij68NIno8m>R)y{1y7A`M zk4JO+>(ZDn&nbzzSI#$1IQMTREN<_vuvB_py&N1=l&fn}f8y;>-AsvemXpP!_`kIr zC7&Fuu~gdHUYP0G=l3;pz(8n^sL%CA7|CH^o0h<&nw`PnVn~9MMorGJ4smBwXJu2n zXWPMq*-aS|$Ptt?4Q~GtuiEbXfjB-_ISf32z(UOpN2QL?+ChH>~1^e#~Xg}G2GG9X@ma2lTyH1|+#$K<* zNn^AG;ROfO3!|}e#-N*0QP2;>*{ngEj&4#5;O!k>TdzU)Ts8dh(nbEzj)W+9DM^7R zOTEF$7!>gOUoU{=tvG??70n)AN$YSig74rXOjqwdAc8x31P-<87w{w{&aOpRZ(_}7$VaVZT4R=Or$xQARr>x!xXPy zFrmUjbYCAi#;72jT~c2-6y<(Hl%9t3o7$Xaqc8#)M6j!|wyHt!!pj}AjjsjCJ;_-< zSRmWtTu$Q>Yt!u0SXc5rqC0(d_!o;s`zR@ zEo!atz|M*W-(ZgAkrB=$TSSz2`!pLzGvIFU#!}1 zkNd|Koj>u#VvUfPAPQG*>Y8bs7#acA`fN^nX8{QMWe;XY;;a!!#3spOksu^zUKN7m zg$jh}?a8PmENXFe(WmcrV8&X|NEj2*4AK9n^46Q`E-mbVX)ZC6k+32FMSRu{$U&1n zgh-TSE%i*SBC43ICRMl`LvhA(rlZ623ct+3?1-;bLoqpsh|a~~z*r>KG>TIN+Rc$IrMlI0 zISXo3WlAtha(-TVh*WU2s~l)+8&8cnUuWjROAIhg`v13om^Or3A+klgMDJH5F!$Mh zV5mo;yO?6$Y|Sh+E`q*`&s~(@QYYlaqVHzul7ZB54~x6nX{iSi~Hv%4_U_G zG7QDkQyAfz$e?l*&MSV}xvtV{H#N*9YGU?*I62P+f~ZN@eiuZSDJoS9d1r2o*Ta68 z*n%?OW%;12r3Nb*91~4zwi*>>k*+++cc6-=3a^SmH%m9je9W^{{-u4f!&I@1;gnyq zfVs-%JUmvanfRAq6&!fNEO{5(8g@@u%(OP!{{kW%@^xBc*cuCr9Qb|arJc7}1Uc-1 zD=TV_s`ol$NrWIM04TJDx5=eTaF>=@h>itm@D98j>;yix34EQIMWrWY5(O*P*=glW zWhu*l=3*{hZ{U`Ziz}yU92(`ZAaar0{rV9Wst*1~dl=Z&!R3dx*2j77ZOjs}8P5o} zyM;T+hHwP#2u4x-fcMWVrbx@FUvY<0q_<$X=I%TdD!~HuX529CtA^1X`4PUlWvDym zzyE;Bbld_(iX%*+K!o%h7ZQqnP|Wm{L(^wq$e2FF1(K zOf+CGnkv~vuZWp$KqcNftzb%NdOlt{47sUZ8;)M6t0cb-hUw7C zEDSEKbV$Y=(AyvmoOFf4GRat}*B)s66;5-yP2-Rn1VNI!OzR|wL25}NTYO|>%3$va zKaCpamY8?NhL4Mp!B@AUs_Lwgy`=vw-sT_$c^Y-9USR6h)*?OB?F-I6#`36Jco69p zr`{~*X-|8dVio9lW#AD?@<`mV$RfrH=^36y%vEwdl3Vep!OH*}+=>0^DFnrjVlntn z>Dbk?E1o;t4QyYucVl=9fe^;w#wHZz)qU7Tmx$MG>^sn2kD>|aimoJ>=&h{#=;Sj1$8&W9U}BBl1KgY?|Ol=)QRVjQQ2i25_(uDKA0}`8%(Gi9Y{P&p(n) z#zfdcL+FCtc+1xB{Vvxp_dzG8N-kmz8i!@rNJt78+W@dr-O zp{-2BTZgB{BWHrzuz-a%VK=uXuI-$+08^mJDC18)8*Q z=MEboaLzixAvm>pSEMM*9t}e!valW(4pwtJ>p$hKx1L?Ze+Q3Wyuc@*RRrEFN&CS? z)b76RLx2&(iBR0pAdxRR4kTB@h#*Hp+rvD)BAQvt1s7uwiaph4mQeXJcXWNowB`Dz zr(Wu9k`h~xEr^M4AzKW6b%rRx)b`?sP!+u0QRJOj5FPEMWDA@(DWG7Qh@In0n=Wa8 z*7y?JY8C<)8h?7AG0UepLB`x~KL2LUpXt7ld*}h zkNxrqnAv#_l}=YW|!DK;RkQdN3f|l#*CTuy5J7bi6HufP%-)G2+teJiAdk! zz=5vf7dkw+bbc^l#0PIU6gKI5(};ND@Q=gAS?NXz@FjdYo3@ggMXY#UhXO}6p^b+fSd6S@x$yfC$*0wFajwJekkG-=@pF>monK% z9*|!nE_Yu&%}7y-=?LLm7_X9_Ge?EiiQ`VBBBuLkf;EE*zUHb+g^xt3MDOo}e*is{ z`rOV$hULMuGtF8+?%I&laFYe)s#2bd6Oi2G!@=}#q4%`T7BkPeJ zOAY1q%DJ{Ts=)n;gBN0v=!a1dQjqD21$pF`2&(tY+Iv(gD+aIEQfJi7aW|`VKs~h5 zw}GJVQtXm!MyW&DY-Wt%Y?gMz{n@_TU;wUxY~R7h#!EOsLS(3i^P>o`ynd$+#53kq z_fZI=6{il&5Sr#Asy_csF%Z@xu@u{S^Bm*2y>$zr&~h`trSD@b0~li~Gp&FCQcF9e z?;692Au?R#{gFENT+J; zY{(C8yv$T!>Ui=(jOmVW>`c5XuTLPkoFYEzJl@NZw2___Ne#>v;~0!#FFQ!7q6mx+-@MrN9Pzl0D|hUW`BX$Q^|VgEBF zJWHqxU5ZC>IEwYmu>|Q=9^_+eb3G&NH;ky!FF=_bU&yV@Ox^@$=he}LSWq97IQzfCYB1w)b<7wkoDGNAlp%o4G-oK=x2MZ!qAeS8 zVpYpW=rzuFq-#$|YC5n-(892B8op9a9Ti1cV9p4#R~B61LQYW=&_7@ac`6kVkA-&u zU)R^eWBJJ{@)8pQrPr{j0Coh0K+D5k_)#*JH=zz$YpujPU^B(oN+M~33UcID#^c+n z7WZLkA$m$9|FO_SaylAbS3djhj=jo51~uEefuj? z?Axsff~K4GsH^E`^9U?LcWkdeeWyO612P~_dqZDbNlCv%dWu@!x6g>)h%#``zk_Mk zp>20|eZRWV<%WtocutMF1)vUiBxfuf6rXjOKTb z1>O`%LaU@+9-O9P-B$N$8wX6!ZSjsF?QlyUNgM4pI_WbUx6>21zC{~n2v-Ob6bq#* zz6q=O8z#YcB@Yaz>yi+GDI{>>;D%~T5T@57e!?X}2Kz%yy#@w@*#w^VIu`O4;sd;) zR?lX5ttuShg8_n|`W4k{%9h^4zauF+S1nxYqw2%)pu<3Li^2#cW5VGufeCx`VqoMM zElB51iy6`pAn$rg+fg;}!ou3?Nsl8ycpH)K(2u(TT3c|M`HX@hYQxD%r77!XP6-(L zDAQFX&%b?R*sHC1=5PbOT}Bxi+zUKUt(`(-r%M(fjp;w7|qIR{loti#AgO9iX&*>8(E!zsM zdKTPUuzgVyYs|!y8=oh|0>MbH)zUM)hk7QnBfx_IAOMnV4@z}=TewvY0}zJfmyc)A zfewy58}c6kQh^y^7++v7Ky1?6dex03L>KeFUI2z9!lKFEb9F~U@ctME`S^`5$fcy{ zOrHk6PL9Ol=G@Wsrgsu8hlTRUFG|0Cuk$Y+^ za_tRk`i(L&{rnpMm&aBB{^G? zQcPM;G2CDnhdWM8C@&#a1tB!H!jrZuBR4ovn56q$kJZhvtm&)PIz|4ujrgDa2V}rs z@r|ap#$U}woIzAWg7`xOIaseKf@X#dN~fwfD_}~7l$J9KBU*TCI zN14)g`9H}Z!ehEDZ}Q0LJ?jLaj6|+}zs*Tb``J_acR8^{yqDjj&)uoS&T?}OtwI+sv+GH09|IY?zrUZX2K5dW*IRSkD1qkNp)5qfzg5!al-e4qK0tmA*mqSfE!GUUmUC$lfxIi!S%I73^#XlM3O) zRVEZ()m0`+(L?s(6D3mei>#?#mbIbuU%=?(e83-9=#bBd8kIZoCpK3or?z?MarM%1 zO|`2jZ^@9{qlYS@i~eY8d-iQnPkECw;Qz~`G};~hmrfCyTFu9FgzM*5r!2E;3fcF^ z(C4;~^QEXc|J*E`OLKWjX%;*RBP2Kv(?S}4xk58~+)wIl?8I?4YR;Uxa0fzB#^md)bR^|m!Lz65ISFIdGfrbs z{e9E`$JU;3`%84^<184Wdy$RdglV%vqK`YX-Mg^Pq6f~wUNbw)fBpQCvK&k>hHD1f zF0zxcZYU(%fHr&$G+d6QaxsH}6-=LmE*MSSE0t<}!Ck-WL~}!qXa;Oqt-+N~kcV=z zBtbn95An@aUsD(aj_WIL{hzBCPhBnR!Vx00K`EjGqIo2@@KYclXDH#@1Je07z1y6F zFt%qt#WebV$Y|R-x3{-1gGr}Ldzzj~L)sJcA9$+?ubW*{T-U!KdzO=J(@A;FVH;0R zbcR5Q=uE5^wjIQI=GASe0J=vWYiPjo=-P#3fa*HV!6)uaWtIVBqV z*m&!2)tby_*rgoA4+bVi1G^$*@NM_j$sYDZTiILGb!F9;5(v-+ZRGd97av)QZ*3VE zzl4nT2kj`TKvpEZ?=FGoKJRNgg)9aJx^;Q>X~yv^tbGXaYy}cj(MNHfAB#Eyx0UP{ zRZ|05@tIe_j^ETCDfr56rI5)Uf{}-%`*C~()XYTgc-`(NUI%k-|3-(fP}RqW4*!sI z5z7njQd%R!$-sQu<`zkWN#9 zW!zblRs7Q-$$MCF$n+@>+WwA~JKb|e>sDCP)S7E~E;ky_hrfMeeJm4MT>>^_$FE0{UxPMYS@w6bJigU-r_#xZ; zpfZgZyAT;5&FL1Grmz|Of+(SK_x{-+?WT6f`*6lg_2fR}k%vH>5_$R7A4$2XrujT_ zlGMjC#Po{VlwT7N!JOvw((h7#a_EI@7r^B2sv*;R*;QJ#Q~K>%TeMFTZ_TdnOL9G0 z%j8E!Q}c^$1?__$nYapm%bj4PO7rb3(L(=6L+S?~%8ma}*Ws$~!u<-0?6=44sJ4`n z2JXZ3y{5*KL4{kvGw~|IA${^mlf1?cEbgAbk!={l}=XiBGJREAZAOgp9Hwfkwq_F=ziwS1tMq>O4%0!$@F3T-Kn~a zrJchrbqBW4=zNJ^;`Zrj4u%sRxQ%m3oV+0W7`$7hW=41v6e%}#Rl8C)ny&OJs6ROR zR8w(vm5@5SP(C_5mTzIp?!)!0kEFSvxl0Swbs_FvB6m9%!e;Y#j|Ea}E@^_<%fhuo zI7Ml7DlGQiE8aKSg%%kYYS}#%1LZTNA$Xl@agF>~xiC&&jTU444RZ1xMWmugY}x;y znVp8!<*{C?P>hQ&@u=e1aQ=x)`Di2O`sgRAsCYTVUm0j-J6_RjUi~78a90u4)U%_c zuCYNApn&I;lz_%E#Jo7Gl#}(S4D#nhGabh*y*?0lHGrZHiv_38P}bvO^Kv!Z*=y*KXE7J_GHmGYrE9z=uNZX|VzXdB3O zdD91d7Frp$6FdhS{{&WT(p!{^!CCVDSH$vID>=@|mh^1;=o{tVuW&McNOpx%)P_(B z%^eh!R}VDHmMzGN6Vh1Ux2E0cL(i4c?w1&Su}Dp<^B8nWhnpulw$3i~DII7ldEh$3 z;hpELWy%eY!{Ie(qKkd`y|2A6srHX1@9i^$6azd#*l6fpl0Cy{xzmOIm8ZA$ zsxSTWZ+Oq(^r5WB7?sVbT~|W!3kdfZGE-0Lj%VIJf$P zLN33X3UB-lm)5+4`CC%jdc4@VAe=%%dHUC70zcgoG26s}Z8*c$ev1$g&ofd44oGSn zC8Ro_hXAh%{e5I_cynHZq7~pFYE~`P9t;mBRHkikA{!^XtUxQpFfMboj5D%*={Hc_ zN|7AMfOPmgwoAz36XQTj9t$7pLN4(z7iD`+VEZARy7z{WNE;$W6{3w4@$(Efe;3y) z{@sTKxd8U?yCEjoJ2w}2#(<)orymKCuL}(os}}1hsa-fI2#hk(a$oAgwKnUP23WD9 z$-t$Vz4T34Q(R|sB*yBioA-W8GCSemh){fX2sb5KT5{4>0tNr&v8G z=~zB>eWL}%!=K1{Pc>>h#rBvHJTbVqMF~j8F^OoC-C~S^^4>P&J(tV9kbN17)koKY zyHRs&%%7vWv64&NiHwZ>o|kt-$7oDdox*?_m&HCuQ1@1b+F(mp6o}AAQ;*RZ?<^i? zpyk};t1yXE;Y0;+*|{BV46w}J9=|3W1e%7oBe%y z_Sz1Mtej?Db!L*)Y0;8a))Cw%jZY16T89FK{d#q^3o1tzROoBTJin<2nl~Bf-C#Yg zooqIztA95oWRN1jnw`mO>nBXY9ZWGLpI|!3fLK;1pEkY-EeZrO9Z&)|_GkfE*~o2( zL8Wam{c4<^W0}c{xNFMOG)rpQY$HmV;T2afUw*M^_8kM%=`U!D`LL-3{$i-3J*k9# z$7G;ogBL3fVOuG^DdIRgwe=?$@>4ljS%w#D1{1d4CN6{I*(=C9&wn_VqT@;6zs%Jj z0IYediP4>`yJe#YbVjdz7ZS(Mo97MjR#25Z`vw;>{hMpSZ%}>cFuyAzg_;uU;G4=a z_Qtse72jnU{d@B2^--tnh*suEGF{Su?YHY}@5`!+w_T7$+J#Y$HW#(vX_Uis(ann8 z&!{d6Z@(}AiHR17tC&Yktv z57ci0P04YcVYzKAB7sZA@QSU{7`U-ua?0b)6bnxSC?~cT#iZNeFLjFAOhG@tG`g&N zv}#Sh#kncmOO>o)`Cr5EPVFINE!Ozejt`y+)w|n+%PpTp!_O3>31eu6{8}I8>MJ z&BjnMt7bdE1~3Y{CyGW*uA8E;<76gJC>+oNQhRhvE-<)}z!aVEN!U3vo%s*YS{8P{ z0Ht_82L}FcMgWy2W) z>>2*T@c39_D!;$}Wq26FdmK&0x-9ttl<$|M{+UKTw&};%w6k-umieKtmEd1%Ti(rK z76(PTAdaBv@l!zV@B7Z%y3(eBz-^u^(}V#NUtzvj8KNtKi9@Eg%3oq{xBc|sgay}!zsvBfaR8+HoaXP6;&A95^e#eJ zI%>C62WXn{QLB>J@Nu5acMwqE0HuK72?=OLcCO5#2@MF#{%2tUZuu?t%b~;;_<-0Q zWAPd6sn^#v2@Q57S6%1DN(>um*ye!ZjJG}Sq!1jVGIm`lDshdyiN?RU-@_{ISjis~-FyNrkdnl1+7kOQc=AkamkY^g@@O|Uy1rQ0RoiR7rn4fzrL zg;;M=%luktwUWNOssfOXNW5=UnQ~OHg4gHn$V}(>vVT%>sd+eXiACAmC-W7(&&|0V z;4xVx)Lr5_1b)26Y!@G5F9wT@u5oR?adNWRB0{*gh$vxXR`y#ch7aWWv?yuR#RdEp zY9vhLS{~xN(bb)bpZ7;>vX+Ys>t5=)*9#Lcq$!svc5!mTq9Wl64jLsy>}g5>Y!LyJ zk9sWre*3J*!4tjdn4Q>+55tj5&M`g(Xn2(YxI55iE*Cjw$M(^Zu*j^^0cR(CB| zGJK@nRZ+1nn;XpcaROsk78<%kS0tP7?O5AZ%pr>C*iwp3f5(SrA91fJ_9h5n15ovK z@pN>bhnFfc^Uy+PZIEl&4vSft0SRp**=7#xQkj(#ItLCiww1UCPftYwgk9WI-U3ri zOObnIzb!Ww~B#iggx zG(^IeUR#UrUst%q@dMd3r~=`CPVPei+-aD2e8vq32P*+VFAZI5_Re1`>45Bhif3tn9*b){0@(fRK(hCRyU(JrC#v~)0P z^lYqVGTbXtMS;pS*#=HIzkP(goXd<4t8BXskDFV}Tyy;5N1w`6(pI z0S-15(Iij{&K|k{FS+q!_Qh+9_Y8t_?F6Ggk7AFM)RUR++uE{JmLQ2)C2>tnpZQ9rHqN?U5po9j z|Kb5d-+amiKjqi;@qe8 zF@7@A?K8G8{?)J?joE~YX?T^ncIEN7z&P^yezU68*l~d>jSveZO-d zFhCl_+rBJWQFfb1I!uQJE@PZXDEovEXR~aE5y*(Ur|@KEjNC4j7Z5=2pwe-ghxhD^o;D?)4fo6G>t zgyE=)dywdcn}4EzKxQD$@I?OVCA9tqAkt4ah3IuVO>uPAGo`|I7W;>k$=8%Zt6gxqCEs#GTg_R8qgD$Y&Bttj)8pJZoYi z-=QRlbje@2=<4guzJ+pFt&XjJnT}0LNMdd==2kPr2r8_HTV$#(kZ|Lm+gTIGduVtU zdlv&j{-6xQ_&Fn@mXLxjag*yA-J(8hlGj}P#FNh-q+TYMj3>$pg^P}aND`lYdcdyl z?#}3bDS(LMxXs(j+>o->-iGWD4VqZ!1yn2lp9Gr2zI8n=ljHiw*S(oH362P!RBIN{ z^g5`%3I44Sm?$zbnNYN#(4P(cb0uS(pq|7rc{gWWm_YbP;o`L{2g960T#>S6o~=B4 zMLkprTkdJG?D0}FvDbT(??US(F0yf7P)7uN2`h^vgXY99&nB!d@lw*JtEVoQ7w3ml zbfp?_2T6tceRXn~oGyM|W%a2WU?{_JZYE|4)aF1mQpQnXndG{^a-SKKp)T+(*%lG; zvOMna&7pi(K>3K*KZ>T^zqfDLr-qm2VX!GQ^wDY6+vy+8`p<(NIo4rJ{pvpHR;Of| zsr9aUuLtoJT^kK9lhD|l{Bqb&qJ zo}m%DK7QXGeq6FJW`B4veh&HhnO>h+wjrf!_@G|w^U|UvofZoh>WKaM97XcM67hKZ z!L_|7({6ktznCgfaM5TZ8)=W{Je)F=)20$VWkJ=d3jJ9=vSV2|d>eVJM#QQtYNpz{Ots-@^jqvL@w8hepg;5xZoX&8ohQwgmUKEh;=5BvhHiw zesqc?Slv@#7<=&DK{K`2c{St$h))=QJO^Tz8F8s>6!Q8|Gj-;lj3TLzFiF-#oK;OR zv+Ct(*|2cuvIi^xlGBGrDo^bic(pM+dky!3;=P)f1Hz*-ksrOebB-ijzkMC+fq|7b z;*}rfKQ+#$0Jo2~@dD8zb$)Fh5aF!~4|>`tr=AGi2_kGq@VfbsyTZLizo=vS?4d23 z>Y>gMAS)~c1GCOankcB$J~emiF~LQjN5|S?1ff?Miz^~E%g_}7_op5Q9YLbkX>a_9M%t*igJ{W7W3<76a4g;G5r$h-(*f) zd#r0@7+*K!;yhfQ)?xaS!zbSgs0!#5+gYcPcSK)nROt+Vses-4}TN{;FDgn1 zsHOE@;4|H^Ilq66=Ta;`mrr<+*w@n3On-b*dR zF^*`3%$wKFij|uF1wepEB{vISqpPowTJl_6JfV)04IfeON1Sm9oeq;UI)75h$ah>J z9`&!K<&>JySeD2m4-7OV&QtuO=x|?r8j8VvjY1)Cy*iJK)&@v{+Al51OV7n$EdZg4~QVF8vl+q|4i zd4>|UrxvqF+|b88I*Z<%_VHR_A^A5u{JNd3VVW#ezse5HDlkum2mAq-=w2{!+)J~F zzO22_^rokfh~DILFWXgCXO*9~&R;F4Z}1fM<bKDD@|xPeQQ|u1<=mSDaNE+6?aG!RjO`VyAO6eYMTXZYFM(`@-c^Ff z)5E+k5g%tS1l}l^pOTayn<4^GhfS}%Zf6 z3b?mDKM|7xMC3};Zmuo*FTGI4UUJZ6W=*IZ(6fD_vp&xrrQ$BYEBQFM{&kxG>QiK! z+I~J#O0U2amfGn%CeKXvzg__B-L|*`Yk(wVQNiA-OyA>i&#b!_3TEcA2T38N3HG;m z2UP@mrkMGYP>y5zfD}x(>Gli2Y=1xgUwrPI5)6@^f=x@>8nck~4YS@NI zF<^Ju5lA(_K)Rc*iL?J;OUrw})ROFER_NPea>7th6NOfrF0 zaDmtV7k?vw+#*^>mwY7I!D`tFCMVyte zF-jVMu>_j*P>TAS+C$-E!nx7wL{Wizyp=P)Vc#(KW4bI>0XKSFH@AUfi`(S3-&s=p zl_J~u4s<$#Ui1{7unhxYDtau${7_3Lv8~W&bmO$78YAx=*DRt1)n(*iL_XDSqH{Wj z+SzwYW>?T1@Z}5ynr$hZJ+UR>`^KXk^vDa_!HJfT8_bb&Y6p-y+V_X@59m?yuLD=X z*^s*@nj81%Z395SX5#pv4Q3-MBvibabE{U;kwGgso9H}Uncl%6dad^y2nL zdT#4};(udwy7~ae@wn~$Q(a!p@2)K@{E-9VTLP;+WLZRkC8V;E&llLMtKMs^skz@O zkAECy!~|wdX(9A>{f9g0*z^0S`4@XaIriK>;&Xg0T`t@U6SK;4Kjy5hxnI6x=KHvA zVU!(^umZGc38h$vwM7L#UXq{pJGr*~kF4YA^zZ5)gpCL21>oM?)=g~t3B-wPrGU{4 zx^CGVnvLfOjNR$#db{%$<`uP7#q>`2t`3U}O3KV;V5TMxME3m)+p%m8F1%fRXS&9D zf9+sV5ONb74Y>-=61Y3eVLI*WWX;AYHe7+A4OEYj0F&a)h%QVn4qLq@K&W1bBtZ-& z|8v{+6HD)4;v@}(5LB2Oqg9w6+`MWYl^5lraZv3n+T|6@qP!ehw_+ZxSu~SM@^eTC zSd-BLk47(TZ=z?yp|>52X+Di-kH#7^a_&c#OGlpDOWo-GeFOtb_rdn?+!0*kGGp(P zj+(8?*v0Ruc@TLtYdE0Z!jB}t0eG4Ceya@j^%hK$DR{JhK0CFMApWK z>r6a01pILUO_&RmL^41+)ImRdb{lOya5QuZ*GVHUB_$Q*#pq@@qEG4P5 zAWpZgUJ#O*F(qI-wjTa0gb01|j)gUJ)5_Y=R|@yj$cxJi0^^3pW59eHeBDk7w9n@! zLUQ6s&736K0I&qKq!GT3DDXBr`8L;1xg@}MLiHP{-XxC3Fal@_%*08;Pthy84${wF z-%Yy@w@?zj17*A<0%1PVp*z+tq@~C$Bc6Ry`wsH_MKxt0{sLNpp_hmBjN2&yq4qbo z)6;M4rROm8^4j_M`Wp>k_U60x->fbw{Jj;0`7e^uaI?S!)&lXh&Aq;;$f>ed&G?8@ zSMfC~*SlQ=#ujKJEhgbp=l(WoLXXeCYn`E_xQ} z)?wtoJm)}Dr}V_eh*aFR|L^krqOV?8RQmn0h}UBYrEDQFd9bWNg5(gjxcpXkX~iE{ zML8dlN>sqBlqm!4{j~4dJ#^?bJh~tGix=ayH{LnK5cS}mLtC0?-!prN>&mQ~kEIml zR$=TeYf0tTR9)F^9FPI#2ZXr+{A01sGa`|m8;i^TV`V|XSL3$ze>H5=z;dkkS`%aP zEYgxEfKgj_-zaYBgFIl$spy^Pd}fswhM3pRD4{#AT1d<1R#H6bri+GqG(*Qad+E_v zcG2@lGml~^*+;|JPQm;_&+VnXpgx^-uOBKn@+L?2=e*mCJ2(frSdu{R-wg8q=gT|j z#T`x5i{6vlLbKrp0#I(uqjxTyO&?vil5RnI#9v=f>V!$g7~%&hTlO5LAHTGn4z_lo z2lC+nbBdtspe^M~i}T|5zPr5i@8`y%EuNCX=3X_-^QCQ+iK&|0I``T+Uw3Ebe?cl$ zjfxb6LIHCVeM59)+hN-I^bYFAqx%$kaMt^LB!yKHu7m43y!8-q-Gny6wuV`wq$Ov% zRRy1SZ(4NK=xUMxaNp1_>z<+{Ruk0_c-l&VGkYI*-Sp3duXHz6Cqqz^RA;|-P{JZ zm2Nk^SDwAxol|_Dv$Fix%#7GV0i8qLHDSW6?Ep3{jORApP*U`rljzJ z>&uFtEcM($+d}t=-ZPNG5kfoClW}`#O}k#)OO0C( zQ*Y-0aT>?-)(I8#_*^%)u@h}|9qpn_9r#!nDo{D@d)&Fj_aQa= z$aRLUFE9DowIxMgFLj+wjKO%=@VOSGLeKALq#wSxjd;ZdWyjrkl-LWOuy44tX@{@) z{>hSYl{(9l=N$lh(10dGlPM`3P+M<5J@(ozdW>_`o_>%`LlX*z^Ul6O`q`U%>Bs2x zk9GA@0?Vx9=x2~D_Cze}@#~6h9C)e=TZRj~PKd zzzSm)^w01sFI68$ATXv+Vxrv9HxOdYpa+;cw800D!XGT`8!nG}CY2Yg%lXQA2LLx1 zXafRHzN7+8JCC%|_g`$F=g>nn96U}hV1?la`21GmQR3C3@pOZ_DObwPHIeAX_f(X9 zt0o?AFQyjHFB%>=~q{-A8HL%X{hA;r5Ua z6WZIIO#s1|V@KMd!#--<)g1n6YNn2StSmXpomuev?t-#=tr@YE8ny|Ev<|ig7!la^ znw*^WcUP2sdsZ}7?@IMFaPuL!{baxLYR*=ql#DTB&6qRx{COvqlG!KY-8UKMn2YBt z7Zd=K?({Ll=ouKI*Y+O46Xhm)1ryzL?=+rX0I0jvi9E2br0jv)$||0Vd7f`+8}?h4 zl)}MW(bX}h!d+7F8FxnhZzIpY9?bSAFd+Z1=#-}BrZ#%>m3?$z&r#y<-iHo`UYv_g zx*!NxhsM3f=&hIc(UGQBVjVMeMq1?B5&NBys)8?guuZ^5wpC&zAhImRfU&r>wEX#7 zO3NQ82S&iEXFLq~9FMY$^@|NnA@<$y0iVB=0)5YWljnUWZUp||2&W&HPC|a-4ZZhi zbpbyI^KJXV+sh08v)qpMq)Z?%c~pT(+Q@kvirGb~I`_ud{L+nX)V*06kt+g&GKR1R zsRHWl?58(h-AmiIHc=l^tpLOSrF42`+pG)gv}5bRP{-cRzEEe@eMIuLd7MWkR zQO(R-$2PO=jDZNWVoXXb)q7(_(NBQYCSz#uwEo3~ATVK^;qbJ$Pu4Zs4l%e^(iIdh7`J z1hakVOB3^|b81V!Y!^g7ryRLbXgvJQAb$#v?sxBLrY*1UqozY`G=$Fqy0ptdSXb7W zb!Qvc7Rs;<$5O&(uQ>8queRtbBG;>COaKirMvPU3>;3)deX6krV8+Gz^!Hr6@~rpy z%J~NXo`iu90O8W5KLLiD+W-I%07*naRNOvBVhS)`^PWGC6+_f{Lq)~QCU7+25kgeB z#;V9$8LKY*j^$WaOG7bgIibC2U24JVm#^+0SK&LizRjuoJQ9?(KsQFpB*kx%e&rVI_G zs9|q2y|ncJbzyl=bDlp9q!YHVO$`{)*v3qo`B(~(@8(3-64D?iFq*M)X}~9AjsyU3 ztegq+vE?uJ=Z5noFz%C(^WGSUsUT^9$J50alQe-9;Ay5Yigb!nG%z@1p{y8dxMM4- zDl4FP%tNeVr&?r5UU^=1Y}aU7IrR~`(Pv}GS6Mg z#H<-|`?l?#< zBGox~tdjzsAZFOAl%#nzWpvH5IaFI!O0HvL6hMlF%7V*5Kw1Yn*Ul`Zt6{@@*upkt zm>Al8@OUS^wCw=Bx}%ADhY~>1_#sA&6)@B2`vcT9R+!`gN@Ii=#$}u`!6cr`wCh}t zFDL*Qh)XUYqz_?Yj;jk6e&W6z=$J+^$Jf$J`392$+;bGmnlHoo{0s2ieM2W<$N5oS zPK1`ut)y!~qVl5LDdFt1Ar|Y(zt+gQ^H;;!mQ3T=R<`+-?M>9&))TJi7(y2TY_tJJ zz=|>RLE`Z+Wvl@vX<#_s^~pY$Y1_FTPoV%XbUwf_A$;zQAmiPS>!u^sADAvXaLM?Z zhWx;=>(CxXkEHL@k+v>+?B(sWbzgJ%^=Y0Pr-77GP+@MAZa`Xe?b6wlm*bHDqYUny zO2Rs^Zmc8g$~v>|X*<{^{^fZ7g@3k{ZBD>;Jz_^-01B{Sj2J7%4Cl=Fm}(lv+?r z3h2#!N9gBo?4{k|uL!14&}!pi^`VOyDPFiZNoW0$`@msf0G7jm-U-F$cC|?#aA=qL0Vg zdoG{PI{@$)U`j-P0P_0`Huf6(Q$-uUw^{osCnzhhQc5|=6gsZ)oZy4ZEnR)|%FcuI z)a!eQe;aKqnEcylb+gN9-O711AM*nq4#ul%GR|?^^QubdCg{RCv2J7aJ>J<%PrtsG z_*YGj!#1{YJl#@8YhS1oW5if7W*T5(JH^ju#vA~T>x+wXeyq*s^7*_2fE(dqbt(uu z>2!J^E~ui^4T4^+JG2)hO^}KnR>!H0?k-#J@|-=~3@M;uN6E3v=oAwF_wNf*Q;fVyC7S zXVs25>-!3GV`!@yfwzDvFfkvi5dS*jo0!Wy{_1Ypd9VfbYRH-EM@JE2!x%AE1k5zB zo9g~${l;Q&E}t(b0M4bMOH~#fqoG#Pvr0%0CUb(pWPjj?Lx3R$?~DRw@K=Dh>^V%2 zzOsXMA4X*$b*4RzWzm9}C3G8l=0oWcC?f z>-Ol&J7~+EL&JaI3%Z}tzLB?&g!MmhkI0)uf1a>MH z{3lR@ut0kD0)kLGQlY2b+&lW)Xc$RG*&&3WsO!*jOpxDu^%hO2#lI!Ty0MO|>)3kE^ckS&fx)g|7BC9Q&zObUkD53w z52k+ru;=UQ%){)^&L?$$EYLo>%K#y8zOo-rk$?DNL-<>0?L7mh5Th_2rQ268pc|LX zrK;jQBoc%)M4yvzr898v4NK?J9cva)5td?U%!}K8ac2|#2)gV)+I~*;kG0*`f$5JG zYXU|gR={jbbJOzU^bY`bTYbG*|8SilI5ly!(_mhJ1-`TI{W*rl3V+k`tpi7ifBF6u z{%hBMz%7l6l9J|D6w$3K=hIEg=TgLT!v1t@J>z}mvNz)4+ih6gsjVy`B~Nx!5^cST zG>6-M6FT%KQfFQFc@goL%Z+4u#KL>6KrmqbvRvp=W?D$fWEysAx1>? z=A?$F@CPJ=$yRCw7tc@0^A3Q&r7a-8_$M^$L#FHoPa1$UB7!=%7Z@f@Z~-aMc=z%j zfcn8p4dGv7zl46Iy?gZD@pL(#b9h8fom2q01PP)LuS^{-*6{__>qFsx+)SI#n}g>a0IawU z{4Zji;H$u^n?GL$Vvf>oV@O{Z%!@Hp{brOO1w!RpYQ@_>8NP=J{|`^SN?Z0F8UDw7 z2;GaJGuIh}T|%Y?Q}P%zBj(tTWFW43u>0{pWDe2=ye=pJ((Qh)=ZQUiU5UGi%;;#!~H#Ij+1}i5QUy?yUdyciyPqyqNCUsZ;;AsKjYd_|4kljbyi0_>& zJvog!tzwt>f+4e?gt#XFw>EV@{_o7SiBf680$@ZT!$ZLH|T}BuDi3LMCtA)cNQ?Gq-So(;)07W}D~$7}e^*wkMh&?0odohQvp3l2?oR zvCr5p-Jt=t0UUlQgme2Rx9+8Xd+sfI7R#s4K>qE8t@Pwud&9ZHrO+vC^U}Fvyq5a~ zng5G)`lFk=9{Y_=9h(o2S7;*6U9H+kvKICev)OYRsbV2Iv8cKSgh)n_rueToR5@OH7biyZ)= z-hbToquz(QpZHmqYi@u!ST?0bB5)9Imtymwga*==T6`$L-{rw>(o(V0wsk2SaiUGm zznX}|*F4v}b5r+YPj2dXq=zQqH7Nm*w&VXCf4m>YeLEoiA@FY#&No9QedErKm;I)O z7@7)tT*5iPQ6lpo0`tQ~U5|e2Pr;GNB*df!z=&byFT0+2Lv+dC*V=p-o$+Hd$uEt; z2Y_GlUN?Xe0VcuNU2*@FKtyq=OIlb|ruRq3*|Ct0y>Sl9jj=0T7zFi5^Kj z)bsdD7!ls3wfb0BpYbPWONAFYHQuTLj{L zBly#K{TSKs zg6V1Q9Uwznj1E4)K7FwF@ty4B)0Dj=r>9f^WFTN3>VEu*0b}ko!h8u`#cLVYvWJ%k z0p0$-_cvte_u;AFAK+vsER&!T1EHc z-%_M+#*oUr$(ertfSvZN5aP~FxR2|}%DE%3sq^t~8KLhqM8Ck@j45`f-FnD~$Fvc* zZ0dRZn}6N@)DiaGX-iL$3sXJ-GH^2A#q$HJyXQkfn@?l4=m|_=QyJI!(8sywFj{

`9ECd+i{lvyTc#^dEwAJ1Hp@*=Z zk4u*g=9g0d09T^#&v{ZNhl0e%uzK+sAkE)4F))#j_{y!$`Q99b3 z_S-+S{iwIF>#_e8q||yC`z1(v1a2impLGJ~?(g+H`Y-GoT)V8c%Ps&$&^3SCySeSb z?#KUoVEPxr3Ad$uQ@g3_u|NCUfya*46X}t1+1&&oz?lAJ|4)wr*&>xKH!Ts!2Eep*Yu2c%8Iu;9)#j79^e`5(jCrUQ@rFCL;Y5V>9(t zomdU8xBy`IinG;DVyw4Am0Q@a&Nj=#8=);!s-(?KlkMj)>4Vk8!VB5B1+MhadygT>L$jA>6nFn$H07ks^ zFeb;Fx*yrMsq692OKq-0AUusr6aL$1sB^M6w&7{g@|_QMJ^EXJ+4I<5wso?UnCk6h z17I8v{YB5?FXCP8);IH4*^T>;rn*o!dZ#Do%pgvKo6Y4dv+Ta zfouSb3(Kh|H(`Fz*OmO3CizKR{V9Y~4D6=F#=1Z!`Ki9H#NC|dPYJCq10}Kna1KuT zpR{@Cp{_@NfSmdhSVR96+{%CG3A;l?q<^Z388PN)~E;61yYH3HcH zIBTd*Ey#bZ;(MKs{iYy&7q0#z<_Bl`%LyoU!5g>nPlkezZ0daUGpxe}*LVe2kPUzf zcfQ%w^~BHnt-jA0L;N6Z%Y|2Y+8SQdMt_i`{M>_GPi{u3Y>y8RE@aCFz=fcfaqmC0 zJ=cr0=&3~1y#e!s`*5idq9GT$aUG3Viv3C=>RkU&_Y*(M{>6(6h4jSz*#J0!#e^LF zZS#)@5D0&S%07sc=zG{-$UgyTMi;+{DEKgzV*ilqMt#|B0ujgtzyyLhS{0bTsq3*f zsY`vCG<}E+J(_Xt><$e*rnUL7=#sy`sq4wt33WoU+XN$!4S)#-G_$%*^hjb;&tord z>VE89M$^Z^;WPackQA3OTHk|v?tG}{@t3%c%sR8j6N^AL045gC)7JK2&tw0LN%H%J zFdI?m)PDkEXb;kz2YeHJ;K80J{t0EW+aw^64S-1id87%0xxs~_%+$Ogbghxp_M>S?q8Q14H7Kh*j7rxl5X zn$&^^J0JVZU-dq^i?5wrI{WEjBajV%i;eKx)c@VCM;jmN`f1bK6uIm!U2Y0z1K@J= zSXSQ&MIajh6AEXxvdayDYyezt9?R-Gp$KFHU_#-{R(81|kPU##&0|@8ClrAR41fvM zD_dC>fk{Ci8vv66bGDJGia<61rmAPMdQA!f*#MXnn6r&cRRppDFjYO1)oW4^$Ogbf zz?`k)vPB>p0GF-nvWiax0@(nV2o$q*T($^g1K_fCT~_glKp-0c6M{H-Yye!g zuA3T)U+!A`e*gdg|NpaLbTa?|00v1!K~w_(HYZHBAnuCB00000NkvXXu0mjffkdkh literal 0 HcmV?d00001 diff --git a/docs/source/_static/android-chrome-512x512.png b/docs/source/_static/android-chrome-512x512.png new file mode 100644 index 0000000000000000000000000000000000000000..5bd67a3ac91c6d81788ffff33c816f28206a931b GIT binary patch literal 98923 zcmYgXby!qgx2J?bx@*V*NoD9Ri9x!QE@=e;M@qVea2Q(YPC*HkknRu$5Tv^X>F&70 zd%ySI@7aIMJm=Z_?7h}r`CIEm>1aWS9@0O=z`!7SrmCcefq{koi-iHeMgJUok6mJ5 zkYYSjQZVqw+-|}T0#3LynVnb+4OrAsW4JSSJ_Aa#CXjPP4?phz$uaremhmGMFT%Me zDsxin+M#X16F3u3$scrvkA)*2L||3bAT-)@=l^H4 z;oFC)E-Wbw&nj2XS3j>#PAr-(KW{#7+Vvop4aWZe0YzZ9VWvRMndpJyr2iQ#bujL1qBT|62v_5hn!(NsKS$*8jT{nu+am6^b2q#CnZO zjPsu{z)g&UXQk!g<>k{aOq2fa}j)nK{SSQ2_!3I10loS?rWBF%TeTC-K-?U*muB#Z4{AVoym;wRY`7;cyy%?!V-*xGh$%%e!Gb+1e$FA^86zrj(#*PKgxCmHy9S^hCjbCaxff zH2-^I-TlNAF>CGrT!8#ytKBa&(PlN~4D{6EENOkms!9g!oapyhjlZ?3j>f~Y65I{* z561o$icLXx=`_Ja92SJElYIgNVuM-zu<`7L?k^$Tmw(J(>upZBLRJ`wc-o`u#?q5Q z21CXb+V8T0N%taPCP}OjfLg@RKxS8pxml^WVc>doPS$)SML0UGD(@yqC5MT?r}w0J z4~+BzvZE@}_n2?G;vGs?Hma;2qmQ%r8_FV%@u`{QDG?T)$y=+pq0vn?gt7xtgfam` z;C+8zWz~>1gUFr)6Z6+Ed*r+QbQ%^3;Y+HM0ASrYP<)0l6x&UdIQ%9UTL^&-u85Ib z!QJ91poLCy^^>Z+k1aEWn<@_IugHGZvN~zhFwAJt8Ucu$}Fm8?gKBg&ISj zo~$$BS*f`kOV_M%xxmk%9W&s|75I0$J!+MJ*ww(0HQOZS9}p7>e9xVts}pyBxDeQ+ zXZ>gXcV*wSSisi3hU}nPrz&WHLEnzJfMnusy5ep{xk`P;@MgNTN1-jn){8CP?;`qX zrz3Wtxg$qACUsxU8~$2&x=Lc^$i`)K(aFCK0lDe2bI-E7S-rX!#j$fr#6w`64iKEL zx+4A_+c|Z~p)VEWFVsk+*oX&91-%#gZNCo%wJ;RVLpO60aDE-JW~nd#Sal3Am!$IK zpPks9D~b3MrG-&K!cOFX@u#)6UZ(w{ae9@a!V&{o(cLh`!O7~lcvd`S4|~v;i9y`T z_nTDIoQ^{{_C*3iK&}7>zATYV175e+*aiybBj!136ia_eH7)V|x2kCQ*iml$k@lD2 z<)4jsdw;e`E1$e5R=LrkA{$Q?t_QV_<8sQQWp44e+x->8_LjavQ?S8D!q=KG5)R?; z?!+qCg334u<`Fx&(I7y-Ft`FyIje7D7CDai-ng3*PY4{ff4^Q*_hQVCZ~}r{aC7Vo z&e3OC0fw~GRsI~2a2lJ~)Zz?6Zv{p$I~(QP>)jEmbqq=PFY8i7@vPQ+=p>Uey7p&! z(^P(t%`>|mpIIkmGu+HBj=Ci}F6QTU{lsS{k{o81&@+(9{StYl+I1lPX=3q*w>*T zKfNoxPjHo-^iTuU2h>*dW{^rTqAtJKHVkKLaI7)5=Mb3n>-D}#YMRRJ-e9>w~3x69zEtP)14OS+zQzrxVhu zS)D_&mQY-@woC5o1J=Mn)t8;x`OjARhTl8rM`pub;J)`@%h$Hi?9jB)T4^pH&*#td zM121~pWl3xa+LY<9q^YuPrnag?3x6-Md-w`h>y#BC^9G_R977PoqV|#c9ZDDViYga zTeyzV9H%=@`Oac+T!FW_JCf7MuTlgN8^dT#hL&?MCSH*R5ExW?!~>)gcY8sc;Z&*f z+xGm*aa-R0+Uh)>gE@6_q8}|I?-x_1RT=89v2V1c=q(W4|I0L|oSVP1lUKPmh;DPM zRXjBj?aiKQuil3F+<-U@`B@a>e-4;uBOeiTw%A0=vDjpcHy-X${WBMlAtvw8wTUQE zUIGl~S~GlzY-JLcfdgpyTN7pDa3KGv$V+#$iqL!*U6TDIf)-g*xPk7gFU|T_b;8&Y z$*?6F`qq4@cVf}qvFC%{?fdBhg}Qd$%ol$2X7uy+`TXE(U>kqxLZR^V*F>1bY?OXS zfbMKAyP>YvTl9@CYh8Ti%q)*|$#&jGRY^y5zCI?7g1RwzyUcfVR9PkExZxVnj7ND$ zg^gk9#p&5hu`CIDilH-Mk8$ukY(=YcRpBZDB_yw=f5``q4<&;6}9&eEo2J;ra}Y1L~r7KYQ2%^;Z~jh4l*A^UR)%vg{!yEukZ~3t!ppn zFxp^$$opAVe=nZt<8#?+sq*`vMI-n={J~0WXAQ=FEN+DhfR#AW)C#8Z4Y_s`s|>-J zRXXem1te>91yz2DzUq$5?oL!mqSzO2FLGat^mX#rp-VtHX5_!6x20cC1P=oTM;+5K zBjStQ4M|okUjR)GkcN%Z@3$kaQ#<{o?R!x@Q}u4-8!VF|3~15t9+H+E-us8)HjKuG z8O0KTg?~fIyB^_x*e3#_BP8+KQ#`!I*>frcF;Kj0sg_if z@DrT?-0_TM%u-`t=(KHrFTuP$>aD)OHPBiJZ-`6`GEdkyiqQa5Sa&opBXAbh&X;-B zY9l92=__V%Dg$2Ke5t+pq!eo8`skxOb17$gBr|H?EIeOF9JjGXf~t;YukVTmiLupL zk6jMh>X+31Lk6}8;_PM+hpbZhbVmjqkT!0A%cRhvw1o$0Ve}ygIJq+*<#Z z4=9HcZv_jLl+MkvRYi6Ka8}sa(IRc;)Yu(yM{+dtJn64RU8Om$ zzT#MzV)^<_W6p2L#@m?+wTk*UD6lA!$?rSgu^ZltwYTHfB_!A?jG2%10wT_d@PLgj z_efzmf@b+5zzJrXm8lHnxw5R(H0(xPW3jUnrO`bSs?{DRP0TeK~9XkZ2e{X-#Y zji5k161pbFO$NCzxI^Aj3?oLQ;c`U2BBQFwGf{@LCutdxlt`>n z7)R7rfuCi787Fe7z7e*J_9~3{k^oPBaqPz0KlmQjX|0i^u+Px)ZzJH0Pt*vEnSrR7 zu9eW$XJrsP0!XVjPb6MYPxtG6Gu1>Xtd&dc`;bC#y|rHKu{Z__`OU4{4t@L!Y8ZB* zqn9k*jz&idUUwgyQsoPEdpDQor&XWZ9Vxe$qd0Q(GR>+Hu}DT4k52bxAm^|^AaegA zkG8`e?GT3#U@f#eQuKBVe@E+jvSx(Gg{MEGsqhUCo_N2UImT8T?-eeTQ)*ZWv5B;c z9>YK^Rf%MDKzHctUL?=V#+G32{dpkvG*widvCdrS!2muOPx-Kzn0}{Hw2jP);S-YC zpe9v|#7EvDj&qRWL5xErfR?09eTCawooO$A93ek1ws;bm*zK|O;=}F0S@=S0aQ(o2 z+OjrJ;=oZEdwmg0{1(d{ zvtkZiqZnI4S4R#_zCUdF2<*SyrE-lA57>74PKT^$ zi%SpoBI}nIUPNB~WMPvJQOp?(6P$^#N$MGQCmssVZ+$(&z4(J5!s<9H%E6qL%4)7I z(7R+il?m*S^HB=RnPR3}^fbHN=l|$jCGt^1-yiP0dqB_uxil52&%j1>D?Y}AOMpBX zS!HwlIK_jJXB)~+BQ3KOTy9pU&_zF!MI@@gO>41rt@9t6F$;z1#&u`|o}?4DtQA;? zb)znN+ufmS~%ie>Cm(N9DX2Ytl$+zaW0hKP!_qz9OhqInP{9i0! z71Ku{D^6^c%ttvZ-fPe)xSb5lNkIJRz~Xd~Boo_#u!U_SUv>#km;OmP$<(JBYq`l+ zBMUXW*Eq!JT)Ix;&-5_bmsSBCtpQd{1CIOh^kk&)=r~7PRbh#Y3|x?HHvlwho)B@B z&b;+4Fa>*M$gI{>I)~6_bdqszRqEnTeQp2>_Gg4k64ozKPok0ya-i3yHxMf5-Gk1t zZd)n?nnypZ)U>B`)c7*m9P3+FI2lA76zRq)=T|HCGTRB{W910K>(mH2SvbF%(w$BT z#6Z2@@V!LGUr;=oF}hakED!XRjF(3LtU}B(?Exw^+0*$VYfREdJP%tt@OglkhmnnJ zLn+iOBs!S0X+)voLo5%9u|I9b*5(9|psn$mYDrmLI1@C~ zD|J|@N-yp&WG3-V8iF`}n+dUx@j!pQpcBR(9@dSfx`V z?QeU>hTfwFtvzQ!AY3*kd+p-aaz=#@Jw&Sj3eFHt`)JrB3HkFdzkljd%i7lZeCv7G zF(Z)_p=|$4GNA$CF#&j;ZuhV(pUP1kBIS`B&vya ze<}}BNdS52`d0MMu`&_o;`v;-hh1jb3J|j<#>Lk3%b=swy_tiHE$Yh3?!a}c)gR__fb3ILPM@9VIZ}zDb{MV%9{t@J>7KKxSMJ@|r-O_Nb8V{M-x& zWs@K@D@+GTDrMy@jk0rGCn4KqY>%z{0i4K7S}pVJFJQLT!rjShVXDDE5EH4+nT z5Qe*lJL?C7sUJ2OAB@rlSpmb8Y!t{-h-m|0ADs-umuqK*;4y1?*5?=7lF+e%=cD@# z-Ev#GOU(YE3zO1!^7>f|MAoa8dOcz%9?Is3ZZWOmr15g)GPL4? z$-k%Otj;G>X4au+(MSrCi4sv-lHZpU&pwN+Vn3LV_V-!MeFJ^4XyctBx(qO4Fx#Gb zBq@S6nxAD)qifE_e#Vovy*Fme#8ScP9wK4k?;Cw_gguF;(R9z9@>qHV#?@Ch99kmq zOXfrz|2ORC!oRRZ50s-`DzvlR+Y3QFWXqe~Yi%M2T;c&|Vx6ea?&pI1eTr>hEMeRn zwpny;M97+cu$T&d)UY~$9{@ZV-4}ukl8Qjb35}0NL}8WIe=^o;csCu#QjTh7#RwU* z*HFD#xH@GVd!)+~CRq8#AqQ}IGj0Db$df8qcK3@vHaJ@OKuDzyS?%*>fN6Zm5&n3x z3a~p1S0w$fS_u7VlgU^kE<+)A@9cP|e_(j;#4jW;OHo zdeRS130wS`j{jVu;c9BUCgm|g+C@Yz$iKBY5=K2!mW&RYE(!fPuZ4eOvAJz7&Kg_h z4~4UL*@T8@-}B?z-ed<_4A#wVlh@yU$$Lf^Z@JawG=4JgB@pOM6q_S;)C!)9*LXYXzI0KSo-z5J%1_D0-A*CTS$x);r6H^Q}Lfn>?=RgKj#mQd9^x=2Wlq z?^>6%06QY{;N2fp4AtQns%MIZDq3oTDp*+5B;VfQDv zSJ6(GTT;ilCmb0VMQ`6itRe*It^7$)Lq=Gs;3-n2@N$+8e*6#JG^XO(sxb+bXncMS z==QDQGWRqXyBkyJT4%UM;8H(Hr!+KY^??EW(g{bLErV}%2zWGhbj=Bbl>3d(!3~WW z@2D!m>#T=UuGTju1aOceZr114lgTqiOBmoGCW7{>cDg$zq{5qXiuDtb zBf^8Sm&dU-I4Je#YQ88l?LyAW4PM#!fHR`({_!V`y2rq6YQ*3R{yNs1-jfz4|81(f z_QAc3Pm5`dqI>a+7O&2f%#HEJlqitwg-%(|dl!Au4w(Ii}JJx3yTKSldFUjEuntPMDHd+t4lF zU;1!~FlcePo`uC4vqiKJ9;?Ye~hPs#80r8Dt z#qG{UrqsgM7+<%(pYfC)l|PL#_4wIzn|J(;n@|GoGKk4XOAR;~%r3k$2+`l%ud-+3ll-u3EC zUf(Yt@-bJ~k`v}-6{m7PuMDea#_y~|E2OojkNei2{~R@)xFNV0nYfLfZCrofK(kv~C$=U3 zGV}1KU6zMkYR>KpSWPEOP1F@Ft5y1xCvcgLuGc4F@)`rPqE{&V0CkOlXOU3cZEAUB zLeW!CwT@#j<{p4?73SX|`_fY2u{rMCqw)@X zYp{E&N~_{$a0$em2HAkGsoB;6Wcj!EP-yczSnSD2WG}b-fQS%E1JnT&hK~5zXVyE> zgXjZf$Hyt67+_UAhu5^_Dt}Dy^WVb{$u>Rdtz`~96rKRPEfj=d{clg2y}Prz6R+xQ z>oO_!aIT(bN?`LDx;KNb`}!aJQqOK&>q&HNS$FtWIJdQi<^Q{I&ah~&>9Tl^)Aa)1 zxXDc!t#6wfE>GFE-gqcWsk10EM2qrFXVl7^IP)(51s!3hQ0qV(a#k77B_!o?w)<=j zr@iOha3*QCDNc#5BK^B$d4ZUx99Kzb6OTY-p}971o3jK+p};^5+ddGqsaCXj`jJaN z>G7L~83TTF0qF2T>>eU%xN?0m9xxmE0hX=wVz$)ZOx_wAuYEWpM7~=q$aQ5NLJCf5 zp~FN+W09#s_Pr_HA$xf!A(Sv%{@xA^1kI~st1z%!jJ7dZFv6c?Lcn0+)po-o4!B@2 zLlim_ty>;iRhEhw2D19q3q1tJyq{VdIvfS4CFB8@-YTa6VM+O~bP**^+f}ism&e(v z_MqDWphuFd_f5p?S<|6dA}>vo32>l9r{a`2@gc(zD(xn;g0Mlg5T#K(1g8{_-qqd zOM7CvQG0M{A%sPHazrM$4yO}bKe`?`nh`=}zLKy{fA^w~^ZbDB#79_?MRo{{dg`xU zI2F8OAL2*WFsNu|G7tKXp66!NySG&nf=Y0g!PrWlfEG@iNAE?F!R_fV2DtIn+Y6huG%_p|L+jGAPC+f)yzw6#PdJ z?i3HE?3omq_P>OX`&J-h4*_U|>KT*h5XoXq$V3L0VDf06QWFDCzv zx9QgZVgZi+KGT|)cMpn^sB@g-ajNZ=UeE|SY6V&GieuWU=NSJ}nvQ;xHj`YS#}V@u z(^Th&k1BP4iW;3Qt0{dT8f9E`ye8;y&7CjOl~ZgRAQBShQ>W(YxFMtyiWo2x%$BgE z?09crr7AxlF{3&<(I|R&?dO>(vH1k2!YM`ak#`m|wq@82C-Ml2qhV75K2Ec>_)FNj zKT~5btLGkNMVT4JYyQo->?(IA;E8WXH#3sj;5`R6{}Wcbqke*Fu{YObXB8GVmTD7^ zXC9T@!@*_Mk0Wt&9!LG)Ua&W{iN78U;~GNYEsB57b_)JE20GNpJTr-H9HuMb6qU-k22|&-;QlLrMg)8Ou^Fd zR98M|C})k_$(O_HsJhwn;?y&B-}g5?u9p*iUKVGT)oVwNju2QzH+l6Mxn7O){KsR& zR=|AV-k{v_Cl7ut9S2EFXY3hOy>N>^p6sY=l|pYO5q!?Em-^k(I~ zQXeOV@hApJtQJa?an>bFD=Nl{d6)nZ9)`3j(u& z)gh#BmFLb;7pcDX8$QQAw+V*)U!P~SSV#{sqj0k&T+tXPUqroIc&<&IrGO+tEQMQRm&O)ZPNf46Vu1tcyH8MyC_Ze2IFuC8b*!kn;$ycmHUGr$S zCC*Iz9R-}o`W04&F5P5ff1%<~3m?;|G^k-A=n^Hod)IzLNQ6c`sACAu&;~=84RBWW zL5Ao9+u1hw(iwTgTg%WL)vdPN3oS)wDMLs3^3HOEbVT&MHSx@b;qiDQq&x^Zp;szg zb8|mWlMg#;10?;sKF8WK7DmSrxv-oq2PFlfO!Y3W6;;+2nzd)!^nZWda>y)Ht9kt= zVm5a6)X8h+($vjdT=)%pT3b!^U|y;rDM}iUQY85ilEK`1hDXbvDYtcvvBhv4e?oNy z_m7bM-mD}Xguv*#r1wWQXG=dGE)&(IECoRyEg7UieLj8<0@(f7DA6^`VHd~B-v+!E z3c9aYVP=9S%7#q@@(fh6|Dvn7de-M>R-Kl#w{syRNk-R zFLIk$amn3u^ZAZ5-O{+o>5)tBs0jDXMZVtUv6ynUKiw=9$p5RQTJ}Xz((N|ci2k3` zSptub7sh(>M+SW?M;k1_b?_ga!mDF5HJX3H~Mp56&l9t}^`Pd`uRd94M6W5HaFd zk8I*Zmhcc5c5WoL@go?CgRqR~?^(>Pqk{2^0M&Di2<#n6F>#EYW!pNW1T^Odn-cec zoZ$vZnOqbmi#`a8IMEMe{xaUpdYN_o;l*%vtDA9D9;ukwW3H^jfW z1@=#z1x$Pm+?9QMQ-Ka>CW486Q>&3`op)nK1)OZ zWYM%qxYTPfFIDZ0txZPTI$8vro5n~F&ex`wXrZ6eS0 zI!Pj~kHuEN3x~Z?3BU(fnCie3^ZiW!>V3T%#-rvg0HDy`eX_p@KS1~vT}!r(qPu0B zmTp*h`c0=-+bZc1c2hsyT*lNO-F(L1LAr&Ei$NWJ#b)AMO9D)x9+G8tIRdgzf2d+l zi~)4Hwutk4bV*3CiiVb+o^ef-P{%((y2bHLB@Rqe;5W_Sp>DY%3vihoq;j#sBGb@gkqJ zJ0WrJ#oQVEGC#T6qi<*GtYx!yZl<w6FqzIiEFSZa2bJ1kJP+hSW`gH? z{;JYFS^7dUO!FDPf3z`{6w@l(c%i=2tIHg5g#B}AtSa^|j#0K);e!prpSjly#QHC| zgp$zNyIqu^@>Cu?wOnCc;}jjXJ&u#Yiqq`hh-{4RwTPEpig3bw(x%PoF2m`}&Asrj z1zieVk-s`^I5p=n<94%pORUou-^~=BKo+jbfgkdl@^>U#SKS>~60eGhK}|-SUF<6b zd{YEtioH4WEh0L&F;Q}T(`5*F|nhj48uo<*@k25=^1FaYVN zm3%=@N{d*I7y5$gF6=Hm3Sh_C(PIXo!Ebi^Flo4smH_c@_AXCc?2kU-6s;?y{gBi; z)h5Mnne&y)l6Po5P2PTPcuVk?$17iB`e_et*G*gK^6dmRn-PDV^q)c?Gx^ySNo3~T zZ!6!;uyhB;LaipBInHa6`NP*Eye|a$F&UK<`9I#W#HvL`M}|1C?+<&IN?1~kduWjBn)263^6Ua6x%$K1)BT@^|zMFG8M zA$*_IR$@me70ZAJ3}@i!X$}Y3aR<;h*wT>YI_UJ$VkquOQauAt`^~b(h)`!NnVe#~ zj1e=ejsIzvxo&Y8+iZm!2Cbea;BlFv+>eM+jZ#{E_h&cS;eL?o&QkG9kS{5mA1nNsfj)x_^xSIC|P5`(lht$X4BOJo5v&) zpG^A_d%AAM=kfv*A*4nAZQX<4=D&(eNqxO^x;sd}4gH!`8?B8WCze5=z?9(dE-g$N zm%nG9x)Jeg)?sXxa|48uxf-b2i|)Z4Dc+u0Z?<31$E7990kw)rA6B(V(+}@S+eQc) zh+`#vX(!7WF#e2>ESAW1Q_&P<&2J~kc%Eq93(mwybn2~%D`4Gg*A^nKb7`gqK44vY zuq5V|6NSBFMno(xRZsWT0~ar6-toF8IiHe4Im@zyQlW_5`dnaMZp9}6art<48eLrp zJb$`2P{Bll zTT>R>nZ-{`xH6KTmAXa-+(;-ix!E1o6D1Yoc`=ioZK)E>dgr_bHrk;lOP#1uWiO|Dj^%P-emKgaE}hXg@&PTRCUiSln!it8?i<_){qr8N^8ztOaleA$=;rFLZ;bmOT!}Tq(3`v#Xl$>gG|$vu6VoSs7LuW`O}4;`lnsj>m;Sm0Lf zY91!b>vJk`g$5oVaojA*JY1j1+?7pQ6#}@_leqh-?cj)3r6cVANB|%MBhgq1o2w&f zZTSW>9K>?n`gdnVkCtrlVXDINd#y=5y=E$OvORs@vHJZ}_R_0G`Pbp$dfOEGo;uz- z0a!dLH^SG{f?bhOYKfo{=-`o%t?yUY0}4n-wl{qg8eaCG1)!Q(Y6l?5^NrM(l+U7_ zoarflZKxVpCy{^bP%ULlzre{4T)%K?y{_ppy+J3_K4;!YM&Kz0>;*K`HcO9X)$DiA6+XC3JCQTEY&N_JvH0d3VY>{E zSS;irl=J>|ff@A$H4`m(8lG^M(}>OyG119#WLazdQC2L`xll~j@Q*Bgrs<&I?@2aTR3Y{m}4s~ zA}q$!t@XzW+Y)^+HWBf6+As1R_)(D7eW*fV4@VM<0qgiaK}QS2YvJn#_Ms!!g8g|@^Ghax0wCW8MuGRJaT4!Wcki% zh}s6*Sit->idI5G6()~gh%w6Q@T`1hw* z!IQ)IEK{E6slGr*uMgo*AG+}6l>>}weVYPCaYPRbbz=OE`*>=kVeM-UrV=gTqI zTr|9JQSwI!(k0N-Z^jG#DtxWSSZTIBc6fcDnq^t$Gimrbs)8p*9u5Y_C_Y>Xa-luD z#ECr-m;kEJx?!1ta567?m9+#U?Y^d!`WWH;iv_6g#1kTeB~J?V4iRO$k+Sb^T>rx6 z=km^10JTI)q1yTBx1-+@y8cF{vJ~^<=Mw)l%uYP6T?9z`(Gp|48i9F+*nHliZW2X9 z+X1}3s2Y+IFAiOgRRJGObRk1{GtW3*dsyt@N^JAXY}bIvx3P<7XpcH2Z!6Er zyTZ(1fMhIwjdIRIQd;xUC1{{v|4BMYH>Q1pLCiHKOFg)dXIpKC?&uoS2MbPfT$IQS9Wm?6zOoy#Fx2MC z?y|t_*n1_@j|MC%2l^g9tuBJc7uz9U{92w(kq$hg^9fiFuvTfLi+I26L}|tETnf0F zDRWpN1VqB}S=xgA?7p;50wO{Ff6`=zm`hv162FeN%ElYJ0~a(~{#KYDu>R#r4tI9M zwgWV90!}ys8!dlR{P`B~*L}!ya;ek;Bx9VsI>Z!@QyRRjb7jHHw^mXx!FQK))odz@ zdSiU_YJb$*MOs@(+`d4Lwor{G{joe4rOe*8OjfY(gJx#rY)qi(|3*?sMAY zVfs%}29G+czbkrtl->Pt-E_%o*DdGfG1}Q)`AAND6)!yyv$q;^RcAQ)FxiV2J*4tf zC#Id^qhno0wk&f8L$Jr>uTTbzEsa1~;J(qP) z02jQJ(nr_mP<%badZL^W#t3*Lh>k1mgnTZcwB13_? z$YH%MMs=*0lQD8TmMsT8b^&(A4SerB>0u`oOv&dhbD+NjHwlCLEywaLop(2~&u*8= z;)6eb|3D(yGP`9_d$6Fq038iX7;H)xuWk2^mDLGff@E_6F`u83UGbcZd$GkpgpbgC zN=~_}^*b^ei?R7|u%!85rmf6ZlE-W*_`~Q@1_4*qQ_ng!X2C%k>Hs z6TWSzT!sHd>#iC)%OhlAgy;aDxp#{!4A{FKFl>6gG0JRer^*jlxBpR&sGnxyvejDG zWA=SQ_9X^*c>S;56d4$At-p?jJGL~C(hyB#FaU#0m?wtw za2~$q_Y8KtcB!Ta&7eKWH%P&o;{f031P#Q#d1-BvkljgL;}1H1+sX(UMcPy;(`1 zpUGx+lj(0xiv-UDNUQ(;_2A(MgKT<}-7nuMn!M;Sd9>Q16>|f?B1!tXQ`T;wd&A~b zgebtcsZZWqI8Q@I)aYv*6=mx7*gtHewEIBu0D{6xl0k>V%HH@%nhAta40oam{laz) zk+?eR+*j{{hzKIXa??2i|NddhQ;uZBgmub#eSRFK7gurMwN0&#^?+C-^B#$jtZ0wd z1O_cTaTVl|a@3e`@d!Mq=|T68+g|rXT$3C_kdz>}97F>s;f}LLrstA1=L&0iguIA0 zguTqoXi)ddYFXW}Y#Vo6y)d%eevmKrBi=kjH*n5kF(2uxg`2;OYDp^;$XVtuJu^9^ zy^$g~D6{u@V{FzOwq!_qi!+lD1)X@@(02EYA+Tf5+6TC zy1glVCO>Q`G`;|4K?SP_ELK%)IP>wZDIwCX`e47*Ih1s)D7H-YkuM2Z|wi}QN8 zgqwBR`jY{>j_3Ri#I|fJl*DRb*|38J-bDI9U*U)K!w%q02cxlx?)6KF+-OjVvcPxF zr#HmH;zlPx7vqf|kKWKq7_VQH1~J0@A!C0jmrT2UK0TNyqc1%IpBv22f3n#VDYGe{ zVbEPF%Oj*DIP4T~xUj(~P1WXOjC4i;25ADXE>jE1hVtI?%8B{~6P_j%jmC-i^HV!o5v z{ZhzU1c2#t-aSlP$4hbk_~(Gkpr0>FP{X)L1Ik2z`ZhU>(JO72O>3nVlmhNKzXoCQ z++*Mfpc~5DPh8r~3W!(IrhEvwseB#Xon5cPBM(!iWAgaHv}QVPr3y<2Dw=v=*Fytpr9Xe~KB}4fpqBWd$|l?B8fXYoym9E$Fni^BZF0QJcsJm2{^Znr zBzD{8IqJTa^yb^1J3a8}X@l5Q4qWHO+tm1yWxprtNZ)YhD7|j^IEmnGUEB{BnA9;L z@ZY+B(&`RIF#1L9GOD#wi?<?wWJ))S8g>o5z(B>sh&3JW0gyOz_jM4L8S_=v-~s0!dy#cp96P52=$$D|=` zfO3+&$dNFv@WbiJkM|!9C}xTWZe5CPL|7cqwX8*Zsjt%FlVVa7srLm3&|Na0shtqS zs!JMW@2=U;UHo-3{~UN1MtF1P%kIe6h)qK~c;ylB52Ug2-8RyR((yX}f z+1?Ff6rT~ZKl154K$Sg9X7v!C8?#s2oL#RcT`8-&=K?cHSp|nHJxMpboX4)F93kjj zO_*Bbtl!POrGc+=dhwDMlbVnsGMWt3`4H8IAz`zQGbsulkDz)<{6WAnvJk-yU_hf- zOF>e(8q!^k{Dob7cPKFw1vhUv3QdrPpiilG5zj{m%EZcRuAdFeR;S{Pi^V9Jk=~ zangaw-sKS9{gr)Y%{udkbcgTcoWU@#cY6HjJK&}?Q}Vc@~H*RdMXiRZbs^x-@7>txH?`6YEBiw zP1P_2wWmOvCON;7v1z`VqdTbF8k@Gi536q!&L+j;oSzM4#yq2jR| zXpg6cn3;U=e_WS0$hTUrs0>L(fx~^21At-85$0W;PTy$L`6tVverK4M7Yq z{}5O(&bjZnZl;R{R$7b`|A?2qv{X%v)f2}F*&wy;v{fE@hhrf8c2bV*HPSa@nfpW1 zpxF~pHUzuBxs0_VdX3cXXSr5cG0<*~$cC|cVH<~(p%YN`bb!{=_S1eQU8h038N-z~ z%QI%?UxPib{TbW1tY65;+pN#KW6Ww!s zpkv=Xc4Vv0CCz}R@`!aOIez3$HKm&DvKVWHKzn@UWpj;u=^{|&R*!kC4=%^Nx77~L zTrxoJi-jb7`Jhy}BD(yUfU@#Em+c<9jeM@Ns7eI4B@Y!9FRm zEryNw8=Ll?7|4#TZ~u!0yjQQVIP18TsYoIxopqLE}I<)lKL!~$c|+?B0H zkyyRJ3$sc4p9H=aZYRIOuqCGv123-58$QY9=ikju+SXMAZuEnvw%0fD&8ze}+s>i5 z!kh|^qotlVlzj#gCI6Y)C<(JO&uBk~9(RreMpTG@fmmdo0qv$D2|)g^o9FZt*J|_n z)=*QOqN$Q4Z9%gbJA-~UooS@QTypsC&WxOKz!k!BU){GxJzDffaP@nhjsW|jl1?1`y|eVPSBR?s zw^PWl95gMY^3l<4bbQOeo*udf>*_l7h&k+l`K9B#jDt@l9TG8q2Mg1ifV1?xu1Ixx zgPMnSA54^=$#l3aT;ZrmlFnOs29b<8RgTJl1Pu7K~Mt)2ooG|y_TpSyvW8?*n= zn==UQwK#DJ30lLfv1yHA1QbG@$j(v>x~Rs3{R5?OI6uM6w^(eNpWVAp7g)^h-mc0E zEOSlQa9TQr-PX7L4^i*HUgsZd4e!{t?KHNX##Un|jds}B#*S^HaoV7`vvanx__EkYi8C;MC#8V-J0l+>*x--JSRK;Mu^sraHy&@wgRDm zbo^J@dGT~3UM z?SqEKzVyE^tl#f&_oF?cLJYo@s)A1`OTcTvv{c8x7duaRB<~b@B2JwW;ZpqHsDrjt zgb(0V7wOI0PPvO95HqE!dMvOTzjB;(;MIW|Cq{!X-7{jDELJT@8;wnR*dXlDsRp7Y zf5k5AeR@Ftv^ZN1SwHCnd80z9qPZ(eNA05C`jvoWqJsmf0K9Hp$pB_M?#` z?R0;O2XpQ4>APXGn-`sj#3sBbqAx`qFQR9ug>a8wwlrrb94NYd*SS+$jBoRscpaWR z)$gvinj{vvss{Dsh&Zep;k*V+)-95h>Grf1e4;vA+z&cExihs{r!0sY9XaB$qJoH# zxM{#|s(6WFA|km*B!yX}2@AVS_1<3}>S2?vRC8BN(uVM={bz7eBgTZW(PK=Q}e5aS;aC}SP=MS10ftD;`+7V{&7+@tn~^Qb=4 zH-qQzm$lq28qJ!y<=~1ISJ0`Y9Cd=uSjktbXY9W(=>zT~juhj}=F~kFt1+W5=wR4c zQ8_2W7>G}y9ZArS$|l9ECxw#8Z5FE=!le)-TsywL-vTB7U;sgh$et$;*hehpGY0rYV7gQmcuthZ%{t^_YBY>~O0C?R|MYK~enm z<+gaT)X(9l4EcKd)s`S#ZIkXelcCtR?`&^S@DYpd6U6cxV>n-}r~=-6bTAGoOhgvbisNFKBUkZEQJnQ?V71jD;0)#re^4J|bJQldkA3hBS`fZvmJ-i2+;+K6c(}qR6^<(LWpKWuEhsZtv%5m~ zCnC#oQ56S-3nV9b^giemZy8VwalMonh^G~DrIxH~c+tvP4T(x*!ef7u(blZbj2GB# zz9oxXZ~K@2I`&wnDGkG+`*x(BH9Auqv3L0XO9%E$klZ$h{AUgfq;b<9iNmHKr%rNN z|92yT)c{>;X#knKEZvw}BnNIyPJsF2zrX$rKL4K3Dp|Fph&T}bl*BTsYo_Q0bue1! zHgIL$(Fb9XM*`?Aj2e`{8;z()Ur!O6fTI?;J}7~Vp)i2mARUWgm*DoyW1ax>C%i`b6;_dU^6ypRd#bond-*$3vJm*t<4yTm9x`dadN7kyw)PajzE&#PAn+!V7&9+=O8MGo<-&QOgC}!JeLv4c zyuLq#62FR~$^Vn?(#G%ngDP9F>+3b1^sW8={AKd=*s>*4QGy4DLtZb=fTcYtdh~;r zt4VN8cK?rGBFBxU2e3>0PXZ~$jG-}*Hez)H-QAw`8<1ptK!>srlD8$LGx7D0coM_nNpr$0g{X zxV;|F^ApnF@WZzc#r`|t98{jFKDJ^;PstIx4|t^M76mKK@{hGwFJnxLF9|+N&H1VB zG0n8cLY4~_xQ$F7{9`cy5scq87Qw*ul^7~A6aov;-Yvv9JoT&!Q*Z#iK}xY^sLMJW z#MB@&f#E~n_bqY^kVqnX>=q_O^i>b)gee~3xdcN-G z$9T8bekJ{>Vl5mZBNE2;zSWM-=)WCl@sKR$!haR$CwMcj#{U8i@1s6>cz3UE7Uhqt zHfl?0S7%D1s?r1))?G$3)wB<-p%o6ptOGlhDE+77cT(#TFw&6r5^f%+hWfS=1!U+xd z_k5PT=V5wTjJG8&b)`=EtP1L5bbiaw9#fRijS1V`Tw0N+n>2TyF{k$9mLe^r4eBC3 zS*e8IxhyCm?uw}ITP?c-y!pvL@_0_J>p?6|k|n+}SL|^b{9fnRSK9~Y$>LqwttgEa zjr2vOsYu__!9#>iPgzXopAEhgAtPDnQ7ix*F7h%eUs^RoRUM7F_?mW3`#^ftYWVuK(_5YN>+ zvgsHkHu+0@LpR$1f6AiPoNg&}5u}u^CL3cF9m-&&d*PPKvB;n1hqOaOoq*;>ucQ0c z=yOfo5)aka2V4FtJhbtKZXd{K6$<^1aHdBW z%bo4`wDqPP*Q4grIyv(36kE`ntfg>6N;-c}4FSDQ<(@XI*&LqzNSs*-?kYvKLXA zcV=^4S9FRgWk>2NXXP&6iyxj@X zp5|(y0g04jvO=P==~PqnYPmysU@#e*}j5 z>dY*5?)5*8A}rO@j;IEwEA26U{lM<=cqm`yO0DmFF>PA>;}sXJ)x2O9YG47W&Li>q z;Jb;J(U-We;#a7Em%yAMcvQxA!;n@WiQBTl&Vj8|Lf-q1P_PuZr!tqdEs3>s9M6$g z4~`Y-4{Lp35oZ^csSV{su5*woShoq`WBvxgwXZg&KS-eOy#otYCwJ~=<84qC^0l-G zf+9}WT&EEqnK0+@V}^KIxG`bE!_AE`c}k>5NU{2e&=M33>MmF)t{cqJJ_vQ6o;?hB z{eZ7>+pSI#$igQOS>DJWoXhsr(pzTuzTT*mw)RWSR|nyah+^Yod?k8`#>(Oh~X zMs@5d4g=eQCbl*!C#pg8K3JPj9r()8b|9?u&$G-n_xtc8)L0RsL#>;5gIb&WgSWOb z##q4>8)Vg1Mv_8f66F*kx6jS`LD+lffv>FI>M2KX)*P>K_F@FQZrjDm?RHxMFLm(Sf?x@k!NUVX$8LC#y&^kuppgc>zWZ% z64*`R?{@;k=Q~GPMSCbve0YwpXncszsN=BA)@NU&T(C_w*8!Op8{JkO0-7E;qnUUYNQaGL* z@#Xgju)8UOA7vyMEMK~zU1++{ZGyod1&vIwrW;|`^sn~yKNXjGB!V_cdof7?&EINA zcy5`S8=<@~{Np1{mL(R#P@aDy91?;GUOwX=ne1xR_CGB7r45!)D`3)#5E}~fGpoE0 zN?vWRP47BcDL6m!s4gXWHJ*JEv>2S!i9X+a34HNy^=^0bYeju>YQCJ)qSW)B^Bx)U zI5PN6^&3`ihdb%3A?n6;tk88j+UUh-p#cw9NCzEQEG*$w4X!oMy^n!bCK4x6W3wdE z6kZn-qD-o!2cKQ`w9p28I+}WpM?fDS1p;mrdasZVxLHlS%=MrrJ+Abjp3Scjv6NwMu#Oe4*AmPr-YPNK@wEZfomMv>GS%hadM=ku|5rHS^AfZSp5wqt|;>QUy^05a${rtsb{G=p;g=@;Vse z^28=l#HRTh)@6MZXQ?(VuH`%>FV%PLtMhPu4ifR9UV|9VYL{f1VYjv~H1{%WW|09Z zDZlCLcgYRUFJ@AJp=*v=uHtHq{qY1pv%nKD{bq2pM0YR$8*TmKZ>^AD7^O1NfQscZd&JI!`UUWH+!_{ zRMfT4Do0lIdB<1&Hh=aqf^+vQxU2V&xoQX%F}hy))QGEjbYzoFyZ~Eu^^#7OUfz8X ziQFl~=lqvqC*=bQ8YdufGLtjX_Q3v!4io_vrBkuGlmML+S9=V*0pfAmGER2E%Ldcw z(!HM2b{}y1LUL~XFE_iKXb%B^u}{q}n8`1%tSu0oV|hM06>EuX8O@rFl|X^Q2EQ{M zrJ^3sDGm=#V7jQ4h@oCH6(8Uv{jYkQy3$`Bkud$x^@+V|Jk{jxGN(6r#1sDkTQh(ND;d1NQ$qeXFhXq)UldCND_P)2N1$0Uj|Nq8*|>(QCLr>QUo$;`XYSd!UY(6R+7 z?gM&Xq8M%W6Yt|q9_S1)6?^N4L!AqBL5jSJ|h0Pa?lws zic^qQ1vdW*ePs*g2tn{`wdq&?oQ!2IMQ)7^+kj1|rLYFPiN_St&?GEO&4S@KfOvM~2$jCeg%*DcSmH{!r3~`Tcn@_Ps(@Dm zj266e9xl5U&ke)wUb3F1TX(;u%oip&6VUq{3L7{Goa7=2SnfFJTdaNOk#B4Fow^lh z(GD&CI~b&{Be&S;LiObM-HqkcJ_g*PcPz>6bw&i_>#kRU1?38edZ|ivBq7hT$}v0- z02;&;(TT1Zuhq@MGAKME3n^b%dvN>T{|0gQuilstP32(_xBf^4r*@@#@65lRmrcw> z&`h{cdSsKXhJ;v5z`P zBrxosR}d2$f1Jy_(IsE&J-*Ngye(>W$LxQOAmayrVjF@#JO#_rq0w8h84X8?;e__u zSvCs-Djddy5{gGThjc13)$EFlotiawB$6Rf>tlbR$sTWKtN~Ysdx9jAbHgX1xH<5KG$uJNS9&D`ayhfzDL2B z1uxUHC4D>vBxaHSb9G|Tsd1K>uHL^jvdvU)_mtM!gWolGw26!7cru0MX8@u5lq~6E z-GqLovl*sjN=}6Z6>h8nU$7}DryzGmCozg8fec^H-l|pzW7>-dD>7;)wkOS>460_;FNkM=~H2S6qJ%YG-Ii4M4V9T z+6^TpNR?0hgSSmjYy+rS57pD=JJIG!jJN_)v`+Nk`U*57@hv0;o0oEOg@AO#8sfKu$)sB~R{_ z)IZa@vEV{)W+8Fw&BDhsb{?Xq8`7ni0Gp7gVov>=60xTZpneo|c@Wyu8*4iCxBz4O zjYf(8YIH0TFfR^-JVsTpS*V^I=;=$OiywgsxGf;|!rWnOC5X=c5t3GYl$z@#b?9~N zqNW9I1^C(ap5=u|JKm@?-g8aTXz|a?R+j9FFsaPtrjqK4++z0vOSV4KFu`B(za5~z zPVnmo)PG!fN|lV3ohg1cP|d&%7E?(JY# zTnbv5oFnk>XTcF4y6-J@Nx4qT6i7NO12n9fNc(s0c+AECQm1roLG76Pu1|guQI(riYY` z_f3$&YM6|;nH2%fs!s{5FO21Fy1!It>z4REO;GQk^M%329$Sob${2h0sZJ_lq_t}hk!&eXQgtCUh`70YCVjjzzgk|(OEBI4_SY?0UPq5- zqgA|ly-`}=;G29li{L@s-0YmG(3;D&%+X<0&qa&u=WrFBy~{dL+eYB8fQ#)5{#&4D zyT(|D=ZFYfeV{HqiWTHX?dl^0vfZdKv)@nSz3}?_m&QEFhj!8 zenfX#3#Dat0=OYXK1LbJxWuS9O#!r(ZUn0bj-k)Ay^-P3N}JFZk5*G~1=~BtL1plE zNkhjroa~U}Kk?>YU<^h5D~w6O5{YdQsGrqURy^il6!u4_{x*+3*)2x1%HsV%!ON0O zzk2H0wOm`U+dW=#z`Ock7;9lH+Hwx#wdq8bWDaNSxMW%Gmiwy==t)!h5&EK6wU+;n z3kceoc+Nkah(wusq63#-%p$Pfp>;VFF%nQ><%0NA;(95Y`@iu3O|0hG!Mf$eWLCbP zhGIVvQlSuenLziA9hfKW-$NGMRN`ck`^rfV8PK|BO*pU|DQoNtRDyfmD&Z^94Wlan z^EMtB400M?TTys>`fRYWdBV|ZCa|(5Pg4G$^t+{MCBK+FrViSLHN`A9{_$arUTVBm z*b|tl0i2I0xGCyC?H#%f@9G1ej66>h>d#qu(&_)p8|nHnp#~t?#qGgc@j)t2iAenp zq~1Cd4fy9Y=lTw&yh4P;Vig7np_jG$RW^G{K5j+HDS}`&A23|~D7mWxIl97=vcQvv zKS1+{lpO|^uE!-a5?ZxaZB}hd?NSW_#0Mdl)hi(}yHne7C0()tG|cNb*Ymm{yBn`G zckrzN?pnT1al8CGy91fA;^ed3_~9{UvMzOaWE(^!20*X>Jn=6F<;!e{P^zpx|CkT| znl3f8j4RWIxRdX}!X=xbOHS)zA3OQXd6RE0SJIPa>cGktd=3xWK-$$50mWHb2JZf4 z5N1a@W)~PTv5ct+$LgU(QlP|D4P_uwftf+5H=w2C0-qtEa$uOOb8d{yeP#0S48Wxq z&Hqu&>ol*tghw~5P|P7CBZJo6L&xHoiFaI!*C7S*YPt0|<9X`#pNo!ObJw=~g-x>C z>qhk>Apjd*E?$P^&XhC|?*scBzB3DGH;WL4T* z4T2Yg#Sz9jJVYAa%j*u%%?r3muyAIKPHPGAuG1x7x_*Vwslh&Q5C3uZ+j&)*(XOR7 zfynVMOH4ZJSEBqs)^O$y@`vnyKP`sa%LYOpqW=I3U2ADO8PK*HwMf?L z{%kc&&vjqY`&G4s{GO#WUYc8X?sv55+;$LyHqIGOoPzJStelx~_Y>0^3_i)o?&x@I zUSM?DCjjb`L;1@|oo4wM6MwYdl8f1m(ZETk#)rIlSzjqokrU3vTvT1XBDX8iMMJYW zz)ih?fs4L*r74C}+3~9fB&wEoV?PkGG3Esuz$Ao2C^O|L+Bb3^u*V*n2**Cur#IIO zotZM8Xy4{hV{m_5!x5myD^uCZ^$`Buc=6;K*JJc&i+y6vQf;z)X{b!`pu^@tPn5XV z=J)vVE}y)#!rqmM-NLF-0ZStzfkd3?r|e5Qod z2LgX9HaIKw1u%=2>}{Sl5A)SJr${_%fh4+;U4nsBW2N(ED>gb7F)li;P-&kxI~O2w zmqL0n%Av&JYf+~D-k{+EZemOvCiT?jv#mG8wOWV7AJEx@Z9op;)A%d8=AFuD)NItG zT`R4Bz0jftd2K)OVDABFl=savU+8__bTmnmI_%5=>XUxM zXGhrtk^x91`S8rBG8tpCX_8$NCa$ddlXvprKe4g#NI8`&?Aez%T zOeJXfm&9BOhV`Ei7n}HsB^AK|UMvp2-YjEF21?QZ=g2&{{)-99_Xp6Lg!sdh3F<_3 z6%3q-)$Oo-m6Au?4>Asnm2FZG>4&V`{0!J}+=w?sq8I zkyRO6iyRVacn4g)jA~YKP=TUODqtL)=u&IdKzuml35Ke{Gon z2ziD<4U=$Pt%eb7<&79fv3kkH$Ht7d`dYI3K1mmPk`zSaPyhlYowj;(8ho3hJvuw; zv24J4IzDOUkqC{Bu#Ujrdb zdf7~uTbX#@d%$caM$%1?7ssr1qhGi&F%Y&H+`+Ws&)k1c>1(y zpc>e4`)U>Yz0yHe<*Z0cW?oX9VUY2*!QXhAcn3MYgNf`SHeXh>4=mW=gL`1#%S7Y#&0~>x`i_n2#_#BZyUJLNy4F)2fA-z0|rtw04W1aMZd)0qLoklGSs^KNa@0 zDt5lkFL(zKI7Ss*t5|O zkG?}2kQF6JOP=Q;;s8a|JLI;@%*TZkP?(|Fu3V=7JPNnQiF@_jqiGhD>?2 z^lxLi@|eu!5<+QQZEX!i+)EJh65<^?#jUFLKuESlaSDI-tF2gTP^j-I_{nM3`1vK> z4^5cP_n<{SbkGX~d?KvO72fMA!V9`&a4ARucTQs36##W?Z@6hik)om;RWT>XgJ-6g%u3>}Lc7SV^+REzSbqMqK&72sDNZ z!!G+~hCC_4%)rq|{?Ct5`%HIe`=j*^pJlEw3`&s~3tQzFph!G$3uwpDxvb!EOIEzowN)RJ%dGWY{A9>zClJ`ejv9Q>=gi z1Mh0Rnzw4X$Itk4L(hvK>)q_lwJ5BR^lCr6;5tRu*^BH6Bt=m+6I}o;GEst`|MLSC zead!|&Kb6&zOM8`o5s`t75hp=|NmO(c-F?YTlvRd<|MG^4;dX(3qY2zKljT_1!X>f z2f{L!nRbq$0|{E|gy-LDb=dKV_k-Kn_sTp@Ul>|Cjv4iSS#p>o+aXbx_5GKZ3-rb4me>Fn9rv{Sb%@&{|fgau9{BzFwiUe(=5_y2IVMK zgl>y`qlW&?jGL|Lxb1MIp<-;`j&yDyns<0xvOf51Jq*j&eQEe?;@zWkR%wv?&L^9g$} zd0}PLjc)_UbLKXeQ|7;LLBHc>fEFy2qrfjHM27cB>X|wCp|W>K|8FR&rs6Tf*AH%L ze6%*g?i_-f2}MIbPT=34_7I$*1+vBV5+H0n|J9y9dMbI!*5^yh4h|uMbU0FP?XI`U z;jk?SPr=`)jOiVxtzUnxy6oWUjkKKp%#4TuPU>`-#hk7-%P^K!d~&yHSIH&jGFpA# z5g8S;eXv6t>=c_ka{3xX9k&MKM)V{IJXm} zJ4&zerg@$nnvj1(2D;1%H_-g9P$5g-hX&@M-C}HkC8vD}X`&%9akm6DKCH_!e`9f6 z4Wz`ra&4N2LI279+C%m5nVAe~<`U5~|7?i8#7CJ#(^ig2bv=4YC~t}U(z|$!N{4U< zC!B$K5d+jinvj8&z{I0jc{Kh5z{Ae>i&)mLX2tDg590B7Eq*lQIWPT7DvIoLOl{tf zi;+rs$5ptJ%Nof2_pDlQ++>-wug*L7-=(q1*8O0wTn>)qY#9=-xG&sk5qF}j#&1LB z5=K9^XZLhz55GRO|L(iMdtj?ZX9Z&~&ct5R1Pyz8j|ZXtus15r*Z$%$WOz5`Gm_Bl1jg2P2k!Ar1m%2f zg9b4ppE<3@$FaTbgr@19J6DQ>9Q!ZbijM7G->bhg*MD+1`@1|;o-a&k{BrE?brxpc z=IR1+7fkn<_+bt=t34)wtF!@28i`^u;K2p#r-VpC*5DbNCh<0_B?xH`lrw80c}UYB zMvfJD?^trw9xDnjp7P&9&PeWiomVQU6oeFFnxXkX`79LIg*i44!(RRLxdRt?1>=)U z^Zn5BL5QZt$-M#F(Gdn-YwW2(R>uq0py=#ai-K(f8Yg#b(9CaOI_EK+TKZdUJU?0P^ z$K&rXAZhbh+U@utljZ=|iQ10{hv#7@A;#xCKKfg}Ocy^EI{wJeWPr3x|4YKV@S&Xv zWH&ry4?)4DXv}OHO^V|GaRIy&-2B%A)D3?eeKjIZlhh)*ak}mcYQu~gdBCCuxoa3U zh3}_`>3^mt98hP3U75bMEg~b{@}xo{n2gIpc|F7jGRr_teyqTA zWbgTcjZK3iq+AHdi-d79P#uXx92ATl(0tj!8uhK?FfV%2=TSb z6GgrHOjYkmkQjnxy7}_ADhB~!-Bc%`ixl5~Z*+U}M-;smt)J%gK)c}%c6(_jJY&8k z@U&IaFK4J;3qOf0@Fd<0ac*FXF#c8a16Howf5bQlx_jfNPwX>SmH;2fw6_BrS$hBj z>&T+xk%^ES#O^AgPW=Y`nRK{Oa%Rc+pDyH2mBSs(@`o=x^JKTVJ!|?d%TWti7>M7Z zE*#upyHNhlXp}YG6P{O58KoS>c= zvMf5#5ARXPPlp#m=U)=JNOl^gu02tVm~=?nG)(TlN+DLG#yz=lK+XbTV<+R#i^SA) zvkT~}iRBZ<*rwF4hF1&vNU~zasfJZXMOR-lgl-0q)cZgG@thCOUx!(9VRCs~ax|pE z;e&y3QlMd7ve*06a<;-8gLn61OZkz8Z!7@5_?5NjQ$&1j&h@LjKhg{1*e|jDI0bO) zllaA7>`A-5d7|ct!vj=d6rMQVhR=$a*#qbaE5`S*Pp>BR(d#aiEi&< ziYh`o0=Da$+C5;yP5rPMn?X!me}>TY>(l}UKxLDB>z?CjYPYO4)1Ym<(h<;2(9+-_ z!GL(iEM(z5PMtEF%!3Z0I5^=4MF5c9GZ|V4j9l$x2s8gG7B*Y1P9BM16`Td2+mvIf zf+EZn*aqccQuVs%)yV2&+~W*`qvp6>f~OK=mLyp9oe?{H!8C`nhpmVUY`%nBgAJ3N zlS(q2)2kXtfWUXT8=!mp8nG-=0@JYNq=1M?q2<`;E%>rVl7YD858ifli#Q)|-Og4x@z45_G9) z#{5ldLpqNyVx!!Jykox1* z3uo!#bdZtXXBAjuMN%Pq70{F6zXs<7-k0b4ApCy@X4@d%M3c^fqMYFi-MoIV#hX#g z5?5)1r*aOl>zge_5Bdp)c(=y0rxRq56fmKHlsR8!XXF({jGz_30YxOH zLk}DNepu=gdWfLO7*>qGg*5=()U%1T3Egq7$*WrR=&yR7@U2twm7IWr02BW#$Gag9 zEB+D$_Ny`qkss*Sqfdc7U#MowN`)bT*vE-^gdD8sO(9va8jE+Ze9#zZ)c`Uu%RV|o zdffy(En;2W!sVyUNooPVTKU06*j>d%E#JQ4H?e>V zqj`q;Pt{QFDDI#mQa(Cuo%z3RSSzN#5dVGkfL1IQZ^_+)4^(l3wr8j0e#8yvl^je8 zm40Z>2CE9#@XAW3NFXIO_m&^2tc|o)*Pe}EUoAJnGii?NgXMHt--4w33Z;^4f-$zTMc6I z|5l4*EzILL23szyZ1*1wwlz2sdTJ@CfzKBVeDPYoMEcQMn7B;`f^&9+#DRDB=zH#~ zS}~OP{%7n#!-&Z5HhjSRGZ{5@Mz#xp6fAAvSdSL|Tew+4Vrc^nJ;F?bd||-AW`1Zx zc~Ch9S?saV_o>8HlZbC2T6NscSc(7YSb!cvQy+2A$kmTG2Z0Y-6CV{LOwE$LLKC-I zA47&&24&Cyg0Y=$hDH3V3E7Gq&DL=xM*uOsuQv_&77k*0m zO}!{z^ri7-a_MfIN~N2=TH~-K>bZI|&O8fAhJE;+f*Y*&)Wd@!nrmqm2Ac++IKUsZ z8{%!S%7D6C>@WN;2+lzi4^%-m4V?8*KkWOf`**nhk4xob-_}7VJ_PX~uX9Evo^b9tV~ z*5WZke?}X0tB#28KW1%0K_dJ0bJFtJKs!jFAa6^Bck8jwj66|<9;1WGTc#5H*%1q@ zRoJlRVS)IM;>P#aV~DR9J_k%jL?_4Q8`0sqAoHBMM~Zy1x>nfSFmXP#a^`#8iTlDK)56;#9*#LV68|qDi^kQBk1%HU@BNwm#n%Ujp=+ z{hXi5x_({J;p?KxvUsqk5psogFGRcH`9MHWy(ke7ba@(IshZ^FFm;LbQGb}Ve~1E? zPVm2QbSp%CwA@lHB2MITgrLEZJn+~<)(tSS3%dL2_y)Nj!*6E`g#7B~zUV{;!JaFT z^#7_?o)@kz=7;>OguW5`ptK90;V9^?sv`)>a)GIVG(b3{X)?s1p*+BOWu>cR!+E89I|W=Z8rY(bXeep7p4vHgz+IZ45)N?) zCAD*#5OFvV1dW4?Bf}oO*{RY;U}9LoY@neT=w7C7PM;hdJO+}FAB&eyz7rWedSzAWd;gl&F0tknGp8t8zd0AVFA5(ZOpKi-L_+2)ByBzrn z8Ig`5JSM^}#*B`IT!Te5AdfLg6>jmF_`kBURZftTSv}L-zyCQRP ztMax$)G}SO7qfdX=Ng+wV8MX!nY3)ky$HhpE-NX6*fAy8TY;7Mb(9!l&V9zqB1{ZU z9q%ki+1>;C$=(_R#b&WwVv>*;VrTj$01VwxNF|u{;VI?+Wnmkk3x7!>@ITwE5JkIp z9dHVrWtH6G3f4qME}o5{jl?&is(@z+jy0h=tiYcb(Z=ETc7~?Ifr3ydHGw4rC%!=d zy5(G^XHzt)sVhcefU3JHRa$a*8`;2Z@Q*K7{d^!qtBeCv(Ns9F3+w$1MPRZxTnJ*+ z6o3SMaTwlI9z-aF4Uc-_C`a#T{Rb)*&~KR1w{RGuC*V9S3y`+p~m7JX9wo3(;F zMVCY#B`PzCjAGx;mO8}<8K7jBo4^F$dn&6S!!#T&(v5+xD|}KUdkx+T-D5e2}I>7dZ-pvwd-wawsr=e=+h0vD2 z!zLXg>G4P>GaQge z0}`@Iq?3uU7jo@ftdj#DOb3s9xMBLNC?CSqQpXV?A5}Hcr2|*a3GGG?uRCsWsy&rE zf%_LtVN-!xPv3+}^tm7)fnGGDH{nMRNrOtTz!zvC(0Ft0nHe$o|K3-~*a=Ao40 zY?F75fVz5@dV6A73Q|e8x)e15i^tQf>do@0B+G_L_ z;YL}+{_gDiKp@x~Rif_OGb%PYwiN|?C=>!>a7Zy^@L;|N0Nr?AlFf@V3$`{by6_Yy zd`TR^ZD|l}kX5V4fQ?wlcK1c4Bc#+lQ)X2s-MfT3eOsWu2Ndd_mdJH0ly=}IMlm1$ z9SCnmj8s}kb6&t!)%p$RnRx)-x<_1(Eq0^=Kmw&K-#9~KO`MbAhLQG4@Nc~#1bU_ho!Jq>Ua?O$z&43Za}a>7MZCTM=Sy^|O7@>$?g;pVf;d9e|JuPcU_nS!q;R@?@z#BRb=niwDKibWi z;LHK^3nUQlO0EbefrkqUJoKmC<53z)0%k$w`R#1epJs)DkN4anXo|bbo1nes{_}x6X_@D6$wKk)IR}>P3ZVw4?A;dzJ%l*K1`!8 zWSUd;VX(j$vw^t+Y?~J)_L;v!Hu+S+YJ&w+LT>bc!9xSPHdAeYw%#IG+zFdIa2)|f z6x91Bu*d|2#l|)l9(v31RNTeci*|8e|2|*jjkD*5UZ6{SzeVVQ#*&>Pty7{5&etR& zg}I%Dx!tFzd?VeP!yGt*uZl5mY`$jE)NVTJM$735KJ)|#zB~jKUmd*PBnaKTV2pAY zrMd{n53&c`rAGY2cf76;`*#9&)rd^5wix9^=-&@uQwEO?Zut%}$vg||9z?(m6~E?XqwEiF>;fm^HO?5Y0S`fac}=Bs)nq(#y61UzAPgGq742# z&6*)uX^q&<%XKc-IL1#=JHQ5sxMR4hGs*V2Ehl`Nx1D#Gcb<25gXV(^7MM}bpNWWt zUw#3|d@khV7yZgBjSUJMA$>8062R@I1VTQA<_o5Fa|EX));N0$vUclsH|>cyHf!FP zEn%5BP>9k0j8;7qLzkdXqnE>;b@uW{a-bNnBxMmY#JW2MI~<9t?Ls_8#vw<|wMz0o zma8H|MNa9L1%ECx^)l9Ws z=Ufa3eWwywVr=r%LMtQNY)J83@%fd>h_AwlKBQhoZ(!|?t#V`R{d;52qfnujc?v!_ zv~k{hDtXnIEpmg5=azQirb}-|1|bi3+Uta8$c=D>&CphI?a1?5zW9CiK3cAqo-`gd znF)fsKykA~aON^7ZFE9zhC7^(vM~=a4}+L%a==IhJp)t52_&(6+GGXM{F+6>63^~;@ zyg=|YM5rcmE~$wfefSsc7{p>@Udc&xQJFJMc!^@8={Qkj@7Nh*0zkYoab@*7I8jrN zq8elet1!~T4|2B}=iWzOSF&CCZ06n4TB^l#N<3Di2~lCX!Df*_fE*9u>kSu&!l_gd z833R4JbI6IH|tm52Ou7|{o=<(1J44kVEDPYe0$G*w)AlhtDDr z2>rK%bqlW*k=>}o=YNZF$ag4k*iOwd@jp)2p@IN>vEU}qm4_O>_)DH?#yUmUjNdUq zBruHYnv(}<$~JNUi>uk>XHb9e>sa*?W>)ayd zz+Y>ja|aMabUSfvw3j4MIv%Xz?c;i?npUe}o8u!;a2_&Z0H= z{8r>uiClqraG_oJAy>^5QP8eQ$F)70V1XD}5^9`c_H=Up=23NZD|vDP>8miRD>*p7 zf4`l~cu9}iOJ^7&^T*h?pkbc|`2k!JeLS#oQtsqOI=(I5_8DAo`A*ANd*}3!OTDUX z*;RIkp85b}YckV?*tse(45kAe^|99~Q8&sP&^tjRWM+`e_YuD|+3^u_xFXLnO>i*l z>=Uhxs}F*+Am!X--iOpO45b--d{m7DDgg&obxVMdGupgu(?qh+SpsWl_%4cWVUh=) zK)oZPtr`-!qV}l#Y_TWO!RLLG>0yi)C0WY(#eH4&o|r*ZQ^i^Z7|ziU?5|{X@3NLD zJU(vl6GwRVv-_*&meoDUrOAF|3;S*#nX_r+z`qB6y)pOQ5spuQFusRDv0#X7Z|r#n zd=essxK)>wu1eNj&pAaGNu-oHJh6iK*K>^PXmfP6!39+$HHkEnjKPJ3Bc!j!E4r#1 z#K6W;nX)Db2;Dwb0O7ql$3KYQMU#N^nz}kSKfKmASla%B!~PFOP1gz=vra_-Le$7xAIs_4729&#wj;FK+zW?RW(6f7 zU8zkt@zfwhHcle;u%%{4=)ZF8^12lQq0d)F70T|f{Rm6X3s^b9mMXSh**WfA+35l? z*ul@l$k(zUVn14iEL_h%z$1Bkbb}~@{g?6X&(#eUPb)>$~}) z1qscola~9VU6TSh`S-6*LKZYEk zvy9{8YD6x^=-y5}q%J16vF_it_^82rH7G}j(T@w9SaQK#vfftrv;9U&41gKLtSgwy zB@=roqIn*%?M^H?+6JSal9&zZpvjk+2dR$=CeaGvq0VZ=_kS!S+XzPm8@+2TS9woZ ze6NdhCWYs{SL}@H_~Son;KuypkBAT0^4h#9q}-yd>qp ztA{ILK(XO6#dah*gpl=L!x{Jx8 zVE<#Y>y8eA8<%$!874#7N1rBdMvjK^!`0r+FwJlmyJYf`YYlQJcond!Yy!n})GgKB z7}&=eHAtD&7%oM7q(vi>Q#xzJtw?Q91!ly4 zV88E3B`;b1P89a-M1ILO)$?J<-N$0y+Wxz?n@59mO4Ghok_7Q(CQeXL5B#T1|IKw8 zPnqn!rncUrdIPP*O^2*K(Jkh2QLf3W=PhDF8ti)^7TNCYD?e=(ccfJ2k3h%SiUuH0 zJ6x+k&-bA=1jlj!by$Rd*G^+ZS93|WA5SkvK{fj>P7Ko~HM44gv-Eo7OTlX(?{nvR z5*sB8F4Hqvp5x^9k7aa^_2Vz%2|PlEY;Z3fGd5 zo8*;a0A9%Bj|}+I7Rt?j&s!npeqSJc6K2?5hYk0P6o%5}%H+F$Xl z$rzrMc+gq5kU zhgYb(Wq+@Fs1+B0C;Xb#^Xw-aV8!CveBiwAN+o~s3g`|^2$$ZDtR%e6tG)VU*gRTq zXqrCTYO2myfcN|eBKw$)n&-a{hV{8!dsiO zD}!v~|5=wqmwrFH^RZt@hqVI|=V-9_emLpE8z2k(6G(#Z8d5VfK;c7*$lDQS#Plo9 zmPz+3){sM_@h#{hLwGxEXp*K_zFA+nHK&+U(1K?&LoX};rFZ|h4|9MsbRTSt{b+t5 zmzx%4(tVSa5@KSKmK)G4IHGA>CqPZMxBzn`2WL5jF+>O)m4#DOgACqx!wsoLixCAP zFpvW#`kQr{Wyj#E*=v%#-^&uKm~=UwwtxL^%|HbfUA*l02yA^xqs>`?5#B+e7U7^?JelFK_RFZgG{C zY-~Q<8FJJtxS4)eG}m`oSOQdp1Vg1L?K*!oeUsnk#yGEKjP#D<6k9ae5db&*d||Mo86t*r=4d4@<>%q)0to3{~xLkH5ZOxEQ=T#|+$e zq z(u`E;sk`zE(^3z9O&t4NSwmqWWGFHTH{^)pg&?}zHU9e|zHaqIq2$j7!aKUBF;9V8 zSSo9(Bh&pV5DhB^=ng=L9%TCeUkiW<+H`YaEndPQl5vl~RC=D5zcvm?pGA+pvp#{M zydG@Hk32>?PLpp$PobXkZ2QGbRa@Ky^9o^h;4Fcl)*$n!kNV^aj!HE-t+MLI;@)M6 z^_#$&KSS>0JOW9k*BkrjFsIGEkew>4m^FX>b>|($~%7m@nXnJ?oOQ>VrIkj z<&*WK)BadIrbS60B_ezsc0B5URdY^{q=R!xy+`(&^`IKQi>CdBmw8o0&+rO?q|RS= zACG#v8GL(eilKt?7!L>lKQyg4o{?W?c=rpeh9@Uh2c@D&{5*)x7 zD%x-<#XLy<6kiB4Hg;rcd!@tPLe~3mqOWu}7iy_YbyL36(&S9|QtK z@_ho%!p`y2AvAHhW3{Cxroy_ajNGhD?;q67L*rwvvLj3ex1>=V4h#f^BSc-6#OA1c?ZEuY>oB2()U+T{tnFh1}b83vQ`})2a5E4$c zl|gfXxPU8-2-mE6yxr(bem{rNL3#s=?NPoMWKR#|qdR1RlPsr;MjlVMZ-=+X&p3K& z-9VJUFsWR(eS=(9$7`e;jsBNBn;?M_#1;5$LT@SW(u|R+9X`2mY4oQEXKlXc_wFbr zBY|aDhe9sEjfT=#VF)$E^y(Ql-@5e-;QbXN@U!)&Kyqucz&y(bwL|dXRvC;>$R!55 z^X>akq1yV#ZzOU9VLWHr|7vZC-j?ekyiesq?$L{SxlSnmXf=ngG6sQXPANMf>k^7U z4!u%TDtff#{i|oo@vDkz&zQ%Dh1k}G>2sR~L^i$)7>TtBfijmS7o(>1%4ln7Pw{LN z6no^QMUPwxiy+~Yim>%e3aI-?c&zFTFP#@`a{j($HDk^2aRUQ-GRNi-Sz7M6JMW&r zdb)P7_I=h@dA(!Xc(nbJy4zbWUE*sEs~Wy5A|gcN*Sq)X603aSR%U(Jo-P*p3=y+g z|EAX)J^qc`ot{244h^qc;=ZzeL`jIU?+R}3UFrNTf1^n1(Yp3#TPL1UDlE;N z2uj^Ub(B&hOET;o3X!BP!~8H_Z@2go@U7u&7412jWwr?AT738?&9?zHE^?-VD75=S zFmbdtQC>Skw%Hdt4dm@_X3|&ZS5myoc;xSDuZIiy&A*v5YpS`NT zOOYq|Q`k{L%6+`gNDHtG-uS&=LhefPan8!L2uB3+SOYv3Ja!W$)RTk`q43G^Az z1dy~6=9%@+SK8*Pp8GCiMh3O6%XqZf4F5uHs}f59J9XCMN(nfBuUm?WM(EdC?ECLl zY8)pu6}I=fh`{QjONtM_qt*;U^p9+=*}6nFFs z|ENzO3aTixWZ#N;l6Pt|7d#U^6Au$yb^c&EW-@6Lx`|vN9Q#&*&j-Xa4_x=%4u2ea z6uTWDiaouX?hK}xu)NK4MC?+f8XiLVqVS%<=!$A?bl}pe6Hthilg81bhCIK`!iPA>$4t7RG{6$)C`lD5ohh626D1J zFOoo;*@4}MHmhqJ+G;@HoXNZJvPdsOU?yFS72-x$##kRfn_5>4UPBGHcSA!D5X4MF z3$QHQ0)95+clh~sYjc;pR{x-qwcg?!L~bTeD$i7;;QBY>tT>Ts-pC2>bi4E;qJqM- z0msK$sh(EG#>AQ5EVXi6+R{!(2}Q>4M-kc%IXa9lNnwsL0K&~>OiW#DIS}?qvwpKo zAOeDL=e0{ywXV-oh7R*vJ|W%@oJqk9Q&n5+_3|`&n7sqiTKIqVateM-i!Im4Z=0k( zG_Buf@4m^FDJ}#P`K>(1{41Kb+rM%%xJ$G@34nZe01Nt8<&9C59Z>SYqXHofpO_bu zb-7FwXa+qZ>nU7wH1`GeB@U&WXwM85iLMZuxx5t&Bf|VDRo#8M#EHb+G(%PY)pqyb z)QALC?*AqzMsC(%PdVbxlaha!JAQ%uZaSRGN0^`uL2iu5LXogU8xj(F?;GU9OoZ6W z_T0s^Lfgiyr6n{z?)bK__aJ-q4BI+x1Q1var5axY+ePn>HAkWL&V!(7)HC)ns^c zvA??P+gqJ_hbH}WzYbyLYH=rL8g1jfkzRLNeaomkHp*Z^o6*UP&1^1&70VjF^0CeF zB>NZVgFG_F?TuS5-HcnT4oDftUR2Tb5}8M9fL!Dvg}}EG|Jaj+0{p6%J%@e6>)obG zqlYwmV-;0Jmi%9dChz=Zjb@*EK<|D0{Y~Uy6O{@|1uHcO=%_U+N(XgS1e4h|k0aIV z&x!WpgXQm{kDec?ee}j}CIuT`BHkM9-h_@0_G`tFT``ZWA+j$D6Y}{{x9YO|s#*Qt zHA9_-?-G*?hikoS<@)L}aa9fXx$1~eVMlpllHAim%p%qz6}xPTja(P{913=*wU5|N z>2wA;-1Ni5i&}h-mER)b0x~Jtl*)1chlH*W;;&!QPu?{J(?)O;be}Olx9aIynhLa; z*sfkBP_M8Z3&C7?4>tAad(XcB!?IMDYBWfw^Bq&;D96WTm0)D|O!pBTZPv8)B7309 zVT}M9rF?vWhA7Xts=xM5eWv1n1Z6Y_ofZr*=XKGTr&PCh=XK}M)w|o2*^GQOW+W7l z%vgCEi6+6GDVZd8`lb%dP{(cbk+)5->PZP^x!LAVOK&zPP1p2mFXXcA`)fJ$rChZc zb6tBT9@hW7U!OGbLA(u@4?f^VmrJ8@a48G~yb4v0$dB<}sZPId5dN8XrvKaGEd8|V z?gzxUZRl?sK66Lzos#r~^deHnpus*($>U#$u<)7h_+i7FkEYw5)5tg0-|Y8ep8oCA z_t9?0iXb8|<)8%epS>92i^?rng#7FEnbobe$^J$6phL%Zq%zgF;M}E_RwgM$a}QKbpTtB&5f5B(Sy0-3P!L+(@76>k2x1m^VA4M_ zz+Xt2fZybv-M0TrRqxWq%~e+5Cvk?}Sxlb&as~{UP3gYoBa*`a0wUiD!%_0^CV%So zo=uXEsId0DB;(}+)QFKiW`hbm;5U6pu5>BhAhf+NUfQq^I@@!Rv_uoR5oay#dnDG| zRboFZ9n)&v5|m|}=-me!b}-GRm;eI1?L#R*CiVy;_hUYeKQ%~G1HL-_iJftSM;1OF z*KDj{k#%{>v*B&b=FO(wW)`cp>e36{az`UUX6!*Za(z?u$Erh{PGwsbBw5Qa00-e$ zG_N$Asj74yT2bhYg|zrys2v+Vk}T$7d+6!80ptCBd4D8@=Ub~e#X+U(vtPMTj{LnK z1?I;vK&~FLymmUzLq-$=0LW0dVU$b%uG$bN^K=?wYTI!Th)_oz5B2yt@EpfNygpWn1-9{YQ$Njh~p|7?Q)6{ z${gJCm@HZrqGd3IwJ2UO$quP=61kCyFg&7hu#pXkZ6E*q=HHDK;`+_dJ9ZT#o}-B> zIXg^kH|xs>xZ(oA%}Q`hWBnRzHqP&JEQa%hEPai@4uPrAw>3kAb#6)A~!Mmh& zjmjx<3)6h4jQyGL^p5?O+v$%s!T8me?DcL>NO*94ri>AJm;?rUsruGYJ~EJl{Rhw+ z6BrB*`7?1UCL!#NnUvLAyXSE-di7&yiQVi}XRUcBqNx)yTP}(R;^mlxU)79RSiqHSJ(9+h|160y7N z#<3o^*Il>5I9zMLujpN^_bxdQ67#dHNN-3XB+GK-JgwI>W!ObAGkq=;!y5Fgmofji z_jOZ#czWgwK0O1bCnsiS7BdU_yGbGZkjIFAQwq|Z+*lRAA|DRoV=@V@=1pKKo}V}k zxqo~;V1W&G_yme7VK|}|1!8XNi-Xd*WBiyHg(Xv1h*582*fX;L5+%SP2_E8`-_y)V zfpIbCfhw!C?%C@fFuF@cB}&&30R^vt}le?3)46>zaua4~GJTq$`po;?K(JbQ9EQ6MgC2JumEDe01+g#48l3b>in8 zo~x!jzdu{lf9tlwG>H*%a{mCQ|5-La<~sp|=x802lbRA<26%7Xcay0maiTQx&O%~hv+Mu0 zfQ{-5e@Go;ygTI=ikD%<2uEnwej*HL7N`96-tCRQ}4-cSO2sifKU?C@5XA#=^`w4+U7rGf* zWytjwC$Sqg`B)g^jq@*%8DTmlkXffw!aWp6R-!#ud59RjcO|=>Os`KWOE);7(L8n| zP`r3CvqP6&KAy0P=l%G4K-Y)lT+YaZw1!M4OaL*@X6b(o={tEGADH&hkMLm}W1eH? zU1zI%mZrc7mDk5L(-9m|3JiX(;K9dZ?IdjKN>sCwEl-d^Jm21k(a{#j(SQgAE(rX| zayE0%jr9gykJGgh+dep#^Nt`=1|6j?N@Bc-nhdm1t%i#lQ#6_hKNM3HgUVb5; z_x|iwxfGodt%=XRyWwwD1#(?^23)Fj^h*c{4SMBoWhM0DZX36HuIf(4LZ!7)keZZgC8zY-L?OMcYcSH zKanP{r^Ap74$@z$?{kyATMX}GuyPV!MLon+dTL`;b&C8XUKXqm-bArRhN13S37Wle z7{?s^scTb^6v8BXgBu#3@djuvDUm%Xc~{}?y;_s{AkH9KR>=m&#UJx3JRB0WGds~+ zb+Aw)wo?W_hB7VkPkkh9~T|1v}W3~Qplw_35#Qvs=KbaK1+tSXhErAfa3 zLXi~3UdJ$w74j}kRYGTi^LHkOs*L4Hg3yXXOO_05Dw*P_VAE1AJU08f3l*GgOl8&> zG+8~T&^aBsUBQw?ZT3$rUF^rZ3C%bf7@m*5_RbvWigTGi5h?Ja=Y6%0cH+f|K`d*kPmJ6EF z1cJAkzt40-Qi#OWpEIVOsvpMd>^GCur8fsaXqOhBq}4xP64Ri7P&M%#Te}q4&a>AV zzt2)YeI72GzPIf2Qc7#b$!qA204dn{{gza*y|bM^-xPlzw^H8e=Kq zwS9HJv(tHz6N0*M!20~sC7RUlq^+RvrG329KFyrh;40o0<8IBs6?WCOO}>5F@7p0~dluV5CYo7kNSt63hF)oRWKvGi>Ke zOE6NPoBV1`?1{17)UvHlwlQT`3s?d~K=VpX281i_-UD(IcJ>e%F^}MX#P&kM#+9em z1j?fY$9cRLCjcR7hyGCCF>ZVHE`r&MS6ZFq^D>BnLmU0eVH1=ouk#(wXb)ZZxr4N_ zef6^*5j}bIMs+SKG(n&nttE9Vqh6NN*7D2Uzt9c)^~c`Dc9^JDLLZNwH#x*1RSx_) z#gBV#_bmc`6mq4~3TU`%ls^SJ=tLJD{PHW%skf~@rpXc?Zw zP52+u>urTk&2w@kbce|8Obh&r+$`n=hMeOdzFDay>Mo)$;!-i+$QUQ-tmtwVJp4MZ zHam3~+%|f)ctO%}GIFMNrjXA*58qjs#&)huE@#Bxn(yb9?pK}eF%}IYdxLWFnQuR| zd1na_J`tceL;p57*H)V>yZ%=6(wQ+ASa!WrTbxEs)%jea?G%Mf%w2?}FU&I}ze&ge z>sYnW1b;!3?Xrz2LXvr#j9V=3pq9WDfcw}T+TkZXO+>49CGm~1Y+L_t zf)ZNqx~xUQ4}A2u53C37{ts=n_8&;1puyiJ#IKj`bMk&&mD4X>>QcBUbG?$Qmu%yj z91}?BN)Dw-FmT`pTwzbl3=Vx|2>MJ}za+(hxy4`3b(zK{P6evUTlgHzl(!~Nr}CY( zW6}4iAt$1d^(B?bMFZm=PCEK)t5T>*&`ZqQLGRR?PAuhd=wy55GCimCMSJ@-%!K?r zaj;lBh2sP}FaJlD@A&QbL8T4UumH%A@W)0{9q5iI3l@mjdhEY4z8`||wS1*@=G_Em zWxLgD^mm15H^k7;o`=l4!$sVsE%ydR>zuaf?T-!bfAy4tcE+P(t{i7N*edS&ct=gG z`ED)hGv`0v6vTA&UMz&vM&i@>3~`dZJA?iG@%&(|Cg7Qywum>Is`mqUej7B-r*LC99hFHOSRwBzg-nXVs|Khhfl;I<-L0PHdN8KD3bG9nPuYr!U zCTI8qotVxa0`x<3HNaOz8vK4H>PenJ({BoXBh(zU{-1!J@cu<=^jkH?Wb_ELDL21v z&K6!1f9g~{Vfh#hqM1r}Cs;F2Z|bT%Pis0~U)?F*1i!{lQpgJINnQ+pb$@X95O3-W ztNE7X8{W)lXt6pN-r~Hik>`ZqnMF`*)O^Vk%99%IEd!wmRt4wh{=S_v*YUbpwSnPj zzT)Ggzf6N|+S0w{s}oJGb1W8VD|?mp5pekVMHyd8+^P-lVwzZuZ_P*0q7}cI;mgn$ zE2Wv{z|VKp?wW<5{`maY&wPEVwj?jnPSJkm&iJKl>h>Rn+fR+q`m)W7QAy|Haf;+^ z-xlMo*F5X(Hjne5_G{nM6uh_l!ot9Y70z#xhK8#DB1T}R4r-6Pv7F>tVI7UVaI}2t zvFVl)l|L64ohmR~eZU#y8skYJFA<0AoLsHk6}j(cd*G?P$~%xo%l6v|I|XSWORln2 zHU;}@t@s#uf!!ux+7`nW(+;a(%pww;81VpSQDOBPF{gvtTQ_WFk^LHpvvR^mE&tH{> zCz<_Wg;!k)NeS3lzV1y-TYz(=+DPW|nH!k-4IW91R_wA9)AYZbjMm-%NZrCF18uN4!v+}=d$ z^T=j{WaIb8G)*On+B<}?)PoCe=AZ~FMl=JQ0Yu*P&#T_5=6e5jA)+`PE>Eic6|KDB zk6Sc98JHr7U%E4Y@jTKDfTAOMhTPY;y>6=@N3VN!qsh?P>8H1lN4lLq28CW-{6t@} zn@IfR7}s82l_QfoZ-kHLd}8Xpx!P!CGy=Stv=R`HB1UN z-DunEHO>F6_!ZQy{Enjk!McH!zVkAP?BTg-V zK9S5Q^{RPHWvXR@U%6?RP@wNtN4#fR zW;&L-T5f{rvIUl`ckF!ph6jif?mq?U8#L+K&I;2FNm8EI zncaugv#37rB8=9`iNC)s(x~C%ayReC(gm6-s;~nbY%|wiI%91YuHS6GEUaEXuHlCm zfnDF%A&}&+;C~`ejPmYBT>Ik^6Cpv!kZ5}O{8OXJbG6v_v-|X%@s;B4Du~4O$$dj? zASnhZ8 z(1D(}+9Rc0@)~jb`{4lB)ef!udw!={3k<`<3G=nED})6s*vfa$^fw2kP!l}8{l5DE za3VAes220qci^rhR}0OrB_er(lSa z24@n4=!}y{db3W`ukMR;KmD?uM7ib6K&V=797LMR$!SVQLC0T^w_h z(3_8LQ*a zY~Z8t0dnHjl1UGHse4!@+^h+1e;(?f;N$9pBxCZFI}DZU(Mgo~(RhmL#FbhjEgLpq zFmW{d0^c<`U&N{_eq1F>R?_5B#=?1WpWyCmfr0R;oq{!uf;EyXj!cOl8r(QP%25E@ z&lE+!Nl|hnU!EU4vhlson;0ynjXpmH`eGZ3G@$uEm;Py8etJObH;F^zH>7(p=iLmu zZKBi2xgWYtqv$Lvf*?Qu-ymp0@TR3qGihQ05)-NO^d(;`|6Jh}mPV@3=cRzll4-Zs ztk(GND4vI-&(yN#*d@XPk#9ewNXm^<3?>=z#Yz6W4N$B4Z5$&6fI5plD_9ycdyDr< z(|5R)?!o}mWR0s#)F0t?`Ek9B>;ggl(`h6pz4iLzUJc9N*FVTi$j%lff5vwIP&S z1!+35nxO^+-vyqN4E!)~9|g$=&5`$q#1Q4p_x;{cuMfW>#OP;d1|DA9cD)(+p)d1k zYceEVxGJ`BtXL&cw1v^;5m`g@SAUCa|_AXWTwctL;VM}^juX9lF1RueuQr4(~_k@vN+VJk$buV zxbww{bNY`8C^ESq`k%-niKEoOiGgN(yrd`w`VXRM0TO8-08Ub38KbJFem50jbFus%wU#aBAtVQ_RZG0a@J<<#ro?HeJsAAp0Q_Hw&wd`dmPwRn626J&wyu#kE@_wRrW}8W{HddN`H1J>`~xQFzp!RY$>a-+3Y{E0rrl;bQVvWby4U!PIY& zH~me%N5!1|*fPEJR+n9!{0#b+OqF}#V&YRAkR7k%&&Q!9GVbO^-=-b46?r*L;;j?U4C!%?)}+)bOq zz}ZofjYUa_lZd#m!;TiH^=dnztZh1yPIBx;>3EuhzRQ8L#5iMAb)~hy^aWe2bMFlQ ztcj7!ORcg$VLtuHIOPnK`kCwoOXaRDNfFtas1SUQOb-wBh<5ks?9{m`q40{N2s%Z~ zME!Gc^$d$vh+1Z2x7XEM(@;J-N`X-7@F_rAm*xn(<1bCQqPmbWCRMm3L@_hzieOWR zt2K_jqj2lLXPTY2;C?aubNHC1q?14Xcy(`2Haq)?UubuxTR+u_5OxE}`_~*gsD()* zs|i>;*0r#qmEWx|F*rNXWG^@wH;bgi15Ezf@S`E^smS8&pg|Xa#KgodJhI*Vn>kq= z+rFx0m7X~z*J$(Gu%ddf6}kaR?~cW&B+kCa;vhc5-voebo2mY2|Ko8 zlN`zb*z>9;Itc)xY7ObcpAIdflB2O^9>w#wR;+>*U1uZU}<1UDY_)W_8&;PCw9-3Y7k#1oQ>L zfe|<);F$2)dy(MCgfKlsayjIfL@08P%JFMMiKUov!BZ{8NMy-y)0u#H4D1#Z-v!A5 z(sk+~WM(`ckrOSsprz3lA%a`9!Ao-|?-Uw#)V*TR}qt{}i51gLy@0%ocE$`Oe$nCpy{w*wOcJA+TwK zrILxpwC79H*o$lf__$(C^z?v}Od6Z=X9Y$UB12wO9&=>O)}OoBy$#SQ@pRp;>2!&A z?xL!XM)+~lT#u}w06tY9KAbIUgHJacxxLc8?|C1Tv7P z{|7aSh>v{n;|U=1Jc_x+rT}j!V~sKUZsX1}_JXS214A<&;_bN=$3A3NE%x}W`#t~j z4k7BJGp^)M^GX_*Upkm8P!&wkM!f($go{PDrE1l`6n7b~%e$|YSz5ej_-|B#NE@L_ z#JT(}{KjPBM)9apLXY2)r`=o>LNDKKOuJA{A{pmgIr@A+z(@9jr@|#o!u%sf$G)=r zgzya;i0)Ha)0;|J6vmIdAc-!pRR@=U!(WUv_6HgJh-{Yw9R$u@VG34-qX~x91ior7 z1d{fsL`BBbvg&B(<=G3qX)`mpqLV|gB%5P$vM#O+=EfsmNJ|NCY4 z*^aB$e$0_PSYtqKY{q|uQkE)j-%XY5EwF7_V9j9L-54icakrs|DpLtD;QJOSwfLUQ z!}rD|&b!ssIiEWedga&eFv=~vw`3Y}9Fw_H8y%n8ZXPJLwyT(i-j@L0KsRLaO%LJe zf}69Qcl>p9r^hLWYgPo_)A5(oa`6Y04if(7K8wY>UhYqR^1G9SNjJ8l)SjJf8Zm6^ zv^ZV_PxvTlk28O+qcz_o);V6>em5L^?oNsXO`=`qozXHB!R^{FjTd-bnd6L5Xei=og{VA3aJYlgA(5uf~v@@%vfYsta*giT3@ zlksu7AWXZD*@TG0_FOWcj}Oz;fxruG5gu-5I069mEx-!H(LR?}4l=o@E>WQ5$ z4tx%XPIfI46qs_`KJiP!S^wG8dU|m85!o1Y_rciM>!r{GYD+Fmpb~omBgC}Ox`&M2 zQBRJf#H7Q~_mbd?5uEBkc9&Ov#5mv#7j^$0ew?h};QTy{?7d)}SYP9<(?w&|s0yfI zl7@pBEExpcF2C)WM-aI_=3{JwK1GoEx?_<^rQb!!Z)P3g2sP(=73hkHc4yqz3*H!pTDxkP_(tc(cui$!ZK#^ zBGV-77axgZ4IuA0b*E^1ksy?~T{j@}c)XX{(@~ICWBVr?RGjWwtnG|VV zd2uSmMR{F?B1h64@B-G$mPnLN3t5F(U-zMzgiLi(g@6LVT{-|yvjjW^Jj*bC>1XS{ zdVudY6F*UTu=&Pl-e20i3IThFZfV0)!4oMARs6BIzY$0 z*TvFurN&T$vP=3qqC#ReXF!^yL$b1*q}5kzd`F!-*U7OGmIP4HRli4b)*L{B z9NnM}CrN2D>-G=sy z;5QA5KqViDwaaLBX6DE~gF0)CPF#Z@XqlRLVvga7`uu4E6$#s28;tb=B#QNP*s3Bi zR;{03Ey^{ezp0cWJ#*?lSKvQ{_f*)LsyQ<-{)}mYW_lJoC%5Yu+t<(5>}a~ursR;4j)rbYk17T~K7ne7whnhodq+GLuGy`3UC zZl5SDs>||%qDq;FtYgV0JVJHizA^l_j>sjE)*=rMOY54cXj1hQg<>P-0pS z*q8))E~^@SDQN_ldMS*6{I6B(T)X+{oBWv73rW_wsi(KYu5Qk;w_^v{53fYCHWXI9 zr&G}x7Ef$MGB1w`L?Bw1W_VzGxtXT3`XI#*ht!)rbwL_2kct+Pjmi-uKtH}JyC+o7R4 zlakLeFb$x}jyV2~l}j&jQVbJ*W&ntoHZ!oHzu>(_f|%cqNC1xytb@x{+xVEo9Z<)P zttu{#e}|+3Cxw()sAFlL`dl%bQR2-osnx(9lha*;35a#ATBjF`r#@Tq0=^ERwYZmQ zuIJ&^k*3AK-iV(yqD(u|WIH;v`qPmE-Xm=`yOzBivzpz1rs3i8Y^gM1H%R&YU;vW( z5-4_&3R_HalT0A&+lh_8ikhl(04IC8oJ51qx2?W6x4q`uw(0B#dK+3mcb;N!ZoT=I zbR5qbbd|gDqzVgm1$1aY#X)lCOS3YKAnA|-pkOZI+`E(EULn>MMM(|`u(@Ce#-gMD zMMr+E;TE23GaaxjuRgq|feK812xYB{0VO6%jyoi-iW;zZqqOG6OyITDGX&sfE_mpE ze;rrg@BE%KF&Ize6t|w5Z4B8?8c!oFd6v=L=H^>R_aqFD(je>;Q>AnZv_pRUbB^XE za2K6-hLtObj<{0YQMX0g;E&l|)4+kYu*u$s1fK_!qbu=3bho?K2en*81N?)gx(JkQ zO?bqxK!0rkwTl0~;nzl`4luwW!}ST6e5j2j9zt2qL+%MhKynA<7WOK&=&bC0Oh$_u z`8kwZDbdZYX_)sC*LLGOla$GrVL58)VFF!_uUGZX1KjgBhNwRo_@Bi8Pb{g0jTC%D(H&qB}jG?BDESYLR@88SneUpV?L6@W@Y6?DiKc4oRlYV;nFpK zc+k>tbN09}Gn|w5*>GKXjmW%XdY~Rjw3T>vT5JP`q-x$ffqd*s`Gf6qVUsE_( z>}?9pzcP~yFHHIz4r>CDTj$s&Jsi0>`z&Hcu16EFVed!@Sj~0N8R;3=yvd`uzfJmt z`3_#Z3(K5ZW@3%t0^WRv!zBE3_OJzlWYp#rTpthyW0jHyPMHmn*hedS#8jM=aCt8DmDtr`u}MSR_Pa0 zVB(8t7{vTIw6%be3Gyx9o?(Xi8;O_1Yfx!X<_uwAf-QC?KI0Sch z2(H0xu$}kY{crYWuFlQrXHIulRd+q7#_IBR5)&ojvz81oef!@oPb=d-@9(WeJkD1e z49H1Gqi;tyoVKl9T;V#f&G%l1;gykHmUZGw;_P77P+|T4F`WV4LYS?*U2dmOKjY7= zen!p9k645tUwHWVrx5!%m8rIwsUEgxx~rXyxm%iFnPu%>Ei05Z>FoUok^5P(c%Pd0 z@Gt!lCu7<5tFJo;75e(D1_mc!76tTEO%6+(?&KB5Vh9q(hu#!=1>q3v4;N&AFoqx6 z*i}RIfHrA@E{d$)NBMN2r)_N)uC~V>wJL>7l}e&OwxkXZ894(3-?%*_YD&RzGpPvh zCkc4$AQy#BKXX4(+mvQD?o8I7#dbUMx-dD`?1U6ReUl<3OMb-DW4>S)_=g`a_67Rd zLnIFnJW8)+5V{#T|AwZdSXCZkm7K_X8YHUcM(Y!!W{75?HVAN-No!JZ*SjgD|!RKr*RaJ^-pQoYNJRjhEaU0L` z*o=5vXH87ai&KM%<|38$!?NecPM3%Jh{3jy$=e_U75`N<)p0hTe-$~@dMoCQ&no|0 zBSwkRFhs970hC}&^6#7N z1Ygm?hAe|bnRw`tL&ZvN!wC@1u$}^GJPpS@QO)vsUUcNbet~%#LI4=bX=;CIkf*7i zw(_N9&_|=Man_1l9VYKH$r%&`?&OmPtm+~|Dw}?XkeCzZ3$1aaX9k!kn0&pELVF!| zF>H@Me7dcYeHyKy9M}==M8!3_9t@*+H1AyW8wn?s8jcq;Pd3A!|M0xl>FgC@ABU3g zEvc$shKE!)(l4uzAuv<1o}sr>6y_f!3{tQ%wqeS})FQu`MUERNofU%L!Q(}wAIymD zXHXhY7~6}@liL!!tnyQs`ih7*7+IoeTKOBR;-4Kp4zX0F0HU2?r@!3tLQ|bWtI72J z?(G|VrRE1--O~?s60N9_9UJ>hkc^nbk02LRRZ2B{hSgal7!afcS^=5xIL`fvIr%P8 zQN9^;7Qwr*Y94%IDq^=V9MGKGV$pwx_xt1) z{)4}rVj+J3Jn{=7_#$Z9D;I6SS`sl@!BnyvCb+lc9A1Jw!3KCp=7UtQQy%M5t$p_G zEBuAX{z@svs)YD;DLT5n$eUehCt7!xfz{5IH~MPPv^)SlH8&LP4dJW-1#pCN-`bHbrkU)^mXnXRG(Ot8s=w5H1>BE(m}+6f{5nV4-2gy>-ko7(Up z&lfyX5Jso}pbU49T~Nd@OMZ;j&Q;@QAEHE2q~z{U1UQbFXp38?oGv5kmcrHhN~4Qj zP{Cp%Z3U#S(npbKV{&aZ^K~Akdd~>q?8c(qDWKi)kIUtBF5_~G_*PCzu|CNxwwzRWM7M$%Y6Tavt5@X_vjwuqf6ocv6GB)|9(!g zox|0ekh%KI+KxoUq35{MkJb1-Vq>v$w?Vy<2ZLr8`l~OHKJtM6NVV;4hRNB7>qr%% z#A#GgjP5=4tf~qec_}^48%vJZI`vr;7zGb1m@2{5Mg_cIvlL{9h+b0h7$+$~Meq&a z3s`^psIp_=CoK;1ehaamuCh6KTU*okv{L?dDG;qbVaWdrF`iv-_2Rn^G`@b z#C!dU+mctB&mB0UiqXJNJYYG)FO1wKW2Q5U^I%|fI3 zn@X8|=530(&6FD=-bjH0blr4Kn8n1C_E>`3YY4wV!@mM2^3enH&nN-&4KTK2+(^1Ube~@RO@S>zhL+)7SPgFI3h)XkZEccPNT4A{Nf9n`ju}qy-D0_G znacowX&i-bq{VnTXu;HcxKcJIxi*7lY4|sU5)v?POWX>v&pS(VWDU&g2c+xDDpGSJ zvd^2)Jj`262Ns|z=!r6#GIzs2N!6F!8XSMt%zP%UbnF?8iS&mrD3@|{J~2!?*=D$E z-h{#@IZc{-;}G=GQPZrHX;NI)_0hY*KIydq zq*_H9jX-fv!)k{XdNKjlNxp3MX?9P8YQ(&MzlfkpTNNq&5FZpsY@2!R%vo|y<3-0n z!wRFtJc+O>s=h6!l81y&U6tVz3+seFmmm-m5hZNMcj5gfUAc~*ayOZ?KjHvdB52wc zq0M-6WBK7geA$zHrmBtFTAJ8<$Heu0>u%U-eu(xAMo^`oe<(N0Ucysy zAY|QN2POEm5P%yDkK&XpEZq)ci_0h&u&O%RW=Ik5kr8OJt#{LprjN;mQa zT%(hg_u3NI>=K;-=?3}lyiH~wZQv*}0J3MfBYnL>Zsy1PSC^6dUL^iAyzs_fAhB-M z^fQzKwndu2TKW@tqqJn1EfWGVz_iK(j540}3dD49Cqa3N+V~*9t7LTLfK_68iBHK4 z4mA2*xNoKypgulb)|OftlW8U!aaX_+%VPs0BAvG3bR5^K7m;ff7~iKn4BUZpfcaOdPQvDh24etsM}n-`A? z=`~Ot093LRi$^ed7p@ke-W5&DPU?F>o@Z>Z`9Cf|Jjz-kR@Y@x{%^bbtWMEAnEUVe z%)#Q{R=v!cF=&7y@5n|Nm?##yV!F!Y|A<9X%)S+ZA+>PXP1)Bo$6Ut`&F>l2SZ}5x zrn|c1ga_^vlmR!A6BtKB8?n@kX7z<8etV7^mmW9!4k)~!IGrlg@=$r_uAM=?_7uTi z6lArE(=@f&JUOks6)ibmr)&){q$Dyi!~dyMQmW|rG4NSdn|1rUL%%nsC$AVdI0U`k zN=HcW*GIa~BrJOXQ%Yz^*e-h}kM|;wv=f3M$frgyY0gAgAqHHb@v!gj$;pet-G&4| z_yh1|nUfXMFo9GYiMnH_&Zcx6$r=!ziZ}`0Fl0cMo_*bWN5;U|DvcGBtHs!M+0a*3 z-+p@_oHa@*Px{erw#)~^D5=!7Lq(N3BiOVP%HJ1Z@`cK6US${s(*H(16PcACj$ixZ zNJoqw5|RDP10yTE3=$>`(D>q+i*A9<8XEm;AAepD-Q1kI;wbLDVo^z>H=9_)rAb>J zyAun}?NX4D$iG=0wRVfiQ8KM1sqOGYD7w?;iDT`}j297F54T?%>~BR9(YcshNkiga zS(|a)TbQZIlso{R9TMcCI!gn=5TJB2KE%Y2c&68923aOHvKw9mWnKrzxos!{=0YirwezH&7AagN%{Epe^0o3Ol_0y{%J zGXm&O-Q)hrBu`|YbS!-E*BqdyN_*YGTbfh7>A4i~z1wxbVcC%9aONF!kBpOKclXd8 zwtSJIRDZQE5C`vnZiGw17Pb|g$WNs>Pp03i!u<{`akUqrWW?kPdz-E@WTL4^%{-sy zQF@QmTm9<x6UxyM1J_ca*J0Khc-%Zi4Jmeenp%S1e!#EW@*xg$qr>^IjnQF-Zng z;lqo1yK%suz?EWbJR4>Wc9PeGtcF3nP3f970xxid8i>Q(rRPXy9OAD|x@64;XLI7#j$7O7KK zaI~x}Z3;dxH3Bc-C~uMw1K3oBX*}OCPm;sz_#m|8NAeE8zk8%+!u-#?5qgf7$mot&&n7{${`kvd&I zLIpSX1_=>S?OMb8F$Hq9?vd^?-ikP!`_#$^j>|Hzncy{VC3SkP>m=H>@Gv##_UTFL zmz>EQ*NOw<0OW`P@>Y@-3}kT5P<(4K6)96_btyfATk4f>kt;ST-Vq2WPsR3mA5ZT3 z3rdLTej)os}_8Hs*C-Yz>2wzb_Ml zErP_b7Q*cipO5NU@l}U}P091%$Z0ceyD7oh7$#lrV$zIWpZsYz?wm4bVAFKWaoGzW zucFDwr7k;KHPYqtp9!q++Nz5t!aR|p3;R|T;Tv7A?2IWNS{JRdqzUI>Ny_;6t* zK@HK1_@^{@x`Y0SoGFj|qoM-K_^qX)XmoihPp6uX~P%the7JtMso3p@nV2ouYBdg4( z;)m%STy$8UaO6nOYuv+Fs?73)a`9V}qc+D7CQgpbph%gSQ<70ead z6szh?p~p6`d?t+5qB&C>s}>hd;F+VI%ZBTPXG%gL8Za$}r}Di7{~m5wRV9U68u_9V z9~$5oZgLI3ZbI0Af)DpFHzbGy+(kmB#>K?i7!lLq!0BTaz0rX{l~6Mr(o8ey#e=al zo;_$HClbEbSHR8b>IAFNEvu$0GELA?@L*E;ac45z*Jd&{=u0>`0R(<9f|;?l~@OOtNSY zIR)-RbqFftFxc;;Kn1a&NLK?*4ikUyYosD3Kg$qv(&BiycHn%un1C&k3rzI{qbXS) z9J=`c2`wd;a`Zhq^yt`(sR{0`cAJ0P=KgYUuky%lzs#@3nlR`2T^A!?fnv&kLT=I= z>Uxuu(I0nr1XGM6AUwY4MSrB&%M;5pe^CZ$G1;hl%M^DNCD?Z$e)@MwOLV!LP$WLy z*=*DV;TXkOD*ub5HE0hQ3Wk!Lq!K9Ypq~oubunShs7J?doA;@0rFD7thm{R(9zRMn z4doYF>NLOyy_52rX7ZPlmO%xayf!0%bjwL|c}~rC6-_44M7#nAw^n7b>{yfK-*Ywd z*)vYoEn#G=c*M^|2@42S0QJr(u%BU84={^;zYLL$!BPMjp=`1-g$KL(Bx$^+?|YH< zeezv){De)vCc&qVj2iQy(eaIK*-+?|Eh*;o#fcyeg^I8FG@$zR<_BLRN}w+?zn*GAd*;uz zOG33U;5oI)- zQsnZu)v|9`S+}|t#M&6YiGLA{nGrTj&4}!m26ztj)B;5@kRl?6q`yP(Lz9_LlDI=k zkfauthu5rPcj#P0JnsE&7nO?2k}x`_$a|^8Qt8c6P$C!Zz&YkH?H^~c7lty3w!JO# zFaE^1QA~yXtMWmu$S6##Mgod-q0vGxw%A3+Ar9u0XBu4WJ&{ZY1SUAbdyW7CMFA## z*yI`?o?+x_=Hh?LMY}dv*O{=u_1|!5F*PnDBoCLH50;RZgWMo@mD(V0n;B{g>c10| zq1GmyRlE4sEwe0AFU|i1p8Ioxkx}#T-UG=QOFO>A%Gwx5a8vH7@dP?aQ`#j<{sIhh z!gSNrzt9%I2_PUyG9`-j5h0&c!fTa4J`lMaeORyGeK5ni#iW!SZ{#x0d@{dMYEfPo z!1@*|Azc!^3&-HWkS6FHWbBi8!W$tzpN~Te9&s;sShIj5&wDyi(eU754gRx{T(0jh zoRjG@fvE{NKIljEiJH=!i;ox>>%EiV^H6EGyc-sNR+` z?irS_GRl!lB7dLm}fM~MRPHHN^CxFvaI1sEkIN`HAUBq)inz@%5Qr5eUK5BG%p zAN>H6UXF+c&0j646S*2Owa^IgthTUD7|q#xe5rYO$?^Nw1Xo6j-*DX)9Sz*{fmJ7r zw#EMohO0w^qhTtaPc&YxU*{>`{d$!LUr++n_HfjLrJI4SeWQ^WxN_FX47@rPF^G5P zg)XB|z1-?~gmt)apYNGc`@)AhbnE+wDK7%|TXh$SZnALhBzyjMjp{E%FU0sKB=I#f z_6bue^J@3~T>U^({Md{TIV~mY4|vGH-U+uwX0=EL3H~r~RYvG4F|hz`et60vQ7O9q zN!%RC@qsDPC%9dc6LL_KFKFdh>D5R5-Z$tdR(inoSj69FdvT_L?E>qybXMl^t4;Fv zom2WTq{MxcpV+37H1M$4eIm$rU7C*V>h-U--VZ59zsCFxzP7xW84XAP9cWOA$av@2 z406}04bWGgJf4CNsf^p$&T(<6UH7gL^k%GkaRjIV1H%PK)q@Ij>y^iCpAK$rBV=-B zOOh%CB=e;ed4(vK-*A?dc>o0k+`$$c|K+gMNj*#bU`)8O#lAl>) ziQtfuyi}YWqyzYj<^g<mD2)ZBQtxay>D z@OuNcYs+3SX`4vU8z?$Zb1y@YI zlXl%APTC?3KVkS2G?1GXV*xMAV?r#$L~|-m>+Xg0gnm@@^XZMo2hZy+C`g#^T83;* z2cbWh{rIP-->kK>5~=5I{64`mT&&fILWz>DU-UglKcb^kE z7B@|tAixZw^z%A;612q-J?px05UjXvcLH&q-B#iFkXNDr`5q$FWQ*+iQ9dP_!^LEDk@7iK~j-80y z;|@T{;}2D}GSsHCup(#Ik(2U!*7`8koSd+(ce+8Q)W8e1}zxCD2r``QFgtt&z$>UPv{|hIHtwkz+sjVS~(Q zv1=$*7!-pmh+AFWJ4^bdlPcBoL>~yE>eR$2*W|udsMKFJJ0OoYpJRcJK;g`Q*U68h zQKUensDb-$>9%jSEdR#^9OIBMt(v5xgf}ZqWOI+g)#|AM>uKs7+ARV|VMvhDF-A`Lp;7^|`4x+H{D@L|cx( z44aaIA6{BI@j4k1S`AC?6xs)(sDiQpJ2ok_Fon`O$y%A^x-R2+cwo?ZUu&hse|*1FGNq$ng+ zUgG$xcWZ{-72Ims9Nr3Du^0|$v`0Ri0Ra7Y06|g;DWVCz9Uuc)XjY=pl024cjN*Jz-s zHgoefYI`2RA1}7zTX{0KCo|+TlMQqyQq(dXg1gc;yy^%$g6xDF*v;P%*!6K>UWA&~ z8O)FT7kxxw@Mm$_`SISTi`@>Rt1#KUa+!QzqRK8WD*NM#QuyMV2#7SThf_7LR4pVtl%A9`=A#Sz^0zRvlr2iO zHznacjQjE1(&Dv+SQZX(Hsd2ik^c(MeS)uBplH7yNSdj?HJ)iRa#zuy^7yn$;G@>d z>zqyyZ*@Z#n|?uH7xL9j#Xfm^`lEg5TLm5I4RTJLMAy#Edrsq7LQA8~gZ-{8dK79y zLPX}lLFS8@ND*$UKcg(&`rs&`sYW<*s%je5oYV#5cZhJiNl5f_{9Ud)C1h#@VN{|( zZ~`)=ZAikGXQnB|G;q?do?wRpyspPMI&h5V4bTy0%maT=<$cV2-p`P@F`%16q~C@F zt0t&nhB6}2C6K)zf60+h1(sV<;xNjEG#U7Jy5AJP&MUR>FqV*OE)LVCh$oEx{U&2U za}@q$yUB)5Z)hi|l{Q~J-hw9$I}*=o_(D(oCsPR1RF^RA8f&S7q2+GGclmeayM|he zh&H7dEC3MJr_Fxm2JMTd=*w{8&#lTIoX{xQQV&ChO`meIY5@nZ4FZ$|BGp1jHzS`t)LG{W=0ZF>@DT)>%8!uf5ldO=%t@q{esjsQ`)<%6{ zKM~5J(|b!;FJkd}&KLf|%#ol~UB+;a!A1z20(Vk_ZxPM?twu2ZE~*goXa-TcYdo}k zRp9@Z)L|z=mU8S?ZdWQ z5z~_y8WP^I;2g0pj|+jNM~(WQWe~)-6!>I$P%H?@hB9{0Hb;0u7X%>4SWLX_ooao= zTb~lHXUl%I+>z~I59}c-G2D?#!Yc?`o3Ve9z0|_`q*&5)%c5P7KlptO8XZp%g1^An zyj*KW3-pElSVN4lM@6oBSX0z)H66E{!a6T`wekV4E=y;fHMm_VFy}m9bs)HTU{*=) zqdos8abtkl7-UFj-7uGA@#WuSdYsBo_m1+!+)Cg^$5aD)^^|?_G8@ECw4BTD%4}%zjV~nVntv zC>c2QKwt?1e|FcEBwPnU;y;=7axO=HpI)L!&XjtLU!nn-NSTIoW zGU#AF;#2!;>xWrxnb~VEaq`xWQO5;^4U2p>qBMM_2VJXO`qF>EolO zMwo)!{VfXEjdP7uqOPpOYxSr2mZV=EUm6$uywmk)D@e}7&f{t|!y(oZPVZ#%1=Z`@ zTJ3fJXZFrJ^p-#LW8D3mGd;NB4Zx(VyLmi9E<4F4_5?(^PPiJ8vVk%Z(g;7 z?$;ooLfrVq<9_P&muiHUXECgf{JryJh<@u*jP0y@pdFHJmIeS(4DdQCj6_^SL;_}q zYBRWxToA6QQ8&JmB;!~}uc`gQ$?6&M$x7ZP1NngbOaTo6>7rW^!RlWtq`o4J_zTOI z&O;(AP)j!UeOV7=AC-BH@Ej5rhSFhKq90>Jr3y@W5yCK~Z%9rbhl7a1HxM*B>pTGL zPPSn-Bm@)v2_w)~!L9dx7{Gvxszm=_N0gUCjQy&r#~aR!5ABB_gA9=!lCk}K=ZK0$ zO$FD>{t_|m($+JvpFCkd9n{Hni_O0YR| z9~Y7c{Z*>9#i;o)q{jIW_z_nF#evKSDjZ6>q09uAQWoAVbyKiWISH~o+COT2Q@!j&$rEQDsW?TP@hA4RybzH>5Dy?)RI#)yvJ@2tI?PKoU z&@_Ly^3$^Hs_azkwTl z9HxYBB0JB;ypvFmPQBrwZkv2aH)#M%24TJCi-YAC2V?=j&u)X9s=4E{gR1j8nl!~l z;Jq;mO?12SnkaYOHUm?xK3k5AifAp>oSXLg7kys#*}Md|iMH{|V{V132w?^(9?zOu z6}J3;zpViUQRETk?E1<#0EF0}HO~Gt?aDG}mtB?#>D*P)3e1YVSy($<^T*XNWKP11 zU)Z4IV-M;|6=wPt4qqWCWn3GIq#mOWwbP!FC8eOIR8p8y71PcLqa75k+n1&1J?noP*itBkx1fary(@0T9S`vQ8I$v%1E$Rmq8&~VlWPaud+9(DNu~DHk6mm{KGbSB>-YU* z3*7ZFZ>47e(Md{S8I=?OXSn8QgEvzcdz&2fz5l!I^ln(7GS6%fdrC`e0*j|;w=grf;%l_6Nz()2uIRp4Zi!XcFyfZVgNx9)`A zF?yf*U4YMb@gITD^{U#Khtg)5j#mj{ejaRz>vU;6;TR<+W;KmvHA*uM&f2WL=V3{W zGw_N-Z8bI+S@8IP!R%l?>q!3fnRbOV>IzZplgWqNp10DWJe6i;I+^c|PL}fTpHly= zobJZAm#O(3v^G&BX2?PySpXVF_-%g4ShP7iN&LVUfIoLK0=oatcYV*o&y_cr&zETi z4To}tXMb3?ZEF)Eyu{gyoEC(urmIZOLy*dJ&L^4e0$}48hCd<+Fe~Fjr4Wtyg$s(j z)V3#3^uZ#L$fP|+u`rtZVw5X8ZQM{Muaj`x!oe^j^RZ7579XHhBs`+ z1_OfztG;J2H#yg3zX|#|pWo|6zSBD>-LKNJ*{$(yWHp;&rdr1di@g>33Iw>w;wm-< zHYlbYP}4!A%PgS`s4mxTQ?PZ{>nXebC5%Ec&8N-=ZF*uI{@{H9{i^)Q=r)TXvCsHm z_{EbtDR-OM*^+BWb+FUHZy%Jo)k}qW|+UhU@%%NBx-YV_*SfU1d&QN}D$? zzIbZ=fPB^O{Z8@k#lYprFVqfb@O&iFs^m#9_4#v{yXkU3xSq?eyxlnvP~W?1xz6J$ zmRoJmt#PYjv_FZ@FADB`kZ_e#H8VY83+6S5Bgzs|K&-rE=jT<~b~8C5YT*!05&5r@ zz0CejQG<-8xNhD7E`ap;R%##-M{FRVMkEj71k2O}tpo7?HI1)o?x4hxWewe^qCQ~Q zpH<@N!Z4S)wc3srH`bc-W(krpFQ(U9DVh4pC%}j9?~wN=Q2yJg|XY$f_EvrAz z322-qLt-&N;O@zD11axJOe3oD7mYoYFNUF^k>@`aoFA6B%s<}b5*3GUl@t4ybXp4t zV6@|`jOa@&Zj-q?NHruc&4_@#H(xfT6E4L0uPf-5$8v>a`k|1F}nSs zEnV5S9vi{xp(DO#>G1l+4o}}P5OnlaHi7gAFM_mwk`O4ByByR$a{XSR!O?=>8;H_-akcIkaWheXXb z2oFK$$D@l3W>Rx%*-j3d47I=p^%?m6%XAJDXCYBj(4zbFh8!_yh(DHCm|{k-NsQZe zlZvm$^ETQqN*oZ-n9VYoV78GKVx9oPtHw&ihOB^EWG_mphF89|uJ|bNYkTDDx>PAm^#)dD z1MG%)6{?|*_^#nT#IVH`U0313a2FHYZVMh+p_g{P>-bjaq#&1wq8%3B@rB7ny0rt5 z!bOaCiDAjU?@FM+#}#Aw2ceO9kvpME#zV+@a?5*xNk+)?Qiwj@V$Q%Pe3rl(OTUE- z`^_$$T)*$##KMp=YT?7mNWH(*q-8kleq?R_e7F5O`gihG`;moI?HM zG?}gBOyvf&y$zvvbE~a~de;Bg(f(_V!cQps9oqE46ON5qwo z+SqELzZOn$KE9#e7T!v*PguA zhU&Ya&*}Nh=tP}uQku4HN-m~9qG2+^3WMS+%Wm_|-s!;PPRBoE1~L<3UL@EqKO`xA zgY!d>jpYQgxqmdQEGiUMIgjOh0u}rPcqQwchoc^CdW|1y_bI}H^_@)DH2Wl*=4cFD zbONc@dj^NiWI}c2CaK*qB_c|OG)7cu=(dQ63>>S?Vw)&{s2 zZ{x%!sOvT%7oij-TY7P>PplIUA3?j!SdDdF5~G&FU-KoSQKsIc4>as`Yl8zww^} z9^bw!oI2aIv;dwhA*i}2C_dOncu45I1#Cz=H3DE0(J{2`(_i2F@M0M8eA>Jq;=4}jK7B(h`kGM7RSE$MxF=6)?p6#W z$2^nK1S>LLWAqs7^ugf~r%46lm+VJIxrIeQ$s=Nz+GsJGtX=wAx7fl#A%s}{{n;XU z64kV*eua``#mres9*Qjy6T>FnoLDiH>KOB4JMSTXeQXK^!sB za%ZFBMuH+tfJ!?HAXlw}y$Ycm8M<*3b`??f)UOC1KA@2=dRt7MfLP2C&?jvnEjNu6 zGIIiT-H@Z1fDzz>Y7rb^Avk>3F=f*}#h+Lmr3by}dnVik2XBX7^d-X1j3^24xoy~s zQT)LNvXmaMD{|n2=2m}})NR_YlQ}qy-I6&#Fn@Lb$A?@N#nf+X8z9o|mA-+P=HARB zj`w{mIijWzdoP8e6#@JeJrCUc80~uW4es;nGW4HPw6?Eil@ozI*FlGdzjzwx;TaV( z&kke$95~5KrybWdC2oi<@`rfT0|JHL(FBRJKj6OS+I<*L_pNeqtiEuKT9ikTo<)S- ztzHe3j@k(Mwxhb>Dt==R?bn3WAXK#bl5^6;Y&JwiYgUEHZHn86`G!G&LM~3&H!5O6 z-AFi6jUyn5SR~te#m_7fShU)5T#h9i)1^6|9+kej&CFxWWl*Y^Uj$r|a9B~@Fy^c4 z6120fP9aL0PVQJ3Q5e8y&|-^RArslAn(3;S)2`cLfy&s7e@MAKB3}(NyXH_Y-_!Sz z`$=&$DFF~k3>}4}X1|@e>5agj8>eQ6AK`~K*-6357bEOr<>y2suY+G*PZuRf5)D8P z{|W6L)FVzFzw4=swtd8rJLt0YbsVrBGuFVm$k6`lNfoglI{az60NmbiC}% zwNgA-$RgyhZTe^o1$-TOps|{EyOlwJd4^)(dTT9y2Z+^E>t&%F1VV zg}C7Se_|I^#`~*e`uZ+bY|;Wsg7h zg#oVu#P>zH&4+74$n25hWRf$Himk=0gEhWG}!+-_4Z+2+LJK6wTstj6wtMv8?=2FHMhveETR;Br?vk%2Ljm zKm>umisKZwOUqSC@xuYdZk1Gj| z3RsZdiicA(J&OA3cr6u zXfD0FVY#kxuMv=A>J3+&9#S*DUXFK1VQ)lT4@4Xda9WRYJgrZ|5_-SAnAk9XIU|{0 z$--E~N76D=wmuDFa7GcOL8BkEH$5ESbPkvX{%dvzyFHm8XGN(-2NsL+75#0q0xX_Iqc*Z}b}Q?8)Na;1Pe$WFH8c z5Ur-7KS^yQ_h)_~=l@Hv0{0l|Yq~T06$2XTqG^Nm?>RudAx@-MjoWsO4MH8oX{@cK zLE7GkQ{#w2Lz0BCAR#*qiyPl`#6sza=Kia1U{%_0$ zXlvxnxVowLwu@;~djDmv^fa<}v{jOMh9SwYm_FIC6lX+%28REN`|;QwfF5Injj>2NZK6o=93skQ-@O`RfRmXU3F2wg(cu7LCqhM%+j7A7W2M3Lg z)AA!iLwci<&hXIxa;Qm0H~-iEy;YHTg&8t+Xl`4c1fPEV1&eIG%jpJL?~8X3HCBk0 zCg$M$+7_J(#Zs@&JL~DbpvYOYgmk$yBo;#`bAaxCKqI2d;htDiw3S;TW+u79EKabd zfZEqSD?@@JQQ{jWC-i>|v8zzgA23utD(vfMqPk3Wn&i1;(zM9pY(spWh~Eg}7;=Np zZ{7ZS*~0P8EX4sR4un*D_k@t_fVcoI>`XhPpiyY6l?wN$@5azvAAIIK(7vv|4?g%E zT29yTZph|Po}PpDo@a+s&ec^qdmw5bmD=<*3LdsxhAEVU1cj9>RXw#&+21c-fq^UmPE}QK8aRck$`YT>`b`j83 z10HCad|%|Kox z(Fa-s6k&ApM39K?>%FG$;JmG}ki2G50Cb=mCxdHX=jZ z5Jt(zOi1}Us3$NRLFbOEY3HRqOgGz7JFe`U$9o3z8WGacQt=^3YQcQ0YbD20F{(ELl=Bo_YgObxPf&XaNhNn^a%+bsQ9X-3}{+xy*P-u!x$&f3lQsthsRo47CKBnX_$);-qNWdRYq&-}+*QTahq2QcmzEDG z!;~85=>xHs@$IF}iuDec4~65_U`xM?=eJ0Uo54~8tjx%Ei4mo5#|hDni<`q-zKZg5 z=FZb%AFk<5>CkOA?Vf*hwW_2KS61*JdNUrrmR5M_n;N}3RCvsM$_O}2wg{%JJR?5r z>4P03unEIDJ<$&Th1v#(k8tPTNOy1(P+1k(v`PP-b8oY2IZuk~X^6XEdX4Up=ldzI zcPZ?sNb2bqY%94O*Nd8%F1)XQe(0`&cpal4wigQe$==@LaVsCWL?Fl5$ucB+A=922hq*o)uL4o{ooj~kGH;*UWt(`ncl zj3z`30acyW?3MvKfyN&UqvDkraE(Aax{*FCgumQGkhww67A%9?bH&~t`NFHZXPCmF zKfV+;(P9*?OdOYh%6yw{;WYa}_3d^vipYFw**Eb|eM`3BaO))n`oWSsr=LP<^Jv^vamfvV&5;HYvm#?vi9QoOAG0`#M z7XR5GDM@2AOnbd^=YeFn1@+8e{%v0BymRjvhMOriayQB*t%k%D_DLUK zYhhY?vWD+oOzhG}r>-T9h(}mXgm9>uhSN{G8`rumaZNZcA$Jg->OQ5`tt?IykX`&E z=_9_?_HCE3lo&X~R5vdnsOmHDpa z`xHneW9Jnj#x-EQS+a06Bj&0kvZ6_T&>Dabyb!?|&+qRv1-fxWi3#vv{ z?b>V4HP&NpD$7N|B!Ny z?!KL{qZ59ofQifh?e$$W1Al6(yLh@Zk$?wF;glh5c_gO=7GCZv#a~`FjF0ra{~b#j zINs>_UUNrOZg%zJel>o?1nuqH)w5a{Fe*`*ffTuA9oplRUfZ->5Y6MsxQmSaz$O0d z<3wKHJjNc#h16mKIspnxmmf8R^ah2yK5;lAF7Fm-w8`Ye4d~B3YyW*@JM1s~`se#;zK8Hn@9dGB(4`-4QB+nHp;S~oGNR&sliKp2^Kb^M) zGC@QH-}`I9nw|llXDM33v2aGwgi$?seWam@Wnr}10o_|8TE0yG#5a2BoyuVN0mk?M zVfkM&)qaL(c|mtFIp!3G{koS3!Ed{2n^k8G5`U9)kEHU)PQT8Y^3x!7b{0NLnb)8| zXh1k?z??CWzYoM;l^>1Q-oFv^egq~-VPjYbuRK?sD2R%I7e!-MH$^nnC!QU>UZ;Jw z9TYhBn9k{S<;fSv$-Os^g!SMofO?dO%sj8S`kW3vgM<7zXUj? zW~hB5F(1YtICZqEH~5So1beX9*?2sCoF~GpqMu!9m;eb*iop@tlezj7+HQoHX9AX5 zNj(|VoQ^h$ox5wCOwLM~lrAW()Y|-X#qB=ZiK30E2;IBa z0d}1`-}o#bVfhgf=1G$0OF>^c#Q1hc5#7noX}{$GeqREiECGx|M0s?iv;L6xGtHNdEUDj*gQky)F^s>U9?G1q-9^MD6MJ42R z86Q})y2^h16|ynAydL(zq!wZrTfuekd2}0^^~on#A;43}@f*rps4r~!j5#q?~C~I5nlxr&U z9&x2j&~5gvLe;H5y}oBJn2J(kdDC_5I19GxUMg39ZrL_#t`b5bK3J98`I2JhV19y6 znJdd6|C=hqXOsHs>oGd}CPVWtN7U98Iui+ud}=bCc4c9$IlxKb%wSF4vOzWwp7fGD00PoOGfl<4yGklb~xRZfe)s7lyl zXmd^&D_ew!dNY?Jp&;>210WL2B?JFpI2L=^KvFf?L31UN_%t4**CA!k%{6C-h7%DV zmo7Ue_a<^(o~(VlY;L6#tqg`q>D#YA6Y<(aDDN=e1Wrf45F7X;tU`mfW9{Msb;p6x z48UjaD;N3$$jEPkiB3w963_s^_=-EE5XkTRjV6D!ed~lUJ3H&w#PFT;*@pa$wXtQ` z?}NjCbfO#cshq;vp*`B!a^tRWCD48-Dq#z5ZSRb0XZF@LqH!$RFY~9r!QSmWpp`u) z*e{X-e(7&}IdY;M6^urI1X#*q@1O$VhMbd?tw6LPC2<>94LLFYu#~0izeuDQ3X+C5p(v4`@^nu?n6C@s=o%KdcgE zB@Jk2aDRB2&_1A=Bz=4sIi|GQp|sQBkhRle(LDpWkmlE7H6pk`y1?3NaESvKhv3`o zFG1|F4BPjOG6VPB$F9Q*ir(wr$|SGqt*z0oZEdcLl>x7)*wx-eo#}Vt&Vb+mM8F$1 zD|HP>z81)mYA`LT*m%BB4iA%7IIE#av~+!;PCgHN>0!yBwz(fIWAkCq8aqt=&LRZS z#)Pztt*_tSzxyk==`0PZmu~-cXx{+$d(H^>^l<92C!(1qs`IlM0`BUcq@g3gq7uj+6SE;59@Fg2 z1r}8OFJD=tbq4BpxdIW}Vhnq36n4@D z1)zr=^n#@L(H-7lZ9?YAZGiq0*$s`TL=8E;-loC|MpwD$8gHLxXaB9+CbgH82SMV z0~_hnkfBh4fFUR!Y?KTYc{vZQ9wCrU^~{}q$+yC}_mXi_@JB-oSZtx}zcl3bk$yJ@@Xs)94|$2= zr5ruWy{jqY&%=h+7lt|htE3YYhD<(uP8nk2b-QCifh{LDM>vU7k)nMbiTJH6DEI!b`BQt%_&Y zVV>x^Ax+=WXfsi3hmCQ2g!P;L-dMYw%N$F}PK7DV%~n=^W@c|@^J(KtKg00XfR0nf zUiHiim@$I3Ko6h?hjx->PyzV#`KApyURoYHBZ2Rz7@8!F+pkxcWVd_F)l%4}zi|Sm zV+pvD+(DX}A|-r<7dHlFmA}?{OxNsJGDN!SHeoG&F>vJb!&b=u!X3fSm_4ZIJQ*_} z@gxGMf3PAO!Gd>^Obgw_@Ki~Xz1iW-KP_KwQ zt&%$WpmrZKfp~v~r3dsqcdE3O=P8e{50w5()=XR@dg0;gM3mQmW?0O=c6gm-w$?R6 z5@c1eK5%6GF{<_4#Q(qXEnL_aBwyf8kupxTmA&SvtJZY{^TryB_(^ogcyeKMF9!i~ zF~h!z^TH>N+#5zbn)jp?zScefeV>^`>1rXO^esR%UAsw)#4Qfucm1aRrj8n8bXye; z7~-5;(%>d}N3^DC)BUK8VAo+(wbSQO{$W)iWv-JC*}WR?G{)^P8nUJ-1}bz-kX!S$ z$-O@OwH3Tpu?bEGWdQExpKjn1SzimrI|~rIw1oQNneqvF;Pao(D^Nhe1j&deA{gA@ z?^Q`+o87y;mf7IlKM!|wb*F-m?=5o%kMAELRQC(>1MVH`m0VjtEgJJK^GWhZo%4b7 za|{Oa#BV&r*XS@o z=}B*!`;mH@P|6Q4`_Y-NFejA4%Dc5|Gl5cUBak>>uLp)8_(2W*HaVbz=dm5Y$HZaVZhT9JQ-OKuJ6Z;yJYEy^J1r3d$zR1<7 zU#4&m?Lbin^d4o;<{1US;R}SVF!u7VMO-rn0-nK5WE&A)6u_x8?S)0;t6YlaQKDF2 z*rsc!D9JprN9Gg~Q`P}7>~vY=-end=_l9Mc%)LPn^TT7A@uz9MSY$(ezZOtP5W4N& zL0-S+txeXu3-@kN4ux%(Uvt7LyngU})3i3X__zoPZ{8F4ZVBfbWK$IK z3!meg#yltEO=7H@qWUtiI!}&tvj2Pcifn1T&Ib=Qh;?q7fX>qgI)mvT_!xxyiTv?y zqQuw)Xdq>c5Ly_~kx}kaSbDm@hO|eTK6QjiZXwWfImEFXJFweBfZn%{09?zfROIME zUSrMSR+hgMJ6KcL^f_S%$W}v_dzR0lRgvE`yKfN6hcQ1JgXtxNl&9=DM^@?T#LbQZ zfc4kz)RtmjDYB+b$p`n3z)?e7h8PWX*cUL(a`NMyxmS$yd+yKXXnqxoPwuMYthp1e0@VV3#HetJ7zeM=tbtUO%l<}VZ-Qu%ksC&tAt}==c+(WfJhzK?6 zREy3xaW_)@C;t_g>^FW744)l24uGQ-1cTEaQUwMC5WCE{J{$xESYPgQB!Ax_6Z^w~ zHgIRs&XeUo{rbp>A|Hpyui5|=zkSeJ#n|t`FR{;JFV@|18vsq`slkzgO*I!sy0x`* z*X7l&$>kTVP=b-*$ti9u)lVz=9HY@L>Oi{xFg*$f3LwnZR(`iBEk`|elEtw!nS|L# z0b^)L28SR`phYz>CS=p(3=xNa)L1}rl3Xnviov4ZVBB6g?-a~YE5vHjKb%!tW!{Yn zip?B#0$ZRZxY-GkF1K~5MwSVYv9z#$4_!_ngx!s>;UU9?(Jw~by&aY3E(0M8 zAw-9ZX_O0-?vymp!9p@{l(Kq%iHJ~xL5M2xdxU;MN7ThS-6eHmn1J7PV#4&jmx!DV zUfR<7u{YLS9TR0@FtWa9ZC+nak1JQUj0w}*xs+P~eFbuFPX9K>oSs>oi`u&1vYzSP zze82|fA^*mbzQ!_HW*MtQT^8R;LJk0(BEnjZvYN=yFI+{+%fGAfGx0Shgkwa zj{nhxEOjHd0B->yh-d-PuQ~L>oUIWnW47`0&txg`-U(cbc<*PSn9{KLG;Zkl$EK|b zD1U-SvA@v!i^W@-?>~)p@FQ-s$MY0opfkCW>~-leR(>Q#GO8*(89kE zdv`0Uqw!1ZN+%#!kiP0jP)Jr`+pqqxE5LJ(}q;D9g*tO2tA$$8gN&VAb23 zY>oroSq5lqFp+<65r_rEjyk{xi44#PP$+eX>LmmP2p2aAij-;)pd1KABU*6xd?K87 zn0@Ci43@Osf+ywMZ49J{b-lNn93!bu1YosD`@{FPRfUfrl5p4UzrS7x=sc^7or{R? zYooePcG7nieEus_kA!vQLxx^aO@$z~S0U4~1Thzp&f5e_gm2gpLxBfhg8uA;CtP8Rr!t`~h9Q zedv;<-r<#9hZUvuHBws+k2Y;6RNocYVDB|sz8=3{#XLTMEtUd%gTebq{)NbSC5&Mqg9qv^}ojS)+gwg9<{leXulpZtm3(GXU8#cTEd@!na z5aS9mB*+BBtN+$4k?-~fXz-j)`&Ws7z{ipaLp`6mrc8VN2=B| zC$ve}cH`qjo?g-FJ3eRy1LNoF&htK|l`Bp6uY~P-vV24;p?-LO69V%uzWRz~X=4qy zqj}me0dXS!G{#AQ8Qx5&UxCWVD1#CH5M~Da2WkpAoPkMy3)XtyoHpy>W1E$@U@0a+-QFX2REw+ZR!Nswtjh~+hhoAWl9G5bJX5OL6%K`*uw zj*ckoEphH=Lq+!ZIR5!Y_5)QEM4K&Vjqc}%AJ3IUW=akn30O_qB~9!T>x>0Cu_(<; zKim^gM7_(dwZNb3g@#F{dif%IzlIPxHdM~O3gIur)q1AYQ$5>X?S1@B%_U#q>ZYmS z>Tc%kjBzSQFj)bW{Yon|cr@t4O>>Nk&Mj6p`MW_ehthZa)&l0Ujv#FhVaCJkqFVB@ zFfM(A#&;o3jOXDb@vjGEjdPa7&Q0QF8h>BMEJ03!O{G?v z$su>Ih#S3KaL*m$4Zxkd!vvM@eU8_Mwl=G-@pzjG--p>-;=apE>K~>#9sN;;=8B)A z`-#AiAkjn}^#K-$Auu7_2KEaN?p1(roQR;P6JkrJ68W&L$&Q&__G|MYLDKJNxtL-% zr~Y>bCGPmHy4yqxVO0WJN6eVR>wH(KatpDugT)Y?VYXxQbgi_koDSzL^IHj34?Tjf zHnyl3%kP3KSN)c9A)YBSWUF*M!B_oR<|u&ihT#h2sfk!@ z7fIK=jbblHj4rGxAw(S(V&YqhO?`@aOObu8Z)Zsx%_b=uZin#Ym3y&&XX3Cir!M*1 zjT(7^lXh<5Zr=3dE2HPFtbH#L_J=>l1~V#zFN9EDhDbJ%W*qa-Vf+b$<9moU+w|JqJxW}y{uqD524~blikn%e}JVqEkMCUeN!Y8M9I zE1pvMiB+gq3JUD<{;0RTDcimv#tyyP*A9g%nqWb#hs1~|ylxeVIoz6q7D z6N3pn8$q#f;LgyF=GMKZV!L7x=y;h7k?+!k&faUJtW^|k0CSl_fzo<>Pi2pQxAgCv zU)j8pTcPn%Ue3C|d<=iK5hWdBYx+Ewt2VsPI0DF|CRQl?=js{97W6wkHva1C%qVn- z4LYCtLP=12@tOa>768Bk$?*bYAQ$?8KvOB?HSbxl^j5h-mmQxZU3Van7URGNYq2PA z=t#@Cur%J1J0C~Y{-j|4OnQo@XM8OEmWZ|Sa;U!WvCBMoa|R+F8*(-Tg!a82E}6Rx&Dk~sajg`g`_MrtX)<_Jc1&|h1K|asIcHK}d2Y?W=U3OK$gl+jx_|lxdy7pZ`>UnpZ|@0_5O2|N54i0PXEO9|~pozc} zh)B|K4~eS@MF1`Y16+WU-vz4BI~1e$#Qc42S^tJH=r#$B__)F2LS#=*kCrTRt25TW zBx3|Jyx@udgVB0D`xQt`(Czww~HG^q3&Y%s_{^zr0yXj2KzAEOWL zOG{8!c-&XQuHDbI97)YD)L0>-NmWBG+pT^LbG(&%M??-iN?657lUWR6Tdwn$Uns-sT{7C=+|2@3I}6Nhec5Bl69J%_Ey(k-E8~FgQGC z?${I)E59Dx7`_a#;G}-RWMY}jVja*0**QH;<)a_}N;EZq2Yw}#?%}c^8f9&H%;>)k?z?Of`=oDk#UrnklAk|+;PL!Wu^_^a>tP!Z4pWO=i%;s&EnzPBQRwK^=&&LSgS9h^$tvK;p3CKxk?Wf+f zfWF9>2VH|O*r_6~Mfgn@PZZ$)z&|C31`CpEF%EN2sMR!=1Y$ZSAaGND+x5l0KH_bw zD7-H({KQ{xdW5UIHeI!ORr2{_jln{k5)$w^#|bbu*&J?N8d-x2s5OYYwFAd-X56l2 zg*CEh7RA?oB8b1{X_Hd3U8fL4A6~n1nsDNHL>oBJLm?h(JAQ*8bV&czzZ9ZP_Z!e# zQ(s7CmrIhQaIe$H!~i;o1F#ul(MX8bLGqJ&Sjf1%xqAAI zVTB@r!6YYO3S>$WtTg^G2AdX#7{@Oj@ace(2hKfp>|L6WTfQm9nQT}GzxwPe)^!T;z%5@|v0KH>WE-~W$l~oEBew-?o3B%#KftV|T&FLB@5TLYC9tI6aSrT!FQGAU(nhO^F^2(=)Y@#GGZAUi^`3hrmtCcp z+}+y~EPvGx{rfi|%0o=)kVFai1N<9l`?Csw8w{Y31(5#%2MPF#Lh;Pm4^MhHF62!E zflPv~IjbnXsdT+mu0PE4JtEVebd_ugk@q#lxBB)Xw)#4WIrhNtNe9@BlwaNYwbrEC z!qX!77JKhUuhCTL!qhESM8WPe*x@s*UNGjm+z(^%7xLd}%FI!2JJH`T`2K=U8{`^G zGYg?BVc#s+JRQb0b&!K0+i_r`hC7pY*YWvTgFLuXHwh_kPK&;|GknI8_Hy)B=70uw*vmak4o zWgtvL0JiypY2dBB?>BL^wfTOpuhefpv4q|NyklTtV8NlLv4_yRMG8o`|_|(MYwfwuiz&78YYZ=JLMWXze^K?#vkNGQ8Eg|@x4puq^RRvO=J;%MUf0c( zcA069k6uT5=u-rvK!3ct!N8aPWd&V=y#jU7l4iJxpGbBRB*UZfzNe^_*ofJSD0e_- zA5e6O0$HwKQCGJ1iSGPt_i(z$nBSNPl|z$(;TD&FX!rOEQ@3&J38ZZGMqBG-sjay^ z-sZ8p*lsCNrnoEQj|YGV2rw`lq2Gw_`Rhke`3=H=W31F63R}PEU&zJ0pvCV&qhTu> zV>W*q{bNn+iL(qctc{#)IG) z2O3%!3Oa@hcsSsiO?vR!THw37S7&4M`IY~0mzLqvQuDXLP6mSKX*}iQ6RW@M)1s#^el{PlWdpaP&tl7Cw88=AWyYyC9*OJTp)Z(}8B5yc z8!t!Y`^T0@qEN%>cKKG4x4T<|2?kw1y0Ei^U^4a?e~4?>k07d&`_x3pPD?<%USREq zwY=Jdy60BALxuw#1F(#_J>>PC+U?2JBYls}=J4oovhDr8&Qwr+6t};|{^oAPwQ8l} z@Q!64;QE6uK!WsPh;9)GCjKG1V;-p4bdq@g!I3MUMpK;(6*!^-S{Giz-c98=(DXic ziC{m>SfV0gvG;Q>989<2tBM7u5&S;L_$5WM%Ekq}nGo1fIMps|hV; z8Q@JxI$n{6kms$7vy&4ij}w{b+Rw`X9d307>MDc`|Ho8~LTMqj_}{hxX*YA-y$vchb=y+%j-fAx#gGtoRAuP)cacO1=+6gmx-AfD0fDv ze$qaO%GYkvZTng!U+{7{m6nClMZZjug$A@az`qa#L=ptpa#c&#!IdKuAyA9(%`jGN zyDC7JIC8;?l=Ua=pmM-$r_Gxvj89z+WZzX&sE;+CyTR(GkMkwP+W?Q+HP*JDqnpMf zU)s7P@faK~fxdb1WGx;$Bb}wvxXvut76JlfTy+I77Y+{@g}!6E-bptb3cG)h`r8zv z+1VTqgSRWS7_OLs#3&I{h@hD{s(wnp=?dNk%)+nzcdns=E$MKsXZLX~KEgVyR%v$lk=}_jmj^f1u0;y=y&PWK_dU(DBlB(zg&%-Te+@_<7LR zn9<79WQ-aJp;38Cx5d9Fusi+O9ujLeqPkc>*1+t#^h<^Yw(hW$m@mM8N|@hZ8PGTW zB*|y8Lzz=t_W^I?P&`7Rq zO$8ST#15A6hQ-u?FRZ_Sd2YbHdoPl}@t?T7UMEG+8+A$q6uLiZD_DX2Qa}JUQ^zH# zsx*Z!=c21k zKB5CcE?Qb#4d2sJ%&i||YyNJkclA<|c5cTTBT_rSEhV*Qt6Uy zYJd*FO+SN6yCO?GZAEVp@3SHSMXlGNLkQFN?qUJP;qW@r@8*MR_h|@IA989Z-Ie>` z`}4pQ_`tmT`cjpBwJz#!Rm=Ik+(-WEYj3RlO%b@xaszoeJ%9nYTp41wr9MLU0f2cS zJ*bzQN?M|vQP8AlM!7r007NB@{N0wpX0&cK)|y@3HQ?2&SA2@Zmw3AcTWVK`+d!s( znI3jxBafVu&ey*^!hr0!>^LaL{IEVKCQtRdxA5oH@_((0Ze`laDM^;tqM=MrfHU|f zY&-=hi3J!)&;=CnBwkmM8^&Q%@IkS&@zA&E>ztX}TbR>_TnkgAROMF9=bC{RgLw4+ z*8)ZkGvz@o;*E}(SuFmE>dG-NcsK^lY~yjyQrH-WyGSlODu%EIXPy&hIV%FGnsk?Y z;yoR4R>vQ^z_DfUT{v16t^n7xj2EQI`<_6~$-GHY%?JAVIGK<8buMz>ofbx0q(4e* zdB$OlM?nTIXW+j4unZOM3jXnWWf)U}N9xXY+aQ3a0nnSsqdzkhAFE9l(HeSIELAGIo3D-W|62+?T}q@3iK>uxo+>fbxU z@bu|HVv9a0Mg<7)q(+d0Il$BGQhq(wV>GxOgS;p|!Q=nBZ@4~zc|B6DJg8UIE@n?0 z$UA>T2%^QME9Si(v)0Pse`|5v?~Z3n*rn9p$vQ29C}atV&x|lZzA^U3qN?6G2mwcC zaJVou+5DrQ6vmrXScPKil6I5%5Xji0*HX7!32DaUhdd0vx zabbW$#a!`{(-3OXcr9HgLMv-K-ekDG#OYpkI{J~{ulrJ{CtH-*IC4{1kMl&uT?QZ` z--uGtW?}hn1{Aw!;>41C9tm4zJN!~x6QDF?Hy$`n@yT}sHYR5wpTaCbMa@fhXQez^CIaoe+*ES zm3(D>fi!~;eF^+4+QPz}>oh`*ZpE+n)kUYW#?p(1yQVd^#+EDgZPxsKB=e=ED|f`z zA3ukwxegctX;h9=-OB}Qu!81y-q1l3ZuUTG5-BoA=P@?y{w}xJ zaYaSj(Qrm3SGQFu>nFO{ErBU*3RA0C(!nDb7s$g>K_2}Kh8duYJHQN4fW?&>V$;Yh zJZYJ;uq=q#+8VLb*=VFe)77;H1bEgq}Ka+mGue-x8$H$CfCT|dz2%~7>-zoJZOdUPMKjdQc@fd1UmBnUK}3u2Go*rh z=Niv@^ugTp{_e5?lIQ$13TPh4d8!;?@7(^CedA}N{jZr=V9VkJ*~LG=u#7}_^9C7O zLA{KbWbJqSEU0dQT*wTqDa!+bp&Ugub|!mX6Nu_`auJSUcNK; zWd4oW{a2v_I{kz*uA3C%4nhi|^==p2CGW5AyOW8#{*p~pfQP0Un-c26>sRZQ9qw~L z2lueg$P9)yhpxho*k}BnMHE@QMeLML+!9hV%x zFcDGJ6GghHaq8lf#ryV{i&0bq>csRLsEJ`q%?1^XUxS_;-Z|!hJ+%t~-U;Zm8gH9m zOJ{?y6Ih#uiO?!}qObC2Rm--l&gxgXj*pcVx~8x$axu$e2-lvVI*$w-)rnkLo;;`z z@L?@UU=vd!%Vv{~5$URo^89QTad&U}p2l4t@ynyz{Z0)*6fQ2(8BF?+)_xl0=Wo|Z zt7z?vhrbr~D$)IhoBPByJg%w`QcAc8r3i;}U;Kjz7!8$Xq;+9@6WJ+s{JtlhFdfJd zHo5dbQU2KO=lTAWcy)h9l4yI6Q+a!>{E^A`R1*`PYh5Af#=KauV*oL@Z8iUT2f>$> zYuc`F$#DaHs^#AW0A|IylP#r5G?k44Q+njz#dBC;-v#F?nVGeG^F4h>w#m=u)boVv zT(p@)f$+o@f~3^AtL>iqKQ|EcF5?(neoJEU$zUBZ<4!kD&naN20j-W>RH~U7+U&KT z^otiM`iR~=YN7x)1OxR2ENchmIE!K)SzS~-G`tBCXlPx@ymBw!VC%d@1y&iXT^Adu zypl7K>5gIfk3Cgc1i}zHer-E1!-%i^RCXA@j8?+F#H?fUd*BE;eJt%X5^5l_jo2t` zF|Bn6ebC#qGCbCYi%;M6?;Q3fjvgV09LC##S`1&nDK8(PLtP_obNAC+>e%=|8ra7) zg?_xxh@3>5sbSI@cXP#?J3oWzoW1^{LQi(Jv8QTZdBo5OA*o+?ohb1kVW_6?2wp#x z^VAXzhc#%ANIav&acDd%-lWQ_wErdE&&B{fx!D*@L1B&B&f)(6FF!fpoAn$U(P!4=SspU0$GjsWfM}?gFX& zoYn15iDIYOm?4Knz%$C>H*d2NfdE6V3ZLC`#>8%o9ymhff%pWwK?VaFmNm_z|9KL# z8mZ?jX}hZhUb;s93#dN{{VCKfkhSR~b~$+0$^ZqTWxLJO9+2cRTX(vCD_|v*9h|U&GN5ZU5y-MN!^(YDA>} znKj7O_sdPB(|1I4|Mqq5mZJcHewXUkZa&uGnCtqz{cB(%Z7Sa2TA7sBc2BkIB=0r$ z_7}b6;V~Jq;||geR#b@u2tb#FLo2EU^f@jU>V`c7$=XnwxhwaZx5rgZo63{{k{z$y z6m?63wer;+*Q~{dI$h3f@X+H0$-bp#la$_Y!S7-Q1|*Z(#CYJ)`7WVT?`sfpwZW|w zT$2*+3MWh^A9vr%Eb0}Ewe|72VNp+<5}G=G9m=088DHpojT!o8zxCB)9^{j0yo?VGPt7^|rY#NG8)B zzB+d|NqX$HHM@m`IV_Lz!;7l;kE#i((lMkyTnh3cvjLzV!e3?beQo`>o=Z^O&f0(v z0zCWRS*+t1R!Gae5Tj^p>m)`h)~?%u=KFR2RJ$_0sCe~XCnk(oUG~d{9$ysJrq40L zqcIwW12o`I;~~U++V$p8CO0o!s~yZ+%dn-mbLxvz_|NK|V*BD%weuuv5Q*`Bj;UfE zaG*tomBD=r(-9PABn8c=RRR#3Eg-94+i7MDApx90r!+Ui>T5C8F+;mY4Ga>4`C zzqM>BTeoIXEa#C7sbksXv`hb`zIWGZ^8a?^#1m&1|6h5u@^?Z|T1;qRU%m^#IA z$tK&i-<;Qq<=S$o#BSnu1>#QoD=iU0T@H9BbM4*UsCRnZ8O>Qa@M9RmzW+M z4w?nUUNZe%6bZ%8etZc~qi6Da$?A!5NKq5^RNG3Ef8Si%qWPwwMddWNsDkJi2)=Sq z3rXx8eS?+A?DBBdO5Ob>G}OR2gvYnF5SE3yXc5hfidW;UJlub5T)&oGEGnTSa>hf4 z$hX<<)X(v9>cZohfkD*8)=ZLzOm%xl!oP57SZlY;Rd91W0h~*y{Ul$Ww$e2#Y5g#x zE7$5rRtACfM+#K3I3xa^jT0#-uUycatRlV#zX@A|#5@`lhC9+PQTzDmwZJvrD!dZaDTc8~_+!qT>4cDowySwk}e_v_1 z?bwp{xyr^*^rOW7B+5n_5f4jq9#ec1A(T@BIe8uAZ{@Ydn0uCD`67u!bS-0eBYE*! z!^Z0)8ehhLye`k=f3mMUue6`J8Fc7Jg~DHP*U167vEOXa4@BPQk=|0VDB@qi0QG3K z=LM`*Cw_Xx4i&&HP}nA`EUN1(bK74WKNf|0B_#w?!bLZWrI!1V-hMIl z^Lqxj%J10Gxa4yH;Ha^B(6o5(1E>oE~yZ{$A<(SqfZvJjFFNQTg@kVCit zH&BXsIMBgTmK-{?4o$E*$#QHKlKQD>^Is5?_uE0g0nS*rzf;h3eeU|Q$5$uhR^N(oSzBK6%nL|>)U=zl`8{B|Do2%F8 zxz$x%_2y*HJ(TIX*Ok>pnrYz@>!phX_`pCzb(hebp9#LkpTQ~eWciV8TJ~J7%uO?@ zxE5r@EUu|Yke)fp7AR+IzZG?|t_C2cFtH)3a9`0Kny=t=njb$>!=$mpDU0N#HJUS4(g)9rQEXFAncHjF`DY{xZnV`>cu zvhK~bRPb&{;RMXX7s{nHeu{NKfiyEUTPY3HvQ62%S5}-I;4wh-BDn*)lN#_v<88Jb z6>>a1v;Tw0Aj0yLdFwmjt#o{s@w`rQt+ssni=@p_iJC*|X9m+P$>QE0Gqvq5Gg2wg zYtbY$%xprp%kV-fwwCXMj0fiiL)-Qu8DME?=iBy*rLZkuRxus%@3p*^!MEVMMyh?KLd5Uh- z=Qh59`@*MjI3ylhAyl)D zLQuv_D5j=^8$kA+{ubPRr2&+DhxEq6`-j-hh@qyDXwT`m=l=N-ht1z)`-gS-N0_4O zVFJn!2ycOwbIPzvRy{m{O_FOS5=*LjuZ9ccK$(mR|qm zK@kVB)=ea;2=UQpbAmu+-jqJ<&i8x!(p&C5Z+1i^2yfTtU4H~^gZH7^>E#6c(_yl$ z1T?jSw)43yUUy$(8X|)lw}4LV2-Kvlo3Yz~zUHD^jOZ zb}1ezc~HC$_*}(w zE81ex@HC5@&8G|XX{_>%hNA)>hO53+`G}NXDEab$c*o)C zG5gpWf>&@YMvU-?hB#L`_Bu!Y>AaQIeC{Jodar=MFPHbC81aa9d0c-FgE9MAv8-zB zX*RM%egS%;#^JxVUf8yg)%1mjTgVrG^u*i%fF<2C()joVei;-@_@n@JyduaNduxxr z(DG3bKd!ng*t&rFgnzJv;192O{;|IRh5r8ff9<6gaetIHZK|s#>P6sD!%C@xM^4#p z`-l+uxpmoWJBw^cls#DW)oNs@mmiailZXKR8egM$T(mtlfVB(ih6@wH4(qCBYi)3J zJ@Q#rx9NgAYcDufaJu9G?|ZZ#(g<%yKDo6F?6GGc-JGPr2M4s&eUnl7$gvVNeq3{daDs0 zdNLw;evX`vQoP7hKgXy4gBAU9ISNr-FRn$>il|9B3yxB&C@Fs1pRqb=kHv5yXpdbR zf*4H0GgW`e2hgBqLrH)$OVGx|g{_GBsF5j$4M^hP{P<@A0g=vGmwO}#sbFizAkS$h_abCvWsW0uQV4fPK?Z-v_KuJ-S+)} zG3xD^@HFsC1#hwV!enBx6*$j6T1amf_qhHfjV$)_l#E`kIDr`o^%TIj;LB&*B?4HKo6d#iIOkU*ir zklvRxOOv`(qHi8FgMkUk<()5hz-^jq2DSh)^KAcIwNS0VPE=ZOs~OM8DN|*W+yt=c z+3`Cqy*XY_G-f5>wXy7h`OhtW-Kli!-XUy}vneA~H>4;7@?1bZEMakcfqo83F#$U? zech3EV7Eje^m3OD+hlC8Ue;-ZwbAsbh_RqKqMCo;(DJZ-)-q@zJ3VpjpvL-m9Q;JZ zvW+T!>nqB|?a>vha3gr$ryKUfCHp$nmugKGeF$pRam`J%G^}KP;L#ZCtUaC(ATS@- zR#8+%2O!+B4o1i7DTJmQc~Rp?nDxeirM$_UWw0eAYq&3kbMzl8?SNJpq)Lf>FbDF9CvcY-C zaK?+hnRe(yZnVPyF#BsY(^5qvMI3i^S2sK7T_ln-*UP%?UQ35pPC5L#P)Yr59NQ_F zakO(kRhg*SO8F`U?O-r;RD&Zu(lJE`Bv!C1Lvb`DWwSPsi**SN~O z#s>ByVOFj^y#t=}4cBFqeH=OYBs0F6e6Qr5 z?3ygo!QGfz}rn&WXq{5i2fZD4s)-C%7n_Ek+z;rvGv?HEDBpFAJ ziI>oAblG!yGG7RR36|X{TBZAxVI6$Nx!mhenU{q~^dg&8?q$N{2ot0=`M{}=Sf9uD z=ZMWb(v8|o^WGKRFbo)jybd&J_KY9D%8dEOD@-C+WxS;9zC3d-uG?jsS`$^C3M-b_ z47*X(w&QLKhFj2+oAd!cRI(qyOZ$mD3~17mlKAtpL&e`pElh-vfk(MCgZI+qV6H^BDVD2zJr5wp3jI1_Snx2id zY;1w`6;ke=sKL}}td}d4nj!M7DO3GGr_F4rw{8xVf%?{9b`QqD4xSn4k4kCB@PxGf zEF)X-(#!k!J=(%Ctk8VmH19M!s~YxTY}vk>OaC≧H!zEaW{u7RX}Nf6fu1+Sm$i zVQrLf$D=hCobehkf60=(p|AI%KPj|;^v}}ujd4=s`s&<70jejxCzIgSnwP2{cNfbi* zAf}T(8OYayaF82I8XkWeJit3Arq06Z`NKxL@@4T2)4AwQNw;D$$G13qT2Ld64|DfQ zc!r_<(USnzo=Chu`JCbnp9VmQ)fxe(30R06&F`=Cdpx-1p8f22Yms&5?@M4AXyX=J zA$4OhSM1--{;E`c(t07-^21~VWoPM;2BNDzC#fdv_>cRI&a1qbNCzA@sXpSd?mo)b z$Zwjfgi!{5Yv2b(6)sfkd*AJ8&A8Kvm^!Odzkq)*@d7sh+y@*@^|Cc$^^;P3bLMgl z3jvZ(I7scW*>lfq1OljksV#loK?q~rJRQErZwmJVEhp!Iv+}6Il8|g;@CeP32a6Xe zORL!oix3Q%T5mw^`k5wQLs8hV*?rqu4BmDIBlyd72sKq_^+^JCs?GY|XILrwWz*sB zte#iAsQwLfZZW8KSJis3nRh$CJWpv)o?m2v_*t7P+BzNv8E8^<>KERjwYGT73CnL2 z8tH5M6mj4?r|>H-Hl*WSyoYcxq|YE!TV>dvAz=y^BqLK2+y;6URaM?L79*_%>7p88tO;^zCMr7-(#m^P73)M6{xxKYtArUuejT*ZQbF)w9XVCIa7CZIog9zO5EH$H>y zDL9mBlo(bo>~@BG9?D|vS1m?aR!TgJWfOeZ719T9tHrnu2xa>cmRq8M>umNcn&$Tx z<2l}Eg?P!~O(3uN`Fz0M^Am*~b63!GG#xsj2;eg22cCh7XM;P_4bX8RhY zZnQ}fMJoY1-ERg1A~J=C{6nQ{2kASak}T)kiT zj|WLI>VWjx*?^TsoA+lQBTNd@s;L%u|HuS`&8KsZiPvb+#>uEPUE!n64Y#Ok$8kJ; z^nE{rT?_|VD|7Ryc|G0UPt9cofk3Sf%<=C=d~LZ86|1v#Cx$2K*m zUoYITI`?^JlT<_OB|d)}2CR4BjrJ_!XL1e#>p*UaIYa$F<=uGIl9Z7J%~ScPt`e<5 zI68xy8MlC^8A0$})sbD{{!ro|RA^@b4PiH_`Cckid8!X#q}N@^h%dC=uP&+pTOer!rQ=H3q-<+C)ea>qmFnAN>vDA0d==P!9 zx+?q%WWXq;#5kyXiH<>9~3V>c_{d7$;+Y2dfpkW8I32TEjdPoBMX=pW50@+AYfU@c=v1O zKu&kqGGxaZN>^g8XdUgk=oHu=fEAGNah8FpnGdJ_b&n5ffJn=B_JDL$lXZsgKnq{^B0*&1a~B&BrX zxwSLan#0Nb4fTmv_zn)008H;!#WO+rF!T3cYnIscfqHv?h^-H(7qh98iGWTQL7W9Ma2AqQ~!)>*D zBQxAg9exkLfZ56XdXaH>OaH`NzIJo!GOS=+@gHyr9&C^+HPR6M4-o&i^XhwD1ruw?$}WlODCgwy%=_lKbzt3;w~ zP7C*oY4y0vxvz#_oaDo4%$Y06XBCLDghs0^S4msrd?TyHi9D=nas-LLT!GRB7mCL{ z+IYVTV+5;1R&;G*^2(t!twJk%=h{4*zvEKR=&^n2=tA*eI_W|Cyms^19varO^kQUP zE>p*^&U;W;zGlCUR5G?&W-JKFx_jPJ+PE5oJUTLqr@b3x^DVAsFZQw7#$`_gr8Kl? zXR1)n1uvffCij&j8PzcL^1P`YC z`D0I-4W)BKzW2ij8ZM}Jd{?fDGy-XE_7(Nanh~DekG_+PR)6$<=k{a!GYRRQk*|TS z0VTgI$VGig^tSdTQiL9RNMae(ZHOPNG=L74?EGrE#Y(3LpzARFu7z5vh_oJ(wc@Ro zA245UQQvr-ST{;8H?sC1zX|=X|oEny+9|P-pCiqXZ1d-#Bu0)=Fi#A7h zHKuOCi8!b1uIk$ro|rv4tlSBc1F?q&aLp$6#w7MQWY80uAZ|@x84yxaQ`1n>s-`&w zIE661!8h{a-3q&HvOF}rexcf_eJK}$|4|2;#!)T1=(S2*?q(M3CZ&7kt*679R^NcY zR-E`+Z_)EqJYjZ`inJkR@3FmR2?~^4)^5^`j^-Y?msOW%zWQGX9LPk)h>OaEAnJall}}8i1oh2#p+XVv`^EPLi~ueDoJGYi$S<5Dyp7i3O}YNe zq}uWAkKT$pCB6At?dJ1NrWtR#W*>q6DGatmsZ#A{V{_W`SFFWnCl&>$_+OFkJVAFF z3g${gX#%tETeY6mtJl~K&yX}D$QJX4Ri^U)s-M^#qaS?!OqhAl#Xu&@GQZOKgFT&a z7^T`m@tX@j2T(*fP7xW{E=D3)y1P3#bqno!7ex8##_~^lNJeopg1g197uPG8yMZE4 z!owgFflilmucsEY^f`@g!YuavDzmejX9rG~nk|ApBq;(w3nw`ol`JAFz{fc^OC{7zla<&74 z!$H%|T}{$Via}$=L^p1ye)%T~X9>8sDX}j*82f=uNSA+DEsHZfk3dp*L0<-EnEO#l{CHah zT%0cVq;a=_5@6-4pDW0xBj)m!36cW=9j28u_Hlymp-7Fo_uU_q?VWw;p+A~evvWMG39lfU{L&Ow zqgg?;B;-+T<4V!rc-T^2LFzrK-=xzX+exp1H!dE_4ts9t8Ln>h1U|WSvXHi=@gF+T za*JxGzaQp4nc)x>vF@5t(2*|~S6ykeD~dEv7@nnria&I@s|LVtTq=(O zY}PTqF1n6AO_-h~!{tAbi}l$M#kVC0HYl(l%`UI)Ms*MT zlBCoL0w+djwf}%}>4!vK48BK(WeiKHMvDO7yi=3=fx>M%`S?wc+(N`!PugQdpV1#` zfJnYjVIW7v5^G=fwg^=N@t5$L)Q{TAHA&-QXj_cS^{|b`Z@=IBwg*tg{W`ME^n4D8 zN)q&_n%Q(E#=hMh{*BxhV><4OCr)3B&_g-PQn`j7BSq};koVd!tyy#OAktkOb1Cv{ zCr=#`<%*mEXBr9q^H?}CU995FBinkrX61gZwQ^{pw4nM)tO}1M4vsvbq$|obN8%OX z=M5C3r~$eUr1fT4mY_fbfXL(kYqvPm<00o-b$T>P)kL@9coPaShRI z)#tn_J4=YA;$3{TOCUZ-)S zX?6;*ZKU+WEHBePY|EaOQq5A!DC;X6E`ud!F^4a>PMytt6`SpSUSAX-{h^yOhUK;l@ zKPbtr`bi&#F6(T#0RLzU(>`T3RUR0W^}#NyWX3)CyP{S5@3s~6o-a^##9mZHK5_*u z=O6m2P`7=3T5VSC3+kv^VJRz3!6pgsn~MSy7is37Riptv6w7tw8zW zPgS*<@{4AbYvc_OHRb0T-JmF(kU`wH<~`TSM0y0OAoE>m4Sgx^t5KBe$+3b!5q)~n zKXebIO%27#FE8h7mAbemzkl(8IxY$E=skhde##mDvFhn{7=XXSKJIM zuYy~1)D4>^Z)6r@IxR-AND#{S;oA5I42)EMnDtP=F^NP?aHOF^K$UU=9ZjLq?h%(Y zNP*Ir-jkuA7BXVGZfu?PDSJ40J0>{iEAqbOm$rIx{4V(lG0Erlci%}*@2k-?!xF{K zLAy(|2Pod5Ef{Gc|3~MqKBv(*fYK+2622pgA9&_Bns(}8)(SM*Rca6W#~LP=E5*~Z zlo~C4%8^MMe#HVQai+sg<))=VcKR07B-z}}vh?{-R(90FX&d;nNL-!exzwqBfW&1F zqluP!=XI+Oj=qCfPExcVGuPHCL8W+7^l?Z)32%YRUYw7|+Vjg(v4ga>WW)V$SxQkGP9SrojroS%Njz3m|zKGjtiPs3b0=9&X^Rht!>F-#s#eqr0rd4SLs zxmflwC(;Ur{-nZhHmG;ba)01tWRA2C+c6i7_NpVhj70#4I?l%n8Yo!pQBZm)VKbEN zu=w)DhI4c6kdRVY(yu?g9x)Cezxnaa^ScQd$U_J&sM+D_mnPPt>)BEy~Rs+{p5&%VwIP!A-YR#qskyl9n>mM!VM(uF|Ax zTDPBYKDu+W*UiCKdd7~M(AFkgKWCc&UmESLlSbaVDZZOJpXyNZCNXpHH6^Zp#tRU( z{B3KhRG@6XRXGmZK%LHiueFq!f3+R`)%i)3c1er?>EKtVKND;_hBNyb@^O*#F0t~~ zHoU7lD~D$ON~s?ks8y-idezTtuDEoyG}`0Pd$!@(6BT{+Jkl`5mg#xtg)jO-H(QG% zt)G?gL0_XTdz(f``Dw4ZE7OuKi|(HHGPU4=2w9_k@wHR&t5RX;4(`&GpNxJNI{6Z; z7o#SCQ@e-oQS}74kPlRyT{l0h*D6lr)+5@uid8AM*x|jN5s2%F{b;op2c`ees(hT7 zf~rHSH{GZ`o!h6G^Q&348~x4%-=dN&_Cf&XW9YdfIkCd6(G=1peu2(0;}r!Lf@oBn zkNcYEg8q)j@@3~v9pf!cUb)UwDs~u(?MGC%>Ds-iri4PlqZ6s?I3SJ^l5Q}HcIdra zt41DM%Y`)c0ud%G%4df~6SlJW)#jeG{}k@_OKO{(?7rmiiz+ESIvFNbTvhC@6a{L` zJVc`IPgn}>-Q-k+2vu5?Ub=TgJt(~=j?LxXjXC}S2aM?2kvU1y$Gkhz2ct?u5IudO zXfpak99fTQYwxipl=wQ0bh2*gcNSp#hQM}%BHw$JWLQ_zKWA9X^pfbzYb8_iUsSDUI8qNF zAGd|xi*4v~fql#IQP|0NJ#~8xW(3je&32_!y%yC2&n;xl-3P)y+xx{m$3!Ahq5Vwz zWUQ#x48?aQIlDb~Y*a09>1mpYzC2t|&%NjR@iJx#uJ>yU|Ac9HIWc-!F-z80B$Cw! zjFcoKXx>3Swg85K=TYXnc(LD8v>wW-IxRx*o5j#Oy=Wv-U4vt;Wru>7M9(j`b$D+d zc{VU*!nb)ZEriSZ8I;Z#$mS~IK-7L+n|CDFIUo5C%Q_Ii5WDoRG>JMh!aTI~Uv1%w zVY}=4J}2dSBHk-LHcoj3(hta&4~L|aa+}W|ZWejZ@jmHlN>QVZD?Vj@2TW%8HFPQL zaGZ5OcI_&Q43cbyZeMj2v`k+wQ^6&-a-A}=@mYiYVMcX@gPvaTdfYH_`ElPeH+#oIZ?B?P+3ACJ z%lvt3=)1DMI1`W((Rds`?OpNy2gjnSmdJ7%Kq|*nr51*PAX1492@`tFD&v;Bd!(U0 z$0Qe7GRa!*#Xx7_7WbM4RQ)B)q;_m32Yy<<^&#XrraqZXQCf=H!fZ~?Q%qglI}vOB z^1kQPAf-?GF|_8*X7e`rIEx z>Rbch^l7{ScE3?Uk}tfnFQrkN$KB!GPA0|DGhyky9Ef~MvcyPt+n?pqje!aReZ>+$ zr5&=Vwl@sS-266nKf_P+)lO&?i@(t`+%UBaGK#&BY${%cA32rXY2?RMR`Yn5?ZR-q z*Qfm(`g~fV?1TGWfEdefPZ_+=?nnh_Yf)}#k@u0*zT!B~%-klgS^p8vooj1w@~>lohvAB{|l zU7x#<3%GauOjb?DTzwc_B6yW7sO|QhuKD|iPft+?f8EMmBX5dI#oRN$_Lz*TKY({N zJ&ju*KU${%@!H%5(<{3AdeKE-P^uclvnUBg=Qu7!QnFv zmkf2i9I(})Il}N8i9C)w)Wlr}{Ao$nr%#{$VIq3YQoz~}Q}E$KsFn+U=&Kmbm~Z0K z&%6fZ!HvGh4dVJ0`pvUZ~Qnyh^T$j%w*nKPfCjcxX?iJRvP?K%=fW68cqNZodX z=PdG!@x@(wEe_3j-E2F0DCmuGBHb!F-qD`VmpxdzZ*6OsK)VNvFNcy3QpfHWu3bqG zFP;!ecJsc6e=>KnV+lh$88y!)^c9LO<$zvs3DF-}-XFPFvwDg#Kc`11h!%aXunhON zzg{th(-+Wq>m0b>UF$k$3qOnurA2=kczs4lvdrvDPW`b``@OjS{4fIX!y`m-^L4B#L}hyv)-d0b7!{<`;C(7g;@m*RVt%+ z!&r1}mx7BvPB^(YH13@OrfvAA_A^7KKb31}tu$iJz<-uBbqT8&NGFh%$R~|GJ##duO`BJ&JQ{R z$snt)xN}XuGVtsG~Z=M{gC=#L?M$b07SykeYr{ zr^t%=en6Sp$4GK1I^ZWnhH2g_xTb~@E@uFpTK;jHaqo2i%mf32!+UC9U(_F`SM_5z z-4bqP#=QRqa$+xvz^YCzJ>I`~J48Ug;aP9>+{g7pcgjcE3ouKOB5{bwy)=PkJD2mz z$Y`Qv5u}}i42kG>uTJQ-3tpA*+*2MFw(fg2yp%y?#~iq-Za+^yMz+ejitvkV@Em99 zF{|tgu+7dryrOmz`tf|F)F&xQ4{6xXRc^7YMflfCU4l}n)JiVLM7_*=PjjcEOWjY^ zDg$3Z{f6}mc3YNSCtgaKTNSfdilUO~MACE*d$jr}j0&A^Lwzwml8dlT9@bh6V_ZT|YtD-X1*Q5n0nW7{ne zjmP-7`FY3g zf7Yf=Eh^8Q|8c5nPSWN-Br;SobLHO z*ycC&rM|uRA)(P)w8x&;da1w%8~@Mf3h?f*Qp)LAFhAQnX?BbQ!6 zNa?61)IyP)2fl{QSm&0y6Xjyz_X0v)UKZycSMz^9U9*{g2I5PK^4MAsVEb4q-%+n; zwSM6CD{lDgM}@aZXGSv4O=7Y7Ywl*Q8|*e{L%-z>Ji|ZIQRQCjw?7D7l*Y`kg-pG- z6=xs@VllPH;l*}UR$Sa4_JW2{t=@?wz~$sZXLQOzE!50zqxpB9S0znfF&w|?M&>e%;m6w-VKs!L5>Xybc- zxQXens=)m;jg%Q`Z^nb^Lo4Mwf%b!?RI*U>H319n2sTB>%ggt!6?niQL)Ic`IFo~X zz-^>$+zq)7`x2YPh4P^PS~8u^S9AmMn2o3KQ^JMkxhmWS(Dp&Wr-<=UAoSMo_>+o} zfh1b0lHf%yBdb0xCfn2uv>1Q%QY;t|IQ)(E98JO0bxC7WT|`o1w&>yjs|=__ zf%ocX^VVQZ)O`T)wR5QMP-TJHTyF*ELi|3S=&~p*@w0dKyHi(vD02%}RB}Na!_XIq zeer7*3`S8zrB>M_igJMsTpHraMM$S!M&GGjx3O>_ad~ZAqLYeB@zDTd|Bs$ggjckB3)YQa)mzRuQ$T%v zj7pFjwr#6D7Ul~?`i-i}uDp@}`R>`jcL9gpPnj72IhIe%O+w@f^lxPKfTQKUI__zg zw5vDHa!-Gp^mIk5Ta5W%Iq?E8^0pD@)^ZNT_~>z&K_-2xU=^_NBE=s&US*y3pp>hqViv zWNwk=ZtEB2-gcWU3Tm4@7i~d1YvF60K))*)s4z=Noow-;h-&)WM9JMOZ6xwV%hh-Y@rD9k&}aWMF45#7tKq#6_13BY3h*%6?hQPi&IxdP zuyLB=pS3U*w6CBTBMS1j6UbyDz>48~;I!l#Sx15%my4nUnB>?$V`c0hhSLHowsZgo zGFh;QxWu0dTd$UhuzQ*dtib-4wHlx~jX6JIAD3Vpo!I3!jkWmD!(HB>_}pqDmhm;#CpC2$ zm(RakgNsDQ;Rl`^i_)GMnye zb}Oe*=Gek8Tppg+@}y66ixQ>)h>D-s_*MsvH_N}}C+Z^V)3UfvqLfwYIqbtaXm`pb zr#|+E%PDKz!8AjXZ^1CpU|`$Ser4|7;P8-r%)5p~F35w5)$<2xEAIg&kulXt87=-& z?Q^X0r%i+K1wDLkH8nBDUqpZ|Z#3ec6Ty;rM0gK^{_$_l$NpOHbjY;S5!KgntE%}1 z(HjfSu*N^C+YhrLE)dWmY<+*WZ^F79;9s)Sb`32X;iKq(2gtc!Q*?;6X+C02hv!JR z%A@f(iH}DH2@SBTU^)guF|q|hyY?qOhQs`ep-tl2)Dv2X?V>8g7M~y*2Ut9`Qz9E& zXRncAD(HX)ec~tYr-1=Z%L!@e^K5jOvu4_)Y!MOI?CCIx6(suiEP+r~<&KNHnb8Y@ zGHgUiz$7!yhT}t|5xrJ-YqrGh9E|slpI$^oxC}0Ifn)qy?FQlrlZW+Jtf`@#fe%}V zuzo-w99g-mZ;$~oK4&-@8tp?QC9`Ete&MBj z;{Wp~XW^~7v-mB2MQCbF>LsY#^GmO)QcXR~wvt?XkF?b&M2scrNazqki1T8Jww9`0 z-VUMZFhwP2=K7@d@F0#jziTCPAD!&CNu3(4ZHHr*GV1U!RC&D*KM7)RKKqjo(^^?p8IGPmt z~~ZH)h=Xb~Aiz1Z#dfi4~v`sRHLb{#0rl9kURFqND-lnnCQC;}UyZ5@-pI z!9m^&i{henMm+S3UFj%suH{UkS3(C^CMVD7pIqci`6fu(51Ao^i@6fp7Ap&lk4-x~ zs!!UFKJpNga@UcDdF9F1N5h(P=cEUIv_%4jtJSgskk4HWa%@j>J*+xr1 z_(4NG!u@VjkuRQ5obl1j{2fSq8Fjk|1fCS3nvb)>Wx8iciMVrN=C`&z78itiSCf%Z3*7n9o2R zdqY%+wqju^3e#6Eo|#phA@`{+)Ry~N85uR`Z>9QmUCb*h7tbILCZ*@qRd+}TVE)Hr z7!vjcpSM;NUcD0KvFy|K(eqn~*WI#qnbRcVzdRTbCU;V#XZjtDfmQVqZz zup#5(#W#owou%0$!kq|bhAgZoD+e=53y4oqt@@uUY^BwE6dn^j1w5A!AddshQ36Cy zu(ek|Ty}Ti@Z(QRTRdZzEk!v7#IA!orNqk%xq!eQ#Rn2y+ZhCL|z z_>Y}Ea1-1b4uDdN>bswG!5aa+A$-HCMOeIf(G=`!M2xrq1oK-#Pi^SfkzmcOK6`n~ zF)ai@(LhxaPSd;WNn3m?^>85bb@K1B%a$56!0%B4ERY1WzBpC1nWP5l6Z)!GSg$Kslm>YIgqOMV139euxT?ij0$yTcBFw6;_wv5A z=uSNxI%?xV>)L&P5?OOb*b53{0rYTfhW=s{t=1U3gxFSC9n~qD82#M*cPRFme7A;l zJ2r5wcDV8lJ7R@SNd}OGxB{hEobobY^~BjQ3(NO^UxrrIY?Jlby%(hEaGfI*h}%wx z&_-GFDY#2PO;N!>yarrn&yWrtI${I(rfUjfFYY6$o!K&u+@Tewxwpt>*Ox)Aj)E|dR{}qmxvR9 z{^dg4s=+KyliA(VvC3$DvFCziO7VWGyGAsXLLd@flr9Z$XxGW4nLqPd zSAwYj%?Zhkg@4QmXLF5Znmy>DoK)BP0r)DqVOq{B{05|QS^BDzm~Td4qR4pTNzBIy2u-)HEf~qYg z50O%)WW}L)3r$^I?3I|4>QzVGzZE(X{wW4gF(;cMbURp3O!)wRLd4!1DidomK0?MbTR3TC)U8 zp>E~Xo#-eU3#>!({Z7)=5+K+gz!?yk7l#nRB9L7c)On9sYjV8&LcG?Fu&1SgE6U}7 z2MqJw<)q-H)K@@!ng{GdGKKFCYpm#~zdewkIU&MXnpo^$lT#OmL!ozRdb=%3HP=Tw zW~+hsQ>PFx9%CACxUV;2@Sm_3`Rb9lE$10*e(RZyB)4k$NvTJyc{=jG%`e+_YoP?5 z0~ioVz?m^|j9_sNA#Qh(_=u#lodxe?_y9?1%|+msb095O;KA?!2rmRgSDuP+BmGa{ zOos%>1Z#Mk0MIT$Dd3}Wz(n+jrwHT!6Il0r1Ps6`diw2uF*_Y4@VNE-Hu1kPyFBnN zj`iHrzffiXTDZ3TZ#DqKdW3f7?*4@e0MNa?ZSsHD{*5F4Lf-*Umh)}Of1%-zQHEdc z++V2nBi{RM>i_UQLJf1r{z9!D@tXTS{6{npDbpc2)fd)-Rv&j+{^)|eaYFvtFbFmd zkfp+q+wpfXE+88Bcx0o0GYL9V;9W6o+r|Io89l*(7=$-qjF0&Z7*7T`>vU?i{00BL z`{b!8;@7wYI-Gy|en;+q@ae!`>temXBs1+Yk zCt@3k{X-hv^GDR_+c|$x#{sCDH~cOBLoErQCN|^G{}**;#rq#15CTZ{t=I?f$qRst z8uBMJ|B(R*Q3fG_^?zglq2U@tAq4GTGXAsB{(s4cd?X{8*hu_eG7JGS0$`TSf62&w zhC;I02qKL*73>J&wPQL`Q~fIuKP{L-Li3it!% ckrzmyiSeMayD6jx;1iInq{7=`aRa~q19ds*zW@LL literal 0 HcmV?d00001 diff --git a/docs/source/_static/apple-touch-icon.png b/docs/source/_static/apple-touch-icon.png new file mode 100644 index 0000000000000000000000000000000000000000..50f35abfe3029241fbf54e2511f58f1e76da3cbf GIT binary patch literal 23089 zcmZs?Wl&sQ&^3x%aCdii4esvlg9mpD5Hz^EhTzT+7~I|6Ex1E)hda;vyrVx}L0E4tMTxwKoT1lT6TGIbWZjxh?2T+&VG9a>Toh{J%-O4{4W%U z0SBF+4$zgl9Q;`7Kvu2*wQHos_^pF)GQ9kf&y}#v)#oG0$=Wg~b)~^_yrK4LPT(=O zbJ^s~a9iI$FJPaE#iEV_Mg{}QB8*hJm)Kp}GIqv-8DASA7Q;^~+3W0Fxz3QpIYc}DJVD|W)+EwGc4lB1!%+7mSH)jVo?b-CF#oW_1wTD&Tpd2v(xKNbx0I*f~i zbWQ9gv0@tjG8D7NFyiKX5JZ8wvuo`9m6JSzodjM1`37c`LKMwpT(*8DjI#Wka&}J@ z3(Ngln~~^cxgPW7(}f9skxRkw+Sg?pB9Il?LrJ_G-Z<9jAwGVLg@)z1rvezf7Zfwf ztVE{8&ADH{a|F}j%i~Hk?jHJLspd~Nq36zr!~vr8Hyt5RbwyuJR#7ZeR2dJ~@_lr2 zG#XtK=6;lo zIq3Uuj(Y~Ol>me6wcJ!CSHlMn{}Ydi$GNrX9|N_EuDSx(Ju6bYa2&v}m;eRb*RMnF z4p>dnV9QZYf2)8ECyudjbHJED3b2x^GcF?h5Em+g-*iv)qSmLUL{T6lueQs$OUKJ` z<>D)AnKVcF-cr1B8%jFOqU2M;vFx48en^33KRWQIOPN1E!zGm5PK!3&T^{MtcjCGO zW?jm07XApZup^OymbWm;!H+v^8NkvQ44R>}i*zFNzeok@l7fAC3c4tR{E2~?t^N)4 zJ7a{IO>Yl3g5&YmgV88V{_^IfEP=j_1v5kAaSIjW7S6FH{%J94zvb~jO8k<%p}H2S z*nd8D^h@~%=)7n)L(u5>|6Q3mLC8xsUp$%_2z`YfaIHGzM0~Jv7DC5;yqBOk(V?LA)x8h7; zVZRcR5n)vCoZPzn40g40^{i(#y747ho{agH7Z*2Sd3h7!-jGIRSEjmHB0Txj)wJT2e>qESKR}E`zg#}SR=oo94@>e{w zhq{dnr|bV%@t@;_^Um_F(Kn{?&U@MQz>ekynAjTK#I;u2y~len?}3Im8Br+De3m3)m2H=m2Z;H0iYy@#T`yL`++8oU9wP7>X){laQpZ811JP);(fvOA33KEhTUYlW(@8~rixnLh#6Eg(VM}cN# zL^?21DJ0fUpaaA)pz0{y#!otAA9gBliaf4+eH*>N zbY~m@G=2>9w=$Gm>5^rBvOEgoZI5&rbodm<;30XqBgYSN2!Qr_*jM^WMAl@8?3l*U z3oq`1yK>bcMqnS>%n(q%NqbszGboz)23xgURHr<9sKS(l?+xvuR_>#2pY$3uXzNKX zU8`7N8tXN=#Ns6hNb%~B=U-J62fW@5PLVjj*!!Oh2n~2LQV+ml0i+akb#jNSU6;6H z%*c+4Ra0Bq`b-3OVqRjLU*J5|@H{p6*zo@|%1g&(9?{S?<;?r3G3lP?iBLGJ7!Ara zKAi%J3^t?INUy0WUHt>9CLW`;kd28g>6l^rP%4W`FG6vrSFZY7?s}@zEms@M;V+e` z-zrl>JY!$p?m~U~maqB8AfPFU&S<$gEAtHv=?dCUSv?Q-+2qyVFmVV4RI=9w-k-nC zezxk|u`S8Ge$!pQc3dgr?_!_3ZmsIJK3Y_Y0mByiA8ZjQ)ph2|xu6nP;H=KtM%+(9 zFmO^L9ES6<*7h&oAu;?HFkkp770Hr_64=;~`txdSr)yFfGZd|WnP>jOqk6naQ9=Y* z?!QmAl`|W2OOFnAy9(+W^l((1RPZwUF3sh#yfGic zx&uy*7%7NUJK*3WUPoAOKzo(=$r8t=(>(Re>d`G*il=S=zcvz5mnk<+DrVt3mQ1FJ z|DA<@!F&!QkY)5|Gbo&<=H)<-2`Un@ND!0=)fwjG3|(<)HVAfTO4e8WmzvOfnPf}A znvU4vm5P_ALYqEIjzH>G+=xbL;rHl3DJ9<X71G0X3nMK6}T4#Lf z7x!jvdb<`wX*6=3NRzoPS(eJ|D%gZ7w8$2g&CPu0hWxX35&OAn|8iZz6TcTS|k~W9JH^gg|-t{9e4EC&;sySLHn_3RwyfW4qBcgV>vmf z0ux%V>L$djh+kO)eC3lf0K_i?DeMc|<=AYuLRUciJuGH~IEK0?qThWLA+8HI1vi=?Uxlrt@?!hdr_{ zr&e>M0d#N!A(`kz$=QGO!U%eaXgC=W^2XWKmrkf|trLCydbLQFUyVey_b()M4pYzV zMAMV=TlRdT*geF|a4=T!zaDUz&%Fc!J3*HQS9=+sarq|K>IdbV^;&KK^;>JAN5$pt z4$7?UtE755F`0A4dnd8V{zj(8kI`RX+aCG`&Qe!il?=UU8W9Ze;^Cx z9cr|bMFjFQW`Ozp0VUI5H%9*gkZ^D-Hd&e(-zSdfESTCZC<>VgE66&8OwLLOJsYa< zInxSonHE`QnypPi=>SOa<@BMDk?H;vzZ;iCjWkOw=8pW8rri2vXwp-Z{fceVm&Y`< zSV+@tK5(12{XWs-{acNPC`et85#L^Hzc8b|NwJM=6FV+#(@jo=QF3=n`kTo;$>~F@ zck8YF(aDUA$P6C5h+XvvjCi8QU#5W4tAQS9+<~wa1oqu*&oF({6LI2cZvvVynax-^ zsqV&5A~6ASKZzYu;~kp3e22dR1Pc?4u>!)F-o)v+BSZPh1c@sK+T@GZ21lik4NsI%fKQc$J3^4G>oX04Gs7FwL|_VaUCi zF!pWv1G~DXV`r#({KFu7e^3NymT3*%1$Ot_iBE2@UO1`yPmJ@gq0ta6I|*Xcpp(Z~ zLCXt7`DD4j-+7-99lfP-weCaDV5W3AOw^pnDK0At72k`Rq+EKwy3nVJFOQBKhMvY) ziNtTdTvZigfqC5MpCn$PlxXK==bU%XfXznlpRQ#aJ$zNvIM1iqQn=vg2ODYBJVPNf z@C^}LE>x9=tjy6(@hWqr1Yfpm7b zgv!CvH~psj%h$sIQ&pB%tP7!(1ZxrQ-9W;F!NP(r$_uuO~8Lq0tJaPg{qR!Jr~*o$0|i%jB*I-W-I#LL8Cx5VQkBCkxqn@Rufo!l!LNs8PUlmL0`Ua001Sx!Jch%-t;D3OtR#ok)41ir>u` z5F|mZ+wO^@w;!CafSFdmI!B8xz4BH$-ojRmbLI!4;W6Ie5AiTq_F0t)@SBOw#1InW zXWRmZVTYz^6jAO8g1l#`Vp{PolDCu3d0C@IZ4n(#TtauilCX((?}gqAqk4c9Y%|=r z%Z2^p=q0>K5>-w%P0m!j66T}3tJQok@#+fk>Q+2qLDx)K5oP5-1-(zu8_Az;G8I+= zS?1l~xq%3L3?_igd1PU1$Z|QkM<6k(wf_Q!r7ezi8fWcT1VPPk+Jts(J_DTU-<#sH zR^tbeHWSnjryR@cUofg0GXweF66O4t?IJJNNn*r_a;yN_1D}w7VCnP~o*By51P4`v0O%AMa1Hj`w~8T&kzR6=(OXpfP2`Y=QM&{nPNW6n%~ zrnClEs9_$NOc`p`DqMG-$dXW#M%9iQGa2S~2REh{MFZfxSLU+n!_lNM$;T|JuHTl! zgch|Wa;y%m0a_~kHswQc-2W0F7*ZnpB4>KSbum0&hVWnNk5uO&?L_CQaA^GDq}3q{ zUq~cH+S4sy`1m7@h4&m)d4y~Xrpv1#7%>z*tORC-w0H9F10HB>$%a?{aVLI{i>SeC z^hqPD!6!S`lwY5CV)mR5GYp)H)^wnVU_pn9d(2`D5-U3&MHNO#<|EQyk2)4OJS11ijKx^bYg!CLM;hxadk2pMI{$2H6 z@6yr#9PnFh_e!w*m)V{{Y#Zlr)E1wr%A66i#jU77mm0YzL}xoySAz zyZmV)kcB;Wv?>g_wWIB$X`}-3wuAH@Ids%LUAZ3?6ibmD)y?aSHyS#AO%fHX(t{~F zvZX5e8aXU;BqESP2CBLk^8gNrWrlz-{rG50?qQ0Yz6xL83`JlwCf#K2SxTN-Ple1r zn2r7^hRT)r3;D$1BH6LZYs@}@Phq$aRcQvU1fdqkQevrG+6tD_LOpMyrBS7$g%0|Z z&^C7d>l6q7k~DsT#XOvma*{IG3FGR7fxHc^&>&g+<2tU0NK}JV#Q(7boPwg5LyI7s zu{&g83$?KkXcWoHn-w!%c`uvhHf&*vSr^UG^{`~Ja3F(<9A*8Rs{@^i9Q2zc*Yr(8 zzlKz2rZkLEX89i=XR#N_bikxPtB4;F#bpKM3s&S07peBk!- z`v1@+sTBw+fb;~+YAc?kz?CUD#v>(6Bg9J~D@riuZM8ei|AImn^s7{|W~gCx;k=wh z$*W9PsMnXTL<|C5vdad(>(Z)JrTb>6W<;-HnuDynxiJ0{0dn#u_B>fSC#n;N=L>cdXNsKwqWkT4&=W(j*YOc?9Y(UdS0Vw4HNyGY3XomgZ)Cnn3Wd^0fIUZzkOTB)b z%JGOV#|}c0@06POoi3|vM(i%QdOiS>Rg0)L?vxDtyuq~-@YyepHCERlcRx~2TD`{t zreBOly~Do{OR`Lf+|D%ruKq{eNqqLm)QkYu9H37G()L<@$3c97q{ZOv(`D1q@w+}7 zyuk-`@;~N}C&2p0YY`-cO$E~d!asG@2nk=hZ1&y_(s?Vt>kmz%c*`D16JoE8&+0-i zmleW#Yr(ZmNpbpcyO)<3n*>_LRGV&73mX&kGP9QRQC<_X@TAeW4VKJX?S{Fz3d_O~ zCLGbmvt=-`=3v~LZ|9Q!A4vw4#0f86II&h!I6i0ZxK#f~)Lo|^@Z0fV?T!3v-b*T;Y2W?E&-b1aZc5LmwffF2ccVJz&?^|0Vn*ORFC@Y69YU^fw~Nml9i zhoY}JxlSwrz?(VVO~v%ZC|(8KZ+MtAQmF@4YTsMtpDLT2`Ops8RMI9#GkN>G&_=+B z^Qgy*E5rRaij8Q*ucQT(N5}NAaazN#r-%@4o=v zF#l1VfXYd_Ul?g7Yn7mTzi^@c0qIpmaIWW-GlgC8;^lV704?Gcti9@QIFbd-+dm&=MjF5>pHE&zi5~ol=L*7Sa1b zT*oSPC0AFUt>S-qyw~+g`(xS#Nyq8g@_A@?Bw9Sds;>cww=t20P#{Y!8Mf~nvILc_ zzk2?UX+oMt#V0}P$f}= z*k)ilfB~em9*D=sR0L%A^A<(>alk*Qx07SME6WTWl3Oc@&qYPXtw?i7X+#5%6bylp zx2tpy`cxilNhfBZr!^^LGm04tvOy!9m9t<89Ga4~o9E9kw!fWa^jzkd&0y4PMMz7T zhfxl@xdh)yw6aGw`5Ihs%EYeE(hyeH-q7lnXMZO3+B+|Ldwb(TKlUZKW<35wl_zZH z>-(!Of~&OQ4w41Mo63$ys@VEmwWI#$*&kBY!Dl~z zlD_E(I3?vc1vP)^pg}}F!s>C25X{12{ugUCJpCWPo<}<3`CWBAHpF?ozUClE@r*o%bhwszyhDG{OyCO*C4xC=5hu0l z$PS7vwmb9vCdfIY#kE?ZO*R(J%T)WKj*MkH=Euax&N`OVg0W509O5n1Q; zm*f@#*8o080BVnQ1Pz5GM2-0V%q)3j-$I$c)%E^(hp5Su>dTP+=kfR5=pda8rR+MC zDLeEv#FRBp^>lKJJ&9v(;iy>4-6h-UNRofz!!8AX@N+mf_B!$d%}z06swV+MlF1Mv zYo1sX9ME0^YAm@vz>fi|PbG)<1m+jIfWg%9fHPVjR^%u)7ca*U;OK6M9EFL#r2`)D z?4NM~cxeQNLzx5dfcO&o5|f&&geGW6g>7 z-)Ok=pZr4HFH*0L(I!5*^wo@=N$SWg{o#;$=wdU>bO^{2YOYg12&%vB)XC>+_x^&k zh96rc&DbmSaK2Bwy}gc-q5`_8Nrd&#dWg*+ROu-kE>9(aZ>9M->A9`mB}CD;wAl1g zIl8Sy*0n!=IOqDM8Wc`=L4XAhU~QFo}YKJFWp>j{guQHFOFaYr3v?<)A|BmK{yUk9p}~@ubV3UHhwJq@{*q zng>`t=elKXlf)qbPZMr|85tyDNEF2qn^FF_6l7p5P$8a1yA>fMWC~Tyd?ZC(QHkW_ z308A-w`%@;j=ZRtFDwJ2W;`*6T(Rz@!+&OI=?2;>aNeuy?lZ&;I}PCr7-E`LL#X7{ zT=Sy_Na+MxD=^+C@#*S*ch@nG-ZmvL{~fFxmUM+XGKu$|)a5s~oL3pL?sA(gh-zpN zV4Ut#<$bVkwe`yS;h}b(A@hz-)B^3G9KvC7j^7V8oB;R_fwai~AV`7UhZ3M*lKL`+ zj^QHRLkyJEiO$6+6yN?T$)3l3_Yq8Lo)>$+FQEOqto=Z)JAXZ!;opA^$obI$AI&~d z+AM`9;LO7Zh2jPTO|Ap8O3!z7R|eJ8_GQ)X-uIce?)k;=X1f{%Y_(Vz&k9S4s0a7q zMpj*{Xwl=3X2n3u?XL%ug8B4?tkL2S3dE+SIwpx`9p|OO!np}3_8>YEqjLA-BLm_)X(fh!^kyGA+Tv!e~vjj2*%BMniTJ$ z#b`}_n4}JS*!$?R)g2El&RpTMLVwwLc_dR23Vr-ge zx?%vEUH~|9jO^C(f3<)KVf#nL*Bg}Ntr=z*QSW#|9Cz9l_sZ^ar>bs6RwC~Tb;b%K zQBsMZNK6=Sjx8a^Z-K&faYV!?9VetTbXms|*bjZ(YaAGXJh^<38o+T0C8p3F zApmQ;KM4u31+1j&)5JUPP9Ap6bY8|{;EeR+;sO%1oM$v}<~24Bm9B+yFE2IfP!t7C z2+v`Y{^YeV#h>}`Hbzz3W5(NKxe~>Mp16j#JD}{!>;K*ec+6#Mxvjd=karH&yQUe4 z#6}n3Uty@}_EBCV-KO3&L{(st+kBbv!4$_wMkwD;lSMu}5a0K7#hhEOOzXL>Ay#W=n!P-dbt--me#Ye=vM)(yY}cDO_) z3>(nauZ-Ox^&T2Kh_Zp3?OzS|3|NF(CqxkbA>#VbKwJsm-nDY@e1eQ3@0vU80qpWu z<7#LdEx}6x+QU8$%U-F-cIyanqAATUSL`a~J*V#6}Zy;)XG2HE< z>3D%>jf+RZVL(IjBtv#ml{(1OK~&m39q_u*4T*cDG1P{^8}Nkr=K;@K2H(q3*n=2! z(kZ(?Q@4Zwqyyf*%10I|rYyH!<=!^G)Np->>C)Nrn-Fpg;v-JtQ!SzHPW+#=CGM0+ z1y!5KwT~(=z$rDt;1)wHN!$Yx*&1c&REe6=TC1uNzeWOF^vmYitDCskSqrA`&<;fB zT_S%bT9r+&jVkX&pp6Q^ia!%BC&vv-50b$1`J-Fp!Q{)*?cX3Sw=Pe%Y*#S6)w_9w zvlBfcj>p6=`tL)C2RJ5JwsZ2w*{zeZecgIEby`h`oofy3CMQ86B*nQ9+{;CzA#>g4 z*b}i4zrcahSM0@ADs&Bn!1n3#?+?44?&<_ROf&9XPsb!l!dZ=|AtFvDTa@3aqtzz# zA8=g7G%qN<&7aY!_(X`zWQ?%Az|DTKXZ7CZI5V;NJ}R452J{P-doX7TPMSh_rJ7g^ z+@@ZgiVdUtJ2pM=fB=|zJLBooW2d$C!8R<>Zmnf~g1&O(y3AmrHJQ#N9M5shi3Fes zNrV!tE_b4J_Q2~7TkOmEcg)k-jF7lR8c~*wu(-tl^Zs*_Bi1^a^F_* zDV~1jZ!u4fFUzZH#;Bp)-$xaZ7+*|RnH+9^y+!4z`bHX=ud!RybqU+eNT3Tw%Lkxt zccLM+S5RFZ5Riqwio=oGE1wP$Pkf)2LvsvPe#MIa$OxwyRCKP&d$YBhQ^#;kq&|-E z)CzW2KQKDV(F zYx|I@qDIfxLc^Zm@8DGG@tUnzRPakF>OuHF)F$>lm&z3fsA<)EAJT7B;K_;a)zX5S zf|V%(i(%wkXB9I|NfsitItR}#fXsq_vO0ZvV2+?AVgJU^lB40U7uY=Ek4jjk; z8M@YvJzQF*`vXlU`3?3xx;Fm6qDaPM^cRNAkC6CAVv{Dr?7;PvFT4%hAuzG$zo_Nl z-7L?P5Dn=*m3IEAx!#B){L`VI_w!0 zK|wwSn;CSp$pv!?OZrTZFw5ku?&X4u0sr7ca!m@pZU*ao+YIx$;R;gYyH^A^jS13} z304$xXm&DQztEZ-ZL;46mNENQX;j|PgZ}YI3Cymw=z6^Yw8P{Rr~2}Hcg8YdB2m~P zOSfy2`sweXzAoVnu=AB-YDeOC^SHiO2We%fwoS4ylvBfd3?Qb(wm`;P_NGO0ArY)Y zCiKv_MQ`lZRV3T*8)+G)gzi%^ut0a;^_2OP@^Sqs#Yof_vtAO)e`+azI+{y01rIHa zGYk&Knc!2$H82dT%M&^`Dy(H{RRe?rn7R+)1Q{^Mn<2$2d~srlTy6*~;1&jV zS2`Lr5z--{BNRd$=wEd&Q%mW1<<^JZ`>y$&R2qvIHyCJ4BTa>$t5gCmnueOlqZP8% zG~&-_jGkQ1!minjj5R+_k~1GsX59R3Q2=~zA@)O- zR)*J1h=5ECmM(-44QY6H z;+WoxgAR)8MroS(42BtK>0*Owx`t$ClNS@}YCQo!hXOc$cAzNi(Q2iUcHGM(!qgkJ&dM1#L^%1;e`B$dBabPTT!Bk1DA7fRQs-vRtGZ_GJh?na9-KVc5 zsNL1*;x*vnO{i92YGuTOlNj@9?R}c)Th-PlJeECKrmo(c3K1AD&K^Ey`Gujnya%fw zzk_A1jdqx*F9)zL7yZ031kUiRfET50Ni%ll6@b~0)7AJo#yHPhT=F%)Lyc$J)zYix zSod+4@0xMtOpWURlqOO=tc?VTs!TT-6l;=ZmdeCKRHJNdlKZ|18~hoN9#RFM~H8W8|Exz74`UQg?YlXz41itp>~bNQ0G_2N~38* z7DV9MAJ$a|@=2nH(bC~11@}ECAie+1g_OS5jt=(s9z*_}y3W zN)Yf-Dfi#aB$j=*URvaP67W#YC&CKBQ*D_gnFL+demIZDYKj z%1+WZq4hw{ZR;=0Im?$x$_agRsV+yzhAXZvntvR80RbN7d0{m`SQWe7O<*`y*Tpo+ zBkQnd%t)SE(FAmc#0Gn^;eIC8&{wzk4J?<$64?QHms|eY8_9y>t_wt>nKqz0%g$^( zXH)mgtRN(iq*h~0*{2qd5$lWeAj*=*;GdGqg`4p=rqZ_((XL|)1?=upO+^IZ#Ckd= zmUh(fm*M5HEl+>y{v1ptqC;XNL^+{s9lp`hGTxD-&Kk^+6@eV!qQcoXXSb81O{TK% zi%JZw1A->OUJZyf4odb|Is0WH>tMFkuyT7xA>Lp< zF+^!528(nGm6P&L!Bb4x#zCkM{fokA$%wa03^CO*H^1y4*GzGW|!TgcLpm{H_XfLO2G<+CgbwM!Bh>; z;VcdJgEezef~(}GRLb$hwi(LPk83L^MD8Hj^%;7)OVjl7s;A`U`iC!|YO@9%6=nnu zMc*OCXfNl=%%{kr1!t6oJ2RMoAZ3z;l0vSzIF8)p1?tb6>w*OET+XQcOGbbGUj2nk zE9kzYCh6kNym&{!U(kFHcOakGR<;$MY(G=a?9AjmL})9-HK>*;nO-uHxu<4|{4C|a z2sruyUL)&i(Vt@UP1scc( zF@#>}@>FbJ*lU{6Rvyuwu|5_RMw4<*4nC@}8b*cil)X{^8SHy*tQotgikPxMITCQ+ zD|tqG>ijfk&)xpYRrA9W!NfK7g#!Y-*R&IQk29t!{d{Oda3rf}nX4s7rb4X_unu*2 zeUFbMf{NdNlDO=dw#Rh04?3SJ_(x&5iyNBB+KTInxB9@V$JY1)J}U5k&Iq{O8{sCD zd7D?-eFryfSJ#-uZDBlTeW_)<95h_7}B}%Gqssg5@!SSnflgk9yL*cqB=?Z$L@##Eb2^iWxD0 z7TFb;$0osgzUn2m^n++0p*8u6taTrUP9!~3-5p*)rdB$a>aj=MO}&`((QE&3^KV)1 z9pR6e;kCs>um0*}sVn%9%c1w?iW$8-9wF*^BkItzh4- zgEp&oYp8|Fjd8sB&Mk8$YnzDRvhU1fjZ)vJ_>P5cKb7dw>D()|0BFv3G$IV)nHsY~ zc3cB5=sM?G9jNGiHS2)B*4d64jUoZ_y5g2w)-=)O6RnhQLIo^l%XONPB4wPkicHpp4q9eCWVKH(0jDhQyC!k4K~ybGc*`HhBX?Zcxu@f z<~GsSa>l3Q6Y3%oIT*`n=ax2Msz#SWs}ot4w*P&PcRequsm?Lu-rh^$N6}vrWs;H+_4J5=*@k{3LSBMUbqrf3v-3;B1f`FBb|2zcT^0s8h z9iBi0w}BfSIoV5Vj-X2gh#+fX$xKj|PzP z%dsY>bZ!{I1?Eo}|MkD=-AEFM(qMv*u5ktWrLuBgvtzwrela?*nQw_&0~P7#2Ue;- zoQp88=0oV>3?cZ@?D%oIzxe1kx@4!OmcpdmL{6==2qoS`OSJCZ#8|#85Jt6Mx_%UM z@LXMh-X0a77$Z@r!_3HU$5=Gg(w}eYoo2nB<<<5sPJn~NE_4mq!#(z>z?4lHp9=+`NLITs2eQ-rX361D z^JB76c%Yh7L;A12ynL_`6y+2dZPM=liBpJqfaB_3oRefUs5=Orl_+-{gbh20K==(6)sm@=_8^qtE*|3 zw~)wyRpJ)|sAcuvMi;i*>P{0t!Ao#!Bv$3y9(BLSfAh^o3viCBKLfZ34v;*>>Hhv1 z$XKbR1=TsFbiSg-2t%$*toU&_w8|UsJL?g+M8HM7mty;0N6WE_DpMq5y-U-Edq_f7?AG_EvCt9H=T6`+C(r}}h{WU%N$K0iSZ?k;3mp&?SH9E1#Q9W~--y0N_3*pv5X-zB?8uEXDKC4_nv>}u8Vht8jtgNj5}{qgU~Ll$bbf-hc; z>8qOdR@ZNWS$dkZNCm13uDl}d(fOQR614{W#@0%exUytul62O2O@@zS^WEFj?P!e@ zs31{wJ-WqUx>`44Rr_$+MM2WtS)bzpOzLV^t-?UfO_s}bo50P^lizji&e?%qF5i%R zS*3_Tp4IVx*gU1$JR)~zD?eOT-0PL|1FimW_kmO_{}lBAhGz{*RV$OLaFy_0WXN9U zIW~szHHW;e2R%94YjxxtGQIDVRUBaWbUHfJK^Fq_`$|E)0{fXMQ@FwI`}sdY_W@v< zmx7ewmPYon|0MT7a=!WNSSH;%+O)lBL*dZOhg>eK5F1W>O@LMM$W(vh>Rf)j&C&lL zX}cz6lU~xL!W{c0#Wjl>^q z9E;5vJ!ES!<%DltpGmyW{j+N3-d}6U?__2uw!B5dNvRzPW|Fqg?tsgpnXPR*%fPEl z!AG|uH7G6Lz^&|d>sd6}E{&^ch{~fe?&X&=!hhTVar&7pl;fqpm?boUT}Q=i}`n&1s*W)epv|xC{Jpuw(VH{)mkbXX(yn1De93GD!mpg%bH*s zcFN6>88&XG47??W6zEJ|fdC1MI~@hYsLuVl7)w$oZ2ZVB{X5Tl)KZkP+})63=1c)~ z1Ru`J7>$3bex{A~+dEV@6IF%>C7wc*?Y?MUyL#mMWp(#A6{7>gcz78fUledlh)Jw$ zR2jC*+3o1kRJ?23kBhDcQ@V=2+X#{l@XL~L?v`f?O}4xBdEXgYzeQym2<8PygnyA> z+74x1p8e+Py4t-^l1vsO@M0eC`@8Wr#p{{wA# zyjO;4znK0_@*ye3^d0j>kIM*gWG?2^0C~Psi5@ja?6mPfp)qNu3uu`%?|tr^WCQUw zb}vMehZ0{gObSImPp31wcv~hPU>s*#qauG*5j)9?^m3>|M5jy9;z`fNPZ}Zoa@n&m z8FMMHFxHL+f;#^Qa7#7d5y0raO>UUGQel&GyfvC$lzE;>kS*UOa!kHVo8LZS$)5Se zq|r=O8XnE(`W_`$sHy)i(55#G?y=HNu_-;wegKWvKjLeS0S4$Nk3}YYsKE!0i)9A#9F)5^=}ek@$j z66lo+lbncywi;d&-%i0j^oo-o|JnfPLEo9Vp z4D@j6uQ%B3JJiAhmo#B7&rioYat;HxNlf1`4$289Ze=wSd)-WjmqI0WMDQ2=QO8!7 zNOW(8MPAl?wKIs0In8-Xe}M<>nB#qwtp-VGYJ=DuR^yyWtkR7(8ihJ1%Oq&A= zW|~JvyU6YW4vW9}4yUwho%zl4I$BSDy0kk<;|qf z^})3&rfYQRqsq&~MxiPCd`i0-LC8LkO6^;h+!tmf6qL=T5=}~{ zHl|5&k6?Ky_ax1I*SWTRK1ctVhVF2I!SYwG8TrYEuinGgR`mywKLVJSxA@ZsYpzwZ zd|v8{rasP59+MnnLuOw%n_?EfAGnRA7FY8##2l$a0koq13%|d|(tLIdAq_kX-V4DEJZKG_ zAKRh88pZI#Um^-1fLUGYo9pc(Jhew4RWPNUZT)}al?_r#HkEy0OHh%$h~tR?EEhKm z_J)>&RXO2&f*p2=T^VP`^eG+p4OtDght#gd>4nZ;ao3GS{2Q3{`Jol1rAMTOH;O*I zswqU<+(N%DM`m}`6xDU`o7DCG&a7Hs`VuqtVJUAlD5}K5xU$DMd<_*98GF{djM&`M z3D_m`xnJCNK?3OZIhgoDiAkw_|7onhdV7w$`aIxgw;jq!ix57pv2nzaZ$Si#d#V3{ zBQG$zlE&)MrmOrUY5W&AjP_2lN;n*7Y?aq$`R&S(k?GUbME!c$mda6>0~Hc;0217` zBI}*JzyNps%q{M$3HoRjn$5zGq6Jb^zpN`-t0XrY^D;%FSYR^@pq<(v>vptB=NmQY z;~Ne=Nv*3K-RgAH>>3ORxTubZ`|d3RBhhBkb)ichJ)om+tav+{VZVG*_FTqV9ewCL z!xKy+h(CKs}331P`xg|^o9o*Y4{lsA%TRt zfN!T~E!`s!8>ars2>0D5?~w>oaQP9;AU7`qi2!|sj3;B3mWxZmH>MCln^OT0xPkj{ zL!e@Eij2v8YM7>mYB9Gx!WDL48#U5cUkoimL2lvmEMHlVw&BA_ovCdqkI;hxpOc1L zz-@Ept`OrM@XFX&|M{SbX^iv;dzz3J+uG{PoTksaGCINqgy*s~tQsl##;Lc8Z`@8XmH?d8FsqvfZr}BK)*+hA!4fmg> zd)pMuX%-1Y4{hHQtYsbTpjOVEAiP?1rud*1J3~~ee zm}as7RVeTf)pBzKw}H0R{QuPgX!uM76Yq!-L zgx&&l64Mf{Ra8z&R4cSis8DBi@bsQSe~l)}BZM2bxPFII6hiCjP|5y>PSFjwFoJ9a zai2G`2uLp5Z!xKD^4LW1zA2 zI!RMpqT~i`LGQr^<`)NcDPvKMhM#`C!8= zIdr5`9{c@9d3fb|*?p)D1T)f~oyfw^VFlr<*N?0E!HmMhJ`dywQLl340?+FFAmJ3d z^Q!+_e0ufQ+@ko++DI&a87@F=y8ZV%WaD!iyNu^sGW}e4UBQrKTa8XcSf8z|rv7Ak(yB!pvIf%iJ z;S`n!o3IMs!yy>jvUO=7v^QTwIRxDvRsa|esz9X$@-M?ts;?e1E#UA%$n7&D_@+rY~m06lH~i4yehy}$w`2xhx@lS z%d3y9k>+j9lIh7J8`>ce{-<5)q;8JEv2x>bOI!TX#CubhP5OabS8=6{xka=&gzfZ! zzR;(e#!dLz_+;YiTB)7z9W*{32V%1gg#7tatK@mCMi4u|boA|LjmiG#N8Ge$<@Z~5 z%fCMPs%*mJYQ_e7WjN*7=%=@*Pp^5dbt3`@!VEGa4V0<+JI-IFd~W@A`PGZ4gR#b% z4a1xf(Y6P5F~R=ji)-cO&AXzE+@Bsa-e2Lz{_cv3iaV~Yt$xB&+VMza!Be`>#+0Mt z&ZMG6vAN^E=#4M=3s>tJX2+Zulotnf9g=k~Y>_Rib_&TC8VjBcy-+vDSogwaIk5AP zm`oUD>YOcFtD4yO;yaVDl{V9MEIY+Q_(Fmww5zYHt$F;4iqbpE;WNs_ef?n)B7Q43 zH_0!5w>C-|`OgM`m$1B8vtwU`{j=+~3uDorp5Zvg@dqP=Bs3BhN_N&}XF~!IXKzb~ zJo?&Z`8V7+wqVu5S}TX!yX4=Vc}9vg%K6nppe)rY8#y z@WDlBLf@bs)nOR7aks>uTl<;Vq_SmB%)1<`3clnymPMVo_QlPze&tqajUF?H$WF-0 zGJFF2Fjl*)e`TwzeQC4scZTzAqK!`6y(|u!{kb)t4dYHJeNoEw(borAy(x*pgYSgT zZ>p^N7_QhS;gW?&z68Lw8KuWlYq!Y*&%PF6N^FTSKR=2RgV?iOe!t<~9}Jn>_9-pRciI*Wb%venl$mt6H@V36>2ARw5N>g zD_6%PuB+WymA57@9`|jpAa;W?xDUyTRCh;@Y+b)wUU+Pc9Be*v91-%G4f}D9{lep` zBU@quaQ5~5aZ!Go7qR&J4X>vjo$eUxJj=grUQ z=%y{*xG7z?u1RR)Vc1HWX*-3!Xi*aqHAQ#CFDn0Yr@Ua6Qp#7#afm?%q^>V2Y^pCQ z__!6@J`lznhceLlvvGWov^)bsBV(C!CAgUhAOnr>DfRGWC1qR0K_g!T@l;`+=+XNoFClM zHy1u8#ua@feo^_SCE-;O0|6EuFuAI-pyK683KNE)gX(Ew`JRYS;7C1y01mj zK_~<09dXBLYZGi<@xpr9vTYx395YA#qK{gn%#V3DC*ivS27;FbCe}-ap5wg8s-r(& zgMqOY zeFhHEHX^n?CvhHR2`%tzaD+B+_FM|HNWbIbVAh(TE9f@qY=A`VnYNBj9hy*`=xJdx zCNm~iiRT@CnvLMec_}V@m@&Cp=+lXQ()YHGZq0JT0uy3`b!*)BSqr%2L>Um)n|4bp)~zjqk^@I6B{k*6vSiVW zh}em^cTR714bti93T>W&i}Ed(_WsJOQpc`a`bwYa`^KgNCfx%(j1v+YVl+?;d$BdR zu3}@->SkGkso5CT z9>-ux#C@4tS1VV~pC*%P%A$-F@`4$`LfmKrDy}vB|0cRY(6BIEYp_`ac)x0@;Uq zoAc*%w}U41*}78u8z%su0a(`GIRc$A*0hG*p-Gv**O{yl1d~DPlOvto^4#jJ^7Lz4 zqW>|?sW~6GT9hHzE|@Nt&YmO%$(TUhh!EOBn`m2YMRC5#LLc@WJ|fS+r)Sq}l{Wa8 zZ+Bo#7-0cSh)pKM0pMu*5n>zkbG*N{!*%R9=T76ppa|WF0fx(wQ`H+p$OrqWQlC2509?gh?%EN(%JANTeB7yHLxHi#3nGthH%*Q zBd>>Tf9;3kSkAxG-42S-9Ar40I31k`hz!&o^@Z8rzSk6$p!RUm*{n4@gdm%fBpZg( zxMQFE^0_tAxP33kaQxRS_^u<1XHS+J7F{G$YRfR=tCMtuxVVNxlu|MoHU0G%d-0qp zLR*M^PBPgbvN^IHKG2te?Fa$`EQ){$u}P!hK={2LwgdGU>f<>dPj@?LLWj7rN1_ne zKxxxmHZ9yAc7%a4$j&A}g(9yft(`sc(uQ5~i)U6#(}7m3HyjUTRcWE{e{sC^(u<@p z5r17>uQ$%Ap5uRwx&dQdGp|mnG4}D+NMz%GKC?QqnYO=?-XK&#piDqah>aD4J7CO~{865_hz=Q@D;2TMRRgELvY2p9KcyiSi zX>RQhL#FgOp*5EI%8KOHtLDljGbTi9zHg-SbkB1g`H~sqotAr{1h*cg-NSZ;t(upM=TW=xDASO>&(fz;JbItZ*ar!$*mLL=*& znuY&w?zwf_qZ*ZZbK-lBESNS<-hSyUnK!jY6i9bE`ul6El#+Q;i zQ^?RKX1b1^jJ&pOul)1lD`f-LvD2toshXlx?=s)w*;C|}E9OXbNue-TKQpc9zIOfLP2Vlc1ku9|8jdb>RKC4jZ1B(a~V8fDwvl)FEO)@G8 ztwLqxfuEZtOYq>Y?onfmJ*^CVVwUSnXNCVc^ud=m$ku&_1og{t9AJ(6&MRiin=hUw z+Nqm=raN=B20jd5oK}N zj5>lw3f@!=KW|p6U%Yzo4Z1&lUaHB?>9zcaZHXq z7@PC^1ST6{Aq~JpfK6_|5x4~gcB2$xbNa>&iqJonm03XYe=r=GgU!f#GyA|kWFY*M z1{)w_hnR7PLIM-Yy50NbCy&1@zuT}=SOVmGuTd^sHgl4^_v(2vXJWN@SoX_}%l$c6 z)8@I$(S~jF`*xzlq0XP6#97<4UpQ8-@0o6ifdC5unSsin$RcOw6@rvMSs^Qe!3>AK zgXt!caASX}to^cx%??h(S?L8*i#5t1LklwI&`CSU_uTp&@-Hi1k+rxGJL?>!qr=Kq?{cT3jzY#+sRJatN?5Z}k2dx~P+1iDsZ~CPa#22JR1AN*ze58IZ1*EMPmV)_3>Z0lf+818;}tRf+d9BRp8nXe7Y ze7{+_UVibr#*@D7iA!#^rA6|=H!YA4LB>VPaK|7Iu$t0zs4eRIYv`aZ>Ktl2sn152 z^Kc_QIQ>?nO|1OXgl_$PZ4ZB&+!~xFkz?o%Qp5L;JiIxR@;?qmx8XixIn)@)h^>;}_WT+!mEflU012u|L_t(}2(|B#t{#kd?CoeRjs`-<4hh$g+4kS? zgpz%!W4P?2-Vs?FiWnvGh=_VuI^}=t{`TK)lrwxCkMbau{dmu>y6$OR{#1EO@hxKY z=ai`3P&J}3G)Mu%EuEQ+tlF_x{_c?%W#!i0(w#ZE8so87aPIGx|1N6FKErJn^@lQ| z!a-y9h>}0Et@@VxS|3_L4uLag{2mfQ=d5Uu2Qv4zJ@~iAnrlVu|G-qzk~?=Ux2zvS zWazUc{9?ohk*R<7>}q*!$39{9rM$LtpZx5()$-flZvxo^Ia~Hbn(u2{zU|)D<#%~Trdz9ii{s!nex9>;hzio4 zRUv`5W0~(iP+R}tH_r;?KVE2*?aeJ9_}RB-1n}8nl6(;NT+_ghyZ*NAv8}^FhhZUf z&QR3XhdbJO-iu1<4@A_HBEoVOo3s1opdrvaSHka%<6nO|SL64Ndu}ML5Oc^=SYG+C z+jiu=7Cbak9=dj^wBSLaIb=v@M9rUY7k%vu8FbS+g9|!AGI%kjxjzXjf3obP6(Z|^#OSx zyr=D<2g0m#nNsR&n5MXXojW~b+bBaLxaJ%Q%MHr88lU41*E%#}#`WZ7>9kz_!?uTi z&X^8a;0B@SoDw=`?~givyRWR}p>HMK_-qmN6A=j@L-yuudo)vl)MtC2*_m>-$31R5 zUe!Z}Xrzc5|4C_Qc6rOg-~4gc!+WEn=lI8nAT%*(5HWm)zq+;O~O2=D~yH;&ERj%4}~ELiT! zojb+tcH^mGS+qaJ@z0B9p_a`sbQ|P!-c8Dwk2;cGdT-0}Uo{|Br_=dtv>kDT&N=;q z))lSywJrZym)lW~V?V%($Ubae@7)CQe_bl{qWcd&-1vG8$IgjF7**&?mDsz1+v_I{ zEBlnlJ%@E`{SS{U|5wIlq+&Q?37rG%9}Yg-j;ZZVZPl#+|L;YjDNpiVM+i1P-@L!I z=o%_8;(K;#JI=o>R6e{Jv|kW6kXYfazre-HoexbsIF zj-Z_pPiSJ*AS;4OPtG<|E=bs+2Cwbh^gQnfO_+WpD>_&Mo~y-s z+Avy$m=arJd_MYkUK09z1bWQovxC4Gq0bJ^W22lU1WqROSsHk3m@x!~1Az;I&=yYv zh6B9wNvFo6z4OW5!P=?^+wpwr7izT!_MY|sSb#A{%DdfYGc3ROHBubR*>OP-y1FI( zl2~;tJ8V?0}?rnSM$0qB%0}l=UKSYSX zDTd9ceuq}JuXQq0!{y`Oqw&*)MChaOyRYT9t2^D;f3r&7kAu%ZMnk}@74@u9 z>h^>i`|#i5aW?$A;NBM`p&8|WIQYx<`&yPi=(Zktn^pEJIL809dp_3JwE3V?>MLGr z+gtB%dFZ$Ny>j?-LB20oLg%9WeR;n7zLw?Rw2r+NHRnAz-VQlm-t9`%J)sl6>4z;3 zeUrYNFNP17ofkZzbHLtrX!-hz*5!9dTF*gk%`e5tof?r_CGu35w)JJLi|_c+p-0z^ zNF&dIQDcP81HD1S+}r(NQ!FO$wAOx1S@|7K59$AJguXL*)0P*lRUh+X_NE_oFW=Q5 z4J_w^t7tv_hpi9& z%{|SJ?LVsqoM(fL5&9GY_WeV@-knbQA43`PUNrv=gi}8@!}<;|-fuzX_hwQa|93UN zjAvK}91Y6H2z`oC=2pNtij?{TbTlJ}mw#;q!w>|V2$HrJ= z?^8fvjL@e5QhpOWMi1_7UH%WAwwG8d-$W^LApiUs?OQnjecv*YU33rXWR7)4`YyoE zV}u@L(EsM}@?HDeik3;J>#f*-hZ<&9YzFn_O2vOBK q2LJ&7|KI^lvH$=821!IgR09Cvb9d^lw|{K_0000q-{0c5}&Ux|n47%=59;{hk)`B+vK!y!jPrhugLTM%%l0>$e}D z1i-hu&ZYs3NT_t)KXCR&AvH0Jw9Z zIi<*5^&9%MZU7@D$irpl^P6vSMOY>`SjPblO;*YdGHZwM0k}1GG2USIK8gK#5WDx= zvT|R4DD+(^ae&WT)+9V+kVOZBL`HGZ>h8OwwFOdzp`TCl{bN0mNb9lFrw%_aEG(7* z@Qz5U3=bKtEtoLoHW|POEcNoY(2sq;{rpk96()x#N!jaA@$5(Q>m< zbkG%}nYcSNR&(HCbL?r1#kZ z4Q8iDP8`{HC_ABcj!&o(0N(6t@Q^_k9sUk*F#g&_!?p=1S(f>FbUa|V4mCD~IKW@| z8a!k;{0_SO4ifsnO%hnl7Zib!r&Ox!Q$eM7*(nZiXz+*x1zFx9s{IEnv5x_gb01j=%UU%AZz-l6?_h*3B^;YM900030|6jqSRsaA121!Ig aR09A-DQ+>#yEC(w?#%4Y?{@~K6bfv?2jgZ==j-{;cfPa# z%%8Q>fBgA>1`tHHGexV0bQQdl4xk1{_wc#v9XR|{X$n2Q=KhC3yI8R5z5+vWMW*LJp7M6_Mq7H?S3sO=a9D!|+QH@+Q~<%>+Fu+V9Zh7p zhb}FAKT+tah*g}QAKiD}4xk2yHAT;SB+Od$3?@X<0D{QYCP~BZpEG`uc;Vp+skY;d z?yt|L%k2RQ#dQ;?!AZQ9@}m1&J8RK{UP|&w8UXzxhpV=LUsD&GWJfO6{-re>X{5BQbQnHZ~Pc?KTMui zu%YY3g%zFUwP_vSoShO*7~5&y zb1(kVWD6=0^Ck1Q`l<1|nk;`u;KKz=Jj+)+U9{dX#9r(IfO^I3qV-t2+}^X!)N&7Y z1s?EfI_RC-o1MPTUOet2%)-p;0i2^8A*bW@Sg-b4^Qk|F`L9T%=-LgCRh3%n+sW71 z>9!9~Eh`$Gxh*b?DxhdY<~F-CwanM(b3+5|&HeCfVsAC3uf#0knSBAO$){J>wW%jg zHd1a*CXLS@Lovgk5TADo!_tRR!IV4-N>Ms~tf6(5E8AH2?xKcL#;kGK2DpF}F*VRy zeYnF5o)54$W?&X(_63+(y`<*RLfvx70jlfMgG@QVCtRZVjcTJ)e79`J%^6?~^;8MC{Bq0*ZDwYeS5J_2g? z*X4pMnLBL`G<7ZQ<>*0gls(`DPj~}zT;+EVQ;FpBz0}U*tXCM#0_~!3;nR?$}XpAvDWMrb$61x)Z(dfS00O_r@%1ZeM^9 zA1UP)keHKoy|e0EYhf&=JH}=?sh7VMo^5}C8#~%b{E|7Km2|BWy|?njtN{U{@i@g2 zi2(o~RMptUZdp6MKmh)2 z`JPEHW5C4+f`1Ft1inpuyPx(PuOoYkot`V2WWt`}FgdJ2BWbFEo&iQz3GcO8=6`*U zV3S-~JxWri+eo~{D*;B|WpJga*Pmf9rqUtr1u8#X&+wWxIK_I-tmT%QuK>I;1GB8f zz5pOf=}3H&RJqwm#3$K=*ved-!EdTCPe_=zyg$0aQR)gus6U4qFD8*dZnVd3!nPHj z@Wu>~7H$Y&ky}ECdN}23VM`PWL#!SpJ2&Fo!0`qMK?cB${ds839`1$5Ci$n;@MPJ5 z>y`m5Rax-h`7)()Ybov3*i^RoWOr!MgI;*tNq@%x{TsuQx>yLYmbozFc30SD6KkQp zo#1AE&;h=nBe62Ma)YkwlLd+0{4M#qYR3gX2(cS$n}U@apuM`1{O_OxSXQ6+R&|z1 zm9N+gdoc<64L}VJYnE=X_qz^Y8E;m8JiS@|`3VzTTm0SL?>WG=LA?9C2iN|&m)pM$ wkaT7;!9M^10RR6Gp`*Y6000I_L_t&o00qmgeFAoV8vp#C6-yLU5L5<`GD96Y>Ied&v4CO(!Ct@; zH@eY{Es053O^At!F^R?)L#!wQGXtpT+4H~g2@Eq}vgq!&pMKx*op4NUS6hNtZ6V`#xq8$$GX89Qd%kmqb#~RwC)cJ{-akIy+yU z`acI#xfJqy6Wf`{yV{$`ciWpPND$UxYYfM(z48^umlPfCOyw6I#XiJ z%oH9RQ)F)ve8{O)JrTq901j&!zH~0! z{&?P`zb4;+Ko@M>2A?JDix|X$gG55jRKuuc_9BgE&SG_fSKs`j?mgC%SKn=d==8i} zXoC-Y5u=N(`JdP(l~8lP!H0j1ay~^$yKe-64jbCwQxj2NbG0sdDKnv@AlB(Th1#AK z1UhW49gAM7AE9rdX0eXyw(FOw<1->liy_dJcItM9IO|7?_kNaz&o2)x?FNA^+W*(C zc+YfUk+R_Z^z7dg=>d${2FL9a$;zOPEfvG@D`D@&`5V zO!))Y(5CO-RQFQJ>=*f46VEx-HF+O?hs5s@x7nEm$HTX&@`^b8y!#u-tNR{4;vwHJ z#{DPPac1)Vb}i+Bt`gM?ocs9W=T{s@3}S)tdpV445!BkwBsi1v-9t`06Gf!GX)xyx z;peOo2eS}@)VAsL;N!X&;qP8&9?6 ztxl;lg{8Jlrzafec@bR)gE4`P9EN!C^IEe18NHZ?>l{V&>@Pkt?pAX7TM&ENB5 zG~e}tAkbkG+u>^7~IMAL*0s>HcghrUp`U7~2`Z)sNF zeDdtQt&YGJ?es0F+4 zu2=h89VTy!eFg<>@PY5tPU*Rb)yEU=IhTI^sa?`luAS3K6t^y`c-?k{5+?sZUi~)T zo7rni{=7-Wp{td53KKmpP~fmnXs+c)fv+wzqeRyKpUYOrwF|NnHNP#K@e@I^SclEW z;?UdhfiGeZ3k+ZZQ^fb^v2kJT{8-1x`t48ivph%(TrW}lr1QL<+{bI&J�T`w1=b zxXx?LLiGK)IQ!e9*QZsiTAx<63OZ~Nm6aOwKDWWgSiHjKGeQpJLQb&{ z6g}wADrdEMX4KuBbmh&8Nc++f=-s-aMHGG~@w+DL z(LL{;sUWm0SZyhzdy!T??@)5x{ zC9iw+*-=pEozrI2nr4Ifp}X&J`?16aT`^vs4Bz=q!=ewOsSE62G? z=s9nT{hzS#*xye<)*n*R%qtW(;hZ4QVH4Zoix|WLqp{yM#yE&}aQ7c_yg+XK870rT zM5(f?JWo!Lzvah*K!*)&@PV)Hc}0Z(U&=cm4gaQlJXX{!bl^qP%y@0!<_t~IV&BUYA-!Nw(6%{4TLfRkG);c|7N_RQ z5!c$qJ86e(&D3UY%PP|(d)}m&aW@2k4jbBH9Lk0&N2uF1KF%NS=|wLU|0c)04Ig!@ z)7x`3Z|-|nvwz!K&3y=T*w6+a_=<5&eZ}5_f`8}7*p_x%zf3(15@TD{4K}njbw0y* zn45*qD=n_Q0@3{rZ;G*IINLMp|8EaqpMo2cBXN2zw2t3UA)q zkv+~H!BeExXe#@bVODmgilh8HTMH3%wlx~~UvzMlxIX7-qUg^1Su2@ywShf?YJQG% z`<8({;QL>sS}*mO*8ZTjVcZ@9^K^8?JA;0mBIV6F2kRRC1pi8Y?2tMy~Q|U_5c6wG!0sFW!I5 z`W=04@L7N%!q>(YZB2D$?3bBqDEHUDxJElcDv>sCD*gf%WAff)0dgQ$Q}ITP*RDnI z2p+E=F}652)ebEJdh7THbHDy^0fs}1klu_f?*HFME+MDk+VNvOn(#hEEyYZc&11Zj z@gM#D8Nb4=g}je*tH>(~aQ{e1E=&lO|>>jr%|4j4xGBOh{K z1{ZK@>>0SvU$8d`nZhwYu{`3duZ#UZ`PG!{N!ZJ+C6jsoYeuiN6zTXmEuM3mrG%8M z!-lrnzR-cu0RP$?egZdeH1-@A)6BJ8pRfKs_>XHH`i5KQ^mS8TUr7+`|IX}_OJQSq z&)EGpigNmjW|?mlB+BU&+lquX_y~K&598ckPhSTt1LLPQcTq0~M`Qf$%>2#uxf;6X z^Ar^m+QnJ8chB9xH3ifaD9Gk8rFaz!`^&O{d&r|_u3$%-N6+;H8}^>z179%?7{KDO zsz3fk&ULQfXq>;&+_6;z&o|?OI)h_tW##lPS;J=a+qJ`Q$osT_>lcgXlv0TFQ}X2f zL*COjC}<0jekS-L2C=jH?cNS7U<%j`g-R$>uztFsF>ll#WYz4yHVq-eY4}qHBfw^1!>Nmim|7k9hM7@)>wQIIn^;E5-yq z@Lf8$?Edn=d#8XgqsNADU;?|L+(H|3G&X)FcS<*(`uYl`L32Gtc1buh6f}%$9V|FMl;Zsht%=v%fRE~}Vn4(n zcD6-cHF%?z6yu`9MlK)Ay*$Tjcs%0NJ;xw_?1g*x${XO(b2Gm$bTZp~C-0-5r8HR? z_sv(lkG_ZZ^S6@k;C)z$occ{nFUeJ_Rg zy&*ZN>U*%Ic$UlH17F01k3At^B)oNAzyvmOAQy5XH#mrVa0SQu`}Vq;K!C;J;Wz^( zv-mrO{OwcXdJxx?u95%HLp1A+UG+q6bDZCYx7g<@1zI2F+uNB13ckU5-UF4qN4VPj7cID#uU zi}L{c+R>IJCP^M8<1?cw@1}Zlf1(aM_$bf8#sz*MqfXMYpc2YZYVRnlz8vY?Hhw3+ zuZtnjVM7~y;QM>9!3CUFGdFMq*SdT0B=mm**Q8bi--%zhu=>Kvuv>Sd98Yu2@B{@8 zEif!r1r%&^QkdJBvE`L(=ie<%o_Wj1qx+_5Tz|ak+$N47tY3jG#q*XdYr)+w;EOqm z*g(cGhL7Cf04_O+%qgfc0UXi)JboVaKgLoPk13UXSu51l7R}JQyc1WoHfNFM;xfgR z2eB^SQ}Bqd>(++&e%_RGlpcBGUQSBI=~a=Hsi+USbL zzkTkT$-cQMRbL|(7>2N8$A3@AowG=LF^jn@a?`qi8#sb1IE&*)ch${d#jsl~!>n(< zrW{*hkt$QW<}Rt)y=9fAdU?oIQjR%C;Uj;fSf>lLCgCP+$<&l*CswRVo?hnc*?apU z*Ul?1@_CPHocYknGSpum>VmBrKJZ1%x}=H>U;qo4z($Vc!B+`6H*junh+{6`1aA6s zx9&4YB&aWk8evWb!wyu1x3e2t?nGd&F?9-6z?j*d16qvNdJEMvk4FD?LK%huq)*F5m=i z`g{%F^SMy2t@&AAc{xH_-X?NPRrh3f_2^9*+T>kZwD0AnX(HyilntHTckcrJ2HnJY zQ$ym}doy|V+H9~5{vJIy-U0@&a#J-CySG;SeM^QqIoV4+8aa_09Kfaicrnbpu2nat z0;MBiHu(4tJzDE5$&no0xURweqwF3x4t;}jDgi?Z zHU|v(l-I+jXhmo-zo&2WIjrxv)_;NmhJH$D6NL33wxJC^@I{QSyJ`8OX45fj7zy(&%(|>q;`EAP zSt@OE{<@0m>z9<$f~iHM9DR|LQX|4`@PRL4)~d8AzyKC78_wF4t$Y?pN-yG!ezeW4 zeo3wxhrFDM-FvocZmo_fc`$$Cb&7fO8by!Q33S-dwtKs#1ilNWmN_C87{JoK7n+h| z!&t%`=JPwD(%YsulN4uPI(S`@I^f-W&B5JUG}_F#GEzx5C2T~j>DZE_Zug&a7jIkRuC<^kluLG^bz zi!~vOW|Y{%2fm0w?9=Y|y8Lx__&vvbnk()2+z5|{aNF`;i`>VTGbp`}FY{DA-iF{X^_Brt!SOZ{90C{rjfu`mC R4DDGDJnMmHJ@9*a;NN0>Jpcdz literal 0 HcmV?d00001 diff --git a/docs/source/_static/site.webmanifest b/docs/source/_static/site.webmanifest new file mode 100644 index 00000000..e7891223 --- /dev/null +++ b/docs/source/_static/site.webmanifest @@ -0,0 +1 @@ +{"name":"Megatron Energon Dataloader Documentation","short_name":"Megatron Energon","icons":[{"src":"/android-chrome-192x192.png","sizes":"192x192","type":"image/png"},{"src":"/android-chrome-512x512.png","sizes":"512x512","type":"image/png"}],"theme_color":"#ED467A","background_color":"#411046","display":"standalone"} \ No newline at end of file diff --git a/docs/source/_templates/favicon.html b/docs/source/_templates/favicon.html new file mode 100644 index 00000000..8de2e6ba --- /dev/null +++ b/docs/source/_templates/favicon.html @@ -0,0 +1,5 @@ + + + + + \ No newline at end of file diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html new file mode 100644 index 00000000..bcd6d81c --- /dev/null +++ b/docs/source/_templates/layout.html @@ -0,0 +1,8 @@ +{# Extend the base theme layout #} +{% extends "!layout.html" %} + +{# Add custom favicon links to the head section #} +{% block extrahead %} + {{ super() }} + {% include "favicon.html" %} +{% endblock %} diff --git a/docs/source/conf.py b/docs/source/conf.py index 0b54693a..8b17f56d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -138,6 +138,9 @@ html_static_path = ["_static"] html_css_files = ["css/custom.css"] +# Favicon configuration +html_favicon = "_static/favicon.ico" + # Custom sidebar templates, must be a dictionary that maps document names # to template names. # From 6742388f9c85c8ebfc1af8d981c7ba7e1279dc5d Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Fri, 22 Aug 2025 16:06:31 +0200 Subject: [PATCH 21/36] Publish dataloader.start --- src/megatron/energon/dataloader/dataloader.py | 6 +++--- tests/test_metadataset.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/megatron/energon/dataloader/dataloader.py b/src/megatron/energon/dataloader/dataloader.py index 63ba28aa..8be531b3 100644 --- a/src/megatron/energon/dataloader/dataloader.py +++ b/src/megatron/energon/dataloader/dataloader.py @@ -202,7 +202,7 @@ def __init__( } ) - def _start(self) -> None: + def start(self) -> None: """Start the workers and restore the state if available.""" self._workers = [ self._worker_type(self._dataset, self._worker_config, local_worker_id, self._cache_pool) @@ -271,7 +271,7 @@ def __del__(self) -> None: def __enter__(self) -> "DataLoader[TSample]": # Already start if using the context manager. This ensures the lifecycle is fixed. # Otherwise, will start when iterating. - self._start() + self.start() return self def __exit__(self, exc_type, exc_value, traceback) -> None: @@ -295,7 +295,7 @@ def _epoch_iter(self) -> Generator[TSample, None, None]: ) if self._workers is None: - self._start() + self.start() assert self._workers is not None, "DataLoader not started" if all(self._exhausted_workers): diff --git a/tests/test_metadataset.py b/tests/test_metadataset.py index ca2cc832..75f9e8b0 100644 --- a/tests/test_metadataset.py +++ b/tests/test_metadataset.py @@ -1370,7 +1370,7 @@ def test_save_restore_next(self): assert same_state == state_initial # This will propagate the state to the workers. - second_loader._start() + second_loader.start() # Save the state again, to check that it is the same as the just restored state same_state = second_loader.save_state_rank() print("same_state:", same_state) From 338eed88af83fdbe6c955e11a9559c55124cd93c Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Mon, 25 Aug 2025 12:57:25 +0200 Subject: [PATCH 22/36] Clean up rng --- src/megatron/energon/task_encoder/base.py | 26 ++++-- .../wrappers/task_encoder_state_dataset.py | 80 ------------------- 2 files changed, 19 insertions(+), 87 deletions(-) delete mode 100644 src/megatron/energon/wrappers/task_encoder_state_dataset.py diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index 2ec124a3..3c7be289 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -4,13 +4,13 @@ import dataclasses import functools import inspect +import threading from abc import ABC from dataclasses import is_dataclass from types import MethodType from typing import ( Any, Callable, - ClassVar, Dict, Generator, Generic, @@ -41,7 +41,6 @@ from megatron.energon.flavors.base_dataset import ExtendableDataclassMixin, RestoreKey from megatron.energon.metadataset.loader_interface import DatasetBlendMode, LoadedDataset from megatron.energon.rng import SystemRng, UserRng -from megatron.energon.savable import Savable from megatron.energon.source_info import SourceInfo from megatron.energon.task_encoder.cooking import Cooker from megatron.energon.worker import WorkerConfig @@ -409,7 +408,7 @@ def from_samples(cls: Type[T_batch], samples: Sequence[Sample], **kwargs) -> T_b return cls(**init_args) -class TaskEncoder(Savable, Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]): +class TaskEncoder(Generic[T_sample, T_encoded_sample, T_raw_batch, T_batch]): """ Base class for task encoders. @@ -435,11 +434,12 @@ class TaskEncoder(Savable, Generic[T_sample, T_encoded_sample, T_raw_batch, T_ba #: The decoder to use for decoding samples. Set manually as needed to override options. decoder: Optional[SampleDecoder] = SampleDecoder() - # Defines which fields are saved and restored when saving and restoring the state of the task encoder. - _state_fields: ClassVar[Tuple[str, ...]] = ("rng",) + #: Thread-local state. Used for properties, that are worker-local. + _worker_local: threading.local - # State fields, they are initialized when the dataloader is started. - rng: UserRng + def __init__(self): + # Create a thread-local state for the workers. + self._worker_local = threading.local() @stateless def cook_crude_sample( @@ -1087,6 +1087,18 @@ def cache(self) -> CachePool: ) return WorkerConfig.active_worker_config._active_state.cache_pool + # State fields, they are initialized when the dataloader is started. + @property + def rng(self) -> UserRng: + """The random generator that should be used within user methods (like `encode_sample`) for reproducibility (and + thus for savability). + """ + if not hasattr(self._worker_local, "rng"): + # Initialize when needed. + self._worker_local.rng = UserRng(WorkerConfig.active_worker_config.worker_seed()) + + return self._worker_local.rng + class DefaultTaskEncoder( TaskEncoder[T_sample, T_encoded_sample, T_raw_batch, T_batch], diff --git a/src/megatron/energon/wrappers/task_encoder_state_dataset.py b/src/megatron/energon/wrappers/task_encoder_state_dataset.py deleted file mode 100644 index 30b0e75e..00000000 --- a/src/megatron/energon/wrappers/task_encoder_state_dataset.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. -# SPDX-License-Identifier: BSD-3-Clause - -from typing import ( - Any, - Dict, - Generic, - Iterator, - TypeVar, -) - -import megatron.energon -from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset -from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset - -T_sample = TypeVar("T_sample") - - -class TaskEncoderStateDataset(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): - """This dataset wrapper applies a custom function to transform each sample.""" - - # Will save it's own state - _task_encoder: "megatron.energon.TaskEncoder" - _task_encoder_was_reset: bool = False - - _savable_fields = ("_task_encoder",) - - def __init__( - self, - dataset: SavableDataset[T_sample], - task_encoder: "megatron.energon.TaskEncoder", - *, - worker_config: WorkerConfig, - ): - """Construct a wrapper for saving/restoring the state of the task encoder. - The dataset is transparently delegated. - - Args: - dataset: The input dataset to wrap - task_encoder: The task encoder to wrap. - worker_config: Worker configuration. - """ - super().__init__(dataset, worker_config=worker_config) - self._task_encoder = task_encoder - - def reset_state_own(self) -> None: - self._task_encoder_was_reset = False - - def __iter__(self) -> Iterator[T_sample]: - if not self._task_encoder_was_reset: - self._task_encoder_was_reset = True - self._task_encoder.reset_state() - for sample in self.dataset: - yield sample - - def restore_sample(self, restore_key: RestoreKey) -> T_sample: - inner_sample = self.dataset.restore_sample(restore_key) - inner_sample = self._task_encoder.restore_sample(inner_sample) - return inner_sample - - def config(self) -> Dict[str, Any]: - return { - "type": type(self).__qualname__, - "dataset": self.dataset.config(), - "map_fn": self._function_config(self.map_fn), - **( - { - "map_fn_config": ( - self.map_fn_config() if callable(self.map_fn_config) else self.map_fn_config - ) - } - if self.map_fn_config - else {} - ), - "map_fn_stateless": self.stateless_map_fn, - } - - def __str__(self): - return f"TaskEncoderStateDataset(map_fn={self.map_fn}, dataset={self.dataset})" From ea4367a01807a5f191f6a483c0731c8a337fd96c Mon Sep 17 00:00:00 2001 From: Philipp Fischer Date: Thu, 28 Aug 2025 15:50:05 +0200 Subject: [PATCH 23/36] Fix rng assignment --- src/megatron/energon/task_encoder/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index 3c7be289..908afa2f 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -1046,7 +1046,7 @@ def build_val_datasets( def reset_state(self) -> None: """Internally reset the state of the task encoder. This is called when the dataloader is started.""" assert WorkerConfig.active_worker_config is not None, "Must be called within worker" - self.rng = UserRng(WorkerConfig.active_worker_config.worker_seed()) + self._worker_local.rng = UserRng(WorkerConfig.active_worker_config.worker_seed()) # Burrow the save_state and restore_state methods from the SavableDataset class. save_state = SavableDataset.save_state From 355e73338bf4d3fcfe5cf471e91dc947f0da4845 Mon Sep 17 00:00:00 2001 From: Philipp Fischer Date: Fri, 29 Aug 2025 13:19:49 +0200 Subject: [PATCH 24/36] Remove THREAD_SAFE --- src/megatron/energon/flavors/base_dataset.py | 50 +++++++++----------- src/megatron/energon/worker.py | 2 - 2 files changed, 22 insertions(+), 30 deletions(-) diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index 9e72170f..d5cc3152 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -252,9 +252,6 @@ def save_state(self) -> MyExtendedState: """ -THREAD_SAFE = True - - class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC): """A dataset that can be saved and restored (i.e. the random state, internal buffers, etc.). I.e. it can be resumed from a checkpoint. @@ -281,8 +278,7 @@ class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC) def __init__(self, worker_config: WorkerConfig): self.worker_config = worker_config - if THREAD_SAFE: - self._thread_state = threading.local() + self._thread_state = threading.local() @abstractmethod def len_worker(self, worker_idx: int | None = None) -> int: @@ -412,32 +408,30 @@ def restore_sample(self, restore_key: "RestoreKey") -> T_sample: "This dataset does not support restoring, because it is not safely deterministic." ) - if THREAD_SAFE: - - def __getattribute__(self, name: str) -> Any: - if name in ("_savable_fields", "_state_fields", "_thread_state", "worker_config"): - return object.__getattribute__(self, name) - elif name in self._savable_fields or name in self._state_fields: - try: - return getattr(self._thread_state, name) - except AttributeError: - return object.__getattribute__(self, name) - else: + def __getattribute__(self, name: str) -> Any: + if name in ("_savable_fields", "_state_fields", "_thread_state", "worker_config"): + return object.__getattribute__(self, name) + elif name in self._savable_fields or name in self._state_fields: + try: + return getattr(self._thread_state, name) + except AttributeError: return object.__getattribute__(self, name) + else: + return object.__getattribute__(self, name) - def __delattr__(self, name: str) -> None: - if name in self._savable_fields or name in self._state_fields: - delattr(self._thread_state, name) - else: - object.__delattr__(self, name) + def __delattr__(self, name: str) -> None: + if name in self._savable_fields or name in self._state_fields: + delattr(self._thread_state, name) + else: + object.__delattr__(self, name) - def __setattr__(self, name: str, value: Any) -> None: - if name in ("_savable_fields", "_state_fields", "_thread_state", "worker_config"): - object.__setattr__(self, name, value) - elif name in self._savable_fields or name in self._state_fields: - setattr(self._thread_state, name, value) - else: - object.__setattr__(self, name, value) + def __setattr__(self, name: str, value: Any) -> None: + if name in ("_savable_fields", "_state_fields", "_thread_state", "worker_config"): + object.__setattr__(self, name, value) + elif name in self._savable_fields or name in self._state_fields: + setattr(self._thread_state, name, value) + else: + object.__setattr__(self, name, value) class BaseCoreDatasetFactory(Generic[T_sample], ABC): diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index a77d0f0e..5759f3d3 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -18,8 +18,6 @@ T = TypeVar("T") -THREAD_SAFE = True - class ActiveWorkerState: """ From dc38c7670d95096165db9dd3128ccf4051167241 Mon Sep 17 00:00:00 2001 From: Philipp Fischer Date: Mon, 1 Sep 2025 13:23:09 +0200 Subject: [PATCH 25/36] Add coverage command --- justfile | 5 +++ pyproject.toml | 12 +++++++ uv.lock | 89 ++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 106 insertions(+) diff --git a/justfile b/justfile index d1e43ff7..8dc3ad5f 100644 --- a/justfile +++ b/justfile @@ -27,6 +27,11 @@ check: dev-sync test: dev-sync uv run -m unittest discover -v -s tests +coverage: dev-sync + uv run -m coverage run -m unittest discover -v -s tests + uv run -m coverage html + echo "Coverage report generated at ./htmlcov/index.html" + # Build the docs docs: dev-sync uv run sphinx-build -b html docs/source docs/build diff --git a/pyproject.toml b/pyproject.toml index 579ba321..8780df99 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ [project.optional-dependencies] dev = [ + "coverage", "ruff", "sphinxcontrib-napoleon", "sphinx", @@ -108,3 +109,14 @@ exclude = [ "docs", ] +[tool.coverage.run] +branch = true +parallel = true +concurrency = ["multiprocessing"] + +[tool.coverage.report] +show_missing = true +skip_covered = true + +[tool.coverage.html] +show_contexts = true diff --git a/uv.lock b/uv.lock index 9cd16693..8c444d92 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.13'", @@ -574,6 +575,91 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335 }, ] +[[package]] +name = "coverage" +version = "7.10.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/14/70/025b179c993f019105b79575ac6edb5e084fb0f0e63f15cdebef4e454fb5/coverage-7.10.6.tar.gz", hash = "sha256:f644a3ae5933a552a29dbb9aa2f90c677a875f80ebea028e5a52a4f429044b90", size = 823736 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a8/1d/2e64b43d978b5bd184e0756a41415597dfef30fcbd90b747474bd749d45f/coverage-7.10.6-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:70e7bfbd57126b5554aa482691145f798d7df77489a177a6bef80de78860a356", size = 217025 }, + { url = "https://files.pythonhosted.org/packages/23/62/b1e0f513417c02cc10ef735c3ee5186df55f190f70498b3702d516aad06f/coverage-7.10.6-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e41be6f0f19da64af13403e52f2dec38bbc2937af54df8ecef10850ff8d35301", size = 217419 }, + { url = "https://files.pythonhosted.org/packages/e7/16/b800640b7a43e7c538429e4d7223e0a94fd72453a1a048f70bf766f12e96/coverage-7.10.6-cp310-cp310-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:c61fc91ab80b23f5fddbee342d19662f3d3328173229caded831aa0bd7595460", size = 244180 }, + { url = "https://files.pythonhosted.org/packages/fb/6f/5e03631c3305cad187eaf76af0b559fff88af9a0b0c180d006fb02413d7a/coverage-7.10.6-cp310-cp310-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:10356fdd33a7cc06e8051413140bbdc6f972137508a3572e3f59f805cd2832fd", size = 245992 }, + { url = "https://files.pythonhosted.org/packages/eb/a1/f30ea0fb400b080730125b490771ec62b3375789f90af0bb68bfb8a921d7/coverage-7.10.6-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:80b1695cf7c5ebe7b44bf2521221b9bb8cdf69b1f24231149a7e3eb1ae5fa2fb", size = 247851 }, + { url = "https://files.pythonhosted.org/packages/02/8e/cfa8fee8e8ef9a6bb76c7bef039f3302f44e615d2194161a21d3d83ac2e9/coverage-7.10.6-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:2e4c33e6378b9d52d3454bd08847a8651f4ed23ddbb4a0520227bd346382bbc6", size = 245891 }, + { url = "https://files.pythonhosted.org/packages/93/a9/51be09b75c55c4f6c16d8d73a6a1d46ad764acca0eab48fa2ffaef5958fe/coverage-7.10.6-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:c8a3ec16e34ef980a46f60dc6ad86ec60f763c3f2fa0db6d261e6e754f72e945", size = 243909 }, + { url = "https://files.pythonhosted.org/packages/e9/a6/ba188b376529ce36483b2d585ca7bdac64aacbe5aa10da5978029a9c94db/coverage-7.10.6-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7d79dabc0a56f5af990cc6da9ad1e40766e82773c075f09cc571e2076fef882e", size = 244786 }, + { url = "https://files.pythonhosted.org/packages/d0/4c/37ed872374a21813e0d3215256180c9a382c3f5ced6f2e5da0102fc2fd3e/coverage-7.10.6-cp310-cp310-win32.whl", hash = "sha256:86b9b59f2b16e981906e9d6383eb6446d5b46c278460ae2c36487667717eccf1", size = 219521 }, + { url = "https://files.pythonhosted.org/packages/8e/36/9311352fdc551dec5b973b61f4e453227ce482985a9368305880af4f85dd/coverage-7.10.6-cp310-cp310-win_amd64.whl", hash = "sha256:e132b9152749bd33534e5bd8565c7576f135f157b4029b975e15ee184325f528", size = 220417 }, + { url = "https://files.pythonhosted.org/packages/d4/16/2bea27e212c4980753d6d563a0803c150edeaaddb0771a50d2afc410a261/coverage-7.10.6-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c706db3cabb7ceef779de68270150665e710b46d56372455cd741184f3868d8f", size = 217129 }, + { url = "https://files.pythonhosted.org/packages/2a/51/e7159e068831ab37e31aac0969d47b8c5ee25b7d307b51e310ec34869315/coverage-7.10.6-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:8e0c38dc289e0508ef68ec95834cb5d2e96fdbe792eaccaa1bccac3966bbadcc", size = 217532 }, + { url = "https://files.pythonhosted.org/packages/e7/c0/246ccbea53d6099325d25cd208df94ea435cd55f0db38099dd721efc7a1f/coverage-7.10.6-cp311-cp311-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:752a3005a1ded28f2f3a6e8787e24f28d6abe176ca64677bcd8d53d6fe2ec08a", size = 247931 }, + { url = "https://files.pythonhosted.org/packages/7d/fb/7435ef8ab9b2594a6e3f58505cc30e98ae8b33265d844007737946c59389/coverage-7.10.6-cp311-cp311-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:689920ecfd60f992cafca4f5477d55720466ad2c7fa29bb56ac8d44a1ac2b47a", size = 249864 }, + { url = "https://files.pythonhosted.org/packages/51/f8/d9d64e8da7bcddb094d511154824038833c81e3a039020a9d6539bf303e9/coverage-7.10.6-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ec98435796d2624d6905820a42f82149ee9fc4f2d45c2c5bc5a44481cc50db62", size = 251969 }, + { url = "https://files.pythonhosted.org/packages/43/28/c43ba0ef19f446d6463c751315140d8f2a521e04c3e79e5c5fe211bfa430/coverage-7.10.6-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b37201ce4a458c7a758ecc4efa92fa8ed783c66e0fa3c42ae19fc454a0792153", size = 249659 }, + { url = "https://files.pythonhosted.org/packages/79/3e/53635bd0b72beaacf265784508a0b386defc9ab7fad99ff95f79ce9db555/coverage-7.10.6-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:2904271c80898663c810a6b067920a61dd8d38341244a3605bd31ab55250dad5", size = 247714 }, + { url = "https://files.pythonhosted.org/packages/4c/55/0964aa87126624e8c159e32b0bc4e84edef78c89a1a4b924d28dd8265625/coverage-7.10.6-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:5aea98383463d6e1fa4e95416d8de66f2d0cb588774ee20ae1b28df826bcb619", size = 248351 }, + { url = "https://files.pythonhosted.org/packages/eb/ab/6cfa9dc518c6c8e14a691c54e53a9433ba67336c760607e299bfcf520cb1/coverage-7.10.6-cp311-cp311-win32.whl", hash = "sha256:e3fb1fa01d3598002777dd259c0c2e6d9d5e10e7222976fc8e03992f972a2cba", size = 219562 }, + { url = "https://files.pythonhosted.org/packages/5b/18/99b25346690cbc55922e7cfef06d755d4abee803ef335baff0014268eff4/coverage-7.10.6-cp311-cp311-win_amd64.whl", hash = "sha256:f35ed9d945bece26553d5b4c8630453169672bea0050a564456eb88bdffd927e", size = 220453 }, + { url = "https://files.pythonhosted.org/packages/d8/ed/81d86648a07ccb124a5cf1f1a7788712b8d7216b593562683cd5c9b0d2c1/coverage-7.10.6-cp311-cp311-win_arm64.whl", hash = "sha256:99e1a305c7765631d74b98bf7dbf54eeea931f975e80f115437d23848ee8c27c", size = 219127 }, + { url = "https://files.pythonhosted.org/packages/26/06/263f3305c97ad78aab066d116b52250dd316e74fcc20c197b61e07eb391a/coverage-7.10.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5b2dd6059938063a2c9fee1af729d4f2af28fd1a545e9b7652861f0d752ebcea", size = 217324 }, + { url = "https://files.pythonhosted.org/packages/e9/60/1e1ded9a4fe80d843d7d53b3e395c1db3ff32d6c301e501f393b2e6c1c1f/coverage-7.10.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:388d80e56191bf846c485c14ae2bc8898aa3124d9d35903fef7d907780477634", size = 217560 }, + { url = "https://files.pythonhosted.org/packages/b8/25/52136173c14e26dfed8b106ed725811bb53c30b896d04d28d74cb64318b3/coverage-7.10.6-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:90cb5b1a4670662719591aa92d0095bb41714970c0b065b02a2610172dbf0af6", size = 249053 }, + { url = "https://files.pythonhosted.org/packages/cb/1d/ae25a7dc58fcce8b172d42ffe5313fc267afe61c97fa872b80ee72d9515a/coverage-7.10.6-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:961834e2f2b863a0e14260a9a273aff07ff7818ab6e66d2addf5628590c628f9", size = 251802 }, + { url = "https://files.pythonhosted.org/packages/f5/7a/1f561d47743710fe996957ed7c124b421320f150f1d38523d8d9102d3e2a/coverage-7.10.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bf9a19f5012dab774628491659646335b1928cfc931bf8d97b0d5918dd58033c", size = 252935 }, + { url = "https://files.pythonhosted.org/packages/6c/ad/8b97cd5d28aecdfde792dcbf646bac141167a5cacae2cd775998b45fabb5/coverage-7.10.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:99c4283e2a0e147b9c9cc6bc9c96124de9419d6044837e9799763a0e29a7321a", size = 250855 }, + { url = "https://files.pythonhosted.org/packages/33/6a/95c32b558d9a61858ff9d79580d3877df3eb5bc9eed0941b1f187c89e143/coverage-7.10.6-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:282b1b20f45df57cc508c1e033403f02283adfb67d4c9c35a90281d81e5c52c5", size = 248974 }, + { url = "https://files.pythonhosted.org/packages/0d/9c/8ce95dee640a38e760d5b747c10913e7a06554704d60b41e73fdea6a1ffd/coverage-7.10.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:8cdbe264f11afd69841bd8c0d83ca10b5b32853263ee62e6ac6a0ab63895f972", size = 250409 }, + { url = "https://files.pythonhosted.org/packages/04/12/7a55b0bdde78a98e2eb2356771fd2dcddb96579e8342bb52aa5bc52e96f0/coverage-7.10.6-cp312-cp312-win32.whl", hash = "sha256:a517feaf3a0a3eca1ee985d8373135cfdedfbba3882a5eab4362bda7c7cf518d", size = 219724 }, + { url = "https://files.pythonhosted.org/packages/36/4a/32b185b8b8e327802c9efce3d3108d2fe2d9d31f153a0f7ecfd59c773705/coverage-7.10.6-cp312-cp312-win_amd64.whl", hash = "sha256:856986eadf41f52b214176d894a7de05331117f6035a28ac0016c0f63d887629", size = 220536 }, + { url = "https://files.pythonhosted.org/packages/08/3a/d5d8dc703e4998038c3099eaf77adddb00536a3cec08c8dcd556a36a3eb4/coverage-7.10.6-cp312-cp312-win_arm64.whl", hash = "sha256:acf36b8268785aad739443fa2780c16260ee3fa09d12b3a70f772ef100939d80", size = 219171 }, + { url = "https://files.pythonhosted.org/packages/bd/e7/917e5953ea29a28c1057729c1d5af9084ab6d9c66217523fd0e10f14d8f6/coverage-7.10.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ffea0575345e9ee0144dfe5701aa17f3ba546f8c3bb48db62ae101afb740e7d6", size = 217351 }, + { url = "https://files.pythonhosted.org/packages/eb/86/2e161b93a4f11d0ea93f9bebb6a53f113d5d6e416d7561ca41bb0a29996b/coverage-7.10.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:95d91d7317cde40a1c249d6b7382750b7e6d86fad9d8eaf4fa3f8f44cf171e80", size = 217600 }, + { url = "https://files.pythonhosted.org/packages/0e/66/d03348fdd8df262b3a7fb4ee5727e6e4936e39e2f3a842e803196946f200/coverage-7.10.6-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:3e23dd5408fe71a356b41baa82892772a4cefcf758f2ca3383d2aa39e1b7a003", size = 248600 }, + { url = "https://files.pythonhosted.org/packages/73/dd/508420fb47d09d904d962f123221bc249f64b5e56aa93d5f5f7603be475f/coverage-7.10.6-cp313-cp313-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:0f3f56e4cb573755e96a16501a98bf211f100463d70275759e73f3cbc00d4f27", size = 251206 }, + { url = "https://files.pythonhosted.org/packages/e9/1f/9020135734184f439da85c70ea78194c2730e56c2d18aee6e8ff1719d50d/coverage-7.10.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:db4a1d897bbbe7339946ffa2fe60c10cc81c43fab8b062d3fcb84188688174a4", size = 252478 }, + { url = "https://files.pythonhosted.org/packages/a4/a4/3d228f3942bb5a2051fde28c136eea23a761177dc4ff4ef54533164ce255/coverage-7.10.6-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:d8fd7879082953c156d5b13c74aa6cca37f6a6f4747b39538504c3f9c63d043d", size = 250637 }, + { url = "https://files.pythonhosted.org/packages/36/e3/293dce8cdb9a83de971637afc59b7190faad60603b40e32635cbd15fbf61/coverage-7.10.6-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:28395ca3f71cd103b8c116333fa9db867f3a3e1ad6a084aa3725ae002b6583bc", size = 248529 }, + { url = "https://files.pythonhosted.org/packages/90/26/64eecfa214e80dd1d101e420cab2901827de0e49631d666543d0e53cf597/coverage-7.10.6-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:61c950fc33d29c91b9e18540e1aed7d9f6787cc870a3e4032493bbbe641d12fc", size = 250143 }, + { url = "https://files.pythonhosted.org/packages/3e/70/bd80588338f65ea5b0d97e424b820fb4068b9cfb9597fbd91963086e004b/coverage-7.10.6-cp313-cp313-win32.whl", hash = "sha256:160c00a5e6b6bdf4e5984b0ef21fc860bc94416c41b7df4d63f536d17c38902e", size = 219770 }, + { url = "https://files.pythonhosted.org/packages/a7/14/0b831122305abcc1060c008f6c97bbdc0a913ab47d65070a01dc50293c2b/coverage-7.10.6-cp313-cp313-win_amd64.whl", hash = "sha256:628055297f3e2aa181464c3808402887643405573eb3d9de060d81531fa79d32", size = 220566 }, + { url = "https://files.pythonhosted.org/packages/83/c6/81a83778c1f83f1a4a168ed6673eeedc205afb562d8500175292ca64b94e/coverage-7.10.6-cp313-cp313-win_arm64.whl", hash = "sha256:df4ec1f8540b0bcbe26ca7dd0f541847cc8a108b35596f9f91f59f0c060bfdd2", size = 219195 }, + { url = "https://files.pythonhosted.org/packages/d7/1c/ccccf4bf116f9517275fa85047495515add43e41dfe8e0bef6e333c6b344/coverage-7.10.6-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:c9a8b7a34a4de3ed987f636f71881cd3b8339f61118b1aa311fbda12741bff0b", size = 218059 }, + { url = "https://files.pythonhosted.org/packages/92/97/8a3ceff833d27c7492af4f39d5da6761e9ff624831db9e9f25b3886ddbca/coverage-7.10.6-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:8dd5af36092430c2b075cee966719898f2ae87b636cefb85a653f1d0ba5d5393", size = 218287 }, + { url = "https://files.pythonhosted.org/packages/92/d8/50b4a32580cf41ff0423777a2791aaf3269ab60c840b62009aec12d3970d/coverage-7.10.6-cp313-cp313t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:b0353b0f0850d49ada66fdd7d0c7cdb0f86b900bb9e367024fd14a60cecc1e27", size = 259625 }, + { url = "https://files.pythonhosted.org/packages/7e/7e/6a7df5a6fb440a0179d94a348eb6616ed4745e7df26bf2a02bc4db72c421/coverage-7.10.6-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d6b9ae13d5d3e8aeca9ca94198aa7b3ebbc5acfada557d724f2a1f03d2c0b0df", size = 261801 }, + { url = "https://files.pythonhosted.org/packages/3a/4c/a270a414f4ed5d196b9d3d67922968e768cd971d1b251e1b4f75e9362f75/coverage-7.10.6-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:675824a363cc05781b1527b39dc2587b8984965834a748177ee3c37b64ffeafb", size = 264027 }, + { url = "https://files.pythonhosted.org/packages/9c/8b/3210d663d594926c12f373c5370bf1e7c5c3a427519a8afa65b561b9a55c/coverage-7.10.6-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:692d70ea725f471a547c305f0d0fc6a73480c62fb0da726370c088ab21aed282", size = 261576 }, + { url = "https://files.pythonhosted.org/packages/72/d0/e1961eff67e9e1dba3fc5eb7a4caf726b35a5b03776892da8d79ec895775/coverage-7.10.6-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:851430a9a361c7a8484a36126d1d0ff8d529d97385eacc8dfdc9bfc8c2d2cbe4", size = 259341 }, + { url = "https://files.pythonhosted.org/packages/3a/06/d6478d152cd189b33eac691cba27a40704990ba95de49771285f34a5861e/coverage-7.10.6-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:d9369a23186d189b2fc95cc08b8160ba242057e887d766864f7adf3c46b2df21", size = 260468 }, + { url = "https://files.pythonhosted.org/packages/ed/73/737440247c914a332f0b47f7598535b29965bf305e19bbc22d4c39615d2b/coverage-7.10.6-cp313-cp313t-win32.whl", hash = "sha256:92be86fcb125e9bda0da7806afd29a3fd33fdf58fba5d60318399adf40bf37d0", size = 220429 }, + { url = "https://files.pythonhosted.org/packages/bd/76/b92d3214740f2357ef4a27c75a526eb6c28f79c402e9f20a922c295c05e2/coverage-7.10.6-cp313-cp313t-win_amd64.whl", hash = "sha256:6b3039e2ca459a70c79523d39347d83b73f2f06af5624905eba7ec34d64d80b5", size = 221493 }, + { url = "https://files.pythonhosted.org/packages/fc/8e/6dcb29c599c8a1f654ec6cb68d76644fe635513af16e932d2d4ad1e5ac6e/coverage-7.10.6-cp313-cp313t-win_arm64.whl", hash = "sha256:3fb99d0786fe17b228eab663d16bee2288e8724d26a199c29325aac4b0319b9b", size = 219757 }, + { url = "https://files.pythonhosted.org/packages/d3/aa/76cf0b5ec00619ef208da4689281d48b57f2c7fde883d14bf9441b74d59f/coverage-7.10.6-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:6008a021907be8c4c02f37cdc3ffb258493bdebfeaf9a839f9e71dfdc47b018e", size = 217331 }, + { url = "https://files.pythonhosted.org/packages/65/91/8e41b8c7c505d398d7730206f3cbb4a875a35ca1041efc518051bfce0f6b/coverage-7.10.6-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:5e75e37f23eb144e78940b40395b42f2321951206a4f50e23cfd6e8a198d3ceb", size = 217607 }, + { url = "https://files.pythonhosted.org/packages/87/7f/f718e732a423d442e6616580a951b8d1ec3575ea48bcd0e2228386805e79/coverage-7.10.6-cp314-cp314-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:0f7cb359a448e043c576f0da00aa8bfd796a01b06aa610ca453d4dde09cc1034", size = 248663 }, + { url = "https://files.pythonhosted.org/packages/e6/52/c1106120e6d801ac03e12b5285e971e758e925b6f82ee9b86db3aa10045d/coverage-7.10.6-cp314-cp314-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:c68018e4fc4e14b5668f1353b41ccf4bc83ba355f0e1b3836861c6f042d89ac1", size = 251197 }, + { url = "https://files.pythonhosted.org/packages/3d/ec/3a8645b1bb40e36acde9c0609f08942852a4af91a937fe2c129a38f2d3f5/coverage-7.10.6-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:cd4b2b0707fc55afa160cd5fc33b27ccbf75ca11d81f4ec9863d5793fc6df56a", size = 252551 }, + { url = "https://files.pythonhosted.org/packages/a1/70/09ecb68eeb1155b28a1d16525fd3a9b65fbe75337311a99830df935d62b6/coverage-7.10.6-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:4cec13817a651f8804a86e4f79d815b3b28472c910e099e4d5a0e8a3b6a1d4cb", size = 250553 }, + { url = "https://files.pythonhosted.org/packages/c6/80/47df374b893fa812e953b5bc93dcb1427a7b3d7a1a7d2db33043d17f74b9/coverage-7.10.6-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:f2a6a8e06bbda06f78739f40bfb56c45d14eb8249d0f0ea6d4b3d48e1f7c695d", size = 248486 }, + { url = "https://files.pythonhosted.org/packages/4a/65/9f98640979ecee1b0d1a7164b589de720ddf8100d1747d9bbdb84be0c0fb/coverage-7.10.6-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:081b98395ced0d9bcf60ada7661a0b75f36b78b9d7e39ea0790bb4ed8da14747", size = 249981 }, + { url = "https://files.pythonhosted.org/packages/1f/55/eeb6603371e6629037f47bd25bef300387257ed53a3c5fdb159b7ac8c651/coverage-7.10.6-cp314-cp314-win32.whl", hash = "sha256:6937347c5d7d069ee776b2bf4e1212f912a9f1f141a429c475e6089462fcecc5", size = 220054 }, + { url = "https://files.pythonhosted.org/packages/15/d1/a0912b7611bc35412e919a2cd59ae98e7ea3b475e562668040a43fb27897/coverage-7.10.6-cp314-cp314-win_amd64.whl", hash = "sha256:adec1d980fa07e60b6ef865f9e5410ba760e4e1d26f60f7e5772c73b9a5b0713", size = 220851 }, + { url = "https://files.pythonhosted.org/packages/ef/2d/11880bb8ef80a45338e0b3e0725e4c2d73ffbb4822c29d987078224fd6a5/coverage-7.10.6-cp314-cp314-win_arm64.whl", hash = "sha256:a80f7aef9535442bdcf562e5a0d5a5538ce8abe6bb209cfbf170c462ac2c2a32", size = 219429 }, + { url = "https://files.pythonhosted.org/packages/83/c0/1f00caad775c03a700146f55536ecd097a881ff08d310a58b353a1421be0/coverage-7.10.6-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:0de434f4fbbe5af4fa7989521c655c8c779afb61c53ab561b64dcee6149e4c65", size = 218080 }, + { url = "https://files.pythonhosted.org/packages/a9/c4/b1c5d2bd7cc412cbeb035e257fd06ed4e3e139ac871d16a07434e145d18d/coverage-7.10.6-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:6e31b8155150c57e5ac43ccd289d079eb3f825187d7c66e755a055d2c85794c6", size = 218293 }, + { url = "https://files.pythonhosted.org/packages/3f/07/4468d37c94724bf6ec354e4ec2f205fda194343e3e85fd2e59cec57e6a54/coverage-7.10.6-cp314-cp314t-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:98cede73eb83c31e2118ae8d379c12e3e42736903a8afcca92a7218e1f2903b0", size = 259800 }, + { url = "https://files.pythonhosted.org/packages/82/d8/f8fb351be5fee31690cd8da768fd62f1cfab33c31d9f7baba6cd8960f6b8/coverage-7.10.6-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:f863c08f4ff6b64fa8045b1e3da480f5374779ef187f07b82e0538c68cb4ff8e", size = 261965 }, + { url = "https://files.pythonhosted.org/packages/e8/70/65d4d7cfc75c5c6eb2fed3ee5cdf420fd8ae09c4808723a89a81d5b1b9c3/coverage-7.10.6-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2b38261034fda87be356f2c3f42221fdb4171c3ce7658066ae449241485390d5", size = 264220 }, + { url = "https://files.pythonhosted.org/packages/98/3c/069df106d19024324cde10e4ec379fe2fb978017d25e97ebee23002fbadf/coverage-7.10.6-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:0e93b1476b79eae849dc3872faeb0bf7948fd9ea34869590bc16a2a00b9c82a7", size = 261660 }, + { url = "https://files.pythonhosted.org/packages/fc/8a/2974d53904080c5dc91af798b3a54a4ccb99a45595cc0dcec6eb9616a57d/coverage-7.10.6-cp314-cp314t-musllinux_1_2_i686.whl", hash = "sha256:ff8a991f70f4c0cf53088abf1e3886edcc87d53004c7bb94e78650b4d3dac3b5", size = 259417 }, + { url = "https://files.pythonhosted.org/packages/30/38/9616a6b49c686394b318974d7f6e08f38b8af2270ce7488e879888d1e5db/coverage-7.10.6-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:ac765b026c9f33044419cbba1da913cfb82cca1b60598ac1c7a5ed6aac4621a0", size = 260567 }, + { url = "https://files.pythonhosted.org/packages/76/16/3ed2d6312b371a8cf804abf4e14895b70e4c3491c6e53536d63fd0958a8d/coverage-7.10.6-cp314-cp314t-win32.whl", hash = "sha256:441c357d55f4936875636ef2cfb3bee36e466dcf50df9afbd398ce79dba1ebb7", size = 220831 }, + { url = "https://files.pythonhosted.org/packages/d5/e5/d38d0cb830abede2adb8b147770d2a3d0e7fecc7228245b9b1ae6c24930a/coverage-7.10.6-cp314-cp314t-win_amd64.whl", hash = "sha256:073711de3181b2e204e4870ac83a7c4853115b42e9cd4d145f2231e12d670930", size = 221950 }, + { url = "https://files.pythonhosted.org/packages/f4/51/e48e550f6279349895b0ffcd6d2a690e3131ba3a7f4eafccc141966d4dea/coverage-7.10.6-cp314-cp314t-win_arm64.whl", hash = "sha256:137921f2bac5559334ba66122b753db6dc5d1cf01eb7b64eb412bb0d064ef35b", size = 219969 }, + { url = "https://files.pythonhosted.org/packages/44/0c/50db5379b615854b5cf89146f8f5bd1d5a9693d7f3a987e269693521c404/coverage-7.10.6-py3-none-any.whl", hash = "sha256:92c4ecf6bf11b2e85fd4d8204814dc26e6a19f0c9d938c207c5cb0eadfcabbe3", size = 208986 }, +] + [[package]] name = "cryptography" version = "3.4.7" @@ -1094,6 +1180,7 @@ azure-storage-blob = [ { name = "multi-storage-client", extra = ["azure-storage-blob"] }, ] dev = [ + { name = "coverage" }, { name = "myst-parser" }, { name = "ruff" }, { name = "soundfile" }, @@ -1125,6 +1212,7 @@ requires-dist = [ { name = "bitstring", marker = "extra == 'av-decode'", specifier = ">=4.2.3" }, { name = "braceexpand" }, { name = "click" }, + { name = "coverage", marker = "extra == 'dev'" }, { name = "dataslots", marker = "python_full_version < '3.10'" }, { name = "ebmlite", marker = "extra == 'av-decode'", specifier = ">=3.3.1" }, { name = "filetype", marker = "extra == 'av-decode'", specifier = ">=1.2.0" }, @@ -1154,6 +1242,7 @@ requires-dist = [ { name = "tqdm" }, { name = "webdataset" }, ] +provides-extras = ["aistore", "av-decode", "azure-storage-blob", "dev", "google-cloud-storage", "guess-content", "oci", "s3", "transforms"] [[package]] name = "mfusepy" From 7ff6c1896ed825786f3f0e6f5589a270a3aa459e Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Mon, 1 Sep 2025 15:54:49 +0200 Subject: [PATCH 26/36] _state_field -> _worker_local_fields --- src/megatron/energon/flavors/base_dataset.py | 12 ++++++------ .../energon/flavors/webdataset/sample_loader.py | 2 +- src/megatron/energon/task_encoder/base.py | 9 --------- src/megatron/energon/wrappers/batch_dataset.py | 2 +- src/megatron/energon/wrappers/group_batch_dataset.py | 2 +- src/megatron/energon/wrappers/map_dataset.py | 2 +- src/megatron/energon/wrappers/packing_dataset.py | 2 +- 7 files changed, 11 insertions(+), 20 deletions(-) diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index d5cc3152..53343a61 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -274,7 +274,7 @@ class SavableDataset(IterableDataset[T_sample], Savable, Generic[T_sample], ABC) _savable_fields: ClassVar[Tuple[str, ...]] = () #: List of names of the fields that are not saved, but are still part of the state (i.e. not shared between workers). - _state_fields: ClassVar[Tuple[str, ...]] = () + _worker_local_fields: ClassVar[Tuple[str, ...]] = () def __init__(self, worker_config: WorkerConfig): self.worker_config = worker_config @@ -409,9 +409,9 @@ def restore_sample(self, restore_key: "RestoreKey") -> T_sample: ) def __getattribute__(self, name: str) -> Any: - if name in ("_savable_fields", "_state_fields", "_thread_state", "worker_config"): + if name in ("_savable_fields", "_worker_local_fields", "_thread_state", "worker_config"): return object.__getattribute__(self, name) - elif name in self._savable_fields or name in self._state_fields: + elif name in self._savable_fields or name in self._worker_local_fields: try: return getattr(self._thread_state, name) except AttributeError: @@ -420,15 +420,15 @@ def __getattribute__(self, name: str) -> Any: return object.__getattribute__(self, name) def __delattr__(self, name: str) -> None: - if name in self._savable_fields or name in self._state_fields: + if name in self._savable_fields or name in self._worker_local_fields: delattr(self._thread_state, name) else: object.__delattr__(self, name) def __setattr__(self, name: str, value: Any) -> None: - if name in ("_savable_fields", "_state_fields", "_thread_state", "worker_config"): + if name in ("_savable_fields", "_worker_local_fields", "_thread_state", "worker_config"): object.__setattr__(self, name, value) - elif name in self._savable_fields or name in self._state_fields: + elif name in self._savable_fields or name in self._worker_local_fields: setattr(self._thread_state, name, value) else: object.__setattr__(self, name, value) diff --git a/src/megatron/energon/flavors/webdataset/sample_loader.py b/src/megatron/energon/flavors/webdataset/sample_loader.py index 8d44b411..81f8a010 100644 --- a/src/megatron/energon/flavors/webdataset/sample_loader.py +++ b/src/megatron/energon/flavors/webdataset/sample_loader.py @@ -85,7 +85,7 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): "_epoch_sample_count", ) - _state_fields = ("_slice_offsets",) + _worker_local_fields = ("_slice_offsets",) def __init__( self, diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index 908afa2f..0b1b0da5 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -1043,15 +1043,6 @@ def build_val_datasets( return dataset - def reset_state(self) -> None: - """Internally reset the state of the task encoder. This is called when the dataloader is started.""" - assert WorkerConfig.active_worker_config is not None, "Must be called within worker" - self._worker_local.rng = UserRng(WorkerConfig.active_worker_config.worker_seed()) - - # Burrow the save_state and restore_state methods from the SavableDataset class. - save_state = SavableDataset.save_state - restore_state = SavableDataset.restore_state - @property def current_batch_index(self) -> int: """Returns the current index for the next batch yielded from the current worker. Each batch diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index d823749a..b9c8fd5e 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -58,7 +58,7 @@ class BatchDataset(BaseWrapperDataset[T_batch_sample, T_batch], Generic[T_batch_ _last_batch_failures: int = 0 _savable_fields = ("_sample_index", "_generator_sample_keys", "_generator_offset") - _state_fields = ("_last_batch_failures",) + _worker_local_fields = ("_last_batch_failures",) def __init__( self, diff --git a/src/megatron/energon/wrappers/group_batch_dataset.py b/src/megatron/energon/wrappers/group_batch_dataset.py index b736d081..e6f16281 100644 --- a/src/megatron/energon/wrappers/group_batch_dataset.py +++ b/src/megatron/energon/wrappers/group_batch_dataset.py @@ -158,7 +158,7 @@ class GroupBatchDataset( _last_batch_failures: int = 0 _savable_fields = ("_group_key_sample_index", "_batch_sample_index", "_buckets") - _state_fields = ("_last_batch_failures",) + _worker_local_fields = ("_last_batch_failures",) def __init__( self, diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index 95fbc5e9..11f64fdb 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -62,7 +62,7 @@ class MapDataset(BaseWrapperDataset[T_sample, T_sample_out], Generic[T_sample, T "_generator_offset", ) - _state_fields = ("_last_map_failures",) + _worker_local_fields = ("_last_map_failures",) def __init__( self, diff --git a/src/megatron/energon/wrappers/packing_dataset.py b/src/megatron/energon/wrappers/packing_dataset.py index e8086137..d2be6977 100644 --- a/src/megatron/energon/wrappers/packing_dataset.py +++ b/src/megatron/energon/wrappers/packing_dataset.py @@ -110,7 +110,7 @@ class PackingDataset( "_final_packing_sample_index", ) - _state_fields = ( + _worker_local_fields = ( "_last_pre_pack_failures", "_last_final_pack_failures", "_last_sample_encoder_failures", From 7db2c0e399ab858825dfa87aa7684d240e81c800 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 2 Sep 2025 10:26:51 +0200 Subject: [PATCH 27/36] Fix docs --- src/megatron/energon/loader.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/megatron/energon/loader.py b/src/megatron/energon/loader.py index f9526459..54179f50 100644 --- a/src/megatron/energon/loader.py +++ b/src/megatron/energon/loader.py @@ -39,7 +39,10 @@ def get_savable_loader( Args: dataset: The dataset to create a loader for. worker_config: Deprecated. Please pass this to the dataset instead. - worker_type: The type of worker to use. + worker_type: The type of worker to use. Options: + "fork": forked workers (default), + "thread": threaded workers (should be used with free-threaded python), + "main": iterate data in the main process without parallelization. gc_freeze_at_start: If True, the garbage collector is frozen at the start of the loader. gc_collect_every_n_steps: The number of steps after which the garbage collector is called. prefetch_factor: The factor by which to prefetch the dataset. @@ -47,7 +50,7 @@ def get_savable_loader( watchdog_timeout_seconds: The timeout in seconds. If None, the watchdog is disabled. watchdog_initial_timeout_seconds: The initial timeout in seconds. If None, the timeout is the same as watchdog_timeout_seconds. fail_on_timeout: If True, stops the whole process upon timeout, after printing a stack trace. - pin_memory: If True, the dataset is pinned to memory. + pin_memory: If True, the data iterated by the dataset is pinned to memory, such that it can be quickly used by CUDA. Returns: The instantiated :class:`megatron.energon.DataLoader`, yielding batches from the dataset. From e8c6c436e9f5f8beb6dec0c38e906b77206c6316 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 2 Sep 2025 10:27:11 +0200 Subject: [PATCH 28/36] Adapt tests to use context managers for loaders --- tests/test_crudedataset.py | 297 +++--- tests/test_dataset.py | 1442 +++++++++++++++-------------- tests/test_dataset_det.py | 413 ++++----- tests/test_jsonl_dataset.py | 74 +- tests/test_metadataset.py | 921 +++++++++--------- tests/test_metadataset_fewsamp.py | 60 +- tests/test_metadataset_v2.py | 656 +++++++------ 7 files changed, 1884 insertions(+), 1979 deletions(-) diff --git a/tests/test_crudedataset.py b/tests/test_crudedataset.py index 3e2ec4bc..2d54ecd6 100644 --- a/tests/test_crudedataset.py +++ b/tests/test_crudedataset.py @@ -363,28 +363,27 @@ def test_metadataset(self): max_samples_per_sequence=None, handler=reraise_exception, ) - loader = get_savable_loader( + with get_savable_loader( train_dataset, - ) - - print(len(train_dataset)) - # assert len(train_dataset) == 11 + ) as loader: + print(len(train_dataset)) + # assert len(train_dataset) == 11 - for idx, data in enumerate(loader): - if idx >= len(train_dataset): - break + for idx, data in enumerate(loader): + if idx >= len(train_dataset): + break - assert isinstance(data, TextBatch) + assert isinstance(data, TextBatch) - print("Batch", idx) - for txt, key in zip(data.txts, data.__key__): - key_int = int(key.split("/")[-1]) - if key_int < 100: - assert txt == f"<{key_int}>" - else: - assert txt == f"<{key_int}|{key_int}>" + print("Batch", idx) + for txt, key in zip(data.txts, data.__key__): + key_int = int(key.split("/")[-1]) + if key_int < 100: + assert txt == f"<{key_int}>" + else: + assert txt == f"<{key_int}|{key_int}>" - print(key, txt) + print(key, txt) def test_loader(self): torch.manual_seed(42) @@ -394,7 +393,7 @@ def test_loader(self): num_workers=2, ) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.mds_path, batch_size=2, @@ -404,17 +403,17 @@ def test_loader(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - ) - samples = [s.__key__ for idx, s in zip(range(100), loader)] + ) as loader: + samples = [s.__key__ for idx, s in zip(range(100), loader)] - print(samples) + print(samples) - state = loader.save_state_rank() + state = loader.save_state_rank() - samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_after) + samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_after) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.mds_path, batch_size=2, @@ -424,14 +423,11 @@ def test_loader(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - ) - - loader.restore_state_rank(state) + ).with_restored_state_rank(state) as loader: + samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_restored) - samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_restored) - - assert all([a == b for a, b in zip(samples_after, samples_restored)]) + assert all([a == b for a, b in zip(samples_after, samples_restored)]) def test_aux_random_access(self): torch.manual_seed(42) @@ -443,7 +439,7 @@ def test_aux_random_access(self): print("Initializing dataset") - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.aux_mds_path, batch_size=2, @@ -453,24 +449,23 @@ def test_aux_random_access(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - ) + ) as loader: + print("Iterating from dataset") + samples = [s.txts for idx, s in zip(range(100), loader)] + for idx, txts in enumerate(samples): + for txt in txts: + m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>", txt) + assert m, f"Invalid aux text: {txt}" + assert int(m.group(2)) == int(m.group(1)) + 100 - print("Iterating from dataset") - samples = [s.txts for idx, s in zip(range(100), loader)] - for idx, txts in enumerate(samples): - for txt in txts: - m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>", txt) - assert m, f"Invalid aux text: {txt}" - assert int(m.group(2)) == int(m.group(1)) + 100 + print(samples) - print(samples) + state = loader.save_state_rank() - state = loader.save_state_rank() + samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_after) - samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_after) - - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.aux_mds_path, batch_size=2, @@ -480,14 +475,11 @@ def test_aux_random_access(self): max_samples_per_sequence=None, packing_buffer_size=2, ), - ) + ).with_restored_state_rank(state) as loader: + samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_restored) - loader.restore_state_rank(state) - - samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_restored) - - assert all([a == b for a, b in zip(samples_after, samples_restored)]) + assert all([a == b for a, b in zip(samples_after, samples_restored)]) def test_aux_random_access_with_cache(self): torch.manual_seed(42) @@ -499,7 +491,7 @@ def test_aux_random_access_with_cache(self): print("Initializing dataset") - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.aux_mds_path, batch_size=2, @@ -513,25 +505,24 @@ def test_aux_random_access_with_cache(self): parent_cache_dir=self.dataset_path / "cache", num_workers=1, ), - ) - - print("Iterating from dataset") - samples = [s.txts for idx, s in zip(range(100), loader)] - for idx, txts in enumerate(samples): - for txt in txts: - m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>\|([0-9]*)", txt) - assert m, f"Invalid aux text: {txt}" - assert int(m.group(2)) == int(m.group(1)) + 100 - assert int(m.group(3)) == (int(m.group(1)) + 1) % 55 + ) as loader: + print("Iterating from dataset") + samples = [s.txts for idx, s in zip(range(100), loader)] + for idx, txts in enumerate(samples): + for txt in txts: + m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>\|([0-9]*)", txt) + assert m, f"Invalid aux text: {txt}" + assert int(m.group(2)) == int(m.group(1)) + 100 + assert int(m.group(3)) == (int(m.group(1)) + 1) % 55 - print(samples) + print(samples) - state = loader.save_state_rank() + state = loader.save_state_rank() - samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_after) + samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_after) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.aux_mds_path, batch_size=2, @@ -545,14 +536,11 @@ def test_aux_random_access_with_cache(self): parent_cache_dir=self.dataset_path / "cache", num_workers=1, ), - ) - - loader.restore_state_rank(state) + ).with_restored_state_rank(state) as loader: + samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_restored) - samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_restored) - - assert all([a == b for a, b in zip(samples_after, samples_restored)]) + assert all([a == b for a, b in zip(samples_after, samples_restored)]) def test_aux_random_access_with_cache_and_postencode(self): torch.manual_seed(42) @@ -564,7 +552,7 @@ def test_aux_random_access_with_cache_and_postencode(self): print("Initializing dataset") - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.aux_mds_path, batch_size=2, @@ -578,25 +566,24 @@ def test_aux_random_access_with_cache_and_postencode(self): parent_cache_dir=self.dataset_path / "cache", num_workers=1, ), - ) + ) as loader: + print("Iterating from dataset") + samples = [s.txts for idx, s in zip(range(100), loader)] + for idx, txts in enumerate(samples): + for txt in txts: + m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>\|([0-9]*)", txt) + assert m, f"Invalid aux text: {txt}" + assert int(m.group(2)) == int(m.group(1)) + 100 + assert int(m.group(3)) == (int(m.group(1)) + 1) % 55 - print("Iterating from dataset") - samples = [s.txts for idx, s in zip(range(100), loader)] - for idx, txts in enumerate(samples): - for txt in txts: - m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>\|([0-9]*)", txt) - assert m, f"Invalid aux text: {txt}" - assert int(m.group(2)) == int(m.group(1)) + 100 - assert int(m.group(3)) == (int(m.group(1)) + 1) % 55 + print(samples) - print(samples) + state = loader.save_state_rank() - state = loader.save_state_rank() + samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_after) - samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_after) - - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.aux_mds_path, batch_size=2, @@ -610,60 +597,57 @@ def test_aux_random_access_with_cache_and_postencode(self): parent_cache_dir=self.dataset_path / "cache", num_workers=1, ), - ) - - loader.restore_state_rank(state) - - samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_restored) - - assert all([a == b for a, b in zip(samples_after, samples_restored)]) - - # Verify that the sources are correct - sample_src_check = [s.__sources__ for idx, s in zip(range(1), loader)][0] - print(sample_src_check) - # NOTE: Auxiliary sources have string as index, not int - assert sample_src_check == ( - # Primary source for the sample, reading all source files - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), - index=2, - shard_name="parts/data-0.tar", - file_names=("000002.pkl", "000002.txt"), - ), - # Auxiliary source for the sample, reading from ds2 - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds2"), - index="000102.txt", - shard_name="parts/data-0.tar", - file_names=("000102.txt",), - ), - # Auxiliary source for the sample, reading from ds1, but next sample - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), - index="000003.txt", - shard_name="parts/data-0.tar", - file_names=("000003.txt",), - ), - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), - index=21, - shard_name="parts/data-2.tar", - file_names=("000021.pkl", "000021.txt"), - ), - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds2"), - index="000121.txt", - shard_name="parts/data-2.tar", - file_names=("000121.txt",), - ), - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), - index="000022.txt", - shard_name="parts/data-2.tar", - file_names=("000022.txt",), - ), - ) + ).with_restored_state_rank(state) as loader: + samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_restored) + + assert all([a == b for a, b in zip(samples_after, samples_restored)]) + + # Verify that the sources are correct + sample_src_check = [s.__sources__ for idx, s in zip(range(1), loader)][0] + print(sample_src_check) + # NOTE: Auxiliary sources have string as index, not int + assert sample_src_check == ( + # Primary source for the sample, reading all source files + SourceInfo( + dataset_path=EPath(self.dataset_path / "ds1"), + index=2, + shard_name="parts/data-0.tar", + file_names=("000002.pkl", "000002.txt"), + ), + # Auxiliary source for the sample, reading from ds2 + SourceInfo( + dataset_path=EPath(self.dataset_path / "ds2"), + index="000102.txt", + shard_name="parts/data-0.tar", + file_names=("000102.txt",), + ), + # Auxiliary source for the sample, reading from ds1, but next sample + SourceInfo( + dataset_path=EPath(self.dataset_path / "ds1"), + index="000003.txt", + shard_name="parts/data-0.tar", + file_names=("000003.txt",), + ), + SourceInfo( + dataset_path=EPath(self.dataset_path / "ds1"), + index=21, + shard_name="parts/data-2.tar", + file_names=("000021.pkl", "000021.txt"), + ), + SourceInfo( + dataset_path=EPath(self.dataset_path / "ds2"), + index="000121.txt", + shard_name="parts/data-2.tar", + file_names=("000121.txt",), + ), + SourceInfo( + dataset_path=EPath(self.dataset_path / "ds1"), + index="000022.txt", + shard_name="parts/data-2.tar", + file_names=("000022.txt",), + ), + ) def test_aux_filesystem_reference(self): torch.manual_seed(42) @@ -673,7 +657,7 @@ def test_aux_filesystem_reference(self): num_workers=0, ) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.aux_mds_path, batch_size=1, @@ -682,11 +666,10 @@ def test_aux_filesystem_reference(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - ) + ) as loader: + sample = next(iter(loader)) - sample = next(iter(loader)) - - assert sample.txts[0].endswith("|aux|__module__: megatron.ener>") + assert sample.txts[0].endswith("|aux|__module__: megatron.ener>") def test_nomds(self): torch.manual_seed(42) @@ -696,7 +679,7 @@ def test_nomds(self): num_workers=2, ) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path / "ds1", batch_size=2, @@ -705,11 +688,11 @@ def test_nomds(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - ) - samples = [s.__key__ for idx, s in zip(range(100), loader)] + ) as loader: + samples = [s.__key__ for idx, s in zip(range(100), loader)] - print(samples) - assert len(samples) == 100 + print(samples) + assert len(samples) == 100 if __name__ == "__main__": diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 74d68c40..718c448b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -317,29 +317,28 @@ def test_captioning_dataset(self): worker_config=no_worker_config, ) - def get_ld(ds): - return get_loader(ds) - # Check len operator assert len(ds) == 50 # Check if iterating returns the same - iter1 = list(get_ld(ds)) - iter2 = list(get_ld(ds)) + with get_loader(ds) as l1, get_loader(ds) as l2: + iter1 = list(l1) + iter2 = list(l2) assert len(iter1) == 50 assert len(iter2) == 50 assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2)) # Check case when batch size is larger than dataset size batch_sizes = [] - for wrapped_sample in get_ld( + with get_loader( BatchDataset( ds, batch_size=DATASET_SIZE * 2, batcher=generic_batch, worker_config=no_worker_config, ) - ): - batch_sizes.append(wrapped_sample.image.shape[0]) + ) as l: + for wrapped_sample in l: + batch_sizes.append(wrapped_sample.image.shape[0]) assert batch_sizes == [DATASET_SIZE] # Check returned dimensions and batch sizes if batch size is smaller than dataset size @@ -352,44 +351,45 @@ def get_ld(ds): cnt = 0 expected_num_batches = math.ceil(DATASET_SIZE / batch_size) - for idx, wrapped_sample in enumerate(get_ld(batched_ds)): - # Check batch sizes - if idx < expected_num_batches - 1: - assert wrapped_sample.image.shape[0] == batch_size - assert wrapped_sample.caption.shape[0] == batch_size - else: - assert wrapped_sample.image.shape[0] == DATASET_SIZE % batch_size - assert wrapped_sample.caption.shape[0] == DATASET_SIZE % batch_size + with get_loader(batched_ds) as l: + for idx, wrapped_sample in enumerate(l): + # Check batch sizes + if idx < expected_num_batches - 1: + assert wrapped_sample.image.shape[0] == batch_size + assert wrapped_sample.caption.shape[0] == batch_size + else: + assert wrapped_sample.image.shape[0] == DATASET_SIZE % batch_size + assert wrapped_sample.caption.shape[0] == DATASET_SIZE % batch_size - # Check image size - assert tuple(wrapped_sample.image.shape[1:]) == (3, 100, 100) + # Check image size + assert tuple(wrapped_sample.image.shape[1:]) == (3, 100, 100) - cnt += 1 + cnt += 1 - logging.info(f" Batch {idx}:") - logging.info(f" {wrapped_sample.image.shape=}") - logging.info(f" {wrapped_sample.caption.shape=}") + logging.info(f" Batch {idx}:") + logging.info(f" {wrapped_sample.image.shape=}") + logging.info(f" {wrapped_sample.caption.shape=}") assert cnt == expected_num_batches # Check if actual image and caption data are correct - loader = get_ld( + with get_loader( BatchDataset(ds, batch_size=9, batcher=generic_batch, worker_config=no_worker_config), - ) - batch_sizes = [] - dataset_samples = {sample["caption"]: sample["image"] for sample in self.samples} - for idx, sample in enumerate(loader): - batch_sizes.append(sample.image.shape[0]) - for bidx in range(sample.image.shape[0]): - refimg = dataset_samples.pop( - sample.caption[bidx].numpy().tobytes().rstrip(b"\0").decode() - ) - assert torch.allclose( - sample.image[bidx], - torch.permute(torch.tensor(refimg, dtype=torch.float32) / 255, (2, 0, 1)), - ) - assert len(dataset_samples) == 0 - assert batch_sizes == [9, 9, 9, 9, 9, 5] + ) as loader: + batch_sizes = [] + dataset_samples = {sample["caption"]: sample["image"] for sample in self.samples} + for idx, sample in enumerate(loader): + batch_sizes.append(sample.image.shape[0]) + for bidx in range(sample.image.shape[0]): + refimg = dataset_samples.pop( + sample.caption[bidx].numpy().tobytes().rstrip(b"\0").decode() + ) + assert torch.allclose( + sample.image[bidx], + torch.permute(torch.tensor(refimg, dtype=torch.float32) / 255, (2, 0, 1)), + ) + assert len(dataset_samples) == 0 + assert batch_sizes == [9, 9, 9, 9, 9, 5] def test_field_access(self): ds = get_dataset_from_config( @@ -401,8 +401,9 @@ def test_field_access(self): sample_type=CaptioningSample, ) captions = set(sample["caption"] for sample in self.samples) - for sample in get_loader(ds.build()): - captions.remove(sample.caption) + with get_loader(ds.build()) as loader: + for sample in loader: + captions.remove(sample.caption) assert len(captions) == 0 def test_sample_loader(self): @@ -415,9 +416,10 @@ def test_sample_loader(self): sample_type=CaptioningSample, ) captions = set(sample["caption"] for sample in self.samples) - for sample in get_loader(ds.build()): - assert sample.caption[:4] == "" - captions.remove(sample.caption[4:]) + with get_loader(ds.build()) as loader: + for sample in loader: + assert sample.caption[:4] == "" + captions.remove(sample.caption[4:]) assert len(captions) == 0 def test_sample_loader_key(self): @@ -433,10 +435,11 @@ def test_sample_loader_key(self): keys = set( f"parts/data-{idx // 30:d}.tar/{idx:06d}" for idx in range(len(self.samples)) ) - for sample in get_loader(ds.build()): - assert sample.caption[:4] == "" - captions.remove(sample.caption[4:]) - keys.remove(sample.__key__) + with get_loader(ds.build()) as loader: + for sample in loader: + assert sample.caption[:4] == "" + captions.remove(sample.caption[4:]) + keys.remove(sample.__key__) assert len(captions) == 0 assert len(keys) == 0 @@ -450,7 +453,8 @@ def test_exclusion(self): sample_type=CaptioningSample, ) - keys = [entry.__key__ for entry in get_loader(ds.build())] + with get_loader(ds.build()) as loader: + keys = [entry.__key__ for entry in loader] assert keys == [ f"parts/data-1.tar/{i:06d}" for i in list(range(30, 35)) + list(range(40, 50)) ], keys @@ -469,7 +473,7 @@ def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8), ) - loader = get_loader( + with get_loader( get_train_dataset( self.dataset_path, batch_size=10, @@ -480,30 +484,29 @@ def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: max_samples_per_sequence=None, task_encoder=TestTaskEncoder(), ) - ) - - assert len(loader) == 2 - - def hist(data): - """Histogram function""" - r = defaultdict(lambda: 0) - for k in data: - r[k] += 1 - return r - - print([[batch.__key__ for batch in loader] for _ in range(100)]) - keys = [key for _ in range(100) for batch in loader for key in batch.__key__] - # 100 iterations, 2 virtual epoch size, batch size 10 - print(len(keys), keys) - keyhist = hist(keys) - print(sorted(keyhist.items())) - print(sorted(keyhist.items(), key=lambda x: (x[1], x[0]))) - assert len(keys) == 100 * 2 * 10 - # Data should be approximately sampled uniformly (40+-1 samples per key) - assert len(keyhist) == 50 - assert all(v in (39, 40, 41) for v in keyhist.values()) - - loader2 = get_loader( + ) as loader: + assert len(loader) == 2 + + def hist(data): + """Histogram function""" + r = defaultdict(lambda: 0) + for k in data: + r[k] += 1 + return r + + print([[batch.__key__ for batch in loader] for _ in range(100)]) + keys = [key for _ in range(100) for batch in loader for key in batch.__key__] + # 100 iterations, 2 virtual epoch size, batch size 10 + print(len(keys), keys) + keyhist = hist(keys) + print(sorted(keyhist.items())) + print(sorted(keyhist.items(), key=lambda x: (x[1], x[0]))) + assert len(keys) == 100 * 2 * 10 + # Data should be approximately sampled uniformly (40+-1 samples per key) + assert len(keyhist) == 50 + assert all(v in (39, 40, 41) for v in keyhist.values()) + + with get_loader( get_val_dataset( self.dataset_path, split_part="train", @@ -511,48 +514,48 @@ def hist(data): worker_config=no_worker_config, task_encoder=TestTaskEncoder(), ) - ) - assert len(loader2) == 5 - # The order in the split is shuffled this way - assert list(key for batch in loader2 for key in batch.__key__) == [ - f"parts/data-1.tar/{i:06d}" for i in range(30, 50) - ] + [f"parts/data-0.tar/{i:06d}" for i in range(30)] + ) as loader2: + assert len(loader2) == 5 + # The order in the split is shuffled this way + assert list(key for batch in loader2 for key in batch.__key__) == [ + f"parts/data-1.tar/{i:06d}" for i in range(30, 50) + ] + [f"parts/data-0.tar/{i:06d}" for i in range(30)] def test_default_dataset(self): torch.manual_seed(42) - train_loader = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=10, - worker_config=no_worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - ) - - val_loader = get_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=10, - worker_config=no_worker_config, - ) - ) - - n_samples = 0 - for i, sample in zip(range(100), train_loader): - assert sample.image.shape == (10, 3, 100, 100) - n_samples += sample.image.shape[0] - assert n_samples == 1000 - n_samples = 0 - for sample in val_loader: - assert sample.image.shape == (10, 3, 100, 100) - n_samples += sample.image.shape[0] - assert n_samples == 50 + with ( + get_loader( + get_train_dataset( + self.dataset_path, + batch_size=10, + worker_config=no_worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + ) as train_loader, + get_loader( + get_val_dataset( + self.dataset_path, + split_part="train", + batch_size=10, + worker_config=no_worker_config, + ) + ) as val_loader, + ): + n_samples = 0 + for i, sample in zip(range(100), train_loader): + assert sample.image.shape == (10, 3, 100, 100) + n_samples += sample.image.shape[0] + assert n_samples == 1000 + n_samples = 0 + for sample in val_loader: + assert sample.image.shape == (10, 3, 100, 100) + n_samples += sample.image.shape[0] + assert n_samples == 50 def test_no_batching(self): - train_loader = get_loader( + with get_loader( get_train_dataset( self.dataset_path, batch_size=None, @@ -560,13 +563,12 @@ def test_no_batching(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ) - ) - - one_sample = next(iter(train_loader)) + ) as train_loader: + one_sample = next(iter(train_loader)) - # Single sample without batching - assert isinstance(one_sample.image, torch.Tensor) - assert isinstance(one_sample.caption, str) + # Single sample without batching + assert isinstance(one_sample.image, torch.Tensor) + assert isinstance(one_sample.caption, str) def test_dataset_len(self): torch.manual_seed(42) @@ -581,30 +583,28 @@ def test_dataset_len(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ) - train_loader = get_loader(train_dataset) + with get_loader(train_dataset) as train_loader: + assert len(train_dataset) == 12 + assert len(train_loader) == 12 + assert len(list(train_loader)) == 12 - assert len(train_dataset) == 12 - assert len(train_loader) == 12 - assert len(list(train_loader)) == 12 - - val_dataset = get_val_dataset( - self.dataset_path, split_part="train", batch_size=1, worker_config=no_worker_config - ) - val_loader = get_loader(val_dataset) - assert len(val_loader) == 50 - assert len(list(val_loader)) == 50 + val_dataset = get_val_dataset( + self.dataset_path, split_part="train", batch_size=1, worker_config=no_worker_config + ) + with get_loader(val_dataset) as val_loader: + assert len(val_loader) == 50 + assert len(list(val_loader)) == 50 val_dataset = get_val_dataset( self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config ) - val_loader = get_loader(val_dataset) - - # n samples: ceil(50 / 11) // 4 * 4 - assert len(val_dataset) == 8 - assert len(val_loader) == 8 - assert len(list(val_loader)) == 8 - assert [len(entry.__key__) for entry in val_loader] == [11, 11, 11, 11, 2, 1, 2, 1] - assert sum(len(entry.__key__) for entry in val_loader) == 50 + with get_loader(val_dataset) as val_loader: + # n samples: ceil(50 / 11) // 4 * 4 + assert len(val_dataset) == 8 + assert len(val_loader) == 8 + assert len(list(val_loader)) == 8 + assert [len(entry.__key__) for entry in val_loader] == [11, 11, 11, 11, 2, 1, 2, 1] + assert sum(len(entry.__key__) for entry in val_loader) == 50 def test_multirank_dataset(self): torch.manual_seed(42) @@ -620,63 +620,60 @@ def test_multirank_dataset(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ) - train_loader = get_loader(train_dataset) - - assert len(train_dataset) == 12 - assert len(train_loader) == 12 - assert len(list(train_loader)) == 12 + with get_loader(train_dataset) as train_loader: + assert len(train_dataset) == 12 + assert len(train_loader) == 12 + assert len(list(train_loader)) == 12 val_dataset0 = get_val_dataset( self.dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r0 ) - val_loader0 = get_loader(val_dataset0) - print(len(val_loader0)) - assert len(val_loader0) == 25 - keys0 = set(key for entry in val_loader0 for key in entry.__key__) - assert len(keys0) == 25 + with get_loader(val_dataset0) as val_loader0: + print(len(val_loader0)) + assert len(val_loader0) == 25 + keys0 = set(key for entry in val_loader0 for key in entry.__key__) + assert len(keys0) == 25 val_dataset0b11 = get_val_dataset( self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r0 ) - val_loader0b11 = get_loader(val_dataset0b11) - - assert len(val_dataset0b11) == 4 - assert len(val_loader0b11) == 4 - assert len(list(val_loader0b11)) == 4 - keys0b11 = set(key for entry in val_loader0b11 for key in entry.__key__) - print([len(entry.__key__) for entry in val_loader0b11]) - assert [len(entry.__key__) for entry in val_loader0b11] == [11, 11, 2, 1] - assert len(keys0b11) == 25 + with get_loader(val_dataset0b11) as val_loader0b11: + assert len(val_dataset0b11) == 4 + assert len(val_loader0b11) == 4 + assert len(list(val_loader0b11)) == 4 + keys0b11 = set(key for entry in val_loader0b11 for key in entry.__key__) + print([len(entry.__key__) for entry in val_loader0b11]) + assert [len(entry.__key__) for entry in val_loader0b11] == [11, 11, 2, 1] + assert len(keys0b11) == 25 - assert keys0b11 == keys0 + assert keys0b11 == keys0 val_dataset1 = get_val_dataset( self.dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r1 ) - val_loader1 = get_loader(val_dataset1) - print(len(val_loader1)) - assert len(val_loader1) == 25 - keys1 = set(key for entry in val_loader1 for key in entry.__key__) - assert len(keys1) == 25 - print(sorted(keys1)) - print(sorted(keys0)) - assert keys1.isdisjoint(keys0) + with get_loader(val_dataset1) as val_loader1: + print(len(val_loader1)) + assert len(val_loader1) == 25 + keys1 = set(key for entry in val_loader1 for key in entry.__key__) + assert len(keys1) == 25 + print(sorted(keys1)) + print(sorted(keys0)) + assert keys1.isdisjoint(keys0) val_dataset1b11 = get_val_dataset( self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r1 ) - val_loader1b11 = get_loader(val_dataset1b11) + with get_loader(val_dataset1b11) as val_loader1b11: + assert len(val_dataset1b11) == 4 + assert len(val_loader1b11) == 4 + assert len(list(val_loader1b11)) == 4 + keys1b11 = set(key for entry in val_loader1b11 for key in entry.__key__) + print([len(entry.__key__) for entry in val_loader1b11]) + assert [len(entry.__key__) for entry in val_loader1b11] == [11, 11, 2, 1] + assert len(keys1b11) == 25 + assert keys1b11.isdisjoint(keys0b11) - assert len(val_dataset1b11) == 4 - assert len(val_loader1b11) == 4 - assert len(list(val_loader1b11)) == 4 - keys1b11 = set(key for entry in val_loader1b11 for key in entry.__key__) - print([len(entry.__key__) for entry in val_loader1b11]) - assert [len(entry.__key__) for entry in val_loader1b11] == [11, 11, 2, 1] - assert len(keys1b11) == 25 - assert keys1b11.isdisjoint(keys0b11) - - assert keys1b11 == keys1 + assert keys1b11 == keys1 def test_weight_aug(self): class WeightAugmentTaskEncoder(AugmentTaskEncoder): @@ -697,7 +694,7 @@ class WeightedCaptioningBatch(Batch): caption: List[str] weight: float - loader = get_loader( + with get_loader( get_val_dataset( self.dataset_path, split_part="train", @@ -709,15 +706,14 @@ class WeightedCaptioningBatch(Batch): target_data_class=WeightedCaptioningBatch, ), ) - ) - - for data in loader: - assert data.weight == [0.8] * 10 + ) as loader: + for data in loader: + assert data.weight == [0.8] * 10 def test_blending(self): torch.manual_seed(42) - loader = get_loader( + with get_loader( BlendDataset( ( get_train_dataset( @@ -741,14 +737,13 @@ def test_blending(self): ), worker_config=no_worker_config, ) - ) - - bs_hist = {10: 0, 20: 0} - for i, sample in zip(range(1000), loader): - bs_hist[sample.image.shape[0]] += 1 - print(bs_hist) - assert 150 <= bs_hist[10] <= 250 - assert 750 <= bs_hist[20] <= 850 + ) as loader: + bs_hist = {10: 0, 20: 0} + for i, sample in zip(range(1000), loader): + bs_hist[sample.image.shape[0]] += 1 + print(bs_hist) + assert 150 <= bs_hist[10] <= 250 + assert 750 <= bs_hist[20] <= 850 def test_mixing_homogeneous(self): @dataclass @@ -764,7 +759,7 @@ def __init__(self, source: int): def encode_batch(self, batch): return TestBatch.extend(batch, source=self.source) - loader = get_loader( + with get_loader( MixBatchDataset( ( get_train_dataset( @@ -792,15 +787,14 @@ def encode_batch(self, batch): batch_mix_fn=homogeneous_concat_mix, worker_config=no_worker_config, ) - ) - - source_hist = {0: 0, 1: 0} - for i, sample in zip(range(1000), loader): - assert sample.image.shape == (10, 3, 100, 100) - for source in sample.source: - source_hist[source] += 1 - assert 1500 <= source_hist[0] <= 2500 - assert 7500 <= source_hist[1] <= 8500 + ) as loader: + source_hist = {0: 0, 1: 0} + for i, sample in zip(range(1000), loader): + assert sample.image.shape == (10, 3, 100, 100) + for source in sample.source: + source_hist[source] += 1 + assert 1500 <= source_hist[0] <= 2500 + assert 7500 <= source_hist[1] <= 8500 def test_mixing_heterogeneous(self): @dataclass @@ -821,7 +815,7 @@ def __init__(self, source: int, batch_cls: Type[TestBatch1]): def encode_batch(self, batch): return self.batch_cls.extend(batch, source=self.source) - loader = get_loader( + with get_loader( MixBatchDataset( ( get_train_dataset( @@ -848,21 +842,20 @@ def encode_batch(self, batch): batch_size=10, worker_config=no_worker_config, ) - ) - - source_hist = {0: 0, 1: 0} - for i, samples in zip(range(1000), loader): - assert len(samples) == 10 - for sample in samples: - assert sample.image.shape == (1, 3, 100, 100) - source_hist[sample.source] += 1 - assert 1500 <= source_hist[0] <= 2500 - assert 7500 <= source_hist[1] <= 8500 + ) as loader: + source_hist = {0: 0, 1: 0} + for i, samples in zip(range(1000), loader): + assert len(samples) == 10 + for sample in samples: + assert sample.image.shape == (1, 3, 100, 100) + source_hist[sample.source] += 1 + assert 1500 <= source_hist[0] <= 2500 + assert 7500 <= source_hist[1] <= 8500 def test_val_limit(self): torch.manual_seed(42) - loader = get_loader( + with get_loader( get_val_dataset( self.dataset_path, split_part="train", @@ -870,19 +863,18 @@ def test_val_limit(self): worker_config=no_worker_config, limit=3, ) - ) + ) as loader: + assert len(loader) == 3 - assert len(loader) == 3 - - samples = [[batch.__key__ for batch in loader] for _ in range(10)] - print(samples) - for s in samples: - print(" -", s) - assert all(samples[0] == one_ep_samples for one_ep_samples in samples) + samples = [[batch.__key__ for batch in loader] for _ in range(10)] + print(samples) + for s in samples: + print(" -", s) + assert all(samples[0] == one_ep_samples for one_ep_samples in samples) worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) - loader = get_loader( + with get_loader( get_val_dataset( self.dataset_path, split_part="train", @@ -890,18 +882,17 @@ def test_val_limit(self): worker_config=worker_config, limit=3, ) - ) - - assert len(loader) == 3 - - samples_wrk2 = [[batch.__key__ for batch in loader] for _ in range(10)] - print(samples_wrk2) - for s in samples_wrk2: - print(" -", s) - assert all( - all(a == b for a, b in zip(samples_wrk2[0], one_ep_samples)) - for one_ep_samples in samples_wrk2 - ) + ) as loader: + assert len(loader) == 3 + + samples_wrk2 = [[batch.__key__ for batch in loader] for _ in range(10)] + print(samples_wrk2) + for s in samples_wrk2: + print(" -", s) + assert all( + all(a == b for a, b in zip(samples_wrk2[0], one_ep_samples)) + for one_ep_samples in samples_wrk2 + ) def test_current_batch_index(self): # Tests if the get_current_batch_index works properly @@ -919,7 +910,7 @@ def encode_sample(self, sample): ) # First, test simple single main-thread loader with accessing get_current_batch_index - loader = get_loader( + with get_loader( get_train_dataset( self.dataset_path, batch_size=2, @@ -928,169 +919,178 @@ def encode_sample(self, sample): shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) - - batches = list(zip(range(20), loader)) - print("bi", [batch.batch_index for batch_idx, batch in batches]) - assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + ) as loader: + batches = list(zip(range(20), loader)) + print("bi", [batch.batch_index for batch_idx, batch in batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches + ) - print("si", [batch.sample_index for batch_idx, batch in batches]) - assert all( - all( - si == sample_offset + batch_idx * 2 - for sample_offset, si in enumerate(batch.sample_index) + print("si", [batch.sample_index for batch_idx, batch in batches]) + assert all( + all( + si == sample_offset + batch_idx * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches ) - for batch_idx, batch in batches - ) - print("pk", [batch.__key__ for batch_idx, batch in batches]) - print("rk", [batch.__restore_key__ for batch_idx, batch in batches]) - assert loader.can_restore_sample() - - # These need to be hard coded to detect breaking changes - # If a change is expected, update the values with the ones printed below - ref_batch_rand_nums = [ - [661, 762], - [206, 470], - [130, 283], - [508, 61], - [625, 661], - [296, 376], - [632, 514], - [715, 406], - [555, 27], - [760, 36], - [607, 610], - [825, 219], - [564, 832], - [876, 512], - [632, 605], - [357, 738], - [40, 378], - [609, 444], - [610, 367], - [367, 69], - ] - - batch_rand_nums = [] - for batch_idx, batch in batches: - restore_batch = loader.restore_sample(batch.__restore_key__) - assert restore_batch.__key__ == batch.__key__ - assert restore_batch.batch_index == batch.batch_index - assert restore_batch.sample_index == batch.sample_index - assert restore_batch.rand_num == batch.rand_num - - batch_rand_nums.append(restore_batch.rand_num) - assert np.allclose(restore_batch.image, batch.image) - - # For constructing the test data above: - print("batch_rand_nums: ", batch_rand_nums) - assert batch_rand_nums == ref_batch_rand_nums + print("pk", [batch.__key__ for batch_idx, batch in batches]) + print("rk", [batch.__restore_key__ for batch_idx, batch in batches]) + assert loader.can_restore_sample() + + # These need to be hard coded to detect breaking changes + # If a change is expected, update the values with the ones printed below + ref_batch_rand_nums = [ + [661, 762], + [206, 470], + [130, 283], + [508, 61], + [625, 661], + [296, 376], + [632, 514], + [715, 406], + [555, 27], + [760, 36], + [607, 610], + [825, 219], + [564, 832], + [876, 512], + [632, 605], + [357, 738], + [40, 378], + [609, 444], + [610, 367], + [367, 69], + ] + + batch_rand_nums = [] + for batch_idx, batch in batches: + restore_batch = loader.restore_sample(batch.__restore_key__) + assert restore_batch.__key__ == batch.__key__ + assert restore_batch.batch_index == batch.batch_index + assert restore_batch.sample_index == batch.sample_index + assert restore_batch.rand_num == batch.rand_num + + batch_rand_nums.append(restore_batch.rand_num) + assert np.allclose(restore_batch.image, batch.image) + + # For constructing the test data above: + print("batch_rand_nums: ", batch_rand_nums) + assert batch_rand_nums == ref_batch_rand_nums # Now, test multi-worker loader with accessing get_current_batch_index worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) - loader = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r0, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ) - ) - loader_r1 = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r1, - shuffle_buffer_size=20, - max_samples_per_sequence=10, + with ( + get_loader( + get_train_dataset( + self.dataset_path, + batch_size=2, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r0, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) as loader, + get_loader( + get_train_dataset( + self.dataset_path, + batch_size=2, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r1, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) as loader_r1, + ): + batches = list(zip(range(20), loader)) + print("bir0", [batch.batch_index for batch_idx, batch in batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches ) - ) - batches = list(zip(range(20), loader)) - print("bir0", [batch.batch_index for batch_idx, batch in batches]) - assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) - - print("sir0", [batch.sample_index for batch_idx, batch in batches]) - assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) - assert all( - all( - si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) - for sample_offset, si in enumerate(batch.sample_index) + print("sir0", [batch.sample_index for batch_idx, batch in batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches + ) + assert all( + all( + si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches ) - for batch_idx, batch in batches - ) - batches_r1 = list(zip(range(20), loader_r1)) - print("bir0", [batch.batch_index for batch_idx, batch in batches_r1]) - print("sir1", [batch.sample_index for batch_idx, batch in batches_r1]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 - ) - assert all( - all( - si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) - for sample_offset, si in enumerate(batch.sample_index) + batches_r1 = list(zip(range(20), loader_r1)) + print("bir0", [batch.batch_index for batch_idx, batch in batches_r1]) + print("sir1", [batch.sample_index for batch_idx, batch in batches_r1]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 + ) + assert all( + all( + si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches_r1 ) - for batch_idx, batch in batches_r1 - ) # Now, test multi-worker loader with accessing get_current_batch_index and save/restore state - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r0, - shuffle_buffer_size=20, - max_samples_per_sequence=10, + with ( + get_savable_loader( + get_train_dataset( + self.dataset_path, + batch_size=2, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r0, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) as loader, + get_savable_loader( + get_train_dataset( + self.dataset_path, + batch_size=2, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r1, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) as loader_r1, + ): + batches = list(zip(range(20), loader)) + print([batch.batch_index for batch_idx, batch in batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches ) - ) - loader_r1 = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r1, - shuffle_buffer_size=20, - max_samples_per_sequence=10, + assert all( + all( + si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches ) - ) - batches = list(zip(range(20), loader)) - print([batch.batch_index for batch_idx, batch in batches]) - assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) - assert all( - all( - si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) - for sample_offset, si in enumerate(batch.sample_index) + batches_r1 = list(zip(range(20), loader_r1)) + print([batch.batch_index for batch_idx, batch in batches_r1]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 ) - for batch_idx, batch in batches - ) - - batches_r1 = list(zip(range(20), loader_r1)) - print([batch.batch_index for batch_idx, batch in batches_r1]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 - ) - assert all( - all( - si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) - for sample_offset, si in enumerate(batch.sample_index) + assert all( + all( + si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches_r1 ) - for batch_idx, batch in batches_r1 - ) - # Save and restore state - state = loader.save_state_rank() + # Save and restore state + state = loader.save_state_rank() # Restore state and check if the batch index is restored correctly - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, batch_size=2, @@ -1099,20 +1099,20 @@ def encode_sample(self, sample): shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) - loader.restore_state_rank(state) - - batches = list(zip(range(20, 40), loader)) - print([batch.batch_index for batch_idx, batch in batches]) - print([batch.sample_index for batch_idx, batch in batches]) - assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) - assert all( - all( - si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) - for sample_offset, si in enumerate(batch.sample_index) + ).with_restored_state_rank(state) as loader: + batches = list(zip(range(20, 40), loader)) + print([batch.batch_index for batch_idx, batch in batches]) + print([batch.sample_index for batch_idx, batch in batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches + ) + assert all( + all( + si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches ) - for batch_idx, batch in batches - ) def test_current_batch_index_generator(self): # Tests if the get_current_batch_index works properly @@ -1137,7 +1137,7 @@ def encode_sample(self, sample): ) # First, test simple single main-thread loader with accessing get_current_batch_index - loader = get_loader( + with get_loader( get_train_dataset( self.dataset_path, batch_size=3, @@ -1146,184 +1146,201 @@ def encode_sample(self, sample): shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) - - batches = list(zip(range(20), loader)) - print("bi", [batch.batch_index for batch_idx, batch in batches]) - assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + ) as loader: + batches = list(zip(range(20), loader)) + print("bi", [batch.batch_index for batch_idx, batch in batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches + ) - print("si", [batch.sample_index for batch_idx, batch in batches]) - assert all( - all( - si == (sample_offset + batch_idx * 3) // 2 - for sample_offset, si in enumerate(batch.sample_index) + print("si", [batch.sample_index for batch_idx, batch in batches]) + assert all( + all( + si == (sample_offset + batch_idx * 3) // 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches ) - for batch_idx, batch in batches - ) - print("rk", [batch.__restore_key__ for batch_idx, batch in batches]) - assert loader.can_restore_sample() - - # These need to be hard coded to detect breaking changes - # If a change is expected, update the values with the ones printed below - ref_batch_rand_nums = [ - [661, 1747, 762], - [1171, 206, 1921], - [470, 1705, 130], - [1722, 283, 1990], - [508, 1041, 61], - [1102, 625, 1559], - [661, 1512, 296], - [1866, 376, 1345], - [632, 1176, 514], - [1652, 715, 1702], - [406, 1552, 555], - [1303, 27, 1520], - [760, 1380, 36], - [1869, 607, 1292], - [610, 1084, 825], - [1113, 219, 1102], - [564, 1695, 832], - [1612, 876, 2000], - [512, 1308, 632], - [1425, 605, 1931], - ] - - batch_rand_nums = [] - for batch_idx, batch in batches: - restore_batch = loader.restore_sample(batch.__restore_key__) - assert restore_batch.batch_index == batch.batch_index - assert restore_batch.sample_index == batch.sample_index - assert restore_batch.rand_num == batch.rand_num - - batch_rand_nums.append(restore_batch.rand_num) - assert np.allclose(restore_batch.image, batch.image) - - # For constructing the test data above: - print("batch_rand_nums: ", batch_rand_nums) - assert batch_rand_nums == ref_batch_rand_nums + print("rk", [batch.__restore_key__ for batch_idx, batch in batches]) + assert loader.can_restore_sample() + + # These need to be hard coded to detect breaking changes + # If a change is expected, update the values with the ones printed below + ref_batch_rand_nums = [ + [661, 1747, 762], + [1171, 206, 1921], + [470, 1705, 130], + [1722, 283, 1990], + [508, 1041, 61], + [1102, 625, 1559], + [661, 1512, 296], + [1866, 376, 1345], + [632, 1176, 514], + [1652, 715, 1702], + [406, 1552, 555], + [1303, 27, 1520], + [760, 1380, 36], + [1869, 607, 1292], + [610, 1084, 825], + [1113, 219, 1102], + [564, 1695, 832], + [1612, 876, 2000], + [512, 1308, 632], + [1425, 605, 1931], + ] + + batch_rand_nums = [] + for batch_idx, batch in batches: + restore_batch = loader.restore_sample(batch.__restore_key__) + assert restore_batch.batch_index == batch.batch_index + assert restore_batch.sample_index == batch.sample_index + assert restore_batch.rand_num == batch.rand_num + + batch_rand_nums.append(restore_batch.rand_num) + assert np.allclose(restore_batch.image, batch.image) + + # For constructing the test data above: + print("batch_rand_nums: ", batch_rand_nums) + assert batch_rand_nums == ref_batch_rand_nums # Now, test multi-worker loader with accessing get_current_batch_index worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) - loader = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=3, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r0, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ) - ) - loader_r1 = get_loader( - get_train_dataset( - self.dataset_path, - batch_size=3, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r1, - shuffle_buffer_size=20, - max_samples_per_sequence=10, + with ( + get_loader( + get_train_dataset( + self.dataset_path, + batch_size=3, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r0, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) as loader, + get_loader( + get_train_dataset( + self.dataset_path, + batch_size=3, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r1, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) as loader_r1, + ): + batches = list(zip(range(20), loader)) + print("bir0", [batch.batch_index for batch_idx, batch in batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches ) - ) - batches = list(zip(range(20), loader)) - print("bir0", [batch.batch_index for batch_idx, batch in batches]) - assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) - - print("sir0", [batch.sample_index for batch_idx, batch in batches]) - # [[0, 0, 2], [1, 1, 3], [2, 4, 4], [3, 5, 5], [6, 6, 8], [7, 7, 9], [8, 10, 10], [9, 11, 11], [12, 12, 14], [13, 13, 15], [14, 16, 16], [15, 17, 17], [18, 18, 20], [19, 19, 21], [20, 22, 22], [21, 23, 23], [24, 24, 26], [25, 25, 27], [26, 28, 28], [27, 29, 29]] - assert all( - all( - si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) + print("sir0", [batch.sample_index for batch_idx, batch in batches]) + # [[0, 0, 2], [1, 1, 3], [2, 4, 4], [3, 5, 5], [6, 6, 8], [7, 7, 9], [8, 10, 10], [9, 11, 11], [12, 12, 14], [13, 13, 15], [14, 16, 16], [15, 17, 17], [18, 18, 20], [19, 19, 21], [20, 22, 22], [21, 23, 23], [24, 24, 26], [25, 25, 27], [26, 28, 28], [27, 29, 29]] + assert all( + all( + si + == batch_idx + + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches ) - for batch_idx, batch in batches - ) - batches_r1 = list(zip(range(20), loader_r1)) - print("bir0", [batch.batch_index for batch_idx, batch in batches_r1]) - print("sir1", [batch.sample_index for batch_idx, batch in batches_r1]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 - ) - assert all( - all( - si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) + batches_r1 = list(zip(range(20), loader_r1)) + print("bir0", [batch.batch_index for batch_idx, batch in batches_r1]) + print("sir1", [batch.sample_index for batch_idx, batch in batches_r1]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 + ) + assert all( + all( + si + == batch_idx + + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches_r1 ) - for batch_idx, batch in batches_r1 - ) # Now, test multi-worker loader with accessing get_current_batch_index and save/restore state - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=3, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r0, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ), - ) - loader_r1 = get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=3, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r1, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ), - ) - - batches = list(zip(range(20), loader)) - print("bi:", [batch.batch_index for batch_idx, batch in batches]) - print("si:", [batch.sample_index for batch_idx, batch in batches]) - assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) - assert all( - all( - si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) + with ( + get_savable_loader( + get_train_dataset( + self.dataset_path, + batch_size=3, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r0, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ), + ) as loader, + get_savable_loader( + get_train_dataset( + self.dataset_path, + batch_size=3, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r1, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ), + ) as loader_r1, + ): + batches = list(zip(range(20), loader)) + print("bi:", [batch.batch_index for batch_idx, batch in batches]) + print("si:", [batch.sample_index for batch_idx, batch in batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches + ) + assert all( + all( + si + == batch_idx + + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches ) - for batch_idx, batch in batches - ) - batches_r1 = list(zip(range(20), loader_r1)) - print([batch.batch_index for batch_idx, batch in batches_r1]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 - ) - assert all( - all( - si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) + batches_r1 = list(zip(range(20), loader_r1)) + print([batch.batch_index for batch_idx, batch in batches_r1]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 + ) + assert all( + all( + si + == batch_idx + + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches_r1 ) - for batch_idx, batch in batches_r1 - ) - # Save and restore state - state = loader.save_state_rank() + # Save and restore state + state = loader.save_state_rank() - # Iter next 20 from the loader - cmp_batches = list(zip(range(20, 40), loader)) - print("bi:", [batch.batch_index for batch_idx, batch in cmp_batches]) - print("si:", [batch.sample_index for batch_idx, batch in cmp_batches]) - print("rnd:", [batch.rand_num for batch_idx, batch in cmp_batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in cmp_batches - ) - assert all( - all( - si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) + # Iter next 20 from the loader + cmp_batches = list(zip(range(20, 40), loader)) + print("bi:", [batch.batch_index for batch_idx, batch in cmp_batches]) + print("si:", [batch.sample_index for batch_idx, batch in cmp_batches]) + print("rnd:", [batch.rand_num for batch_idx, batch in cmp_batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in cmp_batches + ) + assert all( + all( + si + == batch_idx + + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in cmp_batches ) - for batch_idx, batch in cmp_batches - ) # Restore state and check if the batch index is restored correctly - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, batch_size=3, @@ -1332,25 +1349,27 @@ def encode_sample(self, sample): shuffle_buffer_size=20, max_samples_per_sequence=10, ), - ) - loader.restore_state_rank(state) - - batches = list(zip(range(20, 40), loader)) - print("bi:", [batch.batch_index for batch_idx, batch in batches]) - print("si:", [batch.sample_index for batch_idx, batch in batches]) - print("rnd:", [batch.rand_num for batch_idx, batch in batches]) - assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) - assert all( - all( - si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) + ).with_restored_state_rank(state) as loader: + batches = list(zip(range(20, 40), loader)) + print("bi:", [batch.batch_index for batch_idx, batch in batches]) + print("si:", [batch.sample_index for batch_idx, batch in batches]) + print("rnd:", [batch.rand_num for batch_idx, batch in batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches + ) + assert all( + all( + si + == batch_idx + + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches + ) + assert all( + all(b1s == b2s for b1s, b2s in zip(b1.rand_num, b2.rand_num)) + for (_b1idx, b1), (_b2idx, b2) in zip(batches, cmp_batches) ) - for batch_idx, batch in batches - ) - assert all( - all(b1s == b2s for b1s, b2s in zip(b1.rand_num, b2.rand_num)) - for (_b1idx, b1), (_b2idx, b2) in zip(batches, cmp_batches) - ) def test_packing(self): torch.manual_seed(42) @@ -1384,7 +1403,7 @@ def pack_selected_samples( caption=torch.cat([sample.caption for sample in samples]), ) - loader = get_loader( + with get_loader( get_train_dataset( self.dataset_path, batch_size=2, @@ -1395,39 +1414,38 @@ def pack_selected_samples( max_samples_per_sequence=None, task_encoder=TestTaskEncoder(), ) - ) - - assert len(loader) == 6 - - samples = list(loader) - - print([batch.__key__ for batch in samples]) - print([batch.__restore_key__ for batch in samples]) - print([len(batch.__key__) for batch in samples]) - print([[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples]) - - # Each batch should have 2 samples - assert [len(batch.__key__) for batch in samples] == [ - 2, - 2, - 2, - 2, - 2, - 2, - ] - - # The packs of lengths 1, 4, 16 should be unrolled repeatedly across the batches of size 2 - assert [ - [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples - ] == [[1, 4], [16, 1], [4, 16], [1, 4], [16, 1], [4, 16]] - - restored_sample_1 = loader.restore_sample(samples[1].__restore_key__) - assert restored_sample_1.__key__ == samples[1].__key__ - assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ + ) as loader: + assert len(loader) == 6 + + samples = list(loader) + + print([batch.__key__ for batch in samples]) + print([batch.__restore_key__ for batch in samples]) + print([len(batch.__key__) for batch in samples]) + print([[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples]) + + # Each batch should have 2 samples + assert [len(batch.__key__) for batch in samples] == [ + 2, + 2, + 2, + 2, + 2, + 2, + ] + + # The packs of lengths 1, 4, 16 should be unrolled repeatedly across the batches of size 2 + assert [ + [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples + ] == [[1, 4], [16, 1], [4, 16], [1, 4], [16, 1], [4, 16]] + + restored_sample_1 = loader.restore_sample(samples[1].__restore_key__) + assert restored_sample_1.__key__ == samples[1].__key__ + assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - loader_r0 = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, batch_size=2, @@ -1438,24 +1456,24 @@ def pack_selected_samples( max_samples_per_sequence=None, task_encoder=TestTaskEncoder(), ), - ) - - samples_r0 = list(loader_r0) - assert [ - [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples_r0 - ] == [[1, 4], [1, 4], [16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4]] - - restored_sample_1 = loader_r0.restore_sample(samples_r0[1].__restore_key__) - assert restored_sample_1.__key__ == samples_r0[1].__key__ - assert restored_sample_1.__restore_key__ == samples_r0[1].__restore_key__ - - rank_state_r0 = loader_r0.save_state_rank() - samples_r0_cmp = list(loader_r0) - assert [ - [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples_r0_cmp - ] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]] - - loader_r0 = get_savable_loader( + ) as loader_r0: + samples_r0 = list(loader_r0) + assert [ + [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples_r0 + ] == [[1, 4], [1, 4], [16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4]] + + restored_sample_1 = loader_r0.restore_sample(samples_r0[1].__restore_key__) + assert restored_sample_1.__key__ == samples_r0[1].__key__ + assert restored_sample_1.__restore_key__ == samples_r0[1].__restore_key__ + + rank_state_r0 = loader_r0.save_state_rank() + samples_r0_cmp = list(loader_r0) + assert [ + [len(batch_key.split(",")) for batch_key in batch.__key__] + for batch in samples_r0_cmp + ] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]] + + with get_savable_loader( get_train_dataset( self.dataset_path, batch_size=2, @@ -1466,19 +1484,18 @@ def pack_selected_samples( max_samples_per_sequence=None, task_encoder=TestTaskEncoder(), ), - ) - - loader_r0.restore_state_rank(rank_state_r0) - - samples_r0_restored = list(loader_r0) - print("cmp", [batch.__key__ for batch in samples_r0_cmp]) - print("rst", [batch.__key__ for batch in samples_r0_restored]) - assert [ - [len(batch_key.split(",")) for batch_key in batch.__key__] - for batch in samples_r0_restored - ] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]] - - assert all(s0.__key__ == s1.__key__ for s0, s1 in zip(samples_r0_cmp, samples_r0_restored)) + ).with_restored_state_rank(rank_state_r0) as loader_r0: + samples_r0_restored = list(loader_r0) + print("cmp", [batch.__key__ for batch in samples_r0_cmp]) + print("rst", [batch.__key__ for batch in samples_r0_restored]) + assert [ + [len(batch_key.split(",")) for batch_key in batch.__key__] + for batch in samples_r0_restored + ] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]] + + assert all( + s0.__key__ == s1.__key__ for s0, s1 in zip(samples_r0_cmp, samples_r0_restored) + ) def test_packing_val(self): torch.manual_seed(42) @@ -1519,7 +1536,7 @@ def pack_selected_samples( caption=torch.cat([sample.caption for sample in samples]), ) - loader = get_loader( + with get_loader( get_val_dataset( self.dataset_path, batch_size=2, @@ -1528,37 +1545,36 @@ def pack_selected_samples( task_encoder=TestTaskEncoder(), split_part="train", ) - ) - - assert len(loader) == 25, f"len(loader) == {len(loader)}" - - samples = list(loader) - - print([batch.__key__ for batch in samples]) - print([batch.__restore_key__ for batch in samples]) - print([len(batch.__key__) for batch in samples]) - print([[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples]) - - # Each batch should have 2 samples - assert [len(batch.__key__) for batch in samples] == [ - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, - ] - - # The packs of lengths 1, 4, 16 should be unrolled repeatedly across the batches of size 2 - assert [ - [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples - ] == [[2, 5], [3, 1], [2, 5], [3, 1], [2, 5], [3, 1], [2, 5], [3, 1]] - - restored_sample_1 = loader.restore_sample(samples[1].__restore_key__) - assert restored_sample_1.__key__ == samples[1].__key__ - assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ + ) as loader: + assert len(loader) == 25, f"len(loader) == {len(loader)}" + + samples = list(loader) + + print([batch.__key__ for batch in samples]) + print([batch.__restore_key__ for batch in samples]) + print([len(batch.__key__) for batch in samples]) + print([[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples]) + + # Each batch should have 2 samples + assert [len(batch.__key__) for batch in samples] == [ + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + ] + + # The packs of lengths 1, 4, 16 should be unrolled repeatedly across the batches of size 2 + assert [ + [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples + ] == [[2, 5], [3, 1], [2, 5], [3, 1], [2, 5], [3, 1], [2, 5], [3, 1]] + + restored_sample_1 = loader.restore_sample(samples[1].__restore_key__) + assert restored_sample_1.__key__ == samples[1].__key__ + assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ def test_group_batch(self): class GroupingTaskEncoder( @@ -1582,7 +1598,7 @@ def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: return CaptioningEncodedBatch.extend(batch) worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, batch_size=None, @@ -1591,16 +1607,18 @@ def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: max_samples_per_sequence=None, task_encoder=GroupingTaskEncoder(), ), - ) - batches = list(zip(range(40), loader)) - print([batch.__key__ for idx, batch in batches]) + ) as loader: + batches = list(zip(range(40), loader)) + print([batch.__key__ for idx, batch in batches]) - assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches) - assert all(all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches) + assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches) + assert all( + all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches + ) worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - loader_r0 = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, batch_size=None, @@ -1609,21 +1627,22 @@ def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: max_samples_per_sequence=None, task_encoder=GroupingTaskEncoder(), ), - ) + ) as loader_r0: + batches = list(zip(range(40), loader_r0)) - batches = list(zip(range(40), loader_r0)) + print([batch.__key__ for idx, batch in batches]) - print([batch.__key__ for idx, batch in batches]) - - assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches) - assert all(all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches) + assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches) + assert all( + all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches + ) - state = loader_r0.save_state_rank() + state = loader_r0.save_state_rank() - cmp_samples = list(zip(range(40, 80), loader_r0)) - print([batch.__key__ for idx, batch in cmp_samples]) + cmp_samples = list(zip(range(40, 80), loader_r0)) + print([batch.__key__ for idx, batch in cmp_samples]) - loader_r0 = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, batch_size=None, @@ -1632,24 +1651,22 @@ def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: max_samples_per_sequence=None, task_encoder=GroupingTaskEncoder(), ), - ) - loader_r0.restore_state_rank(state) - - cmp_samples_rest = list(zip(range(40, 80), loader_r0)) - print([batch.__key__ for idx, batch in cmp_samples_rest]) - - assert len(cmp_samples) == len(cmp_samples_rest) - assert all( - len(cmp_sample.caption) == len(cmp_sample_rest.caption) - for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest) - ) - assert all( - all( - cmp_cap == cmp_cap_rest - for cmp_cap, cmp_cap_rest in zip(cmp_sample.caption, cmp_sample_rest.caption) + ).with_restored_state_rank(state) as loader_r0: + cmp_samples_rest = list(zip(range(40, 80), loader_r0)) + print([batch.__key__ for idx, batch in cmp_samples_rest]) + + assert len(cmp_samples) == len(cmp_samples_rest) + assert all( + len(cmp_sample.caption) == len(cmp_sample_rest.caption) + for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest) + ) + assert all( + all( + cmp_cap == cmp_cap_rest + for cmp_cap, cmp_cap_rest in zip(cmp_sample.caption, cmp_sample_rest.caption) + ) + for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest) ) - for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest) - ) def test_debug_dataset(self): torch.manual_seed(42) @@ -1664,39 +1681,38 @@ def test_debug_dataset(self): # Reset this to 0 to make sure the test is deterministic DataLoader._next_id = 0 - loader = get_savable_loader( + with get_savable_loader( get_val_dataset( self.dataset_path, split_part="train", batch_size=5, worker_config=worker_config, ), - ) - - assert len(loader) == 10 - - samples = [[batch.__key__ for batch in loader] for _ in range(2)] - print(samples) - - debug_log_path = self.dataset_path / "worker_debug" - assert (debug_log_path / "0.jsonl").is_file() - assert (debug_log_path / "1.jsonl").is_file() - assert (debug_log_path / "2.jsonl").is_file() - - collected_keys_order = [[None] * 10 for _ in range(2)] - with (debug_log_path / "0.jsonl").open() as rf: - for line in rf: - line_data = json.loads(line) - print(line_data) - if line_data["t"] == "DataLoader.epoch_iter.yield": - for i in range(len(collected_keys_order)): - if collected_keys_order[i][line_data["epoch_sample_idx"]] is None: - collected_keys_order[i][line_data["epoch_sample_idx"]] = line_data[ - "keys" - ] - break - else: - assert False, "Too many entries for key" + ) as loader: + assert len(loader) == 10 + + samples = [[batch.__key__ for batch in loader] for _ in range(2)] + print(samples) + + debug_log_path = self.dataset_path / "worker_debug" + assert (debug_log_path / "0.jsonl").is_file() + assert (debug_log_path / "1.jsonl").is_file() + assert (debug_log_path / "2.jsonl").is_file() + + collected_keys_order = [[None] * 10 for _ in range(2)] + with (debug_log_path / "0.jsonl").open() as rf: + for line in rf: + line_data = json.loads(line) + print(line_data) + if line_data["t"] == "DataLoader.epoch_iter.yield": + for i in range(len(collected_keys_order)): + if collected_keys_order[i][line_data["epoch_sample_idx"]] is None: + collected_keys_order[i][line_data["epoch_sample_idx"]] = line_data[ + "keys" + ] + break + else: + assert False, "Too many entries for key" print(collected_keys_order) assert collected_keys_order == samples diff --git a/tests/test_dataset_det.py b/tests/test_dataset_det.py index c12cc943..ab90daad 100644 --- a/tests/test_dataset_det.py +++ b/tests/test_dataset_det.py @@ -166,26 +166,25 @@ def test_split_parts(self): training=False, sample_type=TextSample, ) - dl = get_loader(ds.build()) - - all_keys = [sample.__key__ for sample in dl] - assert all_keys == [ - "parts/data-4.tar/000011", # Shard 4 first - "parts/data-4.tar/000012", - "parts/data-4.tar/000013", - "parts/data-4.tar/000014", - "parts/data-4.tar/000015", - "parts/data-4.tar/000016", - "parts/data-4.tar/000017", - "parts/data-4.tar/000018", - "parts/data-4.tar/000019", - "parts/data-4.tar/000020", - "parts/data-0.tar/000000", # Shard 0 - "parts/data-0.tar/000001", - "parts/data-2.tar/000004", # Shard 2 - "parts/data-2.tar/000005", - "parts/data-2.tar/000006", - ] + with get_loader(ds.build()) as dl: + all_keys = [sample.__key__ for sample in dl] + assert all_keys == [ + "parts/data-4.tar/000011", # Shard 4 first + "parts/data-4.tar/000012", + "parts/data-4.tar/000013", + "parts/data-4.tar/000014", + "parts/data-4.tar/000015", + "parts/data-4.tar/000016", + "parts/data-4.tar/000017", + "parts/data-4.tar/000018", + "parts/data-4.tar/000019", + "parts/data-4.tar/000020", + "parts/data-0.tar/000000", # Shard 0 + "parts/data-0.tar/000001", + "parts/data-2.tar/000004", # Shard 2 + "parts/data-2.tar/000005", + "parts/data-2.tar/000006", + ] def test_text_dataset(self): worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) @@ -201,15 +200,15 @@ def test_text_dataset(self): # Check len operator assert len(ds) == 55 # Check if iterating returns the same - iter1 = list(get_loader(ds)) - iter2 = list(get_loader(ds)) + with get_loader(ds) as l1: + iter1 = list(l1) + with get_loader(ds) as l2: + iter2 = list(l2) assert len(iter1) == 55 assert len(iter2) == 55 assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2)) - assert all(f"{idx}" == x.text for idx, x in enumerate(get_loader(ds))) - - del ds - gc.collect() + with get_loader(ds) as l3: + assert all(f"{idx}" == x.text for idx, x in enumerate(l3)) def test_epoch(self): torch.manual_seed(42) @@ -224,11 +223,11 @@ def test_epoch(self): sample_type=TextSample, worker_config=worker_config, ) - loader5 = get_loader(ds3.build()) - order9 = [data.text for idx, data in zip(range(55), loader5)] - print(order9) - print(Counter(order9)) - assert all(v == 1 for v in Counter(order9).values()) + with get_loader(ds3.build()) as loader5: + order9 = [data.text for idx, data in zip(range(55), loader5)] + print(order9) + print(Counter(order9)) + assert all(v == 1 for v in Counter(order9).values()) def test_determinism(self): worker_config2 = WorkerConfig(rank=0, world_size=1, num_workers=2) @@ -275,33 +274,27 @@ def test_determinism(self): ) # Fork the dataset twice - loader1 = get_loader(ds1) - loader2 = get_loader(ds1) - - order4 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)] - order5 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)] - order6 = [data.text[0] for idx, data in zip(range(55 * 20), loader2)] - print(order4) - print(Counter(order4)) - # +-1 is possible due to the random shuffling (actually +-2 is possible) - assert all(17 <= v <= 22 for v in Counter(order4).values()) - - assert order4 != order5 - assert order4 == order6 - - loader3 = get_loader(ds1b) - order7 = [data.text[0] for idx, data in zip(range(55 * 20), loader3)] - assert order6 != order7 - - loader4 = get_loader(ds3) - order8 = [data.text[0] for idx, data in zip(range(55 * 100), loader4)] - assert order6 != order8[: len(order6)] - print(Counter(order8)) - assert all(90 <= v <= 110 for v in Counter(order8).values()) - - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() + with get_loader(ds1) as loader1, get_loader(ds2) as loader2: + order4 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)] + order5 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)] + order6 = [data.text[0] for idx, data in zip(range(55 * 20), loader2)] + print(order4) + print(Counter(order4)) + # +-1 is possible due to the random shuffling (actually +-2 is possible) + assert all(17 <= v <= 22 for v in Counter(order4).values()) + + assert order4 != order5 + assert order4 == order6 + + with get_loader(ds1b) as loader3: + order7 = [data.text[0] for idx, data in zip(range(55 * 20), loader3)] + assert order6 != order7 + + with get_loader(ds3) as loader4: + order8 = [data.text[0] for idx, data in zip(range(55 * 100), loader4)] + assert order6 != order8[: len(order6)] + print(Counter(order8)) + assert all(90 <= v <= 110 for v in Counter(order8).values()) def test_determinism_taskencoder(self): class TestTaskEncoder(DefaultTaskEncoder): @@ -344,17 +337,12 @@ def encode_sample(self, sample: TextSample) -> TextSample: ) # Fork the dataset twice - loader1a = get_loader(ds1a) - loader1b = get_loader(ds1b) - - order1a = [data.text[0] for idx, data in zip(range(55 * 20), loader1a)] - order1b = [data.text[0] for idx, data in zip(range(55 * 20), loader1b)] + with get_loader(ds1a) as loader1a, get_loader(ds1b) as loader1b: + order1a = [data.text[0] for idx, data in zip(range(55 * 20), loader1a)] + order1b = [data.text[0] for idx, data in zip(range(55 * 20), loader1b)] - assert order1a == order1b - - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() + assert order1a == order1b + assert order1a == order1b def test_determinism_taskencoder_save_restore(self): class TestTaskEncoder(DefaultTaskEncoder): @@ -403,34 +391,27 @@ def encode_sample(self, sample: TextSample) -> TextSample: ) # Fork the dataset twice - loader1a = get_savable_loader(ds1a) - loader1b = get_savable_loader(ds1b) - - # Load 7 samples - data_pre = [data.text[0] for idx, data in zip(range(7), loader1a)] - - # Then save state - state = loader1a.save_state_rank() + with get_savable_loader(ds1a) as loader1a: + # Load 7 samples + _data_pre = [data.text[0] for idx, data in zip(range(7), loader1a)] - print("iterating loader1a") - # Load another 20 samples - data_post = [data.text[0] for idx, data in zip(range(20), loader1a)] + # Then save state + state = loader1a.save_state_rank() - # Restore state - loader1b.restore_state_rank(state) + print("iterating loader1a") + # Load another 20 samples + data_post = [data.text[0] for idx, data in zip(range(20), loader1a)] - print("iterating loader1b") - # Load 20 samples again - data_restored = [data.text[0] for idx, data in zip(range(20), loader1b)] + # Restore state + with get_savable_loader(ds1b).with_restored_state_rank(state) as loader1b: + print("iterating loader1b") + # Load 20 samples again + data_restored = [data.text[0] for idx, data in zip(range(20), loader1b)] - print("Data post:", data_post) - print("Data restored:", data_restored) + print("Data post:", data_post) + print("Data restored:", data_restored) - assert data_post == data_restored - - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() + assert data_post == data_restored def test_restore_state(self): worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) @@ -446,7 +427,7 @@ def test_restore_state(self): # This seed is used by the dataset to shuffle the data torch.manual_seed(42) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, split_part="train", @@ -457,48 +438,23 @@ def test_restore_state(self): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - ) - - # print("save state") - state_0 = loader.save_state_global(global_dst_rank=0) - # print("save state done") - order_1 = [data.text[0] for idx, data in zip(range(count1), loader)] - assert len(order_1) == count1 - # print("save state") - state_1 = loader.save_state_global(global_dst_rank=0) - # print("save state done") - order_2 = [data.text[0] for idx, data in zip(range(count2), loader)] - assert len(order_2) == count2 + ) as loader: + # print("save state") + state_0 = loader.save_state_global(global_dst_rank=0) + # print("save state done") + order_1 = [data.text[0] for idx, data in zip(range(count1), loader)] + assert len(order_1) == count1 + # print("save state") + state_1 = loader.save_state_global(global_dst_rank=0) + # print("save state done") + order_2 = [data.text[0] for idx, data in zip(range(count2), loader)] + assert len(order_2) == count2 - print("state0", state_0) - print("state1", state_1) + print("state0", state_0) + print("state1", state_1) torch.manual_seed(213) - loader = get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - ) - loader.restore_state_global(state_0, src_rank=None) - order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)] - order_4 = order_45[:count1] - order_5 = order_45[count1:] - # print("order1", order_1) - # print("order2", order_2) - # print("order4", order_4) - assert order_1 == order_4 - # print("order5", order_5) - assert order_2 == order_5 - - torch.manual_seed(145) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, split_part="train", @@ -509,15 +465,35 @@ def test_restore_state(self): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - ) - # print("restore state") - loader.restore_state_global(state_1, src_rank=None) - # print("restore state done") - order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] - # print("order1", order_1) - # print("order2", order_2[:100]) - # print("order3", order_3[:100]) - assert order_2 == order_3 + ).with_restored_state_global(state_0, src_rank=None) as loader: + order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)] + order_4 = order_45[:count1] + order_5 = order_45[count1:] + # print("order1", order_1) + # print("order2", order_2) + # print("order4", order_4) + assert order_1 == order_4 + # print("order5", order_5) + assert order_2 == order_5 + + torch.manual_seed(145) + with get_savable_loader( + get_train_dataset( + self.dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + ).with_restored_state_global(state_1, src_rank=None) as loader: + order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] + # print("order1", order_1) + # print("order2", order_2[:100]) + # print("order3", order_3[:100]) + assert order_2 == order_3 def test_restore_state_dist(self): from multiprocessing import Manager, Process @@ -537,7 +513,7 @@ def phase1(rank: int, world_size: int, shared_dict: dict): # This seed is used by the dataset to shuffle the data torch.manual_seed(42) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, split_part="train", @@ -548,24 +524,23 @@ def phase1(rank: int, world_size: int, shared_dict: dict): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - ) - - state_0 = loader.save_state_global(global_dst_rank=0) - order_1 = [data.text[0] for idx, data in zip(range(count1), loader)] - assert len(order_1) == count1 + ) as loader: + state_0 = loader.save_state_global(global_dst_rank=0) + order_1 = [data.text[0] for idx, data in zip(range(count1), loader)] + assert len(order_1) == count1 - # print(f"Rank {rank}: order_1", order_1) + # print(f"Rank {rank}: order_1", order_1) - state_1 = loader.save_state_global(global_dst_rank=0) - order_2 = [data.text[0] for idx, data in zip(range(count2), loader)] - assert len(order_2) == count2 + state_1 = loader.save_state_global(global_dst_rank=0) + order_2 = [data.text[0] for idx, data in zip(range(count2), loader)] + assert len(order_2) == count2 - shared_dict[(rank, "order_1")] = order_1 - shared_dict[(rank, "order_2")] = order_2 + shared_dict[(rank, "order_1")] = order_1 + shared_dict[(rank, "order_2")] = order_2 - if rank == 0: - shared_dict["state_0"] = state_0 - shared_dict["state_1"] = state_1 + if rank == 0: + shared_dict["state_0"] = state_0 + shared_dict["state_1"] = state_1 def phase2(rank: int, world_size: int, shared_dict: dict): order_1 = shared_dict[(rank, "order_1")] @@ -581,7 +556,7 @@ def phase2(rank: int, world_size: int, shared_dict: dict): worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) torch.manual_seed(213) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, split_part="train", @@ -592,20 +567,18 @@ def phase2(rank: int, world_size: int, shared_dict: dict): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - ) - loader.restore_state_global(state_0, src_rank=0) + ).with_restored_state_global(state_0, src_rank=0) as loader: + order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)] + order_4 = order_45[:count1] + order_5 = order_45[count1:] - order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)] - order_4 = order_45[:count1] - order_5 = order_45[count1:] + # print(f"Rank {rank}: order_4", order_4) - # print(f"Rank {rank}: order_4", order_4) - - assert order_1 == order_4 - assert order_2 == order_5 + assert order_1 == order_4 + assert order_2 == order_5 torch.manual_seed(213) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.dataset_path, split_part="train", @@ -616,10 +589,9 @@ def phase2(rank: int, world_size: int, shared_dict: dict): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - ) - loader.restore_state_global(state_1, src_rank=0) - order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] - assert order_2 == order_3 + ).with_restored_state_global(state_1, src_rank=0) as loader: + order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] + assert order_2 == order_3 def init_process(rank, world_size, shared_dict, fn, backend="gloo"): """Initializes the distributed environment.""" @@ -676,28 +648,27 @@ def test_restore_state_workers(self): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - loader = get_savable_loader(ds) - - # print("save state") - state_0 = loader.save_state_rank() - it1 = iter(loader) - # print("save state done") - order_1 = [data.text[0] for idx, data in zip(range(n1), it1)] - # print("save state") - # time.sleep(0.5) - state_1 = loader.save_state_rank() - # print("save state done") - order_2 = [data.text[0] for idx, data in zip(range(n2), it1)] - state_2 = loader.save_state_rank() - order_3 = [data.text[0] for idx, data in zip(range(n3), it1)] - - print("order_1", order_1) - print("order_2", order_2) - print("order_3", order_3) - - # print("state0", state_0) - print("state1", state_1) - print("state2", state_2) + with get_savable_loader(ds) as loader: + # print("save state") + state_0 = loader.save_state_rank() + it1 = iter(loader) + # print("save state done") + order_1 = [data.text[0] for idx, data in zip(range(n1), it1)] + # print("save state") + # time.sleep(0.5) + state_1 = loader.save_state_rank() + # print("save state done") + order_2 = [data.text[0] for idx, data in zip(range(n2), it1)] + state_2 = loader.save_state_rank() + order_3 = [data.text[0] for idx, data in zip(range(n3), it1)] + + print("order_1", order_1) + print("order_2", order_2) + print("order_3", order_3) + + # print("state0", state_0) + print("state1", state_1) + print("state2", state_2) # Restoring the state of a new dataset should also yield the same torch.manual_seed(42) @@ -711,12 +682,11 @@ def test_restore_state_workers(self): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - loader = get_savable_loader(ds) - loader.restore_state_rank(state_0) - order_6 = [data.text[0] for idx, data in zip(range(n1), loader)] - print("order1", order_1) - print("order6", order_6) - assert order_6 == order_1 + with get_savable_loader(ds).with_restored_state_rank(state_0) as loader: + order_6 = [data.text[0] for idx, data in zip(range(n1), loader)] + print("order1", order_1) + print("order6", order_6) + assert order_6 == order_1 # Restoring the state of a new dataset should also yield the same torch.manual_seed(42) @@ -730,12 +700,11 @@ def test_restore_state_workers(self): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - loader = get_savable_loader(ds) - loader.restore_state_rank(state_1) - order_7 = [data.text[0] for idx, data in zip(range(n2), loader)] - print("order2", order_2[:100]) - print("order7", order_7[:100]) - assert order_7 == order_2 + with get_savable_loader(ds).with_restored_state_rank(state_1) as loader: + order_7 = [data.text[0] for idx, data in zip(range(n2), loader)] + print("order2", order_2[:100]) + print("order7", order_7[:100]) + assert order_7 == order_2 # Restoring the state of a new dataset should also yield the same torch.manual_seed(42) @@ -749,12 +718,11 @@ def test_restore_state_workers(self): shuffle_buffer_size=sbs, parallel_shard_iters=psi, ) - loader = get_savable_loader(ds) - loader.restore_state_rank(state_2) - order_8 = [data.text[0] for idx, data in zip(range(n3), loader)] - print("order3", order_3) - print("order8", order_8) - assert order_8 == order_3 + with get_savable_loader(ds).with_restored_state_rank(state_2) as loader: + order_8 = [data.text[0] for idx, data in zip(range(n3), loader)] + print("order3", order_3) + print("order8", order_8) + assert order_8 == order_3 def test_invariance_global_samples(self): # We'd like to ensure that the user can keep the same global batches @@ -827,15 +795,14 @@ def test_invariance_global_samples(self): shuffle_buffer_size=42, max_samples_per_sequence=2, ) - loader = get_loader(ds) - - micro_batches = [ - data.text - for idx, data in zip( - range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader - ) - ] - batches_per_rank.append(micro_batches) + with get_loader(ds) as loader: + micro_batches = [ + data.text + for idx, data in zip( + range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader + ) + ] + batches_per_rank.append(micro_batches) # Compose global batches global_batches_cur_rank = [] @@ -873,10 +840,6 @@ def test_invariance_global_samples(self): f"Global batch {i} of scenario {scenerio_idx} does not match." ) - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() - def test_redist(self): scenarios = [ dict( @@ -1080,10 +1043,6 @@ def test_redist(self): f"Global batch {i} of scenario {scenerio_idx} does not match." ) - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() - if __name__ == "__main__": unittest.main() diff --git a/tests/test_jsonl_dataset.py b/tests/test_jsonl_dataset.py index 0ffe3192..de32200d 100644 --- a/tests/test_jsonl_dataset.py +++ b/tests/test_jsonl_dataset.py @@ -145,15 +145,14 @@ def test_dataset(self): print(len(train_dataset)) assert len(train_dataset) == 55, f"Expected 55 samples, got {len(train_dataset)}" - train_loader1 = get_loader(train_dataset) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 55 - assert all(v == 10 for v in Counter(train_order1).values()) + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 55 + assert all(v == 10 for v in Counter(train_order1).values()) def test_metadataset_all(self): torch.manual_seed(42) @@ -176,15 +175,14 @@ def test_metadataset_all(self): print(len(train_dataset)) assert len(train_dataset) == 55 * 3, f"Expected 55 * 3 samples, got {len(train_dataset)}" - train_loader1 = get_loader(train_dataset) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 55 * 3 - assert all(2 <= v <= 5 for v in Counter(train_order1).values()) + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 55 * 3 + assert all(2 <= v <= 5 for v in Counter(train_order1).values()) def test_metadataset_multirank(self): torch.manual_seed(42) @@ -215,10 +213,9 @@ def test_metadataset_multirank(self): f"Expected {expected_lens[cur_rank]} samples, got {len(train_dataset)}" ) - train_loader1 = get_loader(train_dataset) - - for data in train_loader1: - sample_counts[int(data.text[0])] += 1 + with get_loader(train_dataset) as train_loader1: + for data in train_loader1: + sample_counts[int(data.text[0])] += 1 for i in range(55): assert sample_counts[i] == 1, ( @@ -246,7 +243,7 @@ def test_s3(self): # EPath(self.dataset_path).copy(EPath("msc://s3/test/dataset")) emu.add_file(self.dataset_path, "test/dataset") - train_dataset = get_loader( + with get_loader( get_train_dataset( mixed_mds_path, worker_config=WorkerConfig( @@ -260,13 +257,12 @@ def test_s3(self): virtual_epoch_length=55 * 10, task_encoder=SimpleCookingTaskEncoder(), ) - ) - - data = list(enumerate(train_dataset)) - assert len(data) == 55 * 10, len(data) - cnt = Counter(t for _, entry in data for t in entry.text) - assert len(cnt) == 55 * 3 - assert all(2 <= v <= 5 for v in cnt.values()) + ) as train_dataset: + data = list(enumerate(train_dataset)) + assert len(data) == 55 * 10, len(data) + cnt = Counter(t for _, entry in data for t in entry.text) + assert len(cnt) == 55 * 3 + assert all(2 <= v <= 5 for v in cnt.values()) def test_prepare(self): print("Creating new dataset") @@ -289,7 +285,7 @@ def test_prepare(self): torch.manual_seed(42) # Train mode dataset - train_loader = get_loader( + with get_loader( get_train_dataset( self.dataset_path / "ds_prep.jsonl", worker_config=WorkerConfig( @@ -303,14 +299,14 @@ def test_prepare(self): max_samples_per_sequence=None, task_encoder=SimpleCookingTaskEncoder(), ) - ) - assert len(train_loader) == 10, f"Expected 10 samples, got {len(train_loader)}" - - train_order1 = [text for _, data in zip(range(50), train_loader) for text in data.text] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 10 - assert all(v == 5 for v in Counter(train_order1).values()) + ) as train_loader: + assert len(train_loader) == 10, f"Expected 10 samples, got {len(train_loader)}" + + train_order1 = [text for _, data in zip(range(50), train_loader) for text in data.text] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 10 + assert all(v == 5 for v in Counter(train_order1).values()) if __name__ == "__main__": diff --git a/tests/test_metadataset.py b/tests/test_metadataset.py index 75f9e8b0..03be3892 100644 --- a/tests/test_metadataset.py +++ b/tests/test_metadataset.py @@ -302,25 +302,24 @@ def test_metadataset(self): print(len(train_dataset)) assert len(train_dataset) == 11 - train_loader1 = get_loader(train_dataset) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 110 - assert all(48 <= v <= 52 for v in Counter(train_order1).values()) - - train_subflavors = [ - subflavor["__subflavor__"] - for idx, data in zip(range(55), train_loader1) - for subflavor in data.__subflavors__ - ] - print("train_subflavors[:10]", train_subflavors[:10]) - print("Counter(train_subflavors)", Counter(train_subflavors)) - assert len(Counter(train_subflavors)) == 2 - assert all(250 <= v <= 300 for v in Counter(train_subflavors).values()) + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 110 + assert all(48 <= v <= 52 for v in Counter(train_order1).values()) + + train_subflavors = [ + subflavor["__subflavor__"] + for idx, data in zip(range(55), train_loader1) + for subflavor in data.__subflavors__ + ] + print("train_subflavors[:10]", train_subflavors[:10]) + print("Counter(train_subflavors)", Counter(train_subflavors)) + assert len(Counter(train_subflavors)) == 2 + assert all(250 <= v <= 300 for v in Counter(train_subflavors).values()) # Train mode dataset train_dataset = get_train_dataset( @@ -333,27 +332,25 @@ def test_metadataset(self): print(len(train_dataset)) assert len(train_dataset) == 11 - train_loader1 = get_loader(train_dataset) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 110 - assert all(48 <= v <= 52 for v in Counter(train_order1).values()) + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 110 + assert all(48 <= v <= 52 for v in Counter(train_order1).values()) # Val mode dataset val_dataset = get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10) print(len(val_dataset)) assert len(val_dataset) == 11 - val_loader1 = get_loader(val_dataset) - - val_order1 = [text for data in val_loader1 for text in data.text] - assert len(val_order1) == 110 - print(Counter(val_order1)) - assert all(v == 1 for v in Counter(val_order1).values()) + with get_loader(val_dataset) as val_loader1: + val_order1 = [text for data in val_loader1 for text in data.text] + assert len(val_order1) == 110 + print(Counter(val_order1)) + assert all(v == 1 for v in Counter(val_order1).values()) def test_nested_metadataset(self): torch.manual_seed(42) @@ -417,76 +414,75 @@ def test_nested_metadataset(self): print(len(train_dataset)) assert len(train_dataset) == 22 - train_loader1 = get_loader(train_dataset) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 110 - assert all(48 <= v <= 53 for v in Counter(train_order1).values()) - - train_subflavors = [ - subflavor.get("__subflavor__") - for idx, data in zip(range(55), train_loader1) - for subflavor in data.__subflavors__ - ] - cnt = Counter(train_subflavors) - print(train_subflavors[:10]) - print(cnt) - avg = 55 * 10 / 5 - assert len(Counter(train_subflavors)) == 2 - assert avg * 4 - 40 < cnt["train"] < avg * 4 + 40 - assert avg - 10 < cnt[None] < avg + 10 - - train_subflavorss = [ - tuple(subflavor.items()) - for idx, data in zip(range(55), train_loader1) - for subflavor in data.__subflavors__ - ] - cnt = Counter(train_subflavorss) - print(train_subflavorss[:10]) - print(cnt) - assert len(Counter(train_subflavorss)) == 3 - assert ( - avg * 2 - 20 - < cnt[ - ( - ("source", "nested_metadataset.yaml"), - ("dataset.yaml", True), - ("number", 43), - ("__subflavor__", "train"), - ("mds", "nested_train"), - ) + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text ] - < avg * 2 + 20 - ) - assert ( - avg * 2 - 20 - < cnt[ - ( - ("source", "nested_metadataset.yaml"), - ("dataset.yaml", True), - ("number", 44), - ("__subflavor__", "train"), - ("mds", "nested_train"), - ) + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 110 + assert all(48 <= v <= 53 for v in Counter(train_order1).values()) + + train_subflavors = [ + subflavor.get("__subflavor__") + for idx, data in zip(range(55), train_loader1) + for subflavor in data.__subflavors__ ] - < avg * 2 + 20 - ) - assert ( - avg * 1 - 20 - < cnt[ - ( - ("source", "nested_metadataset.yaml"), - ("dataset.yaml", True), - ("number", 42), - ("mds", "nested_val"), - ) + cnt = Counter(train_subflavors) + print(train_subflavors[:10]) + print(cnt) + avg = 55 * 10 / 5 + assert len(Counter(train_subflavors)) == 2 + assert avg * 4 - 40 < cnt["train"] < avg * 4 + 40 + assert avg - 10 < cnt[None] < avg + 10 + + train_subflavorss = [ + tuple(subflavor.items()) + for idx, data in zip(range(55), train_loader1) + for subflavor in data.__subflavors__ ] - < avg * 1 + 20 - ) + cnt = Counter(train_subflavorss) + print(train_subflavorss[:10]) + print(cnt) + assert len(Counter(train_subflavorss)) == 3 + assert ( + avg * 2 - 20 + < cnt[ + ( + ("source", "nested_metadataset.yaml"), + ("dataset.yaml", True), + ("number", 43), + ("__subflavor__", "train"), + ("mds", "nested_train"), + ) + ] + < avg * 2 + 20 + ) + assert ( + avg * 2 - 20 + < cnt[ + ( + ("source", "nested_metadataset.yaml"), + ("dataset.yaml", True), + ("number", 44), + ("__subflavor__", "train"), + ("mds", "nested_train"), + ) + ] + < avg * 2 + 20 + ) + assert ( + avg * 1 - 20 + < cnt[ + ( + ("source", "nested_metadataset.yaml"), + ("dataset.yaml", True), + ("number", 42), + ("mds", "nested_val"), + ) + ] + < avg * 1 + 20 + ) # Train mode dataset train_dataset = get_train_dataset( @@ -499,27 +495,25 @@ def test_nested_metadataset(self): print(len(train_dataset)) assert len(train_dataset) == 11 - train_loader1 = get_loader(train_dataset) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 110 - assert all(48 <= v <= 52 for v in Counter(train_order1).values()) + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 110 + assert all(48 <= v <= 52 for v in Counter(train_order1).values()) # Val mode dataset val_dataset = get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10) print(len(val_dataset)) assert len(val_dataset) == 11 - val_loader1 = get_loader(val_dataset) - - val_order1 = [text for data in val_loader1 for text in data.text] - assert len(val_order1) == 110 - print(Counter(val_order1)) - assert all(v == 1 for v in Counter(val_order1).values()) + with get_loader(val_dataset) as val_loader1: + val_order1 = [text for data in val_loader1 for text in data.text] + assert len(val_order1) == 110 + print(Counter(val_order1)) + assert all(v == 1 for v in Counter(val_order1).values()) def test_worker_sample_balance(self): torch.manual_seed(42) @@ -669,91 +663,84 @@ def new_loader(): ) # Train mode dataset - loader = new_loader() - state_0 = loader.save_state_rank() - order_0 = [data.text for idx, data in zip(range(10), loader)] - state_1 = loader.save_state_rank() - # print("save state done") - order_1 = [data.text for idx, data in zip(range(20), loader)] - - state_2 = loader.save_state_rank() - # print("save state done") - # Iterated 30 samples, afterwards 50 samples. Checkpoint should be around that - order_2 = [data.text for idx, data in zip(range(20), loader)] - - state_3 = loader.save_state_rank() - # print("save state done") - # Iterated 50 samples, afterwards 53 samples. Checkpoint should be around that - order_3 = [data.text for idx, data in zip(range(3), loader)] - - state_4 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 53 samples, afterwards 54 samples. Checkpoint should be around that - order_4 = [data.text for idx, data in zip(range(1), loader)] - - state_5 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 54 samples, afterwards 55 samples. Checkpoint should be around that - order_5 = [data.text for idx, data in zip(range(1), loader)] - - state_6 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 55 samples, afterwards 75 samples. Checkpoint should be around that - order_6 = [data.text for idx, data in zip(range(70), loader)] - - loader = new_loader() - print("state_1:", _norng_state(state_1)) - loader.restore_state_rank(state_1) - order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)] - assert order_1 == order_1_rest - - loader = new_loader() - loader.restore_state_rank(state_0) - order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)] - assert order_0 == order_0_rest - - loader = new_loader() - print("state_2:", _norng_state(state_2)) - loader.restore_state_rank(state_2) - order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)] - print("order_2:", order_2) - print("order_2_rest:", order_2_rest) - assert order_2 == order_2_rest - - loader = new_loader() - print("state_3:", _norng_state(state_3)) - loader.restore_state_rank(state_3) - order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)] - print("order_3:", order_3) - print("order_3_rest:", order_3_rest) - assert order_3 == order_3_rest - - loader = new_loader() - print("state_4:", _norng_state(state_4)) - loader.restore_state_rank(state_4) - order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)] - print("order_4:", order_4) - print("order_4_rest:", order_4_rest) - assert order_4 == order_4_rest - - loader = new_loader() - print("state_5:", _norng_state(state_5)) - loader.restore_state_rank(state_5) - order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)] - print("order_5:", order_5) - print("order_5_rest:", order_5_rest) - assert order_5 == order_5_rest - - loader = new_loader() - print("state_6:", _norng_state(state_6)) - loader.restore_state_rank(state_6) - order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)] - print("order_6:", order_6) - print("order_6_rest:", order_6_rest) - assert order_6 == order_6_rest + with new_loader() as loader: + state_0 = loader.save_state_rank() + order_0 = [data.text for idx, data in zip(range(10), loader)] + state_1 = loader.save_state_rank() + # print("save state done") + order_1 = [data.text for idx, data in zip(range(20), loader)] + + state_2 = loader.save_state_rank() + # print("save state done") + # Iterated 30 samples, afterwards 50 samples. Checkpoint should be around that + order_2 = [data.text for idx, data in zip(range(20), loader)] + + state_3 = loader.save_state_rank() + # print("save state done") + # Iterated 50 samples, afterwards 53 samples. Checkpoint should be around that + order_3 = [data.text for idx, data in zip(range(3), loader)] + + state_4 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 53 samples, afterwards 54 samples. Checkpoint should be around that + order_4 = [data.text for idx, data in zip(range(1), loader)] + + state_5 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 54 samples, afterwards 55 samples. Checkpoint should be around that + order_5 = [data.text for idx, data in zip(range(1), loader)] + + state_6 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 55 samples, afterwards 75 samples. Checkpoint should be around that + order_6 = [data.text for idx, data in zip(range(70), loader)] + + with new_loader().with_restored_state_rank(state_1) as loader: + print("state_1:", _norng_state(state_1)) + order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)] + assert order_1 == order_1_rest + + with new_loader().with_restored_state_rank(state_0) as loader: + order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)] + assert order_0 == order_0_rest + + with new_loader().with_restored_state_rank(state_2) as loader: + print("state_2:", _norng_state(state_2)) + order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)] + print("order_2:", order_2) + print("order_2_rest:", order_2_rest) + assert order_2 == order_2_rest + + with new_loader().with_restored_state_rank(state_3) as loader: + print("state_3:", _norng_state(state_3)) + order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)] + print("order_3:", order_3) + print("order_3_rest:", order_3_rest) + assert order_3 == order_3_rest + + with new_loader().with_restored_state_rank(state_4) as loader: + print("state_4:", _norng_state(state_4)) + order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)] + print("order_4:", order_4) + print("order_4_rest:", order_4_rest) + assert order_4 == order_4_rest + + with new_loader().with_restored_state_rank(state_5) as loader: + print("state_5:", _norng_state(state_5)) + order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)] + print("order_5:", order_5) + print("order_5_rest:", order_5_rest) + assert order_5 == order_5_rest + + with new_loader().with_restored_state_rank(state_6) as loader: + print("state_6:", _norng_state(state_6)) + order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)] + print("order_6:", order_6) + print("order_6_rest:", order_6_rest) + assert order_6 == order_6_rest wrk_cfg = worker_config.config() assert wrk_cfg == { @@ -992,106 +979,99 @@ def new_loader(): ) # Train mode dataset - loader = new_loader() - state_0 = loader.save_state_rank() - order_0 = [data.text for idx, data in zip(range(10), loader)] - time.sleep(0.5) - state_1 = loader.save_state_rank() - # print("save state done") - order_1 = [data.text for idx, data in zip(range(20), loader)] - - # Ensure a checkpoint is created on next() - time.sleep(1.5) - - state_2 = loader.save_state_rank() - # print("save state done") - # Iterated 30 samples, afterwards 50 samples. Checkpoint should be around that - order_2 = [data.text for idx, data in zip(range(20), loader)] - - state_3 = loader.save_state_rank() - # print("save state done") - # Iterated 50 samples, afterwards 53 samples. Checkpoint should be around that - order_3 = [data.text for idx, data in zip(range(3), loader)] - - # Ensure a checkpoint is created on next() - time.sleep(1.5) - - state_4 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 1 samples, afterwards 54 samples. Checkpoint should be around that - order_4 = [data.text for idx, data in zip(range(1), loader)] - - # Ensure a checkpoint is created on next() - time.sleep(1.5) - - state_5 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 1 samples, afterwards 55 samples. Checkpoint should be around that - order_5 = [data.text for idx, data in zip(range(1), loader)] - - # Ensure a checkpoint is created on next() - time.sleep(1.5) - - state_6 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 1 samples, afterwards 55 samples. Checkpoint should be around that - order_6 = [data.text for idx, data in zip(range(10), loader)] - - loader = new_loader() - print("state_1:", _norng_state(state_1)) - loader.restore_state_rank(state_1) - order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)] - print("order_1:", order_1) - print("order_1_rest:", order_1_rest) - assert order_1 == order_1_rest - - loader = new_loader() - loader.restore_state_rank(state_0) - order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)] - assert order_0 == order_0_rest - - loader = new_loader() - print("state_2:", _norng_state(state_2)) - loader.restore_state_rank(state_2) - order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)] - print("order_2:", order_2) - print("order_2_rest:", order_2_rest) - assert order_2 == order_2_rest - - loader = new_loader() - print("state_3:", _norng_state(state_3)) - loader.restore_state_rank(state_3) - order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)] - print("order_3:", order_3) - print("order_3_rest:", order_3_rest) - assert order_3 == order_3_rest - - loader = new_loader() - print("state_4:", _norng_state(state_4)) - loader.restore_state_rank(state_4) - order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)] - print("order_4:", order_4) - print("order_4_rest:", order_4_rest) - assert order_4 == order_4_rest - - loader = new_loader() - print("state_5:", _norng_state(state_5)) - loader.restore_state_rank(state_5) - order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)] - print("order_5:", order_5) - print("order_5_rest:", order_5_rest) - assert order_5 == order_5_rest - - loader = new_loader() - print("state_6:", _norng_state(state_6)) - loader.restore_state_rank(state_6) - order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)] - print("order_6:", order_6) - print("order_6_rest:", order_6_rest) - assert order_6 == order_6_rest + with new_loader() as loader: + state_0 = loader.save_state_rank() + order_0 = [data.text for idx, data in zip(range(10), loader)] + time.sleep(0.5) + state_1 = loader.save_state_rank() + # print("save state done") + order_1 = [data.text for idx, data in zip(range(20), loader)] + + # Ensure a checkpoint is created on next() + time.sleep(1.5) + + state_2 = loader.save_state_rank() + # print("save state done") + # Iterated 30 samples, afterwards 50 samples. Checkpoint should be around that + order_2 = [data.text for idx, data in zip(range(20), loader)] + + state_3 = loader.save_state_rank() + # print("save state done") + # Iterated 50 samples, afterwards 53 samples. Checkpoint should be around that + order_3 = [data.text for idx, data in zip(range(3), loader)] + + # Ensure a checkpoint is created on next() + time.sleep(1.5) + + state_4 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 1 samples, afterwards 54 samples. Checkpoint should be around that + order_4 = [data.text for idx, data in zip(range(1), loader)] + + # Ensure a checkpoint is created on next() + time.sleep(1.5) + + state_5 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 1 samples, afterwards 55 samples. Checkpoint should be around that + order_5 = [data.text for idx, data in zip(range(1), loader)] + + # Ensure a checkpoint is created on next() + time.sleep(1.5) + + state_6 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 1 samples, afterwards 55 samples. Checkpoint should be around that + order_6 = [data.text for idx, data in zip(range(10), loader)] + + with new_loader().with_restored_state_rank(state_1) as loader: + print("state_1:", _norng_state(state_1)) + order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)] + print("order_1:", order_1) + print("order_1_rest:", order_1_rest) + assert order_1 == order_1_rest + + with new_loader().with_restored_state_rank(state_0) as loader: + order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)] + assert order_0 == order_0_rest + + with new_loader().with_restored_state_rank(state_2) as loader: + print("state_2:", _norng_state(state_2)) + order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)] + print("order_2:", order_2) + print("order_2_rest:", order_2_rest) + assert order_2 == order_2_rest + + with new_loader().with_restored_state_rank(state_3) as loader: + print("state_3:", _norng_state(state_3)) + order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)] + print("order_3:", order_3) + print("order_3_rest:", order_3_rest) + assert order_3 == order_3_rest + + with new_loader().with_restored_state_rank(state_4) as loader: + print("state_4:", _norng_state(state_4)) + order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)] + print("order_4:", order_4) + print("order_4_rest:", order_4_rest) + assert order_4 == order_4_rest + + with new_loader().with_restored_state_rank(state_5) as loader: + print("state_5:", _norng_state(state_5)) + order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)] + print("order_5:", order_5) + print("order_5_rest:", order_5_rest) + assert order_5 == order_5_rest + + with new_loader().with_restored_state_rank(state_6) as loader: + print("state_6:", _norng_state(state_6)) + order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)] + print("order_6:", order_6) + print("order_6_rest:", order_6_rest) + assert order_6 == order_6_rest def test_save_restore_state_train_epochize_workers(self): torch.manual_seed(42) @@ -1108,7 +1088,7 @@ def test_save_restore_state_train_epochize_workers(self): # Train mode dataset torch.manual_seed(42) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.mds_path, worker_config=worker_config, @@ -1118,16 +1098,16 @@ def test_save_restore_state_train_epochize_workers(self): shuffle_buffer_size=sbs, max_samples_per_sequence=sbs, ), - ) - state_0 = loader.save_state_rank() - order_1 = [data.text[0] for data in loader] - state_1 = loader.save_state_rank() - order_2 = [data.text[0] for data in loader] - state_2 = loader.save_state_rank() - order_3 = [data.text[0] for idx, data in zip(range(17), loader)] + ) as loader: + state_0 = loader.save_state_rank() + order_1 = [data.text[0] for data in loader] + state_1 = loader.save_state_rank() + order_2 = [data.text[0] for data in loader] + state_2 = loader.save_state_rank() + order_3 = [data.text[0] for idx, data in zip(range(17), loader)] torch.manual_seed(42) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.mds_path, worker_config=worker_config, @@ -1137,16 +1117,15 @@ def test_save_restore_state_train_epochize_workers(self): shuffle_buffer_size=sbs, max_samples_per_sequence=sbs, ), - ) - print("state_0:", _norng_state(state_0)) - loader.restore_state_rank(state_0) - order_5 = [data.text[0] for data in loader] - print("order_1:", order_1) - print("order_5:", order_5) - assert order_1 == order_5 + ).with_restored_state_rank(state_0) as loader: + print("state_0:", _norng_state(state_0)) + order_5 = [data.text[0] for data in loader] + print("order_1:", order_1) + print("order_5:", order_5) + assert order_1 == order_5 torch.manual_seed(42) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.mds_path, worker_config=worker_config, @@ -1156,16 +1135,15 @@ def test_save_restore_state_train_epochize_workers(self): shuffle_buffer_size=sbs, max_samples_per_sequence=sbs, ), - ) - print("state_1:", _norng_state(state_1)) - loader.restore_state_rank(state_1) - order_6 = [data.text[0] for data in loader] - print("order_2:", order_2) - print("order_6:", order_6) - assert order_2 == order_6 + ).with_restored_state_rank(state_1) as loader: + print("state_1:", _norng_state(state_1)) + order_6 = [data.text[0] for data in loader] + print("order_2:", order_2) + print("order_6:", order_6) + assert order_2 == order_6 torch.manual_seed(42) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.mds_path, worker_config=worker_config, @@ -1175,13 +1153,12 @@ def test_save_restore_state_train_epochize_workers(self): shuffle_buffer_size=sbs, max_samples_per_sequence=sbs, ), - ) - print("state_2:", _norng_state(state_2)) - loader.restore_state_rank(state_2) - order_7 = [data.text[0] for idx, data in zip(range(17), loader)] - print("order_3:", order_3) - print("order_7:", order_7) - assert order_3 == order_7 + ).with_restored_state_rank(state_2) as loader: + print("state_2:", _norng_state(state_2)) + order_7 = [data.text[0] for idx, data in zip(range(17), loader)] + print("order_3:", order_3) + print("order_7:", order_7) + assert order_3 == order_7 def test_save_restore_state_val(self): torch.manual_seed(42) @@ -1194,28 +1171,26 @@ def test_save_restore_state_val(self): ) # Train mode dataset - loader = get_savable_loader( + with get_savable_loader( get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10), - ) - state_0 = loader.save_state_rank() - order_1 = [data.text for idx, data in zip(range(55 * 20), loader)] - state_1 = loader.save_state_rank() - # print("save state done") - order_2 = [data.text for idx, data in zip(range(55 * 20), loader)] - - loader = get_savable_loader( + ) as loader: + state_0 = loader.save_state_rank() + order_1 = [data.text for idx, data in zip(range(55 * 20), loader)] + state_1 = loader.save_state_rank() + # print("save state done") + order_2 = [data.text for idx, data in zip(range(55 * 20), loader)] + + with get_savable_loader( get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10), - ) - loader.restore_state_rank(state_1) - order_3 = [data.text for idx, data in zip(range(55 * 20), loader)] - assert order_2 == order_3 + ).with_restored_state_rank(state_1) as loader: + order_3 = [data.text for idx, data in zip(range(55 * 20), loader)] + assert order_2 == order_3 - loader = get_savable_loader( + with get_savable_loader( get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10), - ) - loader.restore_state_rank(state_0) - order_4 = [data.text for idx, data in zip(range(55 * 20), loader)] - assert order_1 == order_4 + ).with_restored_state_rank(state_0) as loader: + order_4 = [data.text for idx, data in zip(range(55 * 20), loader)] + assert order_1 == order_4 def test_blending_randomness(self): import random @@ -1247,12 +1222,11 @@ def test_blending_randomness(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ) - loader = get_loader(ds) - - subflavors = [ - data.__subflavors__[0].get("__subflavor__") - for idx, data in zip(range(25), loader) - ] + with get_loader(ds) as loader: + subflavors = [ + data.__subflavors__[0].get("__subflavor__") + for idx, data in zip(range(25), loader) + ] all_ranks_subflavors.append(subflavors) @@ -1265,10 +1239,6 @@ def test_blending_randomness(self): f"Rank {i} and rank {j} got the same subflavors." ) - # Delete all locals, otherwise loaders might be kept alive - locals().clear() - gc.collect() - def test_slice_iter_shuffle_over_epochs(self): torch.manual_seed(42) @@ -1293,8 +1263,8 @@ def new_loader(): ) # Train mode dataset - loader = new_loader() - _ = [data.text for idx, data in zip(range(1000), loader)] + with new_loader() as loader: + _ = [data.text for idx, data in zip(range(1000), loader)] def test_save_restore_next(self): torch.manual_seed(42) @@ -1305,7 +1275,7 @@ def test_save_restore_next(self): num_workers=6, ) - initial_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.nested_mds_path, worker_config=wc, @@ -1313,45 +1283,43 @@ def test_save_restore_next(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - ) - skip_initial = 9 - - previous_cp = initial_loader.save_state_rank() - print("initial_samples:") - for i, sample in zip(range(skip_initial), initial_loader): - print(f"sample[@{i}]: {sample.text}") - print("previous_cp:", previous_cp) - rst_loader = get_savable_loader( - get_train_dataset( - self.nested_mds_path, - worker_config=wc, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ) - rst_loader.restore_state_rank(previous_cp) - for i, rst_sample in zip(range(1), rst_loader): - print(f"rst_sample[@{i}]: {rst_sample.text}") - assert sample.text == rst_sample.text, f"{sample} != {rst_sample}" - assert sample.__key__ == rst_sample.__key__, f"{sample} != {rst_sample}" - assert sample.__restore_key__ == rst_sample.__restore_key__, f"{sample} != {rst_sample}" - previous_cp = initial_loader.save_state_rank() + ) as initial_loader: + skip_initial = 9 - # Iterate 10 samples, the save state and store the next 10 samples for reference. - state_initial = initial_loader.save_state_rank() - print("state_initial:", str(state_initial)) - initial_samples = [sample for _, sample in zip(range(20), initial_loader)] - print( - "initial_samples:" - + "".join( - f"\n [@{idx}] {sample.text}" - for idx, sample in enumerate(initial_samples, start=skip_initial) + previous_cp = initial_loader.save_state_rank() + print("initial_samples:") + for i, sample in zip(range(skip_initial), initial_loader): + print(f"sample[@{i}]: {sample.text}") + print("previous_cp:", previous_cp) + with get_savable_loader( + get_train_dataset( + self.nested_mds_path, + worker_config=wc, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ).with_restored_state_rank(previous_cp) as rst_loader: + for i, rst_sample in zip(range(1), rst_loader): + print(f"rst_sample[@{i}]: {rst_sample.text}") + assert sample.text == rst_sample.text, f"{sample} != {rst_sample}" + assert sample.__key__ == rst_sample.__key__, f"{sample} != {rst_sample}" + assert sample.__restore_key__ == rst_sample.__restore_key__, ( + f"{sample} != {rst_sample}" + ) + previous_cp = initial_loader.save_state_rank() + + # Iterate 10 samples, the save state and store the next 10 samples for reference. + state_initial = initial_loader.save_state_rank() + print("state_initial:", str(state_initial)) + initial_samples = [sample for _, sample in zip(range(20), initial_loader)] + print( + "initial_samples:" + + "".join( + f"\n [@{idx}] {sample.text}" + for idx, sample in enumerate(initial_samples, start=skip_initial) + ) ) - ) - - del initial_loader - gc.collect() second_loader = get_savable_loader( get_train_dataset( @@ -1363,90 +1331,95 @@ def test_save_restore_next(self): ), ) second_loader.restore_state_rank(state_initial) - # Save the state again, to check that it is the same as the just restored state same_state = second_loader.save_state_rank() print("same_state:", same_state) - assert same_state == state_initial + assert_nested_equal(same_state, state_initial) + assert same_state is state_initial # This will propagate the state to the workers. second_loader.start() - # Save the state again, to check that it is the same as the just restored state - same_state = second_loader.save_state_rank() - print("same_state:", same_state) - assert_nested_equal(same_state, state_initial) + try: + # Save the state again, to check that it is the same as the just restored state + same_state = second_loader.save_state_rank() + print("same_state:", same_state) + assert_nested_equal(same_state, state_initial) - for offset in range(10): - try: - # Save state and restore in next loader - state_offset = second_loader.save_state_rank() - # Get 1 sample from the current loader - samples = [sample for _, sample in zip(range(1), second_loader)] - assert len(samples) == 1 - sample = samples[0] - - # Check that the sample is the same as the initial loader's reference sample - print(f"sample[@{offset + skip_initial}]: {sample.text}") + for offset in range(10): try: - assert sample.text == initial_samples[offset].text, ( - f"{sample} != {initial_samples[offset]}" - ) - assert sample.__key__ == initial_samples[offset].__key__, ( - f"{sample} != {initial_samples[offset]}" - ) - assert sample.__restore_key__ == initial_samples[offset].__restore_key__, ( - f"{sample} != {initial_samples[offset]}" - ) - except Exception as e: - print( - "samples:" - + f"\n [@{offset + skip_initial}] {sample.text}" - + "".join( - f"\n [@{idx}] {sample.text}" - for idx, sample in zip( - range(skip_initial + offset + 1, skip_initial + offset + 6), - second_loader, + # Save state and restore in next loader + state_offset = second_loader.save_state_rank() + # Get 1 sample from the current loader + samples = [sample for _, sample in zip(range(1), second_loader)] + assert len(samples) == 1 + sample = samples[0] + + # Check that the sample is the same as the initial loader's reference sample + print(f"sample[@{offset + skip_initial}]: {sample.text}") + try: + assert sample.text == initial_samples[offset].text, ( + f"{sample} != {initial_samples[offset]}" + ) + assert sample.__key__ == initial_samples[offset].__key__, ( + f"{sample} != {initial_samples[offset]}" + ) + assert sample.__restore_key__ == initial_samples[offset].__restore_key__, ( + f"{sample} != {initial_samples[offset]}" + ) + except Exception as e: + print( + "samples:" + + f"\n [@{offset + skip_initial}] {sample.text}" + + "".join( + f"\n [@{idx}] {sample.text}" + for idx, sample in zip( + range(skip_initial + offset + 1, skip_initial + offset + 6), + second_loader, + ) ) ) - ) - raise ValueError(f"Failed to iterate @{offset + skip_initial} samples") from e - - # Restore state in a new loader - ref_loader = get_savable_loader( - get_train_dataset( - self.nested_mds_path, - worker_config=wc, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ) - ref_loader.restore_state_rank(state_offset) - - # Get 1 sample from the restored loader - next_loader_samples = [sample for _, sample in zip(range(6), ref_loader)] - assert len(next_loader_samples) == 6 - next_loader_sample = next_loader_samples[0] - print( - "next_loader_samples:" - + f"\n [@{offset + skip_initial}] {sample.text}" - + "".join( - f"\n [@{idx}] {sample}" - for idx, sample in zip( - range(skip_initial + offset, skip_initial + offset + 6), - next_loader_samples, + raise ValueError( + f"Failed to iterate @{offset + skip_initial} samples" + ) from e + + # Restore state in a new loader + with get_savable_loader( + get_train_dataset( + self.nested_mds_path, + worker_config=wc, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ).with_restored_state_rank(state_offset) as ref_loader: + # Get 1 sample from the restored loader + next_loader_samples = [sample for _, sample in zip(range(6), ref_loader)] + assert len(next_loader_samples) == 6 + next_loader_sample = next_loader_samples[0] + print( + "next_loader_samples:" + + f"\n [@{offset + skip_initial}] {sample.text}" + + "".join( + f"\n [@{idx}] {sample}" + for idx, sample in zip( + range(skip_initial + offset, skip_initial + offset + 6), + next_loader_samples, + ) + ) ) - ) - ) - assert next_loader_sample.text == sample.text, f"{next_loader_sample} != {sample}" - assert next_loader_sample.__key__ == sample.__key__, ( - f"{next_loader_sample} != {sample}" - ) - assert next_loader_sample.__restore_key__ == sample.__restore_key__, ( - f"{next_loader_sample} != {sample}" - ) - except Exception as e: - raise ValueError(f"Failed to iterate @{skip_initial}+{offset} samples") from e + assert next_loader_sample.text == sample.text, ( + f"{next_loader_sample} != {sample}" + ) + assert next_loader_sample.__key__ == sample.__key__, ( + f"{next_loader_sample} != {sample}" + ) + assert next_loader_sample.__restore_key__ == sample.__restore_key__, ( + f"{next_loader_sample} != {sample}" + ) + except Exception as e: + raise ValueError(f"Failed to iterate @{skip_initial}+{offset} samples") from e + finally: + second_loader.shutdown() if __name__ == "__main__": diff --git a/tests/test_metadataset_fewsamp.py b/tests/test_metadataset_fewsamp.py index 7a9336bb..d2cddd5d 100644 --- a/tests/test_metadataset_fewsamp.py +++ b/tests/test_metadataset_fewsamp.py @@ -179,21 +179,20 @@ def test_metadataset_few_samples_save_restore(self): assert len(blend_ds.dataset_weights[1][0].dataset.dataset.workers_slice_offsets[0]) == 1 assert len(blend_ds.dataset_weights[1][0].dataset.dataset) == 0 - train_loader = get_savable_loader( + with get_savable_loader( train_dataset, - ) - - # Load 3 samples - list(zip(train_loader, range(3))) + ) as train_loader: + # Load 3 samples + list(zip(train_loader, range(3))) - # Save state mid epoch - state1 = train_loader.save_state_rank() + # Save state mid epoch + state1 = train_loader.save_state_rank() - # Load 5 samples - data1b = list(zip(train_loader, range(5))) + # Load 5 samples + data1b = list(zip(train_loader, range(5))) # Restore state - train_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.mds_path, worker_config=worker_config, @@ -201,21 +200,20 @@ def test_metadataset_few_samples_save_restore(self): shuffle_buffer_size=100, max_samples_per_sequence=None, ), - ) - train_loader.restore_state_rank(state1) - # Load 5 samples - data2_restore = list(zip(train_loader, range(5))) + ).with_restored_state_rank(state1) as train_loader: + # Load 5 samples + data2_restore = list(zip(train_loader, range(5))) - # Check that the restored state is the same - order1b = [(s[0].__key__[0], int(s[0].text[0])) for s in data1b] - order2 = [(s[0].__key__[0], int(s[0].text[0])) for s in data2_restore] + # Check that the restored state is the same + order1b = [(s[0].__key__[0], int(s[0].text[0])) for s in data1b] + order2 = [(s[0].__key__[0], int(s[0].text[0])) for s in data2_restore] - print("order1b") - print(order1b) - print("order2") - print(order2) + print("order1b") + print(order1b) + print("order2") + print(order2) - assert order1b == order2, "The restored state does not match the original state." + assert order1b == order2, "The restored state does not match the original state." def test_too_few_samples(self): # Will only give a single sample, as there are 117 samples in total, and 100 ranks @@ -223,7 +221,7 @@ def test_too_few_samples(self): lens = [] for i_rank in range(ws): worker_config = WorkerConfig(rank=i_rank, world_size=ws, num_workers=0) - loader = get_savable_loader( + with get_savable_loader( get_train_dataset( self.mds_path, batch_size=1, @@ -231,17 +229,17 @@ def test_too_few_samples(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - ) - lens.append(len(loader)) + ) as loader: + lens.append(len(loader)) - txts = [] + txts = [] - for i, sample in zip(range(10), loader): - txts.extend(sample.text) + for i, sample in zip(range(10), loader): + txts.extend(sample.text) - assert len(set(txts)) == len(loader), ( - f"Rank {i_rank} should have exactly {len(loader)} sample, but got {txts}" - ) + assert len(set(txts)) == len(loader), ( + f"Rank {i_rank} should have exactly {len(loader)} sample, but got {txts}" + ) assert lens == [ 2 if i in [0, 3, 6, 12, 18, 25, 31, 37, 43, 50, 56, 62, 68, 75, 81, 87, 93] else 1 diff --git a/tests/test_metadataset_v2.py b/tests/test_metadataset_v2.py index bfb9357d..973055d3 100644 --- a/tests/test_metadataset_v2.py +++ b/tests/test_metadataset_v2.py @@ -247,15 +247,14 @@ def test_metadataset(self): print(len(train_dataset)) assert len(train_dataset) == 11 - train_loader1 = get_loader(train_dataset) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 110 - assert all(48 <= v <= 52 for v in Counter(train_order1).values()) + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 110 + assert all(48 <= v <= 52 for v in Counter(train_order1).values()) def test_nested_metadataset(self): torch.manual_seed(42) @@ -360,40 +359,39 @@ def test_joined_metadataset(self): print(len(train_dataset)) assert len(train_dataset) == 55 - train_loader = get_savable_loader( + with get_savable_loader( train_dataset, - ) - - data = list(zip(range(2 * 55), train_loader)) - txt1_order = [data.text1[0] for idx, data in data] - txt2_order = [data.text2[0] for idx, data in data] - key_order = [data.__key__[0] for idx, data in data] - # ds1 has 55 samples, key range 0:55, txt range 0:55 - # ds3 has 28 samples, key range 0:55, txt range 200:255 - # Joining results in: 0:55 - print("txt1:", txt1_order) - # Joining results in: 200:255 - print("txt2:", txt2_order) - # Joining results in: 0:55 - print("key:", key_order) - # Check matching - assert all(int(txt1) + 200 == int(txt2) for txt1, txt2 in zip(txt1_order, txt2_order)) - # Check frequency - assert set(txt1_order) == set(str(i) for i in range(0, 55)) - assert set(txt2_order) == set(str(i) for i in range(200, 255)) - # Every item must occurr 2 times (2*55). - assert Counter(txt1_order).most_common(1)[0][1] == 2 - - state = train_loader.save_state_rank() - - # Iterate 60 more items - data = list(zip(range(60), train_loader)) - txt1_order = [data.text1 for idx, data in data] - txt2_order = [data.text2 for idx, data in data] - key_order = [data.__key__ for idx, data in data] + ) as train_loader: + data = list(zip(range(2 * 55), train_loader)) + txt1_order = [data.text1[0] for idx, data in data] + txt2_order = [data.text2[0] for idx, data in data] + key_order = [data.__key__[0] for idx, data in data] + # ds1 has 55 samples, key range 0:55, txt range 0:55 + # ds3 has 28 samples, key range 0:55, txt range 200:255 + # Joining results in: 0:55 + print("txt1:", txt1_order) + # Joining results in: 200:255 + print("txt2:", txt2_order) + # Joining results in: 0:55 + print("key:", key_order) + # Check matching + assert all(int(txt1) + 200 == int(txt2) for txt1, txt2 in zip(txt1_order, txt2_order)) + # Check frequency + assert set(txt1_order) == set(str(i) for i in range(0, 55)) + assert set(txt2_order) == set(str(i) for i in range(200, 255)) + # Every item must occurr 2 times (2*55). + assert Counter(txt1_order).most_common(1)[0][1] == 2 + + state = train_loader.save_state_rank() + + # Iterate 60 more items + data = list(zip(range(60), train_loader)) + txt1_order = [data.text1 for idx, data in data] + txt2_order = [data.text2 for idx, data in data] + key_order = [data.__key__ for idx, data in data] # Restore state - train_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( joined_mds_path, worker_config=worker_config, @@ -401,20 +399,17 @@ def test_joined_metadataset(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ), - ) - - train_loader.restore_state_rank(state) - - # Iterate 360 more items - data = list(zip(range(60), train_loader)) - txt1_order_rest = [data.text1 for idx, data in data] - txt2_order_rest = [data.text2 for idx, data in data] - key_order_rest = [data.__key__ for idx, data in data] - - # Verify matching - assert txt1_order == txt1_order_rest - assert txt2_order == txt2_order_rest - assert key_order == key_order_rest + ).with_restored_state_rank(state) as train_loader: + # Iterate 360 more items + data = list(zip(range(60), train_loader)) + txt1_order_rest = [data.text1 for idx, data in data] + txt2_order_rest = [data.text2 for idx, data in data] + key_order_rest = [data.__key__ for idx, data in data] + + # Verify matching + assert txt1_order == txt1_order_rest + assert txt2_order == txt2_order_rest + assert key_order == key_order_rest def test_joined_metadataset_joiner(self): torch.manual_seed(42) @@ -467,31 +462,30 @@ def test_joined_metadataset_joiner(self): print(len(train_dataset)) assert len(train_dataset) == 55 - train_loader = get_savable_loader( + with get_savable_loader( train_dataset, - ) - - data = list(zip(range(2 * 55), train_loader)) - txt1_order = [data.text1[0] for idx, data in data] - txt2_order = [data.text2[0] for idx, data in data] - key_order = [data.__key__[0] for idx, data in data] - # ds1 has 55 samples, key range 0:55, txt range 0:55 - # ds3 has 28 samples, key range 0:55, txt range 200:255 - # Joining results in: 0:55, with prefix "j" - print("txt1:", txt1_order) - # Joining results in: 200:255, with prefix "j" - print("txt2:", txt2_order) - # Joining results in: 0:55 - print("key:", key_order) - # Check matching - assert all( - int(txt1[1:]) + 200 == int(txt2[1:]) for txt1, txt2 in zip(txt1_order, txt2_order) - ) - # Check frequency - assert set(txt1_order) == set(f"j{i}" for i in range(0, 55)) - assert set(txt2_order) == set(f"j{i}" for i in range(200, 255)) - # Every item must occurr 2 times (2*55). - assert Counter(txt1_order).most_common(1)[0][1] == 2 + ) as train_loader: + data = list(zip(range(2 * 55), train_loader)) + txt1_order = [data.text1[0] for idx, data in data] + txt2_order = [data.text2[0] for idx, data in data] + key_order = [data.__key__[0] for idx, data in data] + # ds1 has 55 samples, key range 0:55, txt range 0:55 + # ds3 has 28 samples, key range 0:55, txt range 200:255 + # Joining results in: 0:55, with prefix "j" + print("txt1:", txt1_order) + # Joining results in: 200:255, with prefix "j" + print("txt2:", txt2_order) + # Joining results in: 0:55 + print("key:", key_order) + # Check matching + assert all( + int(txt1[1:]) + 200 == int(txt2[1:]) for txt1, txt2 in zip(txt1_order, txt2_order) + ) + # Check frequency + assert set(txt1_order) == set(f"j{i}" for i in range(0, 55)) + assert set(txt2_order) == set(f"j{i}" for i in range(200, 255)) + # Every item must occurr 2 times (2*55). + assert Counter(txt1_order).most_common(1)[0][1] == 2 def test_left_join(self): torch.manual_seed(42) @@ -545,29 +539,28 @@ def test_left_join(self): print(len(train_dataset)) assert len(train_dataset) == 55, len(train_dataset) - train_loader = get_savable_loader( + with get_savable_loader( train_dataset, - ) - - data = list(zip(range(2 * 55), train_loader)) - txt1_order = [data.text1[0] for idx, data in data] - txt2_order = [data.text2[0] for idx, data in data] - key_order = [data.__key__[0] for idx, data in data] - # ds1 has 55 samples, key range 0:55, txt range 0:55 - # ds3 has 28 samples, key range 0:55, txt range 200:255 - # Joining results in: 0:55, with prefix "j" - print("txt1:", txt1_order) - # Joining results in: 200:255, with prefix "j" - print("txt2:", txt2_order) - # Joining results in: 0:55 - print("key:", key_order) - # Check matching - assert all(int(txt1[1:]) == int(txt2[2:]) for txt1, txt2 in zip(txt1_order, txt2_order)) - # Check frequency - assert set(txt1_order) == set(f"j{i}" for i in range(55)) - assert set(txt2_order) == set(f"jB{i}" for i in range(55)) - # Every item must occurr 2 times (2*55). - assert Counter(txt1_order).most_common(1)[0][1] == 2 + ) as train_loader: + data = list(zip(range(2 * 55), train_loader)) + txt1_order = [data.text1[0] for idx, data in data] + txt2_order = [data.text2[0] for idx, data in data] + key_order = [data.__key__[0] for idx, data in data] + # ds1 has 55 samples, key range 0:55, txt range 0:55 + # ds3 has 28 samples, key range 0:55, txt range 200:255 + # Joining results in: 0:55, with prefix "j" + print("txt1:", txt1_order) + # Joining results in: 200:255, with prefix "j" + print("txt2:", txt2_order) + # Joining results in: 0:55 + print("key:", key_order) + # Check matching + assert all(int(txt1[1:]) == int(txt2[2:]) for txt1, txt2 in zip(txt1_order, txt2_order)) + # Check frequency + assert set(txt1_order) == set(f"j{i}" for i in range(55)) + assert set(txt2_order) == set(f"jB{i}" for i in range(55)) + # Every item must occurr 2 times (2*55). + assert Counter(txt1_order).most_common(1)[0][1] == 2 # Test that changing the file works as expected with open(joined_mds_path, "w") as f: @@ -702,28 +695,27 @@ def test_left_join_exclude(self): print(len(train_dataset)) assert len(train_dataset) == 55 - 16, len(train_dataset) - train_loader = get_savable_loader( + with get_savable_loader( train_dataset, - ) - - data = list(zip(range(2 * 55), train_loader)) - txt1_order = [data.text1[0] for idx, data in data] - txt2_order = [data.text2[0] for idx, data in data] - key_order = [data.__key__[0] for idx, data in data] - # ds1 has 55 samples, key range 0:55, txt range 0:55 - # ds3 has 28 samples, key range 0:55, txt range 200:255 - # Joining results in: 0:55, with prefix "j" - print("txt1:", txt1_order) - # Joining results in: 200:255, with prefix "j" - print("txt2:", txt2_order) - # Joining results in: 0:55 - print("key:", key_order) - # Check matching - assert all(int(txt1[1:]) == int(txt2[2:]) for txt1, txt2 in zip(txt1_order, txt2_order)) - # Check frequency - set_filtered_nums = set(range(5, 10)) | set(range(20, 29)) | set(range(30, 55)) - assert set(txt1_order) == set(f"j{i}" for i in set_filtered_nums) - assert set(txt2_order) == set(f"jB{i}" for i in set_filtered_nums) + ) as train_loader: + data = list(zip(range(2 * 55), train_loader)) + txt1_order = [data.text1[0] for idx, data in data] + txt2_order = [data.text2[0] for idx, data in data] + key_order = [data.__key__[0] for idx, data in data] + # ds1 has 55 samples, key range 0:55, txt range 0:55 + # ds3 has 28 samples, key range 0:55, txt range 200:255 + # Joining results in: 0:55, with prefix "j" + print("txt1:", txt1_order) + # Joining results in: 200:255, with prefix "j" + print("txt2:", txt2_order) + # Joining results in: 0:55 + print("key:", key_order) + # Check matching + assert all(int(txt1[1:]) == int(txt2[2:]) for txt1, txt2 in zip(txt1_order, txt2_order)) + # Check frequency + set_filtered_nums = set(range(5, 10)) | set(range(20, 29)) | set(range(30, 55)) + assert set(txt1_order) == set(f"j{i}" for i in set_filtered_nums) + assert set(txt2_order) == set(f"jB{i}" for i in set_filtered_nums) def test_joined_metadataset_prepare_mock(self): torch.manual_seed(42) @@ -816,59 +808,61 @@ def test_metadataset_fixed_epochs(self): print(len(train_dataset)) assert len(train_dataset) == 5 * 55, len(train_dataset) - train_loader = get_savable_loader( + with get_savable_loader( train_dataset, - ) - - data = list(enumerate(train_loader)) - txt_order = [data.text[0] for idx, data in data] - key_order = [data.__subflavors__[0]["source"] + "/" + data.__key__[0] for idx, data in data] - print("txt1:", txt_order) - print("key:", key_order) - assert len(txt_order) == 5 * 55, Counter(txt_order) - ds1_keys = [key for key in key_order if key.startswith("ds1/")] - ds2_keys = [key for key in key_order if key.startswith("ds2/")] - txt_cnt = Counter(txt_order) - ds1_key_cnt = Counter(ds1_keys) - ds2_key_cnt = Counter(ds2_keys) - assert len(ds1_keys) == 2 * 55, (len(ds1_keys), ds1_key_cnt) - assert len(ds2_keys) == 3 * 55, (len(ds2_keys), ds2_key_cnt) - assert all(ds1_key_cnt[key] == 2 for key in ds1_keys) - assert all(ds2_key_cnt[key] == 3 for key in ds2_keys) - assert all(txt_cnt[key] in (2, 3) for key in txt_order) - - # Next epoch - data = list(enumerate(train_loader)) - print([data.text[0] for idx, data in data]) - assert len(data) == 5 * 55, len(data) - - # Next epoch - data1 = list(zip(range(3 * 55), train_loader)) - assert len(data1) == 3 * 55, len(data1) - # Save state mid epoch - state1 = train_loader.save_state_rank() - print(state1) - - data2 = list(enumerate(train_loader)) - assert len(data2) == 2 * 55 - txt_order = [data.text[0] for idx, data in data1 + data2] - key_order = [ - data.__subflavors__[0]["source"] + "/" + data.__key__[0] for idx, data in data1 + data2 - ] - assert len(txt_order) == 5 * 55, Counter(txt_order) - ds1_keys = [key for key in key_order if key.startswith("ds1/")] - ds2_keys = [key for key in key_order if key.startswith("ds2/")] - txt_cnt = Counter(txt_order) - ds1_key_cnt = Counter(ds1_keys) - ds2_key_cnt = Counter(ds2_keys) - assert len(ds1_keys) == 2 * 55, (len(ds1_keys), ds1_key_cnt) - assert len(ds2_keys) == 3 * 55, (len(ds2_keys), ds2_key_cnt) - assert all(ds1_key_cnt[key] == 2 for key in ds1_keys) - assert all(ds2_key_cnt[key] == 3 for key in ds2_keys) - assert all(txt_cnt[key] in (2, 3) for key in txt_order) + ) as train_loader: + data = list(enumerate(train_loader)) + txt_order = [data.text[0] for idx, data in data] + key_order = [ + data.__subflavors__[0]["source"] + "/" + data.__key__[0] for idx, data in data + ] + print("txt1:", txt_order) + print("key:", key_order) + assert len(txt_order) == 5 * 55, Counter(txt_order) + ds1_keys = [key for key in key_order if key.startswith("ds1/")] + ds2_keys = [key for key in key_order if key.startswith("ds2/")] + txt_cnt = Counter(txt_order) + ds1_key_cnt = Counter(ds1_keys) + ds2_key_cnt = Counter(ds2_keys) + assert len(ds1_keys) == 2 * 55, (len(ds1_keys), ds1_key_cnt) + assert len(ds2_keys) == 3 * 55, (len(ds2_keys), ds2_key_cnt) + assert all(ds1_key_cnt[key] == 2 for key in ds1_keys) + assert all(ds2_key_cnt[key] == 3 for key in ds2_keys) + assert all(txt_cnt[key] in (2, 3) for key in txt_order) + + # Next epoch + data = list(enumerate(train_loader)) + print([data.text[0] for idx, data in data]) + assert len(data) == 5 * 55, len(data) + + # Next epoch + data1 = list(zip(range(3 * 55), train_loader)) + assert len(data1) == 3 * 55, len(data1) + # Save state mid epoch + state1 = train_loader.save_state_rank() + print(state1) + + data2 = list(enumerate(train_loader)) + assert len(data2) == 2 * 55 + txt_order = [data.text[0] for idx, data in data1 + data2] + key_order = [ + data.__subflavors__[0]["source"] + "/" + data.__key__[0] + for idx, data in data1 + data2 + ] + assert len(txt_order) == 5 * 55, Counter(txt_order) + ds1_keys = [key for key in key_order if key.startswith("ds1/")] + ds2_keys = [key for key in key_order if key.startswith("ds2/")] + txt_cnt = Counter(txt_order) + ds1_key_cnt = Counter(ds1_keys) + ds2_key_cnt = Counter(ds2_keys) + assert len(ds1_keys) == 2 * 55, (len(ds1_keys), ds1_key_cnt) + assert len(ds2_keys) == 3 * 55, (len(ds2_keys), ds2_key_cnt) + assert all(ds1_key_cnt[key] == 2 for key in ds1_keys) + assert all(ds2_key_cnt[key] == 3 for key in ds2_keys) + assert all(txt_cnt[key] in (2, 3) for key in txt_order) # Restore state - train_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( fixed_epochs_mds_path, worker_config=worker_config, @@ -877,28 +871,27 @@ def test_metadataset_fixed_epochs(self): max_samples_per_sequence=None, repeat=False, ), - ) - train_loader.restore_state_rank(state1) - data2_restore = list(enumerate(train_loader)) - assert len(data2_restore) == 2 * 55 - txt_order_rst = [data.text[0] for idx, data in data1 + data2_restore] - key_order_rst = [ - data.__subflavors__[0]["source"] + "/" + data.__key__[0] - for idx, data in data1 + data2_restore - ] - assert len(txt_order_rst) == 5 * 55, Counter(txt_order_rst) - assert txt_order_rst == txt_order - assert key_order_rst == key_order - ds1_keys_rst = [key for key in key_order_rst if key.startswith("ds1/")] - ds2_keys_rst = [key for key in key_order_rst if key.startswith("ds2/")] - txt_cnt_rst = Counter(txt_order_rst) - ds1_key_cnt_rst = Counter(ds1_keys_rst) - ds2_key_cnt_rst = Counter(ds2_keys_rst) - assert len(ds1_keys_rst) == 2 * 55, (len(ds1_keys_rst), ds1_key_cnt_rst) - assert len(ds2_keys_rst) == 3 * 55, (len(ds2_keys_rst), ds2_key_cnt_rst) - assert all(ds1_key_cnt_rst[key] == 2 for key in ds1_keys_rst) - assert all(ds2_key_cnt_rst[key] == 3 for key in ds2_keys_rst) - assert all(txt_cnt_rst[key] in (2, 3) for key in txt_order_rst) + ).with_restored_state_rank(state1) as train_loader: + data2_restore = list(enumerate(train_loader)) + assert len(data2_restore) == 2 * 55 + txt_order_rst = [data.text[0] for idx, data in data1 + data2_restore] + key_order_rst = [ + data.__subflavors__[0]["source"] + "/" + data.__key__[0] + for idx, data in data1 + data2_restore + ] + assert len(txt_order_rst) == 5 * 55, Counter(txt_order_rst) + assert txt_order_rst == txt_order + assert key_order_rst == key_order + ds1_keys_rst = [key for key in key_order_rst if key.startswith("ds1/")] + ds2_keys_rst = [key for key in key_order_rst if key.startswith("ds2/")] + txt_cnt_rst = Counter(txt_order_rst) + ds1_key_cnt_rst = Counter(ds1_keys_rst) + ds2_key_cnt_rst = Counter(ds2_keys_rst) + assert len(ds1_keys_rst) == 2 * 55, (len(ds1_keys_rst), ds1_key_cnt_rst) + assert len(ds2_keys_rst) == 3 * 55, (len(ds2_keys_rst), ds2_key_cnt_rst) + assert all(ds1_key_cnt_rst[key] == 2 for key in ds1_keys_rst) + assert all(ds2_key_cnt_rst[key] == 3 for key in ds2_keys_rst) + assert all(txt_cnt_rst[key] in (2, 3) for key in txt_order_rst) def test_metadataset_fixed_fractional_epochs(self): torch.manual_seed(42) @@ -948,34 +941,33 @@ def test_metadataset_fixed_fractional_epochs(self): repeat=False, ) - train_loader = get_savable_loader( + with get_savable_loader( train_dataset, - ) + ) as train_loader: + assert len(train_loader) == 38 + 55 + 27, len(train_loader) - assert len(train_loader) == 38 + 55 + 27, len(train_loader) + data = list(enumerate(train_loader)) - data = list(enumerate(train_loader)) + # Check the overall number of samples + # Should be 0.7*len(ds1) + 1.5*len(ds2) = 0.7*55 + 1.5*55 = 38 + 55 + 27 (floor rounding) + assert len(data) == 38 + 55 + 27, len(data) - # Check the overall number of samples - # Should be 0.7*len(ds1) + 1.5*len(ds2) = 0.7*55 + 1.5*55 = 38 + 55 + 27 (floor rounding) - assert len(data) == 38 + 55 + 27, len(data) + sample_counts = Counter([int(s[1].text[0]) for s in data]) - sample_counts = Counter([int(s[1].text[0]) for s in data]) + # The first 70% of samples from ds1 (0 to incl. 37) should be repeated only once + assert all(sample_counts[sample] == 1 for sample in range(38)) - # The first 70% of samples from ds1 (0 to incl. 37) should be repeated only once - assert all(sample_counts[sample] == 1 for sample in range(38)) + # Since ds2 is repeated 1.5 times, the first 50% of samples from ds2 (100 to incl. 126) should be repeated twice + assert all(sample_counts[sample] == 2 for sample in range(100, 127)) - # Since ds2 is repeated 1.5 times, the first 50% of samples from ds2 (100 to incl. 126) should be repeated twice - assert all(sample_counts[sample] == 2 for sample in range(100, 127)) - - # The remaining samples from ds2 (127 to incl. 154) should be repeated only once - assert all(sample_counts[sample] == 1 for sample in range(127, 155)) + # The remaining samples from ds2 (127 to incl. 154) should be repeated only once + assert all(sample_counts[sample] == 1 for sample in range(127, 155)) # ===== Part 2: Save and restore state ===== # Now let's check if the state is stored and restored correctly - train_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( fixed_epochs_mds_path, worker_config=worker_config, @@ -986,12 +978,11 @@ def test_metadataset_fixed_fractional_epochs(self): max_samples_per_sequence=None, repeat=False, ), - ) + ) as train_loader: + data1 = list(zip(range(95), train_loader)) + state1 = train_loader.save_state_rank() - data1 = list(zip(range(95), train_loader)) - state1 = train_loader.save_state_rank() - - train_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( fixed_epochs_mds_path, worker_config=worker_config, @@ -1002,29 +993,28 @@ def test_metadataset_fixed_fractional_epochs(self): max_samples_per_sequence=None, repeat=False, ), - ) - train_loader.restore_state_rank(state1) - data2_restore = list(enumerate(train_loader)) + ).with_restored_state_rank(state1) as train_loader: + data2_restore = list(enumerate(train_loader)) - total_samples_save_restore = len(data1) + len(data2_restore) + total_samples_save_restore = len(data1) + len(data2_restore) - assert total_samples_save_restore == len(data), ( - "Total number of samples do not match when using save/restore" - ) + assert total_samples_save_restore == len(data), ( + "Total number of samples do not match when using save/restore" + ) - sample_counts_save_restore = Counter( - [int(s[1].text[0]) for d in [data1, data2_restore] for s in d] - ) + sample_counts_save_restore = Counter( + [int(s[1].text[0]) for d in [data1, data2_restore] for s in d] + ) - assert sample_counts_save_restore == sample_counts, ( - "Sample counts do not match when using save/restore" - ) + assert sample_counts_save_restore == sample_counts, ( + "Sample counts do not match when using save/restore" + ) # ===== Part 3: Check if the state is restored correctly when saving right at the end of a dataset ===== torch.manual_seed(42) - train_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( fixed_epochs_mds_path, worker_config=worker_config, @@ -1035,21 +1025,20 @@ def test_metadataset_fixed_fractional_epochs(self): max_samples_per_sequence=None, repeat=False, ), - ) - - ds1_counter = 0 - data1 = [] - for idx, sample in enumerate(train_loader): - data1.append((idx, sample)) - if sample.__subflavors__[0]["source"] == "ds1": - ds1_counter += 1 - if ds1_counter == 38: - # Stop right after the last sample from ds1 - break - - state1 = train_loader.save_state_rank() - - train_loader = get_savable_loader( + ) as train_loader: + ds1_counter = 0 + data1 = [] + for idx, sample in enumerate(train_loader): + data1.append((idx, sample)) + if sample.__subflavors__[0]["source"] == "ds1": + ds1_counter += 1 + if ds1_counter == 38: + # Stop right after the last sample from ds1 + break + + state1 = train_loader.save_state_rank() + + with get_savable_loader( get_train_dataset( fixed_epochs_mds_path, worker_config=worker_config, @@ -1060,27 +1049,26 @@ def test_metadataset_fixed_fractional_epochs(self): max_samples_per_sequence=None, repeat=False, ), - ) - train_loader.restore_state_rank(state1) - data2_restore = list(enumerate(train_loader)) + ).with_restored_state_rank(state1) as train_loader: + data2_restore = list(enumerate(train_loader)) - total_samples_save_restore = len(data1) + len(data2_restore) + total_samples_save_restore = len(data1) + len(data2_restore) - assert total_samples_save_restore == len(data), ( - "Total number of samples do not match when using save/restore" - ) + assert total_samples_save_restore == len(data), ( + "Total number of samples do not match when using save/restore" + ) - sample_counts_save_restore = Counter( - [int(s[1].text[0]) for d in [data1, data2_restore] for s in d] - ) + sample_counts_save_restore = Counter( + [int(s[1].text[0]) for d in [data1, data2_restore] for s in d] + ) - assert sample_counts_save_restore == sample_counts, ( - "Sample counts do not match when using save/restore" - ) + assert sample_counts_save_restore == sample_counts, ( + "Sample counts do not match when using save/restore" + ) # Try in repeat mode # Train mode dataset - train_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( fixed_epochs_mds_path, worker_config=worker_config, @@ -1090,14 +1078,13 @@ def test_metadataset_fixed_fractional_epochs(self): parallel_shard_iters=1, max_samples_per_sequence=None, ), - ) - - data = list(zip(range(200), train_loader)) - assert len(train_loader) == 38 + 55 + 27, len(train_loader) + ) as train_loader: + data = list(zip(range(200), train_loader)) + assert len(train_loader) == 38 + 55 + 27, len(train_loader) - # Check the overall number of samples - # Should be 0.7*len(ds1) + 1.5*len(ds2) = 38 + 55 + 27 (floor rounding) - assert len(data) == 200, len(data) + # Check the overall number of samples + # Should be 0.7*len(ds1) + 1.5*len(ds2) = 38 + 55 + 27 (floor rounding) + assert len(data) == 200, len(data) # ===== Part 4: Test count for multiple workers ===== @@ -1109,7 +1096,7 @@ def test_metadataset_fixed_fractional_epochs(self): ) # Train mode dataset - train_loader = get_savable_loader( + with get_savable_loader( get_train_dataset( fixed_epochs_mds_path, worker_config=worker_config, @@ -1120,17 +1107,16 @@ def test_metadataset_fixed_fractional_epochs(self): max_samples_per_sequence=None, repeat=False, ), - ) - - # TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py - assert len(train_loader) == 58, len(train_loader) + ) as train_loader: + # TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py + assert len(train_loader) == 58, len(train_loader) - data = list(enumerate(train_loader)) + data = list(enumerate(train_loader)) - # Check the overall number of samples - # Should be 0.7*len(ds1)55 + 1.5*len(ds2)55 = 38 + 55 + 27 (floor rounding) - # TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py - assert len(data) == 58, len(data) + # Check the overall number of samples + # Should be 0.7*len(ds1)55 + 1.5*len(ds2)55 = 38 + 55 + 27 (floor rounding) + # TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py + assert len(data) == 58, len(data) @patch.object(WatchdogDataset, "_watchdog_trigger") def test_watchdog_dataset(self, mock_watchdog_trigger): @@ -1167,16 +1153,15 @@ def encode_sample(self, sample: TextSample) -> TextSample: task_encoder=TestTaskEncoder(), ) - train_loader = get_loader( + with get_loader( train_dataset, watchdog_timeout_seconds=3, fail_on_timeout=False, - ) - - for idx, data in enumerate(train_loader): - print(idx, data.text[0]) - if idx > 255: - break + ) as train_loader: + for idx, data in enumerate(train_loader): + print(idx, data.text[0]) + if idx > 255: + break mock_watchdog_trigger.assert_called() @@ -1213,7 +1198,7 @@ def test_dataset_absolute_nested_subset_fail(self): ) try: - get_loader( + with get_loader( get_train_dataset( ratio_mds_path, worker_config=worker_config, @@ -1224,8 +1209,8 @@ def test_dataset_absolute_nested_subset_fail(self): max_samples_per_sequence=None, repeat=False, ) - ) - assert False, "Should have failed" + ): + assert False, "Should have failed" except Exception as e: assert "only allowed for a leaf dataset" in str( e @@ -1259,7 +1244,7 @@ def test_dataset_with_subset_end_keyword(self): ) ) - loader = get_loader( + with get_loader( get_train_dataset( ratio_mds_path, worker_config=worker_config, @@ -1270,11 +1255,10 @@ def test_dataset_with_subset_end_keyword(self): max_samples_per_sequence=None, repeat=False, ) - ) - - all_numbers = [int(s.text[0]) for s in loader] + ) as loader: + all_numbers = [int(s.text[0]) for s in loader] - assert all_numbers == [50, 51, 52, 53, 54], "Subset range [50, end] should be [50, 55]" + assert all_numbers == [50, 51, 52, 53, 54], "Subset range [50, end] should be [50, 55]" def test_dataset_with_subset_ratio(self): worker_config = WorkerConfig( @@ -1309,7 +1293,7 @@ def test_dataset_with_subset_ratio(self): ) ) - loader = get_loader( + with get_loader( get_train_dataset( ratio_mds_path, worker_config=worker_config, @@ -1320,19 +1304,18 @@ def test_dataset_with_subset_ratio(self): max_samples_per_sequence=None, repeat=False, ) - ) - - data = list(enumerate(loader)) - assert len(data) == 33 + 33 * 2, len(data) - - sample_counts = Counter([int(s[1].text[0]) for s in data]) - assert all(sample_counts[sample] == 0 for sample in range(11)), sample_counts - assert all(sample_counts[sample] == 1 for sample in range(11, 44)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(44, 55)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(100, 111)), sample_counts - assert all(sample_counts[sample] == 2 for sample in range(111, 144)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(144, 155)), sample_counts - assert sample_counts.total() == 33 + 33 * 2, sample_counts.total() + ) as loader: + data = list(enumerate(loader)) + assert len(data) == 33 + 33 * 2, len(data) + + sample_counts = Counter([int(s[1].text[0]) for s in data]) + assert all(sample_counts[sample] == 0 for sample in range(11)), sample_counts + assert all(sample_counts[sample] == 1 for sample in range(11, 44)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(44, 55)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(100, 111)), sample_counts + assert all(sample_counts[sample] == 2 for sample in range(111, 144)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(144, 155)), sample_counts + assert sample_counts.total() == 33 + 33 * 2, sample_counts.total() # Combine with subset_samples @@ -1364,7 +1347,7 @@ def test_dataset_with_subset_ratio(self): ) ) - loader = get_loader( + with get_loader( get_train_dataset( ratio2_mds_path, worker_config=worker_config, @@ -1375,19 +1358,18 @@ def test_dataset_with_subset_ratio(self): max_samples_per_sequence=None, repeat=False, ) - ) - - data = list(enumerate(loader)) - assert len(data) == 12 + 12 * 2, len(data) - - sample_counts = Counter([int(s[1].text[0]) for s in data]) - assert all(sample_counts[sample] == 0 for sample in range(14)), sample_counts - assert all(sample_counts[sample] == 1 for sample in range(14, 26)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(26, 55)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(100, 124)), sample_counts - assert all(sample_counts[sample] == 2 for sample in range(124, 136)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(136, 155)), sample_counts - assert sample_counts.total() == 12 + 12 * 2, sample_counts.total() + ) as loader: + data = list(enumerate(loader)) + assert len(data) == 12 + 12 * 2, len(data) + + sample_counts = Counter([int(s[1].text[0]) for s in data]) + assert all(sample_counts[sample] == 0 for sample in range(14)), sample_counts + assert all(sample_counts[sample] == 1 for sample in range(14, 26)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(26, 55)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(100, 124)), sample_counts + assert all(sample_counts[sample] == 2 for sample in range(124, 136)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(136, 155)), sample_counts + assert sample_counts.total() == 12 + 12 * 2, sample_counts.total() # Combine with subset_ratio and subset_samples and nested metadataset nested_mds_path = self.dataset_path / "metadataset_nested_subset.yaml" @@ -1418,7 +1400,7 @@ def test_dataset_with_subset_ratio(self): ) ) - loader = get_loader( + with get_loader( get_train_dataset( nested_mds_path, worker_config=worker_config, @@ -1429,21 +1411,20 @@ def test_dataset_with_subset_ratio(self): max_samples_per_sequence=None, repeat=False, ) - ) - - data = list(enumerate(loader)) - assert len(data) == 10 + 9 * 2, len(data) - sample_counts = Counter([int(s[1].text[0]) for s in data]) - assert all(sample_counts[sample] == 0 for sample in range(17)), sample_counts - assert all(sample_counts[sample] == 2 for sample in range(17, 20)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(20, 55)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(100, 127)), sample_counts - assert all(sample_counts[sample] == 4 for sample in range(127, 130)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(130, 155)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(200, 230)), sample_counts - assert all(sample_counts[sample] == 1 for sample in range(230, 240)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(240, 255)), sample_counts - assert sample_counts.total() == 10 + 9 * 2, sample_counts.total() + ) as loader: + data = list(enumerate(loader)) + assert len(data) == 10 + 9 * 2, len(data) + sample_counts = Counter([int(s[1].text[0]) for s in data]) + assert all(sample_counts[sample] == 0 for sample in range(17)), sample_counts + assert all(sample_counts[sample] == 2 for sample in range(17, 20)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(20, 55)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(100, 127)), sample_counts + assert all(sample_counts[sample] == 4 for sample in range(127, 130)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(130, 155)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(200, 230)), sample_counts + assert all(sample_counts[sample] == 1 for sample in range(230, 240)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(240, 255)), sample_counts + assert sample_counts.total() == 10 + 9 * 2, sample_counts.total() def test_s3(self): # Create a joined dataset configuration @@ -1466,7 +1447,7 @@ def test_s3(self): # EPath(self.dataset_path).copy(EPath("msc://s3/test/dataset")) emu.add_file(self.dataset_path, "test/dataset") - train_dataset = get_loader( + with get_loader( get_train_dataset( mixed_mds_path, worker_config=WorkerConfig( @@ -1479,10 +1460,9 @@ def test_s3(self): max_samples_per_sequence=None, virtual_epoch_length=10, ) - ) - - data = list(enumerate(train_dataset)) - assert len(data) == 10, len(data) + ) as train_dataset: + data = list(enumerate(train_dataset)) + assert len(data) == 10, len(data) if __name__ == "__main__": From df45aed5e41eaf50ccef77a3ac8c0ea8d16de6e5 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 4 Sep 2025 19:42:11 +0200 Subject: [PATCH 29/36] Change all tests to pytest. Cleanup a few unused functions. Implement multiple shutdown/close functionalities, especially worker_init/_close --- .coveragerc | 71 +- justfile | 11 +- pyproject.toml | 4 + src/megatron/energon/cache/base.py | 31 +- src/megatron/energon/cache/file_cache_pool.py | 6 + src/megatron/energon/cache/file_store.py | 18 + src/megatron/energon/cache/no_cache.py | 6 + src/megatron/energon/dataloader/dataloader.py | 58 +- .../dataloader/workers/async_worker.py | 19 +- .../energon/dataloader/workers/base_worker.py | 28 +- .../energon/dataloader/workers/fork_worker.py | 6 +- .../dataloader/workers/thread_worker.py | 5 +- src/megatron/energon/flavors/base_dataset.py | 10 + src/megatron/energon/flavors/jsonl/ijsonl.py | 15 +- .../energon/flavors/jsonl/ijsonl_reader.py | 97 +- .../energon/flavors/webdataset/itar.py | 117 +- .../energon/flavors/webdataset/itar_reader.py | 230 +- .../flavors/webdataset/multi_key_cache.py | 72 + .../flavors/webdataset/sample_loader.py | 18 +- src/megatron/energon/retry_stream.py | 94 - src/megatron/energon/task_encoder/base.py | 56 +- src/megatron/energon/tools/checkpoint.py | 4 +- src/megatron/energon/typed_converter.py | 596 +--- src/megatron/energon/worker.py | 74 +- src/megatron/energon/wrappers/__init__.py | 2 + src/megatron/energon/wrappers/base.py | 10 + .../wrappers/file_store_init_wrapper.py | 77 + tests/test_av_decoder.py | 647 ++-- tests/test_crudedataset.py | 895 ++--- tests/test_dataloader.py | 672 ++-- tests/test_dataset.py | 3063 ++++++++--------- tests/test_dataset_det.py | 1626 ++++----- tests/test_epathlib.py | 483 +-- tests/test_file_cache_pool.py | 1193 +++---- tests/test_jsonl_dataset.py | 420 +-- tests/test_metadataset.py | 2393 +++++++------ tests/test_metadataset_fewsamp.py | 341 +- tests/test_metadataset_v2.py | 2484 ++++++------- tests/test_transforms.py | 450 ++- tests/test_typed_converter.py | 465 +++ tests/test_typedconverter_extended.py | 2 + tests/test_weakref.py | 265 +- uv.lock | 52 +- 43 files changed, 8729 insertions(+), 8457 deletions(-) create mode 100644 src/megatron/energon/flavors/webdataset/multi_key_cache.py delete mode 100644 src/megatron/energon/retry_stream.py create mode 100644 src/megatron/energon/wrappers/file_store_init_wrapper.py create mode 100644 tests/test_typed_converter.py create mode 100644 tests/test_typedconverter_extended.py diff --git a/.coveragerc b/.coveragerc index 0e334746..27e98e10 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,5 +1,72 @@ +[run] +# Source directories to measure coverage for +source = src + +# Parallel mode for multiprocessing support +parallel = True + +# Branch coverage measurement +branch = True + +# Data file location +data_file = .coverage + [report] -include = ./src/megatron/energon/** +# Include source code in the report +include = src/megatron/energon/** + +# Minimum coverage percentage to pass +fail_under = 80 + +# Show missing lines in the report +show_missing = True + +# Skip covered files in the report +skip_covered = False + +# Skip empty files +skip_empty = True + +# Precision for coverage percentages +precision = 2 + +# Sort order for the report +sort = filename + +# Exclude lines from coverage +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + if self\.debug + + # Don't complain if tests don't hit exceptions: + raise .* + + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: + + # Don't complain about abstract methods, they aren't run: + @(abc\.)?abstractmethod + + # Don't complain about type checking code: + if TYPE_CHECKING: + + +[html] +# Directory to put the HTML report +directory = htmlcov + +# Title for the HTML report +title = Megatron Energon Coverage Report [xml] -output = ./coverage.xml \ No newline at end of file +# Output file for XML report +output = coverage.xml + +[lcov] +# Output file for LCOV report +output = lcov.info diff --git a/justfile b/justfile index 8dc3ad5f..1fe3896e 100644 --- a/justfile +++ b/justfile @@ -25,12 +25,15 @@ check: dev-sync # Execute all unit tests test: dev-sync - uv run -m unittest discover -v -s tests + uv run pytest tests -v coverage: dev-sync - uv run -m coverage run -m unittest discover -v -s tests - uv run -m coverage html - echo "Coverage report generated at ./htmlcov/index.html" + COVERAGE_PROCESS_START=.coveragerc uv run -m coverage run --parallel-mode --concurrency=multiprocessing -m pytest tests + # COVERAGE_PROCESS_START=.coveragerc uv run -m coverage run --parallel-mode --concurrency=multiprocessing -m pytest tests/test_dataloader.py + # COVERAGE_PROCESS_START=.coveragerc uv run -m coverage run --parallel-mode --concurrency=multiprocessing -m pytest tests/test_typed_converter.py + uv run -m coverage combine + uv run -m coverage lcov + echo "Coverage LCOV report generated at ./lcov.info" # Build the docs docs: dev-sync diff --git a/pyproject.toml b/pyproject.toml index 8780df99..e0319648 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -39,6 +39,7 @@ dependencies = [ [project.optional-dependencies] dev = [ "coverage", + "pytest", "ruff", "sphinxcontrib-napoleon", "sphinx", @@ -120,3 +121,6 @@ skip_covered = true [tool.coverage.html] show_contexts = true + +[tool.pytest.ini_options] +addopts = "--tb=native --capture=no" diff --git a/src/megatron/energon/cache/base.py b/src/megatron/energon/cache/base.py index f05489ec..588bf1b8 100644 --- a/src/megatron/energon/cache/base.py +++ b/src/megatron/energon/cache/base.py @@ -10,7 +10,7 @@ T = TypeVar("T") -class FileStore(Generic[T]): +class FileStore(ABC, Generic[T]): """Base type for a dataset that can be accessed randomly by sample key.""" @abstractmethod @@ -29,6 +29,21 @@ def get_path(self) -> str: """Returns the path to the dataset.""" ... + @abstractmethod + def worker_init(self) -> None: + """Initializes the file store for the current worker.""" + raise NotImplementedError("worker_init is not implemented for this file store") + + @abstractmethod + def worker_close(self) -> None: + """Closes the file store for the current worker.""" + raise NotImplementedError("worker_close is not implemented for this file store") + + @abstractmethod + def close(self) -> None: + """Closes the file store.""" + raise NotImplementedError("close is not implemented for this file store") + @edataclass class Lazy(Generic[T]): @@ -125,6 +140,20 @@ def get_lazy(self, ds: FileStore, fname: str) -> Lazy: """ ... + @abstractmethod + def worker_init(self) -> None: + """ + Initialize the cache pool for the current worker. + """ + ... + + @abstractmethod + def worker_close(self) -> None: + """ + Close the cache pool for the current worker. + """ + ... + @abstractmethod def close(self) -> None: """ diff --git a/src/megatron/energon/cache/file_cache_pool.py b/src/megatron/energon/cache/file_cache_pool.py index df3a945e..d3a68695 100644 --- a/src/megatron/energon/cache/file_cache_pool.py +++ b/src/megatron/energon/cache/file_cache_pool.py @@ -326,6 +326,12 @@ def get_lazy(self, ds: FileStore, fname: str) -> FileCacheLazy: return FileCacheLazy(ds=ds, fname=fname, pool=self, entry=entry) + def worker_init(self) -> None: + pass + + def worker_close(self) -> None: + pass + def close(self) -> None: """ Shutdown the pool, wait for tasks, and clear our structures. diff --git a/src/megatron/energon/cache/file_store.py b/src/megatron/energon/cache/file_store.py index 6247d748..6cd42b25 100644 --- a/src/megatron/energon/cache/file_store.py +++ b/src/megatron/energon/cache/file_store.py @@ -28,6 +28,15 @@ def __init__( self.inner_reader = inner_reader self.decoder = decoder + def worker_init(self) -> None: + self.inner_reader.worker_init() + + def worker_close(self) -> None: + self.inner_reader.worker_close() + + def close(self) -> None: + self.inner_reader.close() + def __getitem__(self, fname: str) -> tuple[Any, SourceInfo]: data, source_info = self.inner_reader[fname] return self.decoder.decode(fname, data), source_info @@ -69,6 +78,15 @@ def __getitem__(self, key: str) -> tuple[bytes, SourceInfo]: file_names=(key,), ) + def worker_init(self) -> None: + pass + + def worker_close(self) -> None: + pass + + def close(self) -> None: + pass + def get_path(self) -> str: """Returns the path to the dataset.""" return str(self.base_dir) diff --git a/src/megatron/energon/cache/no_cache.py b/src/megatron/energon/cache/no_cache.py index 2da8429a..ffdff861 100644 --- a/src/megatron/energon/cache/no_cache.py +++ b/src/megatron/energon/cache/no_cache.py @@ -47,5 +47,11 @@ def get(self, ds: FileStore, fname: str, sample: Any = None) -> Any: def get_lazy(self, ds: FileStore, fname: str) -> DirectLazy: return DirectLazy(ds=ds, fname=fname, pool=self) + def worker_init(self) -> None: + pass + + def worker_close(self) -> None: + pass + def close(self) -> None: pass diff --git a/src/megatron/energon/dataloader/dataloader.py b/src/megatron/energon/dataloader/dataloader.py index 8be531b3..ef3029f8 100644 --- a/src/megatron/energon/dataloader/dataloader.py +++ b/src/megatron/energon/dataloader/dataloader.py @@ -1,6 +1,5 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -import functools import os import sys import warnings @@ -18,7 +17,7 @@ import torch.distributed from megatron.energon.cache.base import CachePool -from megatron.energon.dataloader.future import CallableFuture, Future +from megatron.energon.dataloader.future import Future from megatron.energon.dataloader.pin_memory import NoPinMemory, PinMemory, PinMemoryThread from megatron.energon.dataloader.workers.base_worker import ( DataLoaderNoWorker, @@ -28,7 +27,7 @@ ) from megatron.energon.dataloader.workers.fork_worker import ForkDataLoaderWorker from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset, set_sample_restore_key +from megatron.energon.flavors.base_dataset import RestoreKey, SavableDataset from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.base import BaseWrapperDataset, get_sample_restore_key from megatron.energon.wrappers.batch_dataset import BatchDataset @@ -49,7 +48,7 @@ class RankState: #: This is a list (per worker) of lists of (batch) sample keys, which have been (asynchronously) prefetched from workers # but not been fetched yet by iterating. - prefetched_samples_keys: list[list[RestoreKey | None]] + prefetched_restore_keys: list[list[RestoreKey | None]] #: This is a list of worker states, which have been saved from the workers (or `None` for the initial state). worker_states: list[WorkerState | None] #: The next worker ID to prefetch from (i.e. append to the prefetched samples). @@ -121,7 +120,7 @@ def __init__( Create the dataloader supporting saving and restoring the state. Args: - dataset: The dataset to load. + dataset: The dataset to load. The loader takes ownership of the dataset, i.e. it cannot be shared and will be closed on shutdown. prefetch_factor: The number of samples to prefetch from each worker. worker_type: The type of worker to use. cache_pool: If set, the cache pool to use for the dataset. @@ -144,6 +143,12 @@ def __init__( self._id = DataLoader._next_id DataLoader._next_id += 1 + if getattr(dataset, "__dataloader_id", None) is not None: + raise ValueError( + f"Dataset {dataset} is already associated with dataloader {getattr(dataset, '__dataloader_id')}. Initialize one dataset per dataloader." + ) + setattr(dataset, "__dataloader_id", self._id) + if dataset.worker_config.num_workers == 0 and worker_type == ForkDataLoaderWorker: worker_type = DataLoaderNoWorker @@ -226,12 +231,10 @@ def start(self) -> None: if self._restore_state is not None: self._prefetching_samples = [ [ - self._pin_memory( - CallableFuture(functools.partial(self.restore_sample, sample_key)) - ) - for sample_key in prefetched_samples_keys + self._pin_memory(self._restore_sample(restore_key)) + for restore_key in prefetched_restore_keys ] - for prefetched_samples_keys in self._restore_state.prefetched_samples_keys + for prefetched_restore_keys in self._restore_state.prefetched_restore_keys ] self._next_worker_id = self._restore_state.next_worker_id self._exhausted_workers = [ @@ -263,6 +266,7 @@ def shutdown(self, in_del: bool = False) -> None: for worker in self._workers: worker.shutdown(in_del=in_del) self._workers = None + self._dataset.close() self._pin_memory.shutdown(in_del=in_del) def __del__(self) -> None: @@ -438,7 +442,7 @@ def _get_batch_size(self) -> int | None: def save_state_rank(self) -> RankState: if self._restore_state is not None: return self._restore_state - prefetched_samples_keys = [ + prefetched_restore_keys = [ [get_sample_restore_key(sample_fut.get()) for sample_fut in prefetching_sample] for prefetching_sample in self._prefetching_samples ] @@ -455,7 +459,7 @@ def save_state_rank(self) -> RankState: ), "Exhausted workers mismatch" return RankState( - prefetched_samples_keys=prefetched_samples_keys, + prefetched_restore_keys=prefetched_restore_keys, worker_states=worker_states, next_worker_id=self._next_worker_id, micro_batch_size=self._get_batch_size(), @@ -607,7 +611,7 @@ def restore_state_global( self.restore_state_rank(rank_state) - def restore_sample(self, restore_key: RestoreKey) -> TSample: + def _restore_sample(self, restore_key: RestoreKey) -> Future[TSample]: """ Restore a sample from a restore key. @@ -615,20 +619,26 @@ def restore_sample(self, restore_key: RestoreKey) -> TSample: restore_key: The restore key to restore the sample from. Returns: - The restored sample. + A future that will be resolved to the restored sample. """ assert isinstance(restore_key, WorkerSampleRestoreKey) - self._worker_config.worker_activate( - restore_key.sample_idx, - override_global_rank=restore_key.worker_id, - cache_pool=self._cache_pool, + assert self._workers is not None, "Workers must be started before restoring a sample" + rank_worker_id = self._worker_config.rank_worker_id( + override_global_worker_id=restore_key.worker_id ) - try: - return set_sample_restore_key( - self._dataset.restore_sample(restore_key.inner), restore_key - ) - finally: - self._worker_config.worker_deactivate() + return self._workers[rank_worker_id].restore_sample(restore_key) + + def restore_sample(self, restore_key: RestoreKey) -> TSample: + """ + Restore a sample from a restore key. + + Args: + restore_key: The restore key to restore the sample from. + + Returns: + The restored sample. + """ + return self._restore_sample(restore_key).get() def with_restored_state_rank(self, state: RankState | None) -> "DataLoader[TSample]": """ diff --git a/src/megatron/energon/dataloader/workers/async_worker.py b/src/megatron/energon/dataloader/workers/async_worker.py index 25b4f370..34d16bb3 100644 --- a/src/megatron/energon/dataloader/workers/async_worker.py +++ b/src/megatron/energon/dataloader/workers/async_worker.py @@ -13,7 +13,11 @@ WorkerResult, ) from megatron.energon.dataloader.future import Future -from megatron.energon.dataloader.workers.base_worker import DataLoaderWorker, WorkerState +from megatron.energon.dataloader.workers.base_worker import ( + DataLoaderWorker, + WorkerSampleRestoreKey, + WorkerState, +) from megatron.energon.flavors.base_dataset import SavableDataset from megatron.energon.rng import SystemRng from megatron.energon.worker import WorkerConfig @@ -78,6 +82,12 @@ def _wrk_prefetch_next(self) -> TSample: # so immediately resolve the future to the result (get returns immediately). return super().prefetch_next().get() + def _wrk_restore_sample(self, restore_key: WorkerSampleRestoreKey) -> TSample: + """Wraps the super class method to call it in the worker process.""" + # The super class implementation already returns a resolved future (to be interface compatible), + # so immediately resolve the future to the result (get returns immediately). + return super().restore_sample(restore_key).get() + def dataset_init(self, initial_state: WorkerState | None) -> None: if self._in_worker(): return super().dataset_init(initial_state) @@ -96,6 +106,13 @@ def prefetch_next(self) -> Future[TSample]: return super().prefetch_next() return self._worker_call(self._wrk_prefetch_next) + def restore_sample(self, restore_key: WorkerSampleRestoreKey) -> Future[TSample]: + # Do not resolve the future here, but return it. + assert isinstance(restore_key, WorkerSampleRestoreKey) + if self._in_worker(): + return super().restore_sample(restore_key) + return self._worker_call(self._wrk_restore_sample, restore_key) + def save_state(self) -> WorkerState: if self._in_worker(): return super().save_state() diff --git a/src/megatron/energon/dataloader/workers/base_worker.py b/src/megatron/energon/dataloader/workers/base_worker.py index fcf5b932..e35c27d6 100644 --- a/src/megatron/energon/dataloader/workers/base_worker.py +++ b/src/megatron/energon/dataloader/workers/base_worker.py @@ -6,7 +6,7 @@ from megatron.energon.cache.base import CachePool from megatron.energon.dataloader.future import DoneFuture, ExceptionFuture, Future from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.flavors.base_dataset import SavableDataset, set_sample_restore_key from megatron.energon.rng import SystemRng, SystemRngState from megatron.energon.state import FlexState from megatron.energon.worker import WorkerConfig @@ -96,7 +96,8 @@ def shutdown(self, in_del: bool = False) -> None: Args: in_del: If True, the worker is being deleted. """ - pass + self.dataset.worker_close() + self.dataset.close() def running(self) -> bool: """ @@ -191,6 +192,29 @@ def prefetch_next(self) -> Future[TSample]: self.worker_config.worker_deactivate() return DoneFuture(next_sample) + def restore_sample(self, restore_key: WorkerSampleRestoreKey) -> Future[TSample]: + """ + Restore a sample from a restore key in the worker. + + Args: + restore_key: The restore key of the sample to restore. + + Returns: + A future that will be resolved to the restored sample. + """ + assert isinstance(restore_key, WorkerSampleRestoreKey) + assert self._global_worker_id == restore_key.worker_id, "Global worker ID mismatch" + self.worker_config.worker_activate( + restore_key.sample_idx, + cache_pool=self._cache_pool, + ) + try: + return DoneFuture( + set_sample_restore_key(self.dataset.restore_sample(restore_key.inner), restore_key) + ) + finally: + self.worker_config.worker_deactivate() + def save_state(self) -> WorkerState: """ Save the state of the worker. diff --git a/src/megatron/energon/dataloader/workers/fork_worker.py b/src/megatron/energon/dataloader/workers/fork_worker.py index 9fa92f72..0897dcb9 100644 --- a/src/megatron/energon/dataloader/workers/fork_worker.py +++ b/src/megatron/energon/dataloader/workers/fork_worker.py @@ -34,4 +34,8 @@ def _worker_run( dataset=self.dataset, ) - super()._worker_run(cmd_queue, result_queue) + try: + super()._worker_run(cmd_queue, result_queue) + finally: + self.dataset.worker_close() + self.dataset.close() diff --git a/src/megatron/energon/dataloader/workers/thread_worker.py b/src/megatron/energon/dataloader/workers/thread_worker.py index 5550d3b4..a8a2dea5 100644 --- a/src/megatron/energon/dataloader/workers/thread_worker.py +++ b/src/megatron/energon/dataloader/workers/thread_worker.py @@ -28,4 +28,7 @@ def _worker_run( seed=self._seed, dataset=self.dataset, ) - return super()._worker_run(cmd_queue, result_queue) + try: + return super()._worker_run(cmd_queue, result_queue) + finally: + self.dataset.worker_close() diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index 53343a61..e08bb5c1 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -362,6 +362,12 @@ def reset_state(self) -> None: """ pass + def worker_close(self) -> None: + """ + Closes all worker-local resources. + """ + pass + @abstractmethod def worker_has_samples(self) -> bool: """Returns True if the worker's split has samples. This is used to determine if this dataset @@ -408,6 +414,10 @@ def restore_sample(self, restore_key: "RestoreKey") -> T_sample: "This dataset does not support restoring, because it is not safely deterministic." ) + def close(self) -> None: + """Closes all shared resources.""" + pass + def __getattribute__(self, name: str) -> Any: if name in ("_savable_fields", "_worker_local_fields", "_thread_state", "worker_config"): return object.__getattribute__(self, name) diff --git a/src/megatron/energon/flavors/jsonl/ijsonl.py b/src/megatron/energon/flavors/jsonl/ijsonl.py index c292aa2b..7952ced1 100644 --- a/src/megatron/energon/flavors/jsonl/ijsonl.py +++ b/src/megatron/energon/flavors/jsonl/ijsonl.py @@ -112,9 +112,13 @@ class CachedIJsonlOffsetReader: cache_size: The number of entries to keep in the cache. By default, we keep 32. """ + ijsonl_index_reader_cache: Dict[int, CacheEntry] + cache_size: int + jsonl_file: EPath + def __init__(self, jsonl_file: Union[str, EPath], cache_size: int = 32): # Maps current_offset -> CacheEntry - self.ijsonl_index_reader_cache: Dict[int, CacheEntry] = {} + self.ijsonl_index_reader_cache = {} self.cache_size = cache_size self.jsonl_file = EPath(jsonl_file) @@ -219,11 +223,6 @@ def get_ijsonl_byte_offset( return result_byte_offset, length - def __len__(self) -> int: - if len(self.ijsonl_index_reader_cache) == 0: - return IJsonlIndexReader.count_samples(self.jsonl_file) - return len(next(iter(self.ijsonl_index_reader_cache.values())).ijsonl_index_reader) - 1 - def get_total_size(self) -> int: if len(self.ijsonl_index_reader_cache) == 0: self.ijsonl_index_reader_cache[0] = CacheEntry( @@ -259,6 +258,10 @@ class IJsonlFile: """ def __init__(self, fileobj: BinaryIO): + """ + Args: + fileobj: The file object to read from. Takes ownership of the file object. + """ self.fileobj = fileobj def seek(self, offset: int): diff --git a/src/megatron/energon/flavors/jsonl/ijsonl_reader.py b/src/megatron/energon/flavors/jsonl/ijsonl_reader.py index f50b242c..c7e1e999 100644 --- a/src/megatron/energon/flavors/jsonl/ijsonl_reader.py +++ b/src/megatron/energon/flavors/jsonl/ijsonl_reader.py @@ -1,7 +1,7 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from abc import ABC +import threading from typing import ( Callable, Generator, @@ -17,13 +17,15 @@ IJsonlIndexReader, IJsonlSamplePointer, ) -from megatron.energon.flavors.webdataset.structs import FilteredSample +from megatron.energon.flavors.webdataset.itar_reader import RawSampleReaderInterface +from megatron.energon.flavors.webdataset.multi_key_cache import MultiKeyCache +from megatron.energon.flavors.webdataset.structs import FilteredSample, WebdatasetRestoreKey from megatron.energon.source_info import SourceInfo T_index = TypeVar("T_index", covariant=False) -class IJsonlReader(ABC): +class IJsonlReader(RawSampleReaderInterface[int | str]): """ Class for reading indexed jsonl files containing json samples. @@ -40,8 +42,12 @@ class IJsonlReader(ABC): jsonl_path: EPath sample_filter: Optional[Callable[[str], bool]] - cached_offset_reader: CachedIJsonlOffsetReader - ijsonl_file: IJsonlFile | None = None + _length: int + _total_size: int + + thread_local: threading.local + cache_lock: threading.Lock + ijsonl_files_cache: MultiKeyCache[int, IJsonlFile] def __init__( self, @@ -51,16 +57,62 @@ def __init__( ): self.jsonl_path = jsonl_path self.sample_filter = sample_filter - self.cached_offset_reader = CachedIJsonlOffsetReader( - jsonl_path, cache_size=index_cache_size - ) + self.index_cache_size = index_cache_size + self.thread_local = threading.local() + self.ijsonl_files_cache = MultiKeyCache() + self.cache_lock = threading.Lock() + + with IJsonlIndexReader(jsonl_path) as ijsonl_index_reader: + # Number of samples + self._length = len(ijsonl_index_reader) - 1 + # Byte size + self._total_size = ijsonl_index_reader[self._length] def __len__(self) -> int: - return len(self.cached_offset_reader) + return self._length def __str__(self) -> str: return f"IJsonlReader(jsonl_path={self.jsonl_path})" + @property + def _cached_offset_reader(self) -> CachedIJsonlOffsetReader: + return self.thread_local._cached_offset_reader + + def worker_init(self): + self.thread_local._cached_offset_reader = CachedIJsonlOffsetReader( + self.jsonl_path, cache_size=self.index_cache_size + ) + + def worker_close(self): + if hasattr(self.thread_local, "_cached_offset_reader"): + self.thread_local._cached_offset_reader.close() + del self.thread_local._cached_offset_reader + + def _get_ijsonl_file_cached(self, sample_idx: int) -> IJsonlFile: + """ + Get the IJsonlFile object for the given sample index. + If the file is not already open, open it. + """ + with self.cache_lock: + reader = self.ijsonl_files_cache.pop(sample_idx) + if reader is None: + if len(self.ijsonl_files_cache) < self.index_cache_size: + reader = IJsonlFile(fileobj=self.jsonl_path.open(mode="rb")) + else: + # Reuse the oldest file + reader = self.ijsonl_files_cache.pop() + return reader + + def _update_ijsonl_file_cache(self, sample_idx: int, reader: IJsonlFile) -> None: + """ + Update the IJsonlFile object for the given sample index. + """ + with self.cache_lock: + while len(self.ijsonl_files_cache) >= self.index_cache_size: + # Evict the oldest file + self.ijsonl_files_cache.pop().close() + self.ijsonl_files_cache.add(sample_idx, reader) + def _get_item_by_sample_pointer( self, sample_pointer: IJsonlSamplePointer, @@ -69,8 +121,7 @@ def _get_item_by_sample_pointer( Get a sample from the dataset or slice it. Args: - sample_pointer: The sample pointer to get the sample from. - sample_index: The global index of the sample in the dataset. + sample_pointer: Pointer to the sample in the jsonl file. Returns: The sample or None if the sample is invalid. @@ -80,20 +131,22 @@ def _get_item_by_sample_pointer( if self.sample_filter is not None and not self.sample_filter(key): return None - if self.ijsonl_file is None: - self.ijsonl_file = IJsonlFile(self.jsonl_path.open("rb")) + ijsonl_file = self._get_ijsonl_file_cached(sample_pointer.index) + + json_data = ijsonl_file.next(sample_pointer.byte_offset, sample_pointer.byte_size) - json_data = self.ijsonl_file.next(sample_pointer.byte_offset, sample_pointer.byte_size) if json_data is None: return None + self._update_ijsonl_file_cache(sample_pointer.index + 1, ijsonl_file) + return FilteredSample( __key__=f"{self.jsonl_path.name}/{key}", __shard__=self.jsonl_path.name, - __restore_key__=("Webdataset", sample_pointer.index), + __restore_key__=WebdatasetRestoreKey(index=sample_pointer.index), __sources__=( SourceInfo( - dataset_path=str(self.jsonl_path), + dataset_path=self.jsonl_path, index=sample_pointer.index, shard_name=self.jsonl_path.name, file_names=(f"{key}.json",), @@ -118,7 +171,7 @@ def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInf except ValueError: raise ValueError(f"Invalid JSONL sample key: {idx}") - byte_offset, byte_size = self.cached_offset_reader.get_ijsonl_byte_offset(idx) + byte_offset, byte_size = self._cached_offset_reader.get_ijsonl_byte_offset(idx) sample: FilteredSample | None = self._get_item_by_sample_pointer( IJsonlSamplePointer( index=idx, @@ -178,13 +231,13 @@ def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int], except ValueError: raise ValueError(f"Invalid JSONL sample key: {sample_key}") - _, byte_size = self.cached_offset_reader.get_ijsonl_byte_offset(sample_idx) + _, byte_size = self._cached_offset_reader.get_ijsonl_byte_offset(sample_idx) yield f"{sample_key}.json", byte_size, 0 def get_total_size(self) -> int: - return self.cached_offset_reader.get_total_size() + return self._total_size def close(self): - if self.ijsonl_file is not None: - self.ijsonl_file.close() - self.cached_offset_reader.close() + with self.cache_lock: + for ijsonl_file in self.ijsonl_files_cache.flush(): + ijsonl_file.close() diff --git a/src/megatron/energon/flavors/webdataset/itar.py b/src/megatron/energon/flavors/webdataset/itar.py index ef6e435e..e5299add 100644 --- a/src/megatron/energon/flavors/webdataset/itar.py +++ b/src/megatron/energon/flavors/webdataset/itar.py @@ -1,15 +1,12 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -import contextlib import struct import tarfile -from types import TracebackType -from typing import BinaryIO, Dict, Generator, Optional, Tuple, Type, Union +from typing import Dict, Generator, Optional, Tuple, Union from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath -from megatron.energon.retry_stream import RetryReadStream ITAR_SUFFIX = ".tar.idx" @@ -106,76 +103,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.close(finalize=exc_val is None) -class SubFileReader(BinaryIO): - """A file-like object that reads a subfile (i.e. offset, size defined portion) of a larger - file.""" - - def __init__(self, stream: BinaryIO, offset: int, size: int): - self.offset = offset - self._pos = 0 - self.size = size - self.stream = stream - self.stream.seek(self.offset) - - def read(self, n: int = -1) -> bytes: - if n == -1: - n = self.size - self._pos - else: - n = min(n, self.size - self._pos) - if n == 0: - return b"" - read = self.stream.read(n) - self._pos += len(read) - return read - - def seek(self, offset: int, whence: int = 0) -> int: - if whence == 0: - self._pos = offset - elif whence == 1: - self._pos += offset - elif whence == 2: - self._pos = self.size + offset - else: - raise ValueError("Invalid whence value") - self._pos = max(0, min(self._pos, self.size)) - self.stream.seek(self.offset + self._pos) - return self._pos - - def tell(self) -> int: - return self._pos - - def __enter__(self) -> BinaryIO: - return self - - def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType - ) -> None: - self.close() - - def close(self) -> None: - self.stream.close() - - def isatty(self) -> bool: - return False - - def seekable(self) -> bool: - return True - - def writable(self) -> bool: - return False - - -def get_itar_byte_offset( - path: Union[str, EPath], - sample_offset: int = 0, -) -> int: - """Gets the byte offset from sample offsets.""" - if sample_offset == 0: - return 0 - with TarIndexReader(path) as itar: - return itar[sample_offset] - - @edataclass class CacheEntry: tar_index_reader: TarIndexReader @@ -308,6 +235,11 @@ def get_itar_byte_offset( return result_byte_offset, length + def close(self): + for cache_entry in self.tar_index_reader_cache.values(): + cache_entry.tar_index_reader.close() + self.tar_index_reader_cache.clear() + class ITarFile(tarfile.TarFile): """This class is a subclass of tarfile.TarFile that allows for reading a tarfile, @@ -341,40 +273,3 @@ def next(self): self.fileobj.seek(self.offset) return super().next() - - -@contextlib.contextmanager -def open_itar(path: Union[str, EPath], byte_offset: int = 0, byte_size: Optional[int] = None): - """ - Open an indexed tarfile with offset and size. - Args: - path: Path to the tarfile to open - byte_offset: Byte offset within the file - byte_size: Size of the file to read - - Returns: - The opened tarfile - """ - path = EPath(path) - - # TODO: if tar file startswith(b"\x1f\x8b\x08") -> Seekable gzip file - with path.open("rb") as f: - if f.read(3) == b"\x1f\x8b\x08": - # Open as seekable tgz - raise ValueError("Seekable tgz not supported yet") - - if byte_offset != 0 or byte_size is not None: - if byte_size is None: - byte_size = path.size() - byte_offset - with RetryReadStream(path) as stream: - with SubFileReader( - stream, - offset=byte_offset, - size=byte_size, - ) as fileobj: - with ITarFile.open(fileobj=fileobj, mode="r:") as f: - yield f - else: - with RetryReadStream(path) as fileobj: - with ITarFile.open(fileobj=fileobj, mode="r:") as f: - yield f diff --git a/src/megatron/energon/flavors/webdataset/itar_reader.py b/src/megatron/energon/flavors/webdataset/itar_reader.py index 057c3bbf..3fca64a1 100644 --- a/src/megatron/energon/flavors/webdataset/itar_reader.py +++ b/src/megatron/energon/flavors/webdataset/itar_reader.py @@ -27,6 +27,7 @@ ITarSamplePointer, ) from megatron.energon.flavors.webdataset.metadata import get_info_shard_files +from megatron.energon.flavors.webdataset.multi_key_cache import MultiKeyCache from megatron.energon.flavors.webdataset.structs import ( FilteredSample, ShardInfo, @@ -35,68 +36,55 @@ from megatron.energon.source_info import SourceInfo T_index = TypeVar("T_index", covariant=False) -T_key = TypeVar("T_key") -T_value = TypeVar("T_value") -class MultiKeyCache(Generic[T_key, T_value]): - """A cache that can store multiple values for the same key.""" - - _size: int - _cache: dict[T_key, list[T_value]] - _lru_keys: list[T_key] - - def __init__(self) -> None: - self._size = 0 - self._cache = {} - self._lru_keys = [] - - @overload - def pop(self, key: None = None) -> T_value: ... - - @overload - def pop(self, key: T_key) -> T_value | None: ... +class RawSampleReaderInterface(ABC, Generic[T_index]): + """ + An abstract base class for reading a sequence of raw samples. + """ - def pop(self, key: T_key | None = None) -> T_value | None: - """Pop the value for the given key from the cache. + @abstractmethod + def __len__(self) -> int: + """Returns the total number of samples in the reader.""" + ... - If no key is provided, pop the oldest key from the cache. + @abstractmethod + def __str__(self) -> str: + """ + Must return a descriptive string of the concrete reader. + """ + ... - Args: - key: The key to pop from the cache. If None, pop the oldest key from the cache. + @abstractmethod + def worker_init(self): + """ + Initialize the reader for the worker. + """ + ... - Returns: - The value popped from the cache. + @abstractmethod + def worker_close(self): """ - if key is None: - key = self._lru_keys.pop(0) - elif key not in self._cache: - return None - else: - self._lru_keys.pop(len(self._lru_keys) - 1 - self._lru_keys[::-1].index(key)) - - l = self._cache[key] - value = l.pop(0) - if len(l) == 0: - del self._cache[key] - self._size -= 1 - return value - - def add(self, key: T_key, value: T_value) -> None: - """Add a value to the cache.""" - if key not in self._cache: - self._cache[key] = [value] - else: - self._cache[key].insert(0, value) + Close the reader for the worker. + """ + ... - self._lru_keys.append(key) - self._size += 1 + @abstractmethod + def close(self): + """ + Close the reader and clear all shared resources. + """ + ... - def __len__(self) -> int: - return self._size + @abstractmethod + def __getitem__(self, idx: T_index) -> FilteredSample | None: + """ + Get a sample from the dataset or slice it. Thread-safe. + """ + ... -class ITarReader(ABC, Generic[T_index]): +class ITarReader(RawSampleReaderInterface[T_index], Generic[T_index]): """ An abstract base class for reading a sequence of tar files containing samples. @@ -141,23 +129,14 @@ def __init__( self.itar_cache_size = itar_cache_size self.sample_filter = sample_filter - @abstractmethod - def __len__(self) -> int: - """Returns the total number of samples in the reader.""" - raise NotImplementedError - - @abstractmethod - def __str__(self) -> str: - """ - Must return a descriptive string of the concrete reader. - """ - raise NotImplementedError - def close(self): - for tar_file in self.itar_files_cache.values(): - tar_file.fileobj.close() - tar_file.close() - self.itar_files_cache.clear() + """Effectively clears the internal shared cache.""" + with self.cache_lock: + for tar_file in self.itar_files_cache.flush(): + fileobj = tar_file.fileobj + tar_file.close() + if fileobj is not None: + fileobj.close() @abstractmethod def _get_itar_sample_pointer(self, idx: T_index) -> ITarSamplePointer: @@ -328,7 +307,9 @@ class JoinIndexFileITarReader(ITarReader[int]): index_file: EPath column: int - index_reader_cache: Dict[int, JoinIndexReader] + index_reader_cache_lock: threading.Lock + index_reader_cache: MultiKeyCache[int, JoinIndexReader] + active_readers: int = 0 index_reader_cache_size: int def __init__( @@ -347,7 +328,8 @@ def __init__( # Create the full path to each tar file tar_filepaths = [base_path / fn for fn in tar_filenames] - self.index_reader_cache = {} + self.index_reader_cache_lock = threading.Lock() + self.index_reader_cache = MultiKeyCache() self.index_reader_cache_size = itar_cache_size super().__init__( @@ -359,24 +341,36 @@ def __init__( sample_filter=sample_filter, ) + def worker_init(self): + pass + + def worker_close(self): + pass + def _get_join_index_reader_cached(self, sample_idx: int) -> JoinIndexReader: """ Get the JoinIndexReader object for the given sample index, or create it if it doesn't exist. """ + with self.index_reader_cache_lock: + index_reader = self.index_reader_cache.pop(sample_idx) + if index_reader is None: + if len(self.index_reader_cache) < self.index_reader_cache_size: + index_reader = JoinIndexReader(self.index_file, column=self.column) + else: + # Just reuse the oldest reader + index_reader = self.index_reader_cache.pop() - if sample_idx not in self.index_reader_cache: - index_reader = JoinIndexReader(self.index_file, column=self.column) - self.index_reader_cache[sample_idx] = index_reader - - # If we hit the limit of open files, close the least recently used file - while len(self.index_reader_cache) > self.index_reader_cache_size: - # Get the oldest file - lru_key = next(iter(self.index_reader_cache)) - - self.index_reader_cache[lru_key].close() - del self.index_reader_cache[lru_key] + return index_reader - return self.index_reader_cache[sample_idx] + def _update_index_reader_cache(self, sample_idx: int, reader: JoinIndexReader) -> None: + """ + Update the JoinIndexReader object for the given tar file id. + """ + with self.index_reader_cache_lock: + # If we hit the limit of open files, close the least recently used file + while len(self.index_reader_cache) >= self.index_reader_cache_size: + self.index_reader_cache.pop().close() + self.index_reader_cache.add(sample_idx, reader) def _get_itar_sample_pointer(self, sample_idx: int) -> ITarSamplePointer: """ @@ -387,8 +381,11 @@ def _get_itar_sample_pointer(self, sample_idx: int) -> ITarSamplePointer: # Update cache entry new_offset = index_reader.tell_row() - del self.index_reader_cache[sample_idx] - self.index_reader_cache[new_offset] = index_reader + assert new_offset == sample_idx + 1, ( + f"Expected new offset to be {sample_idx + 1}, got {new_offset}" + ) + + self._update_index_reader_cache(new_offset, index_reader) assert len(row) == 1 shard_idx, byte_offset, byte_size = row[0] @@ -402,8 +399,8 @@ def _get_itar_sample_pointer(self, sample_idx: int) -> ITarSamplePointer: def __len__(self) -> int: try: # Get any reader, they will all work - index_reader = next(iter(self.index_reader_cache.values())) - except StopIteration: + index_reader = self.index_reader_cache.pop() + except IndexError: # If there's no reader yet, we need to create one to get the length index_reader = self._get_join_index_reader_cached(0) @@ -476,13 +473,19 @@ def __init__( ) @property - def cached_offset_reader(self) -> CachedItarOffsetReader: - if not hasattr(self._thread_local, "_cached_offset_reader"): - self._thread_local._cached_offset_reader = CachedItarOffsetReader( - cache_size=self._itar_cache_size - ) + def _cached_offset_reader(self) -> CachedItarOffsetReader: return self._thread_local._cached_offset_reader + def worker_init(self): + self._thread_local._cached_offset_reader = CachedItarOffsetReader( + cache_size=self._itar_cache_size + ) + + def worker_close(self): + if hasattr(self._thread_local, "_cached_offset_reader"): + self._thread_local._cached_offset_reader.close() + del self._thread_local._cached_offset_reader + def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer: """ Get the ITarSample object for the given index. @@ -500,7 +503,7 @@ def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer: # Now we know the tar file and the sample offset in the file. # We need to figure out the byte offset and size of the sample, # by looking it up in the .tar.idx file. - byte_offset, byte_size = self.cached_offset_reader.get_itar_byte_offset( + byte_offset, byte_size = self._cached_offset_reader.get_itar_byte_offset( shard.path, sample_idx_in_shard_file ) @@ -527,9 +530,10 @@ class SqliteITarEntryReader(ITarReader[str]): A concrete ITarReader that constructs its internal sample list from a SQLite database. """ - sqlite_reader: SqliteIndexReader db_has_sample_parts: int + thread_local: threading.local + def __init__( self, base_path: EPath, @@ -546,12 +550,12 @@ def __init__( tar_filepaths = [base_path / fn for fn in tar_filenames] # Initialize the SQLite reader - sqlite_path = base_path / MAIN_FOLDER_NAME / "index.sqlite" - self.sqlite_reader = SqliteIndexReader(sqlite_path) - - self.db_has_sample_parts = self.sqlite_reader.db_has_sample_parts() + self.sqlite_path = base_path / MAIN_FOLDER_NAME / "index.sqlite" + with SqliteIndexReader(self.sqlite_path) as check_db: + self.db_has_sample_parts = check_db.db_has_sample_parts() self.key_is_full_entryname = key_is_full_entryname + self.thread_local = threading.local() super().__init__( base_path=base_path, @@ -562,12 +566,24 @@ def __init__( sample_filter=sample_filter, ) + @property + def _sqlite_reader(self) -> SqliteIndexReader: + return self.thread_local._sqlite_reader + + def worker_init(self): + self.thread_local._sqlite_reader = SqliteIndexReader(self.sqlite_path) + + def worker_close(self): + if hasattr(self.thread_local, "_sqlite_reader"): + self.thread_local._sqlite_reader.close() + del self.thread_local._sqlite_reader + def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer: """ Get the ITarSample object for the given index. """ - return self.sqlite_reader.get_sample_pointer_by_key(sample_key) + return self._sqlite_reader.get_sample_pointer_by_key(sample_key) def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]: """List all samples in the jsonl file. @@ -575,7 +591,7 @@ def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]: Returns: A generator of tuples of (sample_key, size, tar_file_id) """ - return self.sqlite_reader.list_all_samples() + return self._sqlite_reader.list_all_samples() def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]: """List all sample parts in the jsonl file. @@ -583,7 +599,7 @@ def list_all_sample_parts(self) -> Generator[Tuple[str, int, int], None, None]: Returns: A generator of tuples of (sample_key + "." + part_name, size, tar_file_id) """ - return self.sqlite_reader.list_all_sample_parts() + return self._sqlite_reader.list_all_sample_parts() def list_sample_parts( self, sample_key: str, slow_mode: bool = False @@ -605,7 +621,7 @@ def list_sample_parts( """ if not slow_mode: - yield from self.sqlite_reader.list_sample_parts(sample_key) + yield from self._sqlite_reader.list_sample_parts(sample_key) else: sample_pointer = self._get_itar_sample_pointer(sample_key) @@ -617,7 +633,7 @@ def list_sample_parts( yield ext, len(sample[ext]), sample_pointer.tar_file_id def get_total_size(self) -> int: - return self.sqlite_reader.get_total_size() + return self._sqlite_reader.get_total_size() @overload def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]: ... @@ -648,7 +664,7 @@ def __getitem__( if self.db_has_sample_parts: # Directly fetch the sample part (byte offset and size) from the database - raw_sample_pointer = self.sqlite_reader.get_sample_part(sample_key, sample_ext) + raw_sample_pointer = self._sqlite_reader.get_sample_part(sample_key, sample_ext) raw_data, source_info = self._get_part_by_raw_sample_pointer( raw_sample_pointer, key ) @@ -675,7 +691,7 @@ def __getitem__( def __len__(self) -> int: """Return the total number of samples in the database.""" - return self.sqlite_reader.get_sample_count() + return self._sqlite_reader.get_sample_count() def __str__(self) -> str: """Return a descriptive string of this reader.""" @@ -689,12 +705,11 @@ def __str__(self) -> str: def close(self): """Close the SQLite reader and any open ITarFiles.""" # Close the SQLite reader - if hasattr(self, "sqlite_reader") and self.sqlite_reader is not None: - self.sqlite_reader.close() + if hasattr(self, "_sqlite_reader") and self._sqlite_reader is not None: + self._sqlite_reader.close() # Close any open ITarFiles (using parent class implementation) - for tar_file_id in list(self.itar_files_cache.keys()): - tar_file = self.itar_files_cache[tar_file_id] + for tar_file in self.itar_files_cache.flush(): if ( tar_file is not None and hasattr(tar_file, "fileobj") @@ -703,4 +718,3 @@ def close(self): tar_file.fileobj.close() if tar_file is not None and hasattr(tar_file, "close"): tar_file.close() - del self.itar_files_cache[tar_file_id] diff --git a/src/megatron/energon/flavors/webdataset/multi_key_cache.py b/src/megatron/energon/flavors/webdataset/multi_key_cache.py new file mode 100644 index 00000000..40865591 --- /dev/null +++ b/src/megatron/energon/flavors/webdataset/multi_key_cache.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +from typing import Generator, Generic, TypeVar, overload + +T_key = TypeVar("T_key") +T_value = TypeVar("T_value") + + +class MultiKeyCache(Generic[T_key, T_value]): + """A cache that can store multiple values for the same key.""" + + _size: int + _cache: dict[T_key, list[T_value]] + _lru_keys: list[T_key] + + def __init__(self) -> None: + self._size = 0 + self._cache = {} + self._lru_keys = [] + + @overload + def pop(self, key: None = None) -> T_value: ... + + @overload + def pop(self, key: T_key) -> T_value | None: ... + + def pop(self, key: T_key | None = None) -> T_value | None: + """Pop the value for the given key from the cache. + + If no key is provided, pop the oldest key from the cache. + + Args: + key: The key to pop from the cache. If None, pop the oldest key from the cache. + + Returns: + The value popped from the cache. + """ + if key is None: + key = self._lru_keys.pop(0) + elif key not in self._cache: + return None + else: + self._lru_keys.pop(len(self._lru_keys) - 1 - self._lru_keys[::-1].index(key)) + + l = self._cache[key] + value = l.pop(0) + if len(l) == 0: + del self._cache[key] + self._size -= 1 + return value + + def add(self, key: T_key, value: T_value) -> None: + """Add a value to the cache.""" + if key not in self._cache: + self._cache[key] = [value] + else: + self._cache[key].insert(0, value) + + self._lru_keys.append(key) + self._size += 1 + + def flush(self) -> Generator[T_value, None, None]: + """Flush the cache.""" + for values in self._cache.values(): + for value in values: + yield value + self._cache.clear() + self._lru_keys.clear() + self._size = 0 + + def __len__(self) -> int: + return self._size diff --git a/src/megatron/energon/flavors/webdataset/sample_loader.py b/src/megatron/energon/flavors/webdataset/sample_loader.py index 81f8a010..f067c370 100644 --- a/src/megatron/energon/flavors/webdataset/sample_loader.py +++ b/src/megatron/energon/flavors/webdataset/sample_loader.py @@ -7,7 +7,7 @@ from megatron.energon.edataclass import edataclass from megatron.energon.flavors.base_dataset import FlexState, RestoreKey, SavableDataset -from megatron.energon.flavors.webdataset.itar_reader import ITarReader +from megatron.energon.flavors.webdataset.itar_reader import RawSampleReaderInterface from megatron.energon.flavors.webdataset.structs import FilteredSample, WebdatasetRestoreKey from megatron.energon.rng import WorkerRng from megatron.energon.worker import WorkerConfig @@ -37,7 +37,7 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): """Internal class for loading samples from webdataset slices""" #: The readers for each joined dataset - join_readers: Sequence[ITarReader] + join_readers: Sequence[RawSampleReaderInterface[int]] #: The offsets of the slice slices to iterate over for each worker # On worker initialization, this is set to _slice_offsets for the current worker. @@ -89,7 +89,7 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): def __init__( self, - join_readers: Sequence[ITarReader], + join_readers: Sequence[RawSampleReaderInterface[int]], workers_sample_slice_offsets: Sequence[Sequence[int]], *, worker_config: WorkerConfig, @@ -134,6 +134,8 @@ def reset_state(self) -> None: self._epoch_count = 0 self._epoch_sample_count = 0 self._slice_offsets = self.workers_slice_offsets[self.worker_config.rank_worker_id()] + for reader in self.join_readers: + reader.worker_init() def _get_sample(self, index: int) -> RawSampleData: return RawSampleData( @@ -446,6 +448,16 @@ def restore_sample(self, restore_key: RestoreKey) -> RawSampleData: ) return self._get_sample(restore_key.index) + def worker_close(self) -> None: + for reader in self.join_readers: + reader.worker_close() + super().worker_close() + + def close(self) -> None: + for reader in self.join_readers: + reader.close() + super().close() + def config(self) -> Dict[str, Any]: return { "type": type(self).__qualname__, diff --git a/src/megatron/energon/retry_stream.py b/src/megatron/energon/retry_stream.py deleted file mode 100644 index b678dddf..00000000 --- a/src/megatron/energon/retry_stream.py +++ /dev/null @@ -1,94 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. -# SPDX-License-Identifier: BSD-3-Clause - -from types import TracebackType -from typing import BinaryIO, Optional, Type - -from megatron.energon.epathlib import EPath - - -class RetryReadStream(BinaryIO): - """A stream that retries reading from a file. Only supports reading bytes.""" - - _path: EPath - _file: Optional[BinaryIO] - _pos: int - _size: int - - def __init__(self, path: EPath): - """Construct a RetryReadStream. It reads only bytes from a file.""" - self._path = path - self._file = None - self._pos = 0 - self._size = path.size() - - def __enter__(self) -> "RetryReadStream": - return self - - def __exit__( - self, exc_type: Type[BaseException], exc_val: BaseException, exc_tb: TracebackType - ) -> None: - self.close() - - def close(self) -> None: - if self._file is not None: - self._file.close() - - def read(self, n: int = -1) -> bytes: - buf = b"" - for retry in range(10): - try: - if self._file is None: - self._file = self._path.open("rb") - self._file.seek(self._pos) - res = self._file.read(n) - self._pos += len(res) - buf += res - if ( - (n == -1 and self._pos >= self._size) - or len(buf) == n - or self._pos >= self._size - ): - return res - except IOError: - try: - self._file.close() - except IOError: - pass - self._file = None - if retry == 9: - raise - continue - - def seek(self, offset: int, whence: int = 0) -> int: - if whence == 0: - pass - elif whence == 1: - offset += self._pos - elif whence == 2: - offset += self._size - else: - raise ValueError(f"Invalid whence value: {whence}") - offset = min(max(offset, 0), self._size) - self._pos = offset - try: - if self._file is not None: - self._file.seek(offset) - except IOError: - pass - return self._pos - - def tell(self) -> int: - return self._pos - - def isatty(self) -> bool: - return False - - def readable(self) -> bool: - return True - - def seekable(self) -> bool: - return True - - def writable(self) -> bool: - return False diff --git a/src/megatron/energon/task_encoder/base.py b/src/megatron/energon/task_encoder/base.py index 0b1b0da5..e8bc9df4 100644 --- a/src/megatron/energon/task_encoder/base.py +++ b/src/megatron/energon/task_encoder/base.py @@ -56,6 +56,7 @@ PackingDataset, ShuffleBufferDataset, ) +from megatron.energon.wrappers.file_store_init_wrapper import FileStoreInitWrapper from megatron.energon.wrappers.repeat_dataset import RepeatDataset T = TypeVar("T") @@ -774,16 +775,26 @@ def build_cook_crude_sample( aux = {k: DecodeFileStore(v, decoder=self.decoder) for k, v in aux.items()} # Cache the primary auxiliary dataset for this dataset, i.e. construct it once when needed - primary_aux = None + primary_aux: Optional[FileStore] = None + + all_aux_datasets = list(aux.values()) if aux is not None else [] def _get_primary_aux(): + # Note: This is happening on-the-fly in the worker when this dataset is actually used + # I.e. we don't know ahead that a cooker with primary=True is going to be used for this dataset + # (it may be a cooker with primary=False). Thus this happens on-the-fly in the worker when + # this dataset is actually used by a cooker with primary=True. nonlocal primary_aux if primary_aux is None: try: if aux is not None: primary_aux = aux.get("primary") if primary_aux is None: + # In the worker. Initialize now! primary_aux = get_primary_aux() + primary_aux.worker_init() + # We modify this list on-the-fly. It should then still deinitialize when the worker closes. + all_aux_datasets.append(primary_aux) assert primary_aux is not None, "Primary auxiliary dataset must always exist" if self.decoder is not None: primary_aux = DecodeFileStore(primary_aux, decoder=self.decoder) @@ -799,22 +810,28 @@ def _get_primary_aux(): else: cook_fn = functools.partial(self.cook_crude_sample, get_primary_aux=_get_primary_aux) - return MapDataset( - dataset, - cook_fn, - worker_config=worker_config, - stateless_map_fn=True, - map_fn_config=dict( - cookers=[ - dict( - cook=SavableDataset._function_config(cooker.cook), - has_subflavors=cooker.has_subflavors, - ) - for cooker in self.cookers - ], - subflavors=subflavors, + return FileStoreInitWrapper( + MapDataset( + dataset, + cook_fn, + worker_config=worker_config, + stateless_map_fn=True, + map_fn_config=dict( + cookers=[ + dict( + cook=SavableDataset._function_config(cooker.cook), + has_subflavors=cooker.has_subflavors, + ) + for cooker in self.cookers + ], + subflavors=subflavors, + ), + failure_tolerance=get_failure_tolerance( + cook_fn, self.__default_failure_tolerance__ + ), ), - failure_tolerance=get_failure_tolerance(cook_fn, self.__default_failure_tolerance__), + auxiliary_datasets=all_aux_datasets, + worker_config=worker_config, ) def _load_dataset( @@ -1073,10 +1090,9 @@ def cache(self) -> CachePool: assert WorkerConfig.active_worker_config is not None, ( "The cache can only be fetched within the worker, and to be usable, you must use the get_(savable_)loader methods provided from the package." ) - assert WorkerConfig.active_worker_config._active_state.cache_pool is not None, ( - "Cache pool must be set by the loader." - ) - return WorkerConfig.active_worker_config._active_state.cache_pool + cache_pool = WorkerConfig.active_worker_cache_pool() + assert cache_pool is not None, "Cache pool must be set by the loader." + return cache_pool # State fields, they are initialized when the dataloader is started. @property diff --git a/src/megatron/energon/tools/checkpoint.py b/src/megatron/energon/tools/checkpoint.py index 786d26e5..121954c1 100644 --- a/src/megatron/energon/tools/checkpoint.py +++ b/src/megatron/energon/tools/checkpoint.py @@ -218,7 +218,7 @@ def __iter__(self) -> Generator[tuple[WorkerState | None, list[RestoreKey | None """Iterates the WorkerStates of multiple ranks in a round-robin fashion.""" for rank_state in self.rank_states: for worker_state, prefetched_samples_keys in zip( - rank_state.worker_states, rank_state.prefetched_samples_keys + rank_state.worker_states, rank_state.prefetched_restore_keys ): yield worker_state, prefetched_samples_keys @@ -370,7 +370,7 @@ def command_redist( worker_states=[worker_state for worker_state, prefetched_sample_keys in new_rank_state], next_worker_id=0, # Reset the next worker ID micro_batch_size=new_micro_batch_size, - prefetched_samples_keys=[ + prefetched_restore_keys=[ split_batch_restore_keys(prefetched_sample_keys, batch_split_factor) for worker_state, prefetched_sample_keys in new_rank_state ], diff --git a/src/megatron/energon/typed_converter.py b/src/megatron/energon/typed_converter.py index 23c4adf7..9d3f39c7 100644 --- a/src/megatron/energon/typed_converter.py +++ b/src/megatron/energon/typed_converter.py @@ -5,10 +5,8 @@ import dataclasses import importlib import inspect -import json -import re import typing -from typing import Any, Callable, Dict, Literal, Optional, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Literal, Optional, Tuple, Type, TypeVar, Union from megatron.energon.edataclass import edataclass from megatron.energon.epathlib import EPath @@ -63,6 +61,8 @@ def _check_instance_type(cls, inst_type: Type) -> bool: @edataclass class JsonParser: strict: bool = True + #: If true, parse '__module__' and '__class__/__function__' attributes to allow explicit instantiation of types + allow_imports: bool = True def _resolve_object( self, @@ -77,7 +77,7 @@ def _resolve_object( """Resolve an object reference to the object.""" try: module = importlib.import_module(module_name) - except ModuleNotFoundError: + except ModuleNotFoundError: # pragma: no cover try_energon_module = importlib.import_module("megatron.energon", package=None) if hasattr(try_energon_module, object_name): module = try_energon_module @@ -85,7 +85,7 @@ def _resolve_object( raise try: return getattr(module, object_name) - except AttributeError: + except AttributeError: # pragma: no cover raise ModuleNotFoundError(f"Object {object_name} not found in {module_name}") def raw_to_instance( @@ -188,7 +188,7 @@ def raw_to_instance( inst = cls else: # Do not assert the other cases, we fallback to the passed cls - inst = self.safe_call_function(kwargs, cls, allow_imports=True) + inst = self.safe_call_function(kwargs, cls) assert not isinstance(cls, type) or _check_instance_type(type(inst), inst_type), ( f"Expected {inst_type}, got {cls}" ) @@ -198,7 +198,6 @@ def raw_to_typed( # noqa: C901 self, raw_data: Union[dict, list, str, int, bool, float, None], inst_type: Type[TType], - allow_imports: bool = False, _path: str = "root", _stage: Tuple[int, ...] = (), ) -> TType: @@ -217,8 +216,6 @@ class MyNamedTuple(NamedTuple): Args: raw_data: The raw (e.g. json) data to be made as `inst_type` inst_type: The type to return - allow_imports: If true, parse '__module__' and '__class__/__function__' attributes to allow explicit - instantiation of types _path: (internal for recursive call) The path to the object being converted from the root _stage: (internal for recursive call) Numbers representing the position of the current object being converted from the root @@ -227,7 +224,7 @@ class MyNamedTuple(NamedTuple): The input data as `inst_type`. """ type_name = getattr(inst_type, "__name__", repr(inst_type)) - if raw_data is _missing_value: + if raw_data is _missing_value: # pragma: no cover raise JsonValueError( f"Missing value at {_path}", inst_type, @@ -239,7 +236,7 @@ class MyNamedTuple(NamedTuple): # Literal types or missing data if not isinstance(raw_data, inst_type) and not ( isinstance(raw_data, int) and inst_type is float - ): + ): # pragma: no cover raise JsonValueError( f"Type does not match, expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -250,7 +247,7 @@ class MyNamedTuple(NamedTuple): return raw_data elif inst_type is Any: if ( - allow_imports + self.allow_imports and isinstance(raw_data, dict) and "__module__" in raw_data and ("__class__" in raw_data or "__function__" in raw_data) @@ -261,7 +258,7 @@ class MyNamedTuple(NamedTuple): elif typing.get_origin(inst_type) is Literal: # Literal[value[, ...]] values = typing.get_args(inst_type) - if raw_data not in values: + if raw_data not in values: # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -284,7 +281,6 @@ class MyNamedTuple(NamedTuple): return self.raw_to_typed( raw_data, subtype, - allow_imports, f"{_path} -> {getattr(subtype, '__name__', repr(subtype))}", _stage + (1,), ) @@ -304,7 +300,7 @@ class MyNamedTuple(NamedTuple): except JsonValueError as e: cur_exc = e raise cur_exc - else: + else: # pragma: no cover raise JsonValueError( f"Expected {inst_type} at {_path}, got {raw_data!r}", inst_type, @@ -312,6 +308,13 @@ class MyNamedTuple(NamedTuple): _path, _stage, ) + elif ( + self.allow_imports + and isinstance(raw_data, dict) + and "__module__" in raw_data + and ("__class__" in raw_data or "__function__" in raw_data) + ): + return self.raw_to_instance(raw_data, inst_type, _path=_path, _stage=_stage) elif ( isinstance(inst_type, type) and issubclass(inst_type, tuple) @@ -333,7 +336,6 @@ class MyNamedTuple(NamedTuple): field_name: self.raw_to_typed( raw_data.get(field_name, defaults.get(field_name, _missing_value)), field_type, - allow_imports, f"{_path} -> {type_name}:{field_name}", _stage + (idx,), ) @@ -359,7 +361,7 @@ class MyNamedTuple(NamedTuple): ) elif dataclasses.is_dataclass(inst_type): # dataclass - if not isinstance(raw_data, dict): + if not isinstance(raw_data, dict): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -367,25 +369,24 @@ class MyNamedTuple(NamedTuple): _path, _stage, ) - kwargs = { - field.name: self.raw_to_typed( - raw_data.get( - field.name, - ( - ( - _missing_value - if field.default_factory is dataclasses.MISSING - else field.default_factory() - ) - if field.default is dataclasses.MISSING - else field.default - ), - ), + + def get_field_value(field: dataclasses.Field, idx: int) -> Any: + value = raw_data.get(field.name, _missing_value) + if value is _missing_value: + # Use the factory value directly, without going through the conversion + if field.default_factory is not dataclasses.MISSING: + return field.default_factory() + elif field.default is not dataclasses.MISSING: + return field.default + return self.raw_to_typed( + value, field.type, - allow_imports, f"{_path} -> {type_name}:{field.name}", _stage + (idx,), ) + + kwargs = { + field.name: get_field_value(field, idx) for idx, field in enumerate(dataclasses.fields(inst_type)) if field.init } @@ -401,7 +402,7 @@ class MyNamedTuple(NamedTuple): ) try: return inst_type(**kwargs) - except BaseException: + except BaseException: # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -412,7 +413,7 @@ class MyNamedTuple(NamedTuple): elif typing.get_origin(inst_type) is list: # List[inner_type] (inner_type,) = typing.get_args(inst_type) - if not isinstance(raw_data, list): + if not isinstance(raw_data, list): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -421,15 +422,13 @@ class MyNamedTuple(NamedTuple): _stage, ) return [ - self.raw_to_typed( - val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,) - ) + self.raw_to_typed(val, inner_type, f"{_path} -> {idx}", _stage + (idx,)) for idx, val in enumerate(raw_data) ] elif typing.get_origin(inst_type) is set: # Set[inner_type] (inner_type,) = typing.get_args(inst_type) - if not isinstance(raw_data, list): + if not isinstance(raw_data, list): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -438,12 +437,10 @@ class MyNamedTuple(NamedTuple): _stage, ) res = set( - self.raw_to_typed( - val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,) - ) + self.raw_to_typed(val, inner_type, f"{_path} -> {idx}", _stage + (idx,)) for idx, val in enumerate(raw_data) ) - if len(res) != len(raw_data): + if len(res) != len(raw_data): # pragma: no cover raise JsonValueError( f"Duplicate element at {_path}", inst_type, @@ -455,7 +452,7 @@ class MyNamedTuple(NamedTuple): elif typing.get_origin(inst_type) is tuple: # Tuple[inner_types[0], inner_types[1], ...] or Tuple[inner_types[0], Ellipsis/...] inner_types = typing.get_args(inst_type) - if not isinstance(raw_data, list): + if not isinstance(raw_data, (list, tuple)): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -467,15 +464,13 @@ class MyNamedTuple(NamedTuple): # Tuple of arbitrary length, all elements same type # Tuple[inner_types[0], Ellipsis/...] return tuple( - self.raw_to_typed( - val, inner_types[0], allow_imports, f"{_path} -> {idx}", _stage + (idx,) - ) + self.raw_to_typed(val, inner_types[0], f"{_path} -> {idx}", _stage + (idx,)) for idx, val in enumerate(raw_data) ) else: # Fixed size/typed tuple # Tuple[inner_types[0], inner_types[1], ...] - if len(raw_data) != len(inner_types): + if len(raw_data) != len(inner_types): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -483,17 +478,15 @@ class MyNamedTuple(NamedTuple): _path, _stage, ) - return [ - self.raw_to_typed( - val, inner_type, allow_imports, f"{_path} -> {idx}", _stage + (idx,) - ) + return tuple( + self.raw_to_typed(val, inner_type, f"{_path} -> {idx}", _stage + (idx,)) for idx, (val, inner_type) in enumerate(zip(raw_data, inner_types)) - ] + ) elif typing.get_origin(inst_type) is dict: # Dict[str, value_type] key_type, value_type = typing.get_args(inst_type) assert key_type is str - if not isinstance(raw_data, dict): + if not isinstance(raw_data, dict): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -502,14 +495,12 @@ class MyNamedTuple(NamedTuple): _stage, ) return { - key: self.raw_to_typed( - val, value_type, allow_imports, f"{_path} -> {key!r}", _stage + (idx,) - ) + key: self.raw_to_typed(val, value_type, f"{_path} -> {key!r}", _stage + (idx,)) for idx, (key, val) in enumerate(raw_data.items()) } elif inst_type in (dict, list): # dict, list (no subtyping) - if not isinstance(raw_data, inst_type): + if not isinstance(raw_data, inst_type): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -521,7 +512,7 @@ class MyNamedTuple(NamedTuple): elif inst_type is EPath: if isinstance(raw_data, str): return EPath(raw_data) - elif not isinstance(raw_data, EPath): + elif not isinstance(raw_data, EPath): # pragma: no cover raise JsonValueError( f"Expected {type_name} at {_path}, got {raw_data!r}", inst_type, @@ -530,13 +521,6 @@ class MyNamedTuple(NamedTuple): _stage, ) return raw_data - elif ( - allow_imports - and isinstance(raw_data, dict) - and "__module__" in raw_data - and ("__class__" in raw_data or "__function__" in raw_data) - ): - return self.raw_to_instance(raw_data, inst_type, _path=_path, _stage=_stage) else: return raw_data @@ -544,7 +528,6 @@ def safe_call_function( self, raw_data: Union[dict, list, str, int, bool, float, None], fn: Callable[..., TType], - allow_imports: bool = False, ) -> TType: """ Converts raw data (i.e. dicts, lists and primitives) to typed call arguments. @@ -562,7 +545,6 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: raw_data: The raw (e.g. json) data to be made as `inst_type` fn: The function to call with the converted data strict: If true, don't allow additional attributes - allow_imports: If true, allow instantiating objects by specifying __module__ and __class__/__function__. Returns: The return value of `fn` @@ -587,15 +569,12 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: kwargs[key] = self.raw_to_typed( unused_args.pop(key, param.default), t, - allow_imports, _path=key, _stage=(idx,), ) elif param.kind == inspect.Parameter.VAR_KEYWORD: for arg_key, arg_val in unused_args.items(): - kwargs[arg_key] = self.raw_to_typed( - arg_val, t, allow_imports, _path=key, _stage=(idx,) - ) + kwargs[arg_key] = self.raw_to_typed(arg_val, t, _path=key, _stage=(idx,)) unused_args.clear() elif param.kind == inspect.Parameter.VAR_POSITIONAL: # No way to pass positional arguments @@ -607,7 +586,7 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: raise RuntimeError(f"Unknown parameter kind {param.kind!r}") if self.strict and len(unused_args) > 0: raise ValueError(f"Unexpected arguments: {unused_args!r}") - elif isinstance(raw_data, list): + elif isinstance(raw_data, list): # pragma: no cover unused_args = raw_data.copy() for idx, (key, param) in enumerate(parameters): t = Any if param.annotation is inspect.Parameter.empty else param.annotation @@ -616,11 +595,7 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: raise ValueError( f"Missing required positional-only argument {key!r} at index {idx}" ) - args.append( - self.raw_to_typed( - unused_args.pop(), t, allow_imports, _path=key, _stage=(idx,) - ) - ) + args.append(self.raw_to_typed(unused_args.pop(), t, _path=key, _stage=(idx,))) elif param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD: if param.default is inspect.Parameter.empty and len(unused_args) == 0: raise ValueError( @@ -630,14 +605,10 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: arg_val = param.default else: arg_val = unused_args.pop() - args.append( - self.raw_to_typed(arg_val, t, allow_imports, _path=key, _stage=(idx,)) - ) + args.append(self.raw_to_typed(arg_val, t, _path=key, _stage=(idx,))) elif param.kind == inspect.Parameter.VAR_POSITIONAL: for arg_val in unused_args: - args.append( - self.raw_to_typed(arg_val, t, allow_imports, _path=key, _stage=(idx,)) - ) + args.append(self.raw_to_typed(arg_val, t, _path=key, _stage=(idx,))) unused_args.clear() elif param.kind == inspect.Parameter.VAR_KEYWORD: # No way to pass keyword arguments @@ -648,424 +619,12 @@ def fn(arg1: float, arg2: MyType, arg3) -> Any: raise RuntimeError(f"Unknown parameter kind {param.kind!r}") if self.strict and len(unused_args) > 0: raise ValueError(f"Unexpected arguments: {unused_args!r}") - else: + else: # pragma: no cover raise ValueError( f"Cannot call function with raw data of type {type(raw_data)!r}, require list or dict" ) return fn(*args, **kwargs) - def override( # noqa: C901 - self, - value: TType, - overrides: Any, - inst_type: Optional[Type[TType]] = None, - allow_imports: bool = False, - _path: str = "root", - _stage: Tuple[int, ...] = (), - ) -> TType: - """ - Allows overriding values of a typed object using environment config. - Allows overriding single config variables, or whole objects. - - Examples:: - - class MyNamedTuple(NamedTuple): - x: int - y: str - - class MyNested(NamedTuple): - nested: MyNamedTuple - - assert override( - MyNested(nested=MyNamedTuple(x=42, y="foo")), - {'nested.x': 5}, - ) == MyNested(nested=MyNamedTuple(x=5, y="foo")) - assert override( - MyNested(nested=MyNamedTuple(x=42, y="foo")), - {'nested': '{"x": 5, "y": "bar"}'}, - ) == MyNested(nested=MyNamedTuple(x=5, y="bar")) - - Args: - value: The base value to override. - overrides: The overrides to apply - strict: If true, no additional keys are allowed - inst_type: If given, validate against this base type instead of the type of `value`. - allow_imports: If true, allow instantiating types with dicts of __module__ and __class__/__function__. - _path: Internal: The path to the current value. - _stage: Internal: The current stage of the override. - - Returns: - Same type as the input object (or `inst_type` if set), copied and updated from the - overrides. - """ - if inst_type is None: - inst_type = type(value) - type_name = getattr(inst_type, "__name__", repr(inst_type)) - if inst_type in (str, int, float, bool, None, type(None)): - # Literal types - if inst_type in (None, type(None)) and overrides == "None": - overrides = None - elif inst_type is bool and overrides in ("True", "true", "1", "False", "false", "0"): - overrides = overrides in ("True", "true", "1") - elif inst_type in (int, float) and isinstance(overrides, str): - overrides = inst_type(overrides) - if not isinstance(overrides, inst_type) and not ( - isinstance(overrides, int) and inst_type is float - ): - raise JsonValueError( - f"Type does not match, expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - return overrides - elif inst_type is Any: - # Any - if isinstance(overrides, str): - if overrides.isnumeric(): - return int(overrides) - elif overrides == "True": - return True - elif overrides == "False": - return True - return overrides - if isinstance(value, (dict, list, tuple)): - # Merge with dict, list, str - return self.override(value, overrides, type(value), allow_imports, _path, _stage) - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - elif typing.get_origin(inst_type) is Literal: - # Literal[value] - (value,) = typing.get_args(inst_type) - if value != overrides: - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - return value - elif typing.get_origin(inst_type) is Union: - # Union[union_types[0], union_types[1], ...] - union_types = typing.get_args(inst_type) - if isinstance(overrides, str): - for subtype in union_types: - if subtype is None and overrides == "None": - return None - elif subtype is bool: - if overrides == "True": - return True - elif overrides == "False": - return False - elif subtype is int and overrides.strip().isnumeric(): - return int(overrides) - elif subtype is str: - return overrides - elif subtype is float and float_pattern.fullmatch(overrides): - return float(overrides) - if overrides.lstrip().startswith("{") or overrides.lstrip().startswith("["): - overrides = json.loads(overrides) - return self.raw_to_typed( - overrides, - inst_type, - allow_imports, - _path, - _stage, - ) - for subtype in union_types: - if _isinstance_deep(value, subtype): - return self.override( - value, - overrides, - subtype, - allow_imports, - f"{_path} -> {getattr(subtype, '__name__', repr(subtype))}", - _stage + (1,), - ) - raise JsonValueError( - f"Expected {type_name} at {_path}, existing is {value!r} which is invalid", - inst_type, - value, - _path, - _stage, - ) - elif ( - isinstance(inst_type, type) - and issubclass(inst_type, tuple) - and hasattr(inst_type, "__annotations__") - ): - # class MyClass(NamedTuple): ... - if not isinstance(overrides, (dict, str)): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - if isinstance(overrides, str): - return self.raw_to_typed( - json.loads(overrides), - inst_type, - allow_imports, - _path, - _stage, - ) - local_overrides = _split_dict_keys(overrides) - if getattr(inst_type, "__dash_keys__", "False"): - local_overrides = { - key.replace("-", "_"): val for key, val in local_overrides.items() - } - kwargs = { - field_name: ( - self.override( - getattr(value, field_name), - local_overrides.pop(field_name), - field_type, - allow_imports, - f"{_path} -> {type_name}:{field_name}", - _stage + (idx,), - ) - if field_name in local_overrides - else getattr(value, field_name) - ) - for idx, (field_name, field_type) in enumerate(inst_type.__annotations__.items()) - } - if self.strict and len(local_overrides) != 0: - raise JsonValueError( - f"Invalid config keys {', '.join(local_overrides.keys())} for {type_name} at " - f"{_path}", - inst_type, - overrides, - _path, - _stage, - ) - try: - return inst_type(**kwargs) - except BaseException: - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - elif dataclasses.is_dataclass(inst_type): - # dataclass - if not isinstance(overrides, (dict, str)): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - if isinstance(overrides, str): - return self.raw_to_typed( - json.loads(overrides), - inst_type, - allow_imports, - _path, - _stage, - ) - local_overrides = _split_dict_keys(overrides) - if getattr(inst_type, "__dash_keys__", "False"): - local_overrides = { - key.replace("-", "_"): val for key, val in local_overrides.items() - } - kwargs = { - field.name: ( - self.override( - getattr(value, field.name), - local_overrides.pop(field.name), - field.type, - allow_imports, - f"{_path} -> {type_name}:{field.name}", - _stage + (idx,), - ) - if field.name in local_overrides - else getattr(value, field.name) - ) - for idx, field in enumerate(dataclasses.fields(inst_type)) - if field.init - } - if self.strict and len(local_overrides) != 0: - raise JsonValueError( - f"Invalid config keys {', '.join(local_overrides.keys())} for {type_name} at " - f"{_path}", - inst_type, - overrides, - _path, - _stage, - ) - try: - return inst_type(**kwargs) - except BaseException: - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - elif ( - typing.get_origin(inst_type) is list - or typing.get_origin(inst_type) is tuple - or inst_type in (list, tuple) - ): - # List[inner_type] or Tuple[inner_type, Ellipsis] or - # Tuple[inner_type[0], inner_type[1], ...] - if inst_type is list: - inner_type = Any - inner_types = [] - cls = list - elif inst_type is tuple: - inner_type = Any - inner_types = [] - cls = tuple - elif typing.get_origin(inst_type) is list: - (inner_type,) = typing.get_args(inst_type) - inner_types = [] - cls = list - else: - inner_types = typing.get_args(inst_type) - if len(inner_types) == 2 and inner_types[1] is Ellipsis: - inner_type = inner_types[0] - else: - inner_type = None - cls = tuple - if not isinstance(overrides, (dict, str)): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - if isinstance(overrides, str): - return self.raw_to_typed( - json.loads(overrides), - inst_type, - allow_imports, - _path, - _stage, - ) - local_overrides = _split_dict_keys(overrides) - if not all(key.isnumeric() for key in local_overrides.keys()): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}, expected integer keys", - inst_type, - overrides, - _path, - _stage, - ) - local_overrides_int = {int(key): value for key, value in local_overrides.items()} - new_max_idx = max(local_overrides_int.keys()) - original_max_idx = len(value) - if inner_type is None and new_max_idx >= len(inner_types): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}, index {new_max_idx} out of " - f"bounds", - inst_type, - overrides, - _path, - _stage, - ) - for i in range(original_max_idx, new_max_idx): - if i not in local_overrides_int: - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}, missing value for index " - f"{i}", - inst_type, - overrides, - _path, - _stage, - ) - return cls( - ( - self.override( - value[idx], - local_overrides_int[idx], - inner_type, - allow_imports, - f"{_path} -> {idx}", - _stage + (idx,), - ) - if idx in local_overrides_int - else value[idx] - ) - for idx in range(max(new_max_idx + 1, original_max_idx)) - ) - elif typing.get_origin(inst_type) is dict or inst_type is dict: - # Dict[str, value_type] - if inst_type is dict: - value_type = Any - else: - key_type, value_type = typing.get_args(inst_type) - assert key_type is str - if not isinstance(overrides, (dict, str)): - raise JsonValueError( - f"Expected {type_name} at {_path}, got {overrides!r}", - inst_type, - overrides, - _path, - _stage, - ) - if isinstance(overrides, str): - return self.raw_to_typed( - json.loads(overrides), - inst_type, - allow_imports, - _path, - _stage, - ) - local_overrides = _split_dict_keys(overrides) - if getattr(inst_type, "__dash_keys__", "False"): - local_overrides = { - key.replace("-", "_"): val for key, val in local_overrides.items() - } - res = { - key: ( - self.override( - subvalue, - local_overrides.pop(key), - value_type, - allow_imports, - f"{_path} -> {type_name}:{key!r}", - _stage + (idx,), - ) - if key in local_overrides - else subvalue - ) - for idx, (key, subvalue) in value.items() - } - for key, val in local_overrides.items(): - if not isinstance(val, str): - raise JsonValueError( - f"Expected new {type_name} at {_path} -> {type_name}:{key!r}, got {val!r}", - inst_type, - overrides, - _path, - _stage, - ) - res[key] = self.raw_to_typed( - json.loads(val), - value_type, - allow_imports, - f"{_path} -> {type_name}:{key!r}", - _stage + (len(res),), - ) - return res - else: - raise RuntimeError(f"Unknown type {inst_type}") - def to_json_object(obj: Any) -> Any: """ @@ -1086,6 +645,16 @@ def to_json_object(obj: Any) -> Any: field_name: to_json_object(getattr(obj, field_name)) for field_name in obj.__annotations__.keys() } + elif isinstance(obj, type): + return { + "__module__": obj.__module__, + "__class__": obj.__name__, + } + elif isinstance(obj, Callable): + return { + "__module__": obj.__module__, + "__function__": obj.__name__, + } elif dataclasses.is_dataclass(obj): # dataclass return { @@ -1093,7 +662,7 @@ def to_json_object(obj: Any) -> Any: for field in dataclasses.fields(obj) if field.init } - elif isinstance(obj, (list, tuple)): + elif isinstance(obj, (list, tuple, set)): return [to_json_object(val) for val in obj] elif isinstance(obj, dict): return {key: to_json_object(val) for key, val in obj.items()} @@ -1101,41 +670,18 @@ def to_json_object(obj: Any) -> Any: raise RuntimeError(f"Unknown type {type(obj)}") -float_pattern = re.compile(r"[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)?") - - -def _split_dict_keys(dct: Dict[str, Any]) -> Dict[str, Any]: - """Splits the given dict keys by first '.' to subdicts.""" - res = {} - for key, value in dct.items(): - if "." in key: - outer_key, _, inner_key = key.partition(".") - if outer_key in res: - if not isinstance(res[outer_key], dict): - raise ValueError(f"Cannot combine {outer_key!r} with {res!r}") - res[outer_key][inner_key] = value - else: - res[outer_key] = {inner_key: value} - else: - if key in res: - raise ValueError(f"Cannot combine {key!r} with {res!r}") - res[key] = value - - return res - - def _isinstance_deep(val: Any, tp_chk: Type) -> bool: """Verifies if the given value is an instance of the tp_chk, allowing for typing extensions.""" if tp_chk is Any: return True elif typing.get_origin(tp_chk) is Literal: - (value,) = typing.get_args(val) - return val == value + values = typing.get_args(tp_chk) + return val in values elif typing.get_origin(tp_chk) is list: - (inner_type,) = typing.get_args(val) + (inner_type,) = typing.get_args(tp_chk) return isinstance(val, list) and all(_isinstance_deep(v, inner_type) for v in val) elif typing.get_origin(tp_chk) is tuple: - inner_types = typing.get_args(val) + inner_types = typing.get_args(tp_chk) if len(inner_types) == 2 and inner_types[1] == Ellipsis: return isinstance(val, tuple) and all(_isinstance_deep(v, inner_types[0]) for v in val) else: @@ -1145,7 +691,7 @@ def _isinstance_deep(val: Any, tp_chk: Type) -> bool: and all(_isinstance_deep(v, inner_type) for v, inner_type in zip(val, inner_types)) ) elif typing.get_origin(tp_chk) is dict: - key_type, value_type = typing.get_args(val) + key_type, value_type = typing.get_args(tp_chk) return isinstance(val, dict) and all( _isinstance_deep(k, key_type) and _isinstance_deep(v, value_type) for k, v in val.items() diff --git a/src/megatron/energon/worker.py b/src/megatron/energon/worker.py index 5759f3d3..df959d1f 100644 --- a/src/megatron/energon/worker.py +++ b/src/megatron/energon/worker.py @@ -31,11 +31,6 @@ def sample_index_stack(self) -> Optional[List[int]]: """The current sample index stack for the worker.""" return getattr(self._thread_local, "sample_index_stack", None) - @property - def override_global_rank(self) -> Optional[int]: - """The global rank override for the worker. Required for restoring samples.""" - return getattr(self._thread_local, "override_global_rank", None) - @property def cache_pool(self) -> Optional[CachePool]: """The current cache pool for the worker.""" @@ -49,10 +44,6 @@ def worker_config(self) -> "WorkerConfig | None": def sample_index_stack(self, value: List[int]): self._thread_local.sample_index_stack = value - @override_global_rank.setter - def override_global_rank(self, value: Optional[int]): - self._thread_local.override_global_rank = value - @cache_pool.setter def cache_pool(self, value: Optional[CachePool]): self._thread_local.cache_pool = value @@ -127,7 +118,6 @@ def active_worker_config(cls) -> Optional["WorkerConfig"]: def worker_activate( self, sample_index: int, - override_global_rank: Optional[int] = None, cache_pool: "Optional[CachePool]" = None, ): """Activates the worker config for the current worker and sets it as actively iterating. @@ -137,7 +127,6 @@ def worker_activate( ) WorkerConfig._active_state.sample_index_stack = [sample_index] WorkerConfig._active_state.worker_config = self - WorkerConfig._active_state.override_global_rank = override_global_rank WorkerConfig._active_state.cache_pool = cache_pool def worker_push_sample_index(self, sample_index: int): @@ -162,7 +151,6 @@ def worker_deactivate(self): ) WorkerConfig._active_state.sample_index_stack = None WorkerConfig._active_state.worker_config = None - WorkerConfig._active_state.override_global_rank = None WorkerConfig._active_state.cache_pool = None @property @@ -187,6 +175,11 @@ def active_worker_batch_index(self) -> int: + self.rank_worker_id() ) + @staticmethod + def active_worker_cache_pool() -> Optional[CachePool]: + """Returns the current cache pool for the actively iterating worker.""" + return WorkerConfig._active_state.cache_pool + @property def safe_num_workers(self) -> int: """Returns the number of workers, but at least 1.""" @@ -233,22 +226,6 @@ def default_worker_config( data_parallel_group=data_parallel_group, ) - def rank_worker_id(self) -> int: - """Returns the self worker id within the current rank.""" - if WorkerConfig._active_state.override_global_rank: - assert self.worker_id_offset == 0 - return WorkerConfig._active_state.override_global_rank % self.num_workers - worker_info = torch.utils.data.get_worker_info() - if worker_info is None: - return self.worker_id_offset - assert worker_info.num_workers == self.num_workers - # Apply the worker_id_offset as a left rotation of the logical worker ids. - # This ensures that after restoring a checkpoint the first physical - # worker (id=0) corresponds to the logical worker that should emit the - # next sample. For example, if `worker_id_offset` is 1, logical worker - # 1 becomes the first to emit a sample, shifting the ordering forward. - return (worker_info.id + self.worker_id_offset) % max(worker_info.num_workers, 1) - def assert_worker(self): """Checks if the current process is a worker (if configured so), and that the workers are properly configured.""" @@ -262,20 +239,49 @@ def assert_worker(self): f"match the configured number of workers ({self.num_workers})" ) + def rank_worker_id(self, override_global_worker_id: Optional[int] = None) -> int: + """Returns the self worker id within the current rank. + Optionally computes the worker id from a global worker id. + + Args: + override_global_worker_id: The global worker id to compute the rank worker id from. + None means the current worker, which is the default. If not set, must be called + within the worker. + """ + if override_global_worker_id is not None: + assert ( + self.rank * self.safe_num_workers + <= override_global_worker_id + < (self.rank + 1) * self.safe_num_workers + ), f"Invalid global worker id: {override_global_worker_id}" + return override_global_worker_id - self.rank * self.safe_num_workers + worker_info = torch.utils.data.get_worker_info() + if worker_info is None: + return self.worker_id_offset + assert worker_info.num_workers == self.num_workers + # Apply the worker_id_offset as a left rotation of the logical worker ids. + # This ensures that after restoring a checkpoint the first physical + # worker (id=0) corresponds to the logical worker that should emit the + # next sample. For example, if `worker_id_offset` is 1, logical worker + # 1 becomes the first to emit a sample, shifting the ordering forward. + return (worker_info.id + self.worker_id_offset) % max(worker_info.num_workers, 1) + def global_worker_id(self, override_local_worker_id: Optional[int] = None) -> int: """Returns the global worker index by multiplying the rank with the number of workers. Alternatively, you can override the local worker id. Args: - override_local_worker_id (int, optional): The local worker id to override. None means - the current worker, which is the default. + override_local_worker_id: The local worker id to override. None means + the current worker, which is the default. If not set, must be called + within the worker. """ if override_local_worker_id is not None: - return self.rank * self.num_workers + override_local_worker_id - if WorkerConfig._active_state.override_global_rank is not None: - return WorkerConfig._active_state.override_global_rank + assert 0 <= override_local_worker_id < self.safe_num_workers, ( + f"Invalid local worker id: {override_local_worker_id}" + ) + return self.rank * self.safe_num_workers + override_local_worker_id self.assert_worker() - return self.rank * self.num_workers + self.rank_worker_id() + return self.rank * self.safe_num_workers + self.rank_worker_id() def worker_seed(self, override_local_worker_id: Optional[int] = None) -> int: """Returns the seed for the current worker (or a specified worker). diff --git a/src/megatron/energon/wrappers/__init__.py b/src/megatron/energon/wrappers/__init__.py index 808ca50e..221a270c 100644 --- a/src/megatron/energon/wrappers/__init__.py +++ b/src/megatron/energon/wrappers/__init__.py @@ -5,6 +5,7 @@ from megatron.energon.wrappers.blend_dataset import BlendDataset from megatron.energon.wrappers.concat_dataset import ConcatDataset from megatron.energon.wrappers.epochize_dataset import EpochizeDataset +from megatron.energon.wrappers.file_store_init_wrapper import FileStoreInitWrapper from megatron.energon.wrappers.filter_dataset import FilterDataset from megatron.energon.wrappers.gc_dataset import GcDataset from megatron.energon.wrappers.group_batch_dataset import GroupBatchDataset @@ -24,6 +25,7 @@ from megatron.energon.wrappers.skip import SkipSample __all__ = [ + "FileStoreInitWrapper", "BatchDataset", "BlendDataset", "ConcatDataset", diff --git a/src/megatron/energon/wrappers/base.py b/src/megatron/energon/wrappers/base.py index f90bca21..55c405bf 100644 --- a/src/megatron/energon/wrappers/base.py +++ b/src/megatron/energon/wrappers/base.py @@ -111,6 +111,16 @@ def restore_sample(self, restore_key: RestoreKey) -> T_sample_out: assert len(self.datasets) == 1, "Must be implemented by subclass" return self.dataset.restore_sample(restore_key) + def worker_close(self) -> None: + for ds in self.datasets: + ds.worker_close() + super().worker_close() + + def close(self) -> None: + for ds in self.datasets: + ds.close() + super().close() + class SampleIndex(Savable): """A simple class to hold the sample index for one worker.""" diff --git a/src/megatron/energon/wrappers/file_store_init_wrapper.py b/src/megatron/energon/wrappers/file_store_init_wrapper.py new file mode 100644 index 00000000..d47aa951 --- /dev/null +++ b/src/megatron/energon/wrappers/file_store_init_wrapper.py @@ -0,0 +1,77 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause + +from typing import Any, Dict, Generic, Iterator, Sequence, TypeVar + +from megatron.energon.cache.base import FileStore +from megatron.energon.flavors.base_dataset import SavableDataset +from megatron.energon.state import FlexState +from megatron.energon.worker import WorkerConfig +from megatron.energon.wrappers.base import BaseWrapperDataset + +T_sample = TypeVar("T_sample") + + +class FileStoreInitWrapper(BaseWrapperDataset[T_sample, T_sample], Generic[T_sample]): + """This dataset wraps a primary dataset, and additional auxiliary datasets for (de)initialization.""" + + auxiliary_datasets: Sequence[FileStore] + + def __init__( + self, + dataset: SavableDataset[T_sample], + *, + auxiliary_datasets: Sequence[FileStore] = (), + worker_config: WorkerConfig, + ): + """Construct the auxiliary datasets dataset, which wraps a primary dataset and additional + auxiliary datasets for initialization. + + Args: + dataset: The input dataset to wrap + auxiliary_datasets: The additional datasets to (de)initialize + worker_config: The worker configuration + """ + super().__init__(dataset, worker_config=worker_config) + self.auxiliary_datasets = auxiliary_datasets + + def reset_state_own(self) -> None: + for ds in self.auxiliary_datasets: + ds.worker_init() + cache_pool = self.worker_config.active_worker_cache_pool() + if cache_pool is not None: + cache_pool.worker_init() + + def worker_close(self) -> None: + for ds in self.auxiliary_datasets: + ds.worker_close() + cache_pool = self.worker_config.active_worker_cache_pool() + if cache_pool is not None: + cache_pool.worker_close() + super().worker_close() + + def close(self) -> None: + for ds in self.auxiliary_datasets: + ds.close() + cache_pool = self.worker_config.active_worker_cache_pool() + if cache_pool is not None: + cache_pool.close() + super().close() + + def __iter__(self) -> Iterator[T_sample]: + yield from self.dataset + + def save_state(self) -> FlexState: + # Just delegate, make self transparent + return self.dataset.save_state() + + def restore_state(self, state: FlexState): + # Just delegate, make self transparent + return self.dataset.restore_state(state) + + def config(self) -> Dict[str, Any]: + # Transparent logger, it won't change the samples + return self.dataset.config() + + def __str__(self): + return f"FileStoreInitWrapper(auxiliary_datasets={self.auxiliary_datasets}, dataset={self.dataset})" diff --git a/tests/test_av_decoder.py b/tests/test_av_decoder.py index 7e28acb3..b43e7269 100644 --- a/tests/test_av_decoder.py +++ b/tests/test_av_decoder.py @@ -8,11 +8,11 @@ import os import sys import time -import unittest from pathlib import Path import av import numpy as np +import pytest import torch import torchvision.transforms as transforms @@ -71,152 +71,136 @@ def tensors_close(tensor1: torch.Tensor, tensor2: torch.Tensor, tolerance: float return mae <= tolerance -class TestVideoDecode(unittest.TestCase): - """Test video decoding functionality.""" - - def setUp(self): - """Set up test fixtures.""" - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - self.decode_baseline_video_pyav() - self.loaders = [] # Keep track of loaders for cleanup - - def tearDown(self): - """Clean up test fixtures.""" - # Clean up any loaders - for loader in self.loaders: - if hasattr(loader, "_iterator"): - loader._iterator = None - if hasattr(loader, "_shutdown_workers"): - try: - loader._shutdown_workers() - except Exception: - pass - - def decode_baseline_video_pyav(self): - """Load the baseline video using PyAV directly.""" - self.complete_video_tensor = load_video_to_tensor("tests/data/sync_test.mp4") - - def test_decode_all_frames(self): - """Test decoding all frames from a video file.""" - av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) - av_data = av_decoder.get_frames() - video_tensor = av_data.video_clips[0] - - print(video_tensor.shape) - assert (video_tensor == self.complete_video_tensor).all(), ( - "Energon decoded video does not match baseline" +@pytest.fixture +def video_test_setup(): + """Set up test fixtures for video tests.""" + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + complete_video_tensor = load_video_to_tensor("tests/data/sync_test.mp4") + yield complete_video_tensor + + +def test_decode_all_frames(video_test_setup): + """Test decoding all frames from a video file.""" + av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) + av_data = av_decoder.get_frames() + video_tensor = av_data.video_clips[0] + + print(video_tensor.shape) + assert (video_tensor == video_test_setup).all(), "Energon decoded video does not match baseline" + + +def test_decode_video_metadata(video_test_setup): + """Test decoding metadata.""" + expected_metadata = [ + AVMetadata( + video_duration=63.054, + video_num_frames=1891, + video_fps=30.0, + video_width=192, + video_height=108, + audio_duration=63.103, + audio_channels=2, + audio_sample_rate=48000, + ), + AVMetadata( + video_duration=63.03333333333333, + video_num_frames=1891, + video_fps=30.0, + video_width=192, + video_height=108, + audio_duration=63.068, + audio_channels=2, + audio_sample_rate=48000, + ), + ] + for video_file, expected_metadata in zip( + ["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"], expected_metadata + ): + av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes())) + assert av_decoder.get_metadata() == expected_metadata, ( + f"Metadata does not match expected metadata for {video_file}" ) - def test_decode_metadata(self): - """Test decoding metadata.""" - expected_metadata = [ - AVMetadata( - video_duration=63.054, - video_num_frames=1891, - video_fps=30.0, - video_width=192, - video_height=108, - audio_duration=63.103, - audio_channels=2, - audio_sample_rate=48000, - ), - AVMetadata( - video_duration=63.03333333333333, - video_num_frames=1891, - video_fps=30.0, - video_width=192, - video_height=108, - audio_duration=63.068, - audio_channels=2, - audio_sample_rate=48000, - ), - ] - for video_file, expected_metadata in zip( - ["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"], expected_metadata - ): - av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes())) - assert av_decoder.get_metadata() == expected_metadata, ( - f"Metadata does not match expected metadata for {video_file}" - ) - - assert av_decoder.get_video_duration(get_frame_count=False) in ( - (expected_metadata.video_duration, None), - (expected_metadata.video_duration, expected_metadata.video_num_frames), - ) - assert av_decoder.get_video_duration(get_frame_count=True) == ( - expected_metadata.video_duration, - expected_metadata.video_num_frames, - ) - - assert av_decoder.get_audio_duration() == expected_metadata.audio_duration - assert av_decoder.get_video_fps() == expected_metadata.video_fps - assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate - - def test_decode_strided_resized(self): - """Test decoding a subset of frames with resizing.""" - for video_file in ["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"]: - print(f"================= Testing {video_file} ==================") - av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes())) - - video_tensor = get_single_frames_uniform( - av_decoder=av_decoder, - num_frames=64, - video_out_frame_size=(224, 224), - ) - - # Get strided frames from baseline complete video tensor - strided_baseline_tensor = self.complete_video_tensor[ - np.linspace(0, self.complete_video_tensor.shape[0] - 1, 64, dtype=int).tolist() - ] - # Now resize the baseline frames - resize = transforms.Resize((224, 224)) - strided_resized_baseline_tensor = resize(strided_baseline_tensor) - - # We allow small numerical differences due to different resize implementations - assert tensors_close(video_tensor, strided_resized_baseline_tensor, tolerance=0.01), ( - "Energon decoded video does not match baseline" - ) - - def test_video_audio_sync(self): - """Test decoding video frames and audio clips together.""" - av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) - - # Extract a single frame every 2 seconds and an audio clip (0.05 seconds long) at the same time. - # We extract the frames from the sync video that shows the full white circle on the left, - # when the click sound occurs. - # Note that the click sound is actually off by 0.022 secs in the original video, - # I verified this in Davinci Resolve. - av_data = av_decoder.get_clips( - video_clip_ranges=[(a * 2 + 1 / 30, a * 2 + 1 / 30) for a in range(65)], - audio_clip_ranges=[(a * 2 + 1 / 30, a * 2 + 1 / 30 + 0.05) for a in range(65)], - video_unit="seconds", - audio_unit="seconds", - video_out_frame_size=None, + assert av_decoder.get_video_duration(get_frame_count=False) in ( + (expected_metadata.video_duration, None), + (expected_metadata.video_duration, expected_metadata.video_num_frames), + ) + assert av_decoder.get_video_duration(get_frame_count=True) == ( + expected_metadata.video_duration, + expected_metadata.video_num_frames, + ) + + assert av_decoder.get_audio_duration() == expected_metadata.audio_duration + assert av_decoder.get_video_fps() == expected_metadata.video_fps + assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate + + +def test_decode_strided_resized(video_test_setup): + """Test decoding a subset of frames with resizing.""" + for video_file in ["tests/data/sync_test.mkv", "tests/data/sync_test.mp4"]: + print(f"================= Testing {video_file} ==================") + av_decoder = AVDecoder(io.BytesIO(Path(video_file).read_bytes())) + + video_tensor = get_single_frames_uniform( + av_decoder=av_decoder, + num_frames=64, + video_out_frame_size=(224, 224), ) - # We drop the first two extracted frames because the click sequence hasn't started yet - video_clips = av_data.video_clips[2:] - audio_clips = av_data.audio_clips[2:] - # Then we check that the first extracted frame is all white in the area (18, 18, 55, 55) - # Image.fromarray(video_clips[0][0, :, 18:55, 18:55].numpy().transpose(1,2,0)).save('circ.png') - assert (video_clips[0][0, :, 18:55, 18:55] > 250).all(), ( - "First extracted frame is not all white in the area (18, 18, 55, 55)" + # Get strided frames from baseline complete video tensor + strided_baseline_tensor = video_test_setup[ + np.linspace(0, video_test_setup.shape[0] - 1, 64, dtype=int).tolist() + ] + # Now resize the baseline frames + resize = transforms.Resize((224, 224)) + strided_resized_baseline_tensor = resize(strided_baseline_tensor) + + # We allow small numerical differences due to different resize implementations + assert tensors_close(video_tensor, strided_resized_baseline_tensor, tolerance=0.01), ( + "Energon decoded video does not match baseline" ) - # Check that all the video frames are the same (close value) - for video_clip in video_clips: - assert tensors_close(video_clip, video_clips[0], tolerance=0.01), ( - "All video frames are not the same" - ) - # Check that the first audio clip has the click sound - assert (audio_clips[0] > 0.5).any(), "Audio click not found" +def test_video_audio_sync(video_test_setup): + """Test decoding video frames and audio clips together.""" + av_decoder = AVDecoder(io.BytesIO(Path("tests/data/sync_test.mp4").read_bytes())) + + # Extract a single frame every 2 seconds and an audio clip (0.05 seconds long) at the same time. + # We extract the frames from the sync video that shows the full white circle on the left, + # when the click sound occurs. + # Note that the click sound is actually off by 0.022 secs in the original video, + # I verified this in Davinci Resolve. + av_data = av_decoder.get_clips( + video_clip_ranges=[(a * 2 + 1 / 30, a * 2 + 1 / 30) for a in range(65)], + audio_clip_ranges=[(a * 2 + 1 / 30, a * 2 + 1 / 30 + 0.05) for a in range(65)], + video_unit="seconds", + audio_unit="seconds", + video_out_frame_size=None, + ) + + # We drop the first two extracted frames because the click sequence hasn't started yet + video_clips = av_data.video_clips[2:] + audio_clips = av_data.audio_clips[2:] + # Then we check that the first extracted frame is all white in the area (18, 18, 55, 55) + # Image.fromarray(video_clips[0][0, :, 18:55, 18:55].numpy().transpose(1,2,0)).save('circ.png') + assert (video_clips[0][0, :, 18:55, 18:55] > 250).all(), ( + "First extracted frame is not all white in the area (18, 18, 55, 55)" + ) + + # Check that all the video frames are the same (close value) + for video_clip in video_clips: + assert tensors_close(video_clip, video_clips[0], tolerance=0.01), ( + "All video frames are not the same" + ) - # Check that all the audio clips are the same (close value) - for audio_clip in audio_clips: - assert tensors_close(audio_clip, audio_clips[0], tolerance=0.01), ( - "All audio clips are not the same" - ) + # Check that the first audio clip has the click sound + assert (audio_clips[0] > 0.5).any(), "Audio click not found" + + # Check that all the audio clips are the same (close value) + for audio_clip in audio_clips: + assert tensors_close(audio_clip, audio_clips[0], tolerance=0.01), ( + "All audio clips are not the same" + ) def load_audio_to_tensor(audio_path: str) -> torch.Tensor: @@ -238,218 +222,195 @@ def load_audio_to_tensor(audio_path: str) -> torch.Tensor: return audio_tensor -class TestAudioDecode(unittest.TestCase): - """Test audio decoding functionality.""" - - def setUp(self): - """Set up test fixtures.""" - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - self.decode_baseline_audio_pyav() - self.loaders = [] # Keep track of loaders for cleanup - - def tearDown(self): - """Clean up test fixtures.""" - # Clean up any loaders - for loader in self.loaders: - if hasattr(loader, "_iterator"): - loader._iterator = None - if hasattr(loader, "_shutdown_workers"): - try: - loader._shutdown_workers() - except Exception: - pass - - def decode_baseline_audio_pyav(self): - """Load the baseline audio using PyAV directly.""" - self.complete_audio_tensor = load_audio_to_tensor("tests/data/test_audio.flac") - - def test_decode_all_samples(self): - """Test decoding all samples from an audio file.""" - with open("tests/data/test_audio.flac", "rb") as f: - raw_bytes = f.read() - stream = io.BytesIO(raw_bytes) - - av_decoder = AVDecoder(stream) - av_data = av_decoder.get_audio() - audio_tensor = av_data.audio_clips[0] - - assert (audio_tensor == self.complete_audio_tensor).all(), ( - "Energon decoded audio does not match baseline" - ) +@pytest.fixture +def audio_test_setup(): + """Set up test fixtures for audio tests.""" + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + complete_audio_tensor = load_audio_to_tensor("tests/data/test_audio.flac") + yield complete_audio_tensor + + +def test_decode_all_samples(audio_test_setup): + """Test decoding all samples from an audio file.""" + with open("tests/data/test_audio.flac", "rb") as f: + raw_bytes = f.read() + stream = io.BytesIO(raw_bytes) + + av_decoder = AVDecoder(stream) + av_data = av_decoder.get_audio() + audio_tensor = av_data.audio_clips[0] + + assert (audio_tensor == audio_test_setup).all(), "Energon decoded audio does not match baseline" + - def test_decode_clips(self): - """Test decoding multiple clips from an audio file.""" - with open("tests/data/test_audio.flac", "rb") as f: - raw_bytes = f.read() - stream = io.BytesIO(raw_bytes) - - av_decoder = AVDecoder(stream) - av_data = get_clips_uniform( - av_decoder=av_decoder, num_clips=5, clip_duration_seconds=3, request_audio=True - ) +def test_decode_clips(audio_test_setup): + """Test decoding multiple clips from an audio file.""" + with open("tests/data/test_audio.flac", "rb") as f: + raw_bytes = f.read() + stream = io.BytesIO(raw_bytes) + + av_decoder = AVDecoder(stream) + av_data = get_clips_uniform( + av_decoder=av_decoder, num_clips=5, clip_duration_seconds=3, request_audio=True + ) + audio_tensor = av_data.audio_clips[0] + audio_sps = av_decoder.get_audio_samples_per_second() + + # Check audio tensor shape (5 clips, channels, 3 seconds at original sample rate) + assert len(av_data.audio_clips) == 5 + assert len(av_data.audio_timestamps) == 5 + assert audio_tensor.shape[1] >= int(3 * audio_sps) + assert audio_tensor.shape[1] <= int(4 * audio_sps) + + +def test_decode_wav(audio_test_setup): + """Test decoding a WAV file.""" + # Skip WAV test if file doesn't exist + if not os.path.exists("tests/data/test_audio.wav"): + pytest.skip("WAV test file not found") + return + + with open("tests/data/test_audio.wav", "rb") as f: + raw_bytes = f.read() + stream = io.BytesIO(raw_bytes) + + av_decoder = AVDecoder(stream) + av_data = get_clips_uniform( + av_decoder=av_decoder, num_clips=3, clip_duration_seconds=3, request_audio=True + ) + audio_sps = av_decoder.get_audio_samples_per_second() + + # Check audio tensor shape (3 clips, 2 channels, samples) + expected_samples = int(3 * audio_sps) # 3 seconds at original sample rate + assert all( + audio_tensor.shape == torch.Size([2, expected_samples]) + for audio_tensor in av_data.audio_clips + ), "Energon decoded WAV file has wrong shape." + + +def test_decode_wav_same_shape(audio_test_setup): + """Test decoding a WAV file.""" + # Skip WAV test if file doesn't exist + if not os.path.exists("tests/data/test_audio.wav"): + pytest.skip("WAV test file not found") + return + + with open("tests/data/test_audio.wav", "rb") as f: + raw_bytes = f.read() + stream = io.BytesIO(raw_bytes) + + av_decoder = AVDecoder(stream) + av_data = get_clips_uniform( + av_decoder=av_decoder, + num_clips=10, + clip_duration_seconds=0.9954783485892385, + request_audio=True, + ) + audio_sps = av_decoder.get_audio_samples_per_second() + + print(f"SPS: {audio_sps}") + for audio_tensor in av_data.audio_clips: + print(audio_tensor.shape) + + assert all( + audio_tensor.shape == av_data.audio_clips[0].shape for audio_tensor in av_data.audio_clips + ), "Audio clips have different shapes" + + +def test_wav_decode_against_soundfile(audio_test_setup): + """Test decoding a WAV file against the soundfile library.""" + + try: + import soundfile + except ImportError: + pytest.skip("soundfile library not found") + + with open("tests/data/test_audio.wav", "rb") as f: + raw_bytes = f.read() + stream = io.BytesIO(raw_bytes) + + av_decoder = AVDecoder(stream) + av_data = av_decoder.get_clips(audio_clip_ranges=[(0, float("inf"))], audio_unit="samples") + audio_tensor = av_data.audio_clips[0] + + # Load the same audio file using soundfile + + audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16") + audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1) + + # Check that the two tensors are close + assert tensors_close(audio_tensor, audio_tensor_soundfile, tolerance=0.01), ( + "Energon decoded audio does not match baseline" + ) + + # Now check partial extraction in the middle of the audio + av_data = av_decoder.get_clips(audio_clip_ranges=[(0.5, 1.0)], audio_unit="seconds") + audio_tensor = av_data.audio_clips[0] + audio_sps = av_decoder.get_audio_samples_per_second() + audio_tensor_soundfile = torch.from_numpy( + audio_data[int(0.5 * audio_sps) : int(1.0 * audio_sps)] + ).transpose(0, 1) + + # Check that the two tensors are close + assert tensors_close(audio_tensor, audio_tensor_soundfile, tolerance=0.01), ( + "Energon decoded audio does not match baseline" + ) + + # Now compare the speed of the two implementations by repeatedly decoding the same audio + num_trials = 100 + + start_time = time.perf_counter() + for _ in range(num_trials): + av_data = av_decoder.get_clips(audio_clip_ranges=[(0, float("inf"))], audio_unit="samples") audio_tensor = av_data.audio_clips[0] - audio_sps = av_decoder.get_audio_samples_per_second() - - # Check audio tensor shape (5 clips, channels, 3 seconds at original sample rate) - assert len(av_data.audio_clips) == 5 - assert len(av_data.audio_timestamps) == 5 - assert audio_tensor.shape[1] >= int(3 * audio_sps) - assert audio_tensor.shape[1] <= int(4 * audio_sps) - - def test_decode_wav(self): - """Test decoding a WAV file.""" - # Skip WAV test if file doesn't exist - if not os.path.exists("tests/data/test_audio.wav"): - self.skipTest("WAV test file not found") - return - - with open("tests/data/test_audio.wav", "rb") as f: - raw_bytes = f.read() - stream = io.BytesIO(raw_bytes) - - av_decoder = AVDecoder(stream) - av_data = get_clips_uniform( - av_decoder=av_decoder, num_clips=3, clip_duration_seconds=3, request_audio=True - ) - audio_sps = av_decoder.get_audio_samples_per_second() - - # Check audio tensor shape (3 clips, 2 channels, samples) - expected_samples = int(3 * audio_sps) # 3 seconds at original sample rate - assert all( - audio_tensor.shape == torch.Size([2, expected_samples]) - for audio_tensor in av_data.audio_clips - ), "Energon decoded WAV file has wrong shape." - - def test_decode_wav_same_shape(self): - """Test decoding a WAV file.""" - # Skip WAV test if file doesn't exist - if not os.path.exists("tests/data/test_audio.wav"): - self.skipTest("WAV test file not found") - return - - with open("tests/data/test_audio.wav", "rb") as f: - raw_bytes = f.read() - stream = io.BytesIO(raw_bytes) - - av_decoder = AVDecoder(stream) - av_data = get_clips_uniform( - av_decoder=av_decoder, - num_clips=10, - clip_duration_seconds=0.9954783485892385, - request_audio=True, - ) - audio_sps = av_decoder.get_audio_samples_per_second() - - print(f"SPS: {audio_sps}") - for audio_tensor in av_data.audio_clips: - print(audio_tensor.shape) + end_time = time.perf_counter() + print(f"AVDecoder time: {end_time - start_time} seconds") - assert all( - audio_tensor.shape == av_data.audio_clips[0].shape - for audio_tensor in av_data.audio_clips - ), "Audio clips have different shapes" - - def test_wav_decode_against_soundfile(self): - """Test decoding a WAV file against the soundfile library.""" - - try: - import soundfile - except ImportError: - self.skipTest("soundfile library not found") - - with open("tests/data/test_audio.wav", "rb") as f: - raw_bytes = f.read() - stream = io.BytesIO(raw_bytes) + # Now do the same with soundfile + start_time = time.perf_counter() + for _ in range(num_trials): + audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16") + audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1) + end_time = time.perf_counter() + print(f"Soundfile time: {end_time - start_time} seconds") - av_decoder = AVDecoder(stream) + start_time = time.perf_counter() + for _ in range(num_trials): av_data = av_decoder.get_clips(audio_clip_ranges=[(0, float("inf"))], audio_unit="samples") audio_tensor = av_data.audio_clips[0] + end_time = time.perf_counter() + print(f"AVDecoder time: {end_time - start_time} seconds") - # Load the same audio file using soundfile - + # Now do the same with soundfile + start_time = time.perf_counter() + for _ in range(num_trials): audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16") audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1) - - # Check that the two tensors are close - assert tensors_close(audio_tensor, audio_tensor_soundfile, tolerance=0.01), ( - "Energon decoded audio does not match baseline" - ) - - # Now check partial extraction in the middle of the audio - av_data = av_decoder.get_clips(audio_clip_ranges=[(0.5, 1.0)], audio_unit="seconds") - audio_tensor = av_data.audio_clips[0] - audio_sps = av_decoder.get_audio_samples_per_second() - audio_tensor_soundfile = torch.from_numpy( - audio_data[int(0.5 * audio_sps) : int(1.0 * audio_sps)] - ).transpose(0, 1) - - # Check that the two tensors are close - assert tensors_close(audio_tensor, audio_tensor_soundfile, tolerance=0.01), ( - "Energon decoded audio does not match baseline" + end_time = time.perf_counter() + print(f"Soundfile time: {end_time - start_time} seconds") + + +def test_decode_audio_metadata(audio_test_setup): + """Test decoding metadata.""" + expected_metadata = [ + AVMetadata( + audio_duration=10.0, + audio_channels=1, + audio_sample_rate=32000, + ), + AVMetadata( + audio_duration=12.782585034013605, + audio_channels=2, + audio_sample_rate=44100, + ), + ] + for audio_file, expected_metadata in zip( + ["tests/data/test_audio.flac", "tests/data/test_audio.wav"], expected_metadata + ): + av_decoder = AVDecoder(io.BytesIO(Path(audio_file).read_bytes())) + assert av_decoder.get_metadata() == expected_metadata, ( + f"Metadata does not match expected metadata for {audio_file}: {av_decoder.get_metadata()}" ) - # Now compare the speed of the two implementations by repeatedly decoding the same audio - num_trials = 100 - - start_time = time.perf_counter() - for _ in range(num_trials): - av_data = av_decoder.get_clips( - audio_clip_ranges=[(0, float("inf"))], audio_unit="samples" - ) - audio_tensor = av_data.audio_clips[0] - end_time = time.perf_counter() - print(f"AVDecoder time: {end_time - start_time} seconds") - - # Now do the same with soundfile - start_time = time.perf_counter() - for _ in range(num_trials): - audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16") - audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1) - end_time = time.perf_counter() - print(f"Soundfile time: {end_time - start_time} seconds") - - start_time = time.perf_counter() - for _ in range(num_trials): - av_data = av_decoder.get_clips( - audio_clip_ranges=[(0, float("inf"))], audio_unit="samples" - ) - audio_tensor = av_data.audio_clips[0] - end_time = time.perf_counter() - print(f"AVDecoder time: {end_time - start_time} seconds") - - # Now do the same with soundfile - start_time = time.perf_counter() - for _ in range(num_trials): - audio_data, _ = soundfile.read("tests/data/test_audio.wav", dtype="int16") - audio_tensor_soundfile = torch.from_numpy(audio_data).transpose(0, 1) - end_time = time.perf_counter() - print(f"Soundfile time: {end_time - start_time} seconds") - - def test_decode_metadata(self): - """Test decoding metadata.""" - expected_metadata = [ - AVMetadata( - audio_duration=10.0, - audio_channels=1, - audio_sample_rate=32000, - ), - AVMetadata( - audio_duration=12.782585034013605, - audio_channels=2, - audio_sample_rate=44100, - ), - ] - for audio_file, expected_metadata in zip( - ["tests/data/test_audio.flac", "tests/data/test_audio.wav"], expected_metadata - ): - av_decoder = AVDecoder(io.BytesIO(Path(audio_file).read_bytes())) - assert av_decoder.get_metadata() == expected_metadata, ( - f"Metadata does not match expected metadata for {audio_file}: {av_decoder.get_metadata()}" - ) - - assert av_decoder.get_audio_duration() == expected_metadata.audio_duration - assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate - - -if __name__ == "__main__": - unittest.main() + assert av_decoder.get_audio_duration() == expected_metadata.audio_duration + assert av_decoder.get_audio_samples_per_second() == expected_metadata.audio_sample_rate diff --git a/tests/test_crudedataset.py b/tests/test_crudedataset.py index 2d54ecd6..9219b3d6 100644 --- a/tests/test_crudedataset.py +++ b/tests/test_crudedataset.py @@ -9,11 +9,11 @@ import re import sys import tempfile -import unittest import warnings from pathlib import Path from typing import List +import pytest import torch import webdataset as wds @@ -218,482 +218,485 @@ def batch(self, samples: List[TextSample]) -> TextBatch: ) -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - (self.dataset_path / "ds1").mkdir(exist_ok=True, parents=True) - (self.dataset_path / "ds2").mkdir(exist_ok=True, parents=True) - - # Create a small dummy captioning dataset - self.create_crude_text_test_dataset(self.dataset_path / "ds1", 0) - self.create_crude_text_test_dataset(self.dataset_path / "ds2", 100) - - self.mds_path = self.dataset_path / "metadataset.yaml" - with open(self.mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: Metadataset", - "splits:", - " train:", - " datasets:", - " - weight: 1", - " path: ds1", - " subflavors:", - " source: metadataset.yaml", - " number: 43", - " mds: mds", - " crude_type: txtpkl", - " shuffle_over_epochs_multiplier: 3", - " - weight: 1", - " path: ds2", - " subflavors:", - " source: metadataset.yaml", - " number: 44", - " mds: mds", - " crude_type: otherpkl", - " val:", - " datasets:", - " - weight: 1", - " path: ds1", - " split_part: train", - " - weight: 1", - " path: ds2", - " split_part: train", - ] - ) +@pytest.fixture +def dataset_path(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + # Create a temporary directory + temp_dir = tempfile.TemporaryDirectory() + dataset_path = Path(temp_dir.name) + # dataset_path = Path("./test_dataset") + + dataset_path.mkdir(exist_ok=True, parents=True) + + (dataset_path / "ds1").mkdir(exist_ok=True, parents=True) + (dataset_path / "ds2").mkdir(exist_ok=True, parents=True) + + # Create a small dummy captioning dataset + create_crude_text_test_dataset(dataset_path / "ds1", 0) + create_crude_text_test_dataset(dataset_path / "ds2", 100) + + mds_path = dataset_path / "metadataset.yaml" + with open(mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: Metadataset", + "splits:", + " train:", + " datasets:", + " - weight: 1", + " path: ds1", + " subflavors:", + " source: metadataset.yaml", + " number: 43", + " mds: mds", + " crude_type: txtpkl", + " shuffle_over_epochs_multiplier: 3", + " - weight: 1", + " path: ds2", + " subflavors:", + " source: metadataset.yaml", + " number: 44", + " mds: mds", + " crude_type: otherpkl", + " val:", + " datasets:", + " - weight: 1", + " path: ds1", + " split_part: train", + " - weight: 1", + " path: ds2", + " split_part: train", + ] ) + ) - self.aux_mds_path = self.dataset_path / "aux_metadataset.yaml" - with open(self.aux_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " path: ds1", - " aux:", - " pkl_source: ds2", - " fs_source: filesystem://.", - " subflavors:", - " crude_type: aux_random_access", - ] - ) + aux_mds_path = dataset_path / "aux_metadataset.yaml" + with open(aux_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " path: ds1", + " aux:", + " pkl_source: ds2", + " fs_source: filesystem://.", + " subflavors:", + " crude_type: aux_random_access", + ] ) - - print(self.dataset_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_crude_text_test_dataset(path: Path, offset: int): - """Creates a small dummy test dataset for testing purposes.""" - - # Create num_samples unique captions - (path / "parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: - for idx in range(55): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{idx + offset:06d}", - "txt": f"{idx + offset}".encode(), - "pkl": pickle.dumps({"idx": idx + offset}), - }, - ) - total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, - workers=1, ) - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: CrudeWebdataset", - "subflavors:", - " dataset.yaml: true", - " number: 42", - ] - ) + print(dataset_path) + + yield dataset_path + + # Remove all temporary files + gc.collect() + temp_dir.cleanup() + + +def create_crude_text_test_dataset(path: Path, offset: int): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) + + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for idx in range(55): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{idx + offset:06d}", + "txt": f"{idx + offset}".encode(), + "pkl": pickle.dumps({"idx": idx + offset}), + }, ) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + workers=1, + ) - def test_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: CrudeWebdataset", + "subflavors:", + " dataset.yaml: true", + " number: 42", + ] + ) ) - # Train mode dataset - torch.manual_seed(42) - train_dataset = get_train_dataset( - self.mds_path, + +def test_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + ) + + # Train mode dataset + torch.manual_seed(42) + train_dataset = get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=3, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + handler=reraise_exception, + ) + with get_savable_loader( + train_dataset, + ) as loader: + print(len(train_dataset)) + # assert len(train_dataset) == 11 + + for idx, data in enumerate(loader): + if idx >= len(train_dataset): + break + + assert isinstance(data, TextBatch) + + print("Batch", idx) + for txt, key in zip(data.txts, data.__key__): + key_int = int(key.split("/")[-1]) + if key_int < 100: + assert txt == f"<{key_int}>" + else: + assert txt == f"<{key_int}|{key_int}>" + + print(key, txt) + + +def test_loader(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ) + + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + batch_size=2, worker_config=worker_config, - batch_size=3, task_encoder=CookingTaskEncoder(), shuffle_buffer_size=None, max_samples_per_sequence=None, - handler=reraise_exception, - ) - with get_savable_loader( - train_dataset, - ) as loader: - print(len(train_dataset)) - # assert len(train_dataset) == 11 - - for idx, data in enumerate(loader): - if idx >= len(train_dataset): - break - - assert isinstance(data, TextBatch) - - print("Batch", idx) - for txt, key in zip(data.txts, data.__key__): - key_int = int(key.split("/")[-1]) - if key_int < 100: - assert txt == f"<{key_int}>" - else: - assert txt == f"<{key_int}|{key_int}>" - - print(key, txt) - - def test_loader(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ) + packing_buffer_size=2, + ), + ) as loader: + samples = [s.__key__ for idx, s in zip(range(100), loader)] - with get_savable_loader( - get_train_dataset( - self.mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - ) as loader: - samples = [s.__key__ for idx, s in zip(range(100), loader)] + print(samples) - print(samples) + state = loader.save_state_rank() - state = loader.save_state_rank() + samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_after) - samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_after) + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + ).with_restored_state_rank(state) as loader: + samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_restored) - with get_savable_loader( - get_train_dataset( - self.mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - ).with_restored_state_rank(state) as loader: - samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_restored) - - assert all([a == b for a, b in zip(samples_after, samples_restored)]) - - def test_aux_random_access(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ) + assert all([a == b for a, b in zip(samples_after, samples_restored)]) - print("Initializing dataset") - - with get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - ) as loader: - print("Iterating from dataset") - samples = [s.txts for idx, s in zip(range(100), loader)] - for idx, txts in enumerate(samples): - for txt in txts: - m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>", txt) - assert m, f"Invalid aux text: {txt}" - assert int(m.group(2)) == int(m.group(1)) + 100 - - print(samples) - - state = loader.save_state_rank() - - samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_after) - - with get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - ).with_restored_state_rank(state) as loader: - samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_restored) - - assert all([a == b for a, b in zip(samples_after, samples_restored)]) - - def test_aux_random_access_with_cache(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ) - print("Initializing dataset") - - with get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=LazyCookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, - ), - cache_pool=FileStoreCachePool( - parent_cache_dir=self.dataset_path / "cache", - num_workers=1, - ), - ) as loader: - print("Iterating from dataset") - samples = [s.txts for idx, s in zip(range(100), loader)] - for idx, txts in enumerate(samples): - for txt in txts: - m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>\|([0-9]*)", txt) - assert m, f"Invalid aux text: {txt}" - assert int(m.group(2)) == int(m.group(1)) + 100 - assert int(m.group(3)) == (int(m.group(1)) + 1) % 55 - - print(samples) - - state = loader.save_state_rank() - - samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_after) - - with get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=CookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, +def test_aux_random_access(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ) + + print("Initializing dataset") + + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + ) as loader: + print("Iterating from dataset") + samples = [s.txts for idx, s in zip(range(100), loader)] + for idx, txts in enumerate(samples): + for txt in txts: + m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>", txt) + assert m, f"Invalid aux text: {txt}" + assert int(m.group(2)) == int(m.group(1)) + 100 + + print(samples) + + state = loader.save_state_rank() + + samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_after) + + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + ).with_restored_state_rank(state) as loader: + samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_restored) + + assert all([a == b for a, b in zip(samples_after, samples_restored)]) + + +def test_aux_random_access_with_cache(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ) + + print("Initializing dataset") + + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=LazyCookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + cache_pool=FileStoreCachePool( + parent_cache_dir=dataset_path / "cache", + num_workers=1, + ), + ) as loader: + print("Iterating from dataset") + samples = [s.txts for idx, s in zip(range(100), loader)] + for idx, txts in enumerate(samples): + for txt in txts: + m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>\|([0-9]*)", txt) + assert m, f"Invalid aux text: {txt}" + assert int(m.group(2)) == int(m.group(1)) + 100 + assert int(m.group(3)) == (int(m.group(1)) + 1) % 55 + + print(samples) + + state = loader.save_state_rank() + + samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_after) + + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + cache_pool=FileStoreCachePool( + parent_cache_dir=dataset_path / "cache", + num_workers=1, + ), + ).with_restored_state_rank(state) as loader: + samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_restored) + + assert all([a == b for a, b in zip(samples_after, samples_restored)]) + + +def test_aux_random_access_with_cache_and_postencode(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ) + + print("Initializing dataset") + + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=LazyCookingTaskEncoderWithPostencode(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + cache_pool=FileStoreCachePool( + parent_cache_dir=dataset_path / "cache", + num_workers=1, + ), + ) as loader: + print("Iterating from dataset") + samples = [s.txts for idx, s in zip(range(100), loader)] + for idx, txts in enumerate(samples): + for txt in txts: + m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>\|([0-9]*)", txt) + assert m, f"Invalid aux text: {txt}" + assert int(m.group(2)) == int(m.group(1)) + 100 + assert int(m.group(3)) == (int(m.group(1)) + 1) % 55 + + print(samples) + + state = loader.save_state_rank() + + samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_after) + + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=2, + worker_config=worker_config, + task_encoder=LazyCookingTaskEncoderWithPostencode(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + packing_buffer_size=2, + ), + cache_pool=FileStoreCachePool( + parent_cache_dir=dataset_path / "cache", + num_workers=1, + ), + ).with_restored_state_rank(state) as loader: + samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] + print(samples_restored) + + assert all([a == b for a, b in zip(samples_after, samples_restored)]) + + # Verify that the sources are correct + sample_src_check = [s.__sources__ for idx, s in zip(range(1), loader)][0] + print(sample_src_check) + # NOTE: Auxiliary sources have string as index, not int + assert sample_src_check == ( + # Primary source for the sample, reading all source files + SourceInfo( + dataset_path=EPath(dataset_path / "ds1"), + index=2, + shard_name="parts/data-0.tar", + file_names=("000002.pkl", "000002.txt"), ), - cache_pool=FileStoreCachePool( - parent_cache_dir=self.dataset_path / "cache", - num_workers=1, + # Auxiliary source for the sample, reading from ds2 + SourceInfo( + dataset_path=EPath(dataset_path / "ds2"), + index="000102.txt", + shard_name="parts/data-0.tar", + file_names=("000102.txt",), ), - ).with_restored_state_rank(state) as loader: - samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_restored) - - assert all([a == b for a, b in zip(samples_after, samples_restored)]) - - def test_aux_random_access_with_cache_and_postencode(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ) - - print("Initializing dataset") - - with get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=LazyCookingTaskEncoderWithPostencode(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, + # Auxiliary source for the sample, reading from ds1, but next sample + SourceInfo( + dataset_path=EPath(dataset_path / "ds1"), + index="000003.txt", + shard_name="parts/data-0.tar", + file_names=("000003.txt",), ), - cache_pool=FileStoreCachePool( - parent_cache_dir=self.dataset_path / "cache", - num_workers=1, + SourceInfo( + dataset_path=EPath(dataset_path / "ds1"), + index=21, + shard_name="parts/data-2.tar", + file_names=("000021.pkl", "000021.txt"), ), - ) as loader: - print("Iterating from dataset") - samples = [s.txts for idx, s in zip(range(100), loader)] - for idx, txts in enumerate(samples): - for txt in txts: - m = re.fullmatch(r"<([0-9]*)\|aux\|([0-9]*)>\|([0-9]*)", txt) - assert m, f"Invalid aux text: {txt}" - assert int(m.group(2)) == int(m.group(1)) + 100 - assert int(m.group(3)) == (int(m.group(1)) + 1) % 55 - - print(samples) - - state = loader.save_state_rank() - - samples_after = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_after) - - with get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=2, - worker_config=worker_config, - task_encoder=LazyCookingTaskEncoderWithPostencode(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - packing_buffer_size=2, + SourceInfo( + dataset_path=EPath(dataset_path / "ds2"), + index="000121.txt", + shard_name="parts/data-2.tar", + file_names=("000121.txt",), ), - cache_pool=FileStoreCachePool( - parent_cache_dir=self.dataset_path / "cache", - num_workers=1, + SourceInfo( + dataset_path=EPath(dataset_path / "ds1"), + index="000022.txt", + shard_name="parts/data-2.tar", + file_names=("000022.txt",), ), - ).with_restored_state_rank(state) as loader: - samples_restored = [s.__key__ for idx, s in zip(range(100, 200), loader)] - print(samples_restored) - - assert all([a == b for a, b in zip(samples_after, samples_restored)]) - - # Verify that the sources are correct - sample_src_check = [s.__sources__ for idx, s in zip(range(1), loader)][0] - print(sample_src_check) - # NOTE: Auxiliary sources have string as index, not int - assert sample_src_check == ( - # Primary source for the sample, reading all source files - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), - index=2, - shard_name="parts/data-0.tar", - file_names=("000002.pkl", "000002.txt"), - ), - # Auxiliary source for the sample, reading from ds2 - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds2"), - index="000102.txt", - shard_name="parts/data-0.tar", - file_names=("000102.txt",), - ), - # Auxiliary source for the sample, reading from ds1, but next sample - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), - index="000003.txt", - shard_name="parts/data-0.tar", - file_names=("000003.txt",), - ), - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), - index=21, - shard_name="parts/data-2.tar", - file_names=("000021.pkl", "000021.txt"), - ), - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds2"), - index="000121.txt", - shard_name="parts/data-2.tar", - file_names=("000121.txt",), - ), - SourceInfo( - dataset_path=EPath(self.dataset_path / "ds1"), - index="000022.txt", - shard_name="parts/data-2.tar", - file_names=("000022.txt",), - ), - ) - - def test_aux_filesystem_reference(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, ) - with get_savable_loader( - get_train_dataset( - self.aux_mds_path, - batch_size=1, - worker_config=worker_config, - task_encoder=CookingTaskEncoderWithAuxFilesystemReference(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ) as loader: - sample = next(iter(loader)) - assert sample.txts[0].endswith("|aux|__module__: megatron.ener>") +def test_aux_filesystem_reference(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + ) + + with get_savable_loader( + get_train_dataset( + dataset_path / "aux_metadataset.yaml", + batch_size=1, + worker_config=worker_config, + task_encoder=CookingTaskEncoderWithAuxFilesystemReference(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ) as loader: + sample = next(iter(loader)) - def test_nomds(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ) + assert sample.txts[0].endswith("|aux|__module__: megatron.ener>") - with get_savable_loader( - get_train_dataset( - self.dataset_path / "ds1", - batch_size=2, - worker_config=worker_config, - task_encoder=GenericCookingTaskEncoder(), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ) as loader: - samples = [s.__key__ for idx, s in zip(range(100), loader)] - print(samples) - assert len(samples) == 100 +def test_nomds(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ) + with get_savable_loader( + get_train_dataset( + dataset_path / "ds1", + batch_size=2, + worker_config=worker_config, + task_encoder=GenericCookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ) as loader: + samples = [s.__key__ for idx, s in zip(range(100), loader)] -if __name__ == "__main__": - unittest.main() + print(samples) + assert len(samples) == 100 diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index 2bac88e1..c7c58361 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -7,12 +7,12 @@ import logging import sys import tempfile -import unittest import warnings from collections import Counter from pathlib import Path from typing import Iterable +import pytest import torch import webdataset as wds @@ -42,361 +42,349 @@ def encode_sample(self, sample): return sample -class TestDataloader(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - self.ds1_path = self.dataset_path / "ds1" - self.ds1_path.mkdir(exist_ok=True, parents=True) - - # Create a small dummy captioning dataset - self.create_text_test_dataset(self.ds1_path, range(55), range(55)) - print(self.ds1_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): - """Creates a small dummy test dataset for testing purposes.""" - - # Create num_samples unique captions - (path / "parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: - for key, txt in zip(key_range, txt_range): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{key:06d}", - "txt": f"{txt}".encode(), - }, - ) - total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, - ) +@pytest.fixture +def temp_dir(): + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: TextSample", - "field_map:", - " text: txt", - "subflavors:", - " source: dataset.yaml", - " dataset.yaml: true", - " number: 42", - ] - ) - ) - def test_dataloader_no_workers(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) +@pytest.fixture +def dataset_path(temp_dir): + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + return dataset_path - # Train mode dataset - with DataLoader( - get_train_dataset( - self.ds1_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=0), - ), - ) as train_loader: - assert len(train_loader) == 6, len(train_loader) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(train_order1) == 55, len(train_order1) - assert len(Counter(train_order1)) == 55, Counter(train_order1) - assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) - - state1 = train_loader.save_state_rank() - - train_order2 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - - with DataLoader( - get_train_dataset( - self.ds1_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=0), - ), - ).with_restored_state_rank(state1) as train_loader: - cmp_order2 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - assert train_order2 == cmp_order2, (train_order1, cmp_order2) - - def test_dataloader_fork(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - seed_offset=42, - ) - # Train mode dataset - with DataLoader( - get_train_dataset( - self.ds1_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=2), - ), - prefetch_factor=2, - worker_type=ForkDataLoaderWorker, - gc_collect_every_n_steps=10, - gc_freeze_at_start=True, - watchdog_timeout_seconds=60, - fail_on_timeout=True, - ) as train_loader: - assert len(train_loader) == 6, len(train_loader) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(train_order1) == 55, len(train_order1) - assert len(Counter(train_order1)) == 55, Counter(train_order1) - assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) - - state1 = train_loader.save_state_rank() - - train_order2 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - - assert len(train_order1) == len(train_order2), (len(train_order1), len(train_order2)) - - with DataLoader( - get_train_dataset( - self.ds1_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - task_encoder=VerifyWorkerTaskEncoder( - expected_num_workers=worker_config.num_workers - ), - ), - prefetch_factor=2, - worker_type=ForkDataLoaderWorker, - gc_collect_every_n_steps=10, - gc_freeze_at_start=True, - watchdog_timeout_seconds=60, - fail_on_timeout=True, - ).with_restored_state_rank(state1) as train_loader: - cmp_order2 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - assert train_order2 == cmp_order2, (train_order1, cmp_order2) - - def test_dataloader_fork_multi_parallel(self): - torch.manual_seed(42) - worker_config_r0 = WorkerConfig( - rank=0, - world_size=2, - num_workers=2, - seed_offset=42, - ) - worker_config_r1 = WorkerConfig( - rank=1, - world_size=2, - num_workers=2, - seed_offset=42, - ) +@pytest.fixture +def ds1_path(dataset_path): + ds1_path = dataset_path / "ds1" + ds1_path.mkdir(exist_ok=True, parents=True) + create_text_test_dataset(ds1_path, range(55), range(55)) + print(ds1_path) + return ds1_path + + +@pytest.fixture(autouse=True) +def setup_logging(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + +def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) - # Train mode dataset - train_loader_r0 = DataLoader( - get_train_dataset( - self.ds1_path, - worker_config=worker_config_r0, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - task_encoder=VerifyWorkerTaskEncoder( - expected_num_workers=worker_config_r0.num_workers - ), - ), - prefetch_factor=2, - worker_type=ForkDataLoaderWorker, - gc_collect_every_n_steps=10, - gc_freeze_at_start=True, - watchdog_timeout_seconds=60, - fail_on_timeout=True, + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for key, txt in zip(key_range, txt_range): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{key:06d}", + "txt": f"{txt}".encode(), + }, + ) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: TextSample", + "field_map:", + " text: txt", + "subflavors:", + " source: dataset.yaml", + " dataset.yaml: true", + " number: 42", + ] + ) ) - assert len(train_loader_r0) == 4, len(train_loader_r0) - train_order1_r0 = [ - text for idx, data in zip(range(55 * 10), train_loader_r0) for text in data.text + +def test_dataloader_no_workers(ds1_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Train mode dataset + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=0), + ), + ) as train_loader: + assert len(train_loader) == 6, len(train_loader) + + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text ] - print(train_order1_r0[:10]) - print(Counter(train_order1_r0)) - assert len(train_order1_r0) == 28, len(train_order1_r0) - assert len(Counter(train_order1_r0)) == 28, Counter(train_order1_r0) - assert all(v == 1 for v in Counter(train_order1_r0).values()), Counter(train_order1_r0) - - train_loader_r1 = DataLoader( - get_train_dataset( - self.ds1_path, - worker_config=worker_config_r1, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - task_encoder=VerifyWorkerTaskEncoder( - expected_num_workers=worker_config_r1.num_workers - ), - ), - prefetch_factor=2, - worker_type=ForkDataLoaderWorker, - gc_collect_every_n_steps=10, - gc_freeze_at_start=True, - watchdog_timeout_seconds=60, - fail_on_timeout=True, - ) - assert len(train_loader_r1) == 4, len(train_loader_r1) + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + + state1 = train_loader.save_state_rank() - train_order1_r1 = [ - text for idx, data in zip(range(55 * 10), train_loader_r1) for text in data.text + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text ] - print(train_order1_r1[:10]) - print(Counter(train_order1_r1)) - assert len(train_order1_r1) == 27, len(train_order1_r1) - assert len(Counter(train_order1_r1)) == 27, Counter(train_order1_r1) - assert all(v == 1 for v in Counter(train_order1_r1).values()), Counter(train_order1_r1) - train_loader_r1.save_state_rank() + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=0), + ), + ).with_restored_state_rank(state1) as train_loader: + cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) + + +def test_dataloader_fork(ds1_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + seed_offset=42, + ) + + # Train mode dataset + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=2), + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) as train_loader: + assert len(train_loader) == 6, len(train_loader) + + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) - train_loader_r0.save_state_rank() + state1 = train_loader.save_state_rank() - train_order2_r0 = [ - text for idx, data in zip(range(55 * 10), train_loader_r0) for text in data.text + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text ] - assert len(train_order2_r0) == 28 - train_order2_r1 = [ - text for idx, data in zip(range(55 * 10), train_loader_r1) for text in data.text + assert len(train_order1) == len(train_order2), (len(train_order1), len(train_order2)) + + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=worker_config.num_workers), + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ).with_restored_state_rank(state1) as train_loader: + cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) + + +def test_dataloader_fork_multi_parallel(ds1_path): + torch.manual_seed(42) + worker_config_r0 = WorkerConfig( + rank=0, + world_size=2, + num_workers=2, + seed_offset=42, + ) + worker_config_r1 = WorkerConfig( + rank=1, + world_size=2, + num_workers=2, + seed_offset=42, + ) + + # Train mode dataset + train_loader_r0 = DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config_r0, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=worker_config_r0.num_workers), + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) + assert len(train_loader_r0) == 4, len(train_loader_r0) + + train_order1_r0 = [ + text for idx, data in zip(range(55 * 10), train_loader_r0) for text in data.text + ] + print(train_order1_r0[:10]) + print(Counter(train_order1_r0)) + assert len(train_order1_r0) == 28, len(train_order1_r0) + assert len(Counter(train_order1_r0)) == 28, Counter(train_order1_r0) + assert all(v == 1 for v in Counter(train_order1_r0).values()), Counter(train_order1_r0) + + train_loader_r1 = DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config_r1, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=worker_config_r1.num_workers), + ), + prefetch_factor=2, + worker_type=ForkDataLoaderWorker, + gc_collect_every_n_steps=10, + gc_freeze_at_start=True, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) + assert len(train_loader_r1) == 4, len(train_loader_r1) + + train_order1_r1 = [ + text for idx, data in zip(range(55 * 10), train_loader_r1) for text in data.text + ] + print(train_order1_r1[:10]) + print(Counter(train_order1_r1)) + assert len(train_order1_r1) == 27, len(train_order1_r1) + assert len(Counter(train_order1_r1)) == 27, Counter(train_order1_r1) + assert all(v == 1 for v in Counter(train_order1_r1).values()), Counter(train_order1_r1) + + train_loader_r1.save_state_rank() + + train_loader_r0.save_state_rank() + + train_order2_r0 = [ + text for idx, data in zip(range(55 * 10), train_loader_r0) for text in data.text + ] + assert len(train_order2_r0) == 28 + + train_order2_r1 = [ + text for idx, data in zip(range(55 * 10), train_loader_r1) for text in data.text + ] + assert len(train_order2_r1) == 27 + + train_loader_r0.shutdown() + train_loader_r1.shutdown() + + +def test_dataloader_thread(ds1_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + seed_offset=42, + ) + + # Train mode dataset + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=worker_config.num_workers), + ), + prefetch_factor=2, + worker_type=ThreadDataLoaderWorker, + gc_collect_every_n_steps=0, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ) as train_loader: + assert len(train_loader) == 6, len(train_loader) + + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(train_order1) == 55, len(train_order1) + assert len(Counter(train_order1)) == 55, Counter(train_order1) + assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) + + state1 = train_loader.save_state_rank() + + train_order2 = [ + text for idx, data in zip(range(55 * 10), train_loader) for text in data.text ] - assert len(train_order2_r1) == 27 - - train_loader_r0.shutdown() - train_loader_r1.shutdown() - - def test_dataloader_thread(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - seed_offset=42, - ) - # Train mode dataset - with DataLoader( - get_train_dataset( - self.ds1_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - task_encoder=VerifyWorkerTaskEncoder( - expected_num_workers=worker_config.num_workers - ), - ), - prefetch_factor=2, - worker_type=ThreadDataLoaderWorker, - gc_collect_every_n_steps=0, - watchdog_timeout_seconds=60, - fail_on_timeout=True, - ) as train_loader: - assert len(train_loader) == 6, len(train_loader) - - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(train_order1) == 55, len(train_order1) - assert len(Counter(train_order1)) == 55, Counter(train_order1) - assert all(v == 1 for v in Counter(train_order1).values()), Counter(train_order1) - - state1 = train_loader.save_state_rank() - - train_order2 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - - with DataLoader( - get_train_dataset( - self.ds1_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - task_encoder=VerifyWorkerTaskEncoder( - expected_num_workers=worker_config.num_workers - ), - ), - prefetch_factor=2, - worker_type=ThreadDataLoaderWorker, - gc_collect_every_n_steps=0, - watchdog_timeout_seconds=60, - fail_on_timeout=True, - ).with_restored_state_rank(state1) as train_loader: - cmp_order2 = [ - text for idx, data in zip(range(55 * 10), train_loader) for text in data.text - ] - assert train_order2 == cmp_order2, (train_order1, cmp_order2) - - -if __name__ == "__main__": - unittest.main() + with DataLoader( + get_train_dataset( + ds1_path, + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + task_encoder=VerifyWorkerTaskEncoder(expected_num_workers=worker_config.num_workers), + ), + prefetch_factor=2, + worker_type=ThreadDataLoaderWorker, + gc_collect_every_n_steps=0, + watchdog_timeout_seconds=60, + fail_on_timeout=True, + ).with_restored_state_rank(state1) as train_loader: + cmp_order2 = [text for idx, data in zip(range(55 * 10), train_loader) for text in data.text] + assert train_order2 == cmp_order2, (train_order1, cmp_order2) diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 718c448b..a172ba9b 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -11,7 +11,6 @@ import random import sys import tempfile -import unittest import warnings from collections import defaultdict from dataclasses import dataclass @@ -19,6 +18,7 @@ from typing import Hashable, List, Tuple, Type, Union import numpy as np +import pytest import torch import webdataset as wds from click.testing import CliRunner @@ -92,221 +92,224 @@ class CaptioningBatch(Batch): caption: torch.Tensor -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - # Create a small dummy captioning dataset - self.samples = self.create_captioning_test_dataset(self.dataset_path, DATASET_SIZE) - print(self.dataset_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_captioning_test_dataset(path: Union[str, Path], num_samples: int = 50): - """Creates a small dummy captioning dataset for testing purposes.""" - path = Path(path) - - animals = ( - "ant bee beetle bug bumblebee butterfly caterpillar cicada cricket dragonfly earwig " - "firefly grasshopper honeybee hornet inchworm ladybug locust mantis mayfly mosquito " - "moth sawfly silkworm termite wasp woodlouse" - ).split() - adjectives = ( - "adorable affable amazing amiable attractive beautiful calm charming cherubic classic " - "classy convivial cordial cuddly curly cute debonair elegant famous fresh friendly " - "funny gorgeous graceful gregarious grinning handsome hilarious hot interesting kind " - "laughing lovely meek mellow merciful neat nifty notorious poetic pretty refined " - "refreshing sexy smiling sociable spiffy stylish sweet tactful whimsical" - ).split() - - # Set random seeds for numpy and torch - np.random.seed(42) - torch.manual_seed(42) - - entries = [] - - assert num_samples < len(animals) * len(adjectives), ( - "Cannot generate more samples than unique captions." - ) - - # Create num_samples unique captions - captions = set() - while len(captions) < num_samples: - # Create random description by sampling from adjectives and animals - adjective = np.random.choice(adjectives) - prefix = "An" if adjective[0] in "aeiou" else "A" - description = f"{prefix} {adjective} {np.random.choice(animals)}." - captions.add(description) - - (path / "parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=30) as shard_writer: - for idx in range(num_samples): - # Create a dummy image with random noise and save to disk - img_buf = io.BytesIO() - randimg = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) - image = Image.fromarray(randimg) - image.save(img_buf, format="PNG") - img_bytes = img_buf.getvalue() - - description = captions.pop() - - entries.append({"image": randimg, "caption": description}) - - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{idx:06d}", - "png": img_bytes, - "txt": description.encode("utf-8"), - "json": json.dumps({"caption": description}), - }, - ) - total_shards = shard_writer.shard - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - ) - - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: CaptioningSample", - "field_map:", - " image: png", - " caption: txt", - ] - ) +@pytest.fixture +def temp_dir(): + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() + + +@pytest.fixture +def dataset_path(temp_dir): + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + return dataset_path + + +@pytest.fixture +def samples(dataset_path): + return create_captioning_test_dataset(dataset_path, DATASET_SIZE) + + +@pytest.fixture(autouse=True) +def setup_logging(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + +def create_captioning_test_dataset(path: Union[str, Path], num_samples: int = 50): + """Creates a small dummy captioning dataset for testing purposes.""" + path = Path(path) + + animals = ( + "ant bee beetle bug bumblebee butterfly caterpillar cicada cricket dragonfly earwig " + "firefly grasshopper honeybee hornet inchworm ladybug locust mantis mayfly mosquito " + "moth sawfly silkworm termite wasp woodlouse" + ).split() + adjectives = ( + "adorable affable amazing amiable attractive beautiful calm charming cherubic classic " + "classy convivial cordial cuddly curly cute debonair elegant famous fresh friendly " + "funny gorgeous graceful gregarious grinning handsome hilarious hot interesting kind " + "laughing lovely meek mellow merciful neat nifty notorious poetic pretty refined " + "refreshing sexy smiling sociable spiffy stylish sweet tactful whimsical" + ).split() + + # Set random seeds for numpy and torch + np.random.seed(42) + torch.manual_seed(42) + + entries = [] + + assert num_samples < len(animals) * len(adjectives), ( + "Cannot generate more samples than unique captions." + ) + + # Create num_samples unique captions + captions = set() + while len(captions) < num_samples: + # Create random description by sampling from adjectives and animals + adjective = np.random.choice(adjectives) + prefix = "An" if adjective[0] in "aeiou" else "A" + description = f"{prefix} {adjective} {np.random.choice(animals)}." + captions.add(description) + + (path / "parts").mkdir(exist_ok=True, parents=True) + + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=30) as shard_writer: + for idx in range(num_samples): + # Create a dummy image with random noise and save to disk + img_buf = io.BytesIO() + randimg = np.random.randint(0, 255, (100, 100, 3), dtype=np.uint8) + image = Image.fromarray(randimg) + image.save(img_buf, format="PNG") + img_bytes = img_buf.getvalue() + + description = captions.pop() + + entries.append({"image": randimg, "caption": description}) + + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{idx:06d}", + "png": img_bytes, + "txt": description.encode("utf-8"), + "json": json.dumps({"caption": description}), + }, ) + total_shards = shard_writer.shard + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: CaptioningSample", + "field_map:", + " image: png", + " caption: txt", + ] + ) + ) - with open(path / MAIN_FOLDER_NAME / "dataset_field.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: CaptioningSample", - "field_map:", - " image: png", - " caption: json[caption]", - ] - ) + with open(path / MAIN_FOLDER_NAME / "dataset_field.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: CaptioningSample", + "field_map:", + " image: png", + " caption: json[caption]", + ] ) + ) - with open(path / MAIN_FOLDER_NAME / "dataset_sample_loader.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: CaptioningSample", - "sample_loader: sample_loader.py:sample_loader", - "part_filter: sample_loader.py:part_filter", - ] - ) + with open(path / MAIN_FOLDER_NAME / "dataset_sample_loader.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: CaptioningSample", + "sample_loader: sample_loader.py:sample_loader", + "part_filter: sample_loader.py:part_filter", + ] ) + ) - with open(path / MAIN_FOLDER_NAME / "dataset_sample_loader_key.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: CaptioningSample", - "sample_loader: sample_loader.py:sample_loader_key", - "part_filter: sample_loader.py:part_filter", - ] - ) + with open(path / MAIN_FOLDER_NAME / "dataset_sample_loader_key.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: CaptioningSample", + "sample_loader: sample_loader.py:sample_loader_key", + "part_filter: sample_loader.py:part_filter", + ] ) + ) - with open(path / MAIN_FOLDER_NAME / "sample_loader.py", "w") as f: - f.write( - "\n".join( - [ - "def sample_loader(raw: dict) -> dict:", - " assert 'txt' not in raw", - " return dict(", - ' image=raw["png"],', - ' caption="" + raw["json"]["caption"],', - " )", - "", - "def sample_loader_key(raw: dict) -> dict:", - " assert 'txt' not in raw", - " return dict(", - ' __key__="" + raw["__key__"],', - ' image=raw["png"],', - ' caption="" + raw["json"]["caption"],', - " )", - "", - "def part_filter(part: str) -> bool:", - ' return part in ["json", "png"]', - "", - ] - ) + with open(path / MAIN_FOLDER_NAME / "sample_loader.py", "w") as f: + f.write( + "\n".join( + [ + "def sample_loader(raw: dict) -> dict:", + " assert 'txt' not in raw", + " return dict(", + ' image=raw["png"],', + ' caption="" + raw["json"]["caption"],', + " )", + "", + "def sample_loader_key(raw: dict) -> dict:", + " assert 'txt' not in raw", + " return dict(", + ' __key__="" + raw["__key__"],', + ' image=raw["png"],', + ' caption="" + raw["json"]["caption"],', + " )", + "", + "def part_filter(part: str) -> bool:", + ' return part in ["json", "png"]', + "", + ] ) + ) - with open(path / MAIN_FOLDER_NAME / "dataset_exclude.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: CaptioningSample", - "field_map:", - " image: png", - " caption: txt", - "split_config: split2.yaml", - ] - ) + with open(path / MAIN_FOLDER_NAME / "dataset_exclude.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: CaptioningSample", + "field_map:", + " image: png", + " caption: txt", + "split_config: split2.yaml", + ] ) + ) - with open(path / MAIN_FOLDER_NAME / "split2.yaml", "w") as f: - with open(path / MAIN_FOLDER_NAME / "split.yaml", "r") as rf: - origsplit = rf.read() - f.write( - origsplit - + "\n" - + "\n".join( - [ - "exclude:", - " - parts/data-0.tar", - " - parts/data-1.tar/00003{5..9}", - ] - ) + with open(path / MAIN_FOLDER_NAME / "split2.yaml", "w") as f: + with open(path / MAIN_FOLDER_NAME / "split.yaml", "r") as rf: + origsplit = rf.read() + f.write( + origsplit + + "\n" + + "\n".join( + [ + "exclude:", + " - parts/data-0.tar", + " - parts/data-1.tar/00003{5..9}", + ] ) + ) - return entries + return entries - def test_captioning_dataset(self): - ds = get_dataset_from_config( - self.dataset_path, - split_part="train", - worker_config=no_worker_config, - training=False, - sample_type=CaptioningSample, - ) - ds = MapDataset( - ds.build(), +def test_captioning_dataset(dataset_path, samples): + def new_ds(): + return MapDataset( + get_dataset_from_config( + dataset_path, + split_part="train", + worker_config=no_worker_config, + training=False, + sample_type=CaptioningSample, + ).build(), lambda x: CaptioningSample( __key__=x.__key__, __restore_key__=x.__restore_key__, @@ -317,1469 +320,1451 @@ def test_captioning_dataset(self): worker_config=no_worker_config, ) - # Check len operator - assert len(ds) == 50 - # Check if iterating returns the same - with get_loader(ds) as l1, get_loader(ds) as l2: - iter1 = list(l1) - iter2 = list(l2) - assert len(iter1) == 50 - assert len(iter2) == 50 - assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2)) - - # Check case when batch size is larger than dataset size + ds = new_ds() + # Check len operator + assert len(ds) == 50 + # Check if iterating returns the same + with get_loader(ds) as l1, get_loader(new_ds()) as l2: + iter1 = list(l1) + iter2 = list(l2) + assert len(iter1) == 50 + assert len(iter2) == 50 + assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2)) + + # Check case when batch size is larger than dataset size + batch_sizes = [] + with get_loader( + BatchDataset( + new_ds(), + batch_size=DATASET_SIZE * 2, + batcher=generic_batch, + worker_config=no_worker_config, + ) + ) as l: + for wrapped_sample in l: + batch_sizes.append(wrapped_sample.image.shape[0]) + assert batch_sizes == [DATASET_SIZE] + + # Check returned dimensions and batch sizes if batch size is smaller than dataset size + batch_size = 4 + assert batch_size < DATASET_SIZE + + batched_ds = BatchDataset( + new_ds(), batch_size=batch_size, batcher=generic_batch, worker_config=no_worker_config + ) + + cnt = 0 + expected_num_batches = math.ceil(DATASET_SIZE / batch_size) + with get_loader(batched_ds) as l: + for idx, wrapped_sample in enumerate(l): + # Check batch sizes + if idx < expected_num_batches - 1: + assert wrapped_sample.image.shape[0] == batch_size + assert wrapped_sample.caption.shape[0] == batch_size + else: + assert wrapped_sample.image.shape[0] == DATASET_SIZE % batch_size + assert wrapped_sample.caption.shape[0] == DATASET_SIZE % batch_size + + # Check image size + assert tuple(wrapped_sample.image.shape[1:]) == (3, 100, 100) + + cnt += 1 + + logging.info(f" Batch {idx}:") + logging.info(f" {wrapped_sample.image.shape=}") + logging.info(f" {wrapped_sample.caption.shape=}") + + assert cnt == expected_num_batches + + # Check if actual image and caption data are correct + with get_loader( + BatchDataset(new_ds(), batch_size=9, batcher=generic_batch, worker_config=no_worker_config), + ) as loader: batch_sizes = [] - with get_loader( - BatchDataset( - ds, - batch_size=DATASET_SIZE * 2, - batcher=generic_batch, - worker_config=no_worker_config, + dataset_samples = {sample["caption"]: sample["image"] for sample in samples} + for idx, sample in enumerate(loader): + batch_sizes.append(sample.image.shape[0]) + for bidx in range(sample.image.shape[0]): + refimg = dataset_samples.pop( + sample.caption[bidx].numpy().tobytes().rstrip(b"\0").decode() + ) + assert torch.allclose( + sample.image[bidx], + torch.permute(torch.tensor(refimg, dtype=torch.float32) / 255, (2, 0, 1)), + ) + assert len(dataset_samples) == 0 + assert batch_sizes == [9, 9, 9, 9, 9, 5] + + +def test_field_access(dataset_path, samples): + ds = get_dataset_from_config( + dataset_path, + dataset_config="dataset_field.yaml", + split_part="train", + worker_config=no_worker_config, + training=False, + sample_type=CaptioningSample, + ) + captions = set(sample["caption"] for sample in samples) + with get_loader(ds.build()) as loader: + for sample in loader: + captions.remove(sample.caption) + assert len(captions) == 0 + + +def test_sample_loader(dataset_path, samples): + ds = get_dataset_from_config( + dataset_path, + dataset_config="dataset_sample_loader.yaml", + split_part="train", + worker_config=no_worker_config, + training=False, + sample_type=CaptioningSample, + ) + captions = set(sample["caption"] for sample in samples) + with get_loader(ds.build()) as loader: + for sample in loader: + assert sample.caption[:4] == "" + captions.remove(sample.caption[4:]) + assert len(captions) == 0 + + +def test_sample_loader_key(dataset_path, samples): + ds = get_dataset_from_config( + dataset_path, + dataset_config="dataset_sample_loader_key.yaml", + split_part="train", + worker_config=no_worker_config, + training=False, + sample_type=CaptioningSample, + ) + captions = set(sample["caption"] for sample in samples) + keys = set(f"parts/data-{idx // 30:d}.tar/{idx:06d}" for idx in range(len(samples))) + with get_loader(ds.build()) as loader: + for sample in loader: + assert sample.caption[:4] == "" + captions.remove(sample.caption[4:]) + keys.remove(sample.__key__) + assert len(captions) == 0 + assert len(keys) == 0 + + +def test_exclusion(dataset_path, samples): + ds = get_dataset_from_config( + dataset_path, + dataset_config="dataset_exclude.yaml", + split_part="train", + worker_config=no_worker_config, + training=False, + sample_type=CaptioningSample, + ) + + with get_loader(ds.build()) as loader: + keys = [entry.__key__ for entry in loader] + assert keys == [ + f"parts/data-1.tar/{i:06d}" for i in list(range(30, 35)) + list(range(40, 50)) + ], keys + + +def test_loader(dataset_path, samples): + torch.manual_seed(42) + + class TestTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) + + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8), ) - ) as l: - for wrapped_sample in l: - batch_sizes.append(wrapped_sample.image.shape[0]) - assert batch_sizes == [DATASET_SIZE] - # Check returned dimensions and batch sizes if batch size is smaller than dataset size - batch_size = 4 - assert batch_size < DATASET_SIZE - - batched_ds = BatchDataset( - ds, batch_size=batch_size, batcher=generic_batch, worker_config=no_worker_config - ) - - cnt = 0 - expected_num_batches = math.ceil(DATASET_SIZE / batch_size) - with get_loader(batched_ds) as l: - for idx, wrapped_sample in enumerate(l): - # Check batch sizes - if idx < expected_num_batches - 1: - assert wrapped_sample.image.shape[0] == batch_size - assert wrapped_sample.caption.shape[0] == batch_size - else: - assert wrapped_sample.image.shape[0] == DATASET_SIZE % batch_size - assert wrapped_sample.caption.shape[0] == DATASET_SIZE % batch_size - - # Check image size - assert tuple(wrapped_sample.image.shape[1:]) == (3, 100, 100) - - cnt += 1 - - logging.info(f" Batch {idx}:") - logging.info(f" {wrapped_sample.image.shape=}") - logging.info(f" {wrapped_sample.caption.shape=}") - - assert cnt == expected_num_batches - - # Check if actual image and caption data are correct - with get_loader( - BatchDataset(ds, batch_size=9, batcher=generic_batch, worker_config=no_worker_config), - ) as loader: - batch_sizes = [] - dataset_samples = {sample["caption"]: sample["image"] for sample in self.samples} - for idx, sample in enumerate(loader): - batch_sizes.append(sample.image.shape[0]) - for bidx in range(sample.image.shape[0]): - refimg = dataset_samples.pop( - sample.caption[bidx].numpy().tobytes().rstrip(b"\0").decode() - ) - assert torch.allclose( - sample.image[bidx], - torch.permute(torch.tensor(refimg, dtype=torch.float32) / 255, (2, 0, 1)), - ) - assert len(dataset_samples) == 0 - assert batch_sizes == [9, 9, 9, 9, 9, 5] - - def test_field_access(self): - ds = get_dataset_from_config( - self.dataset_path, - dataset_config="dataset_field.yaml", - split_part="train", - worker_config=no_worker_config, - training=False, - sample_type=CaptioningSample, - ) - captions = set(sample["caption"] for sample in self.samples) - with get_loader(ds.build()) as loader: - for sample in loader: - captions.remove(sample.caption) - assert len(captions) == 0 - - def test_sample_loader(self): - ds = get_dataset_from_config( - self.dataset_path, - dataset_config="dataset_sample_loader.yaml", - split_part="train", - worker_config=no_worker_config, - training=False, - sample_type=CaptioningSample, - ) - captions = set(sample["caption"] for sample in self.samples) - with get_loader(ds.build()) as loader: - for sample in loader: - assert sample.caption[:4] == "" - captions.remove(sample.caption[4:]) - assert len(captions) == 0 - - def test_sample_loader_key(self): - ds = get_dataset_from_config( - self.dataset_path, - dataset_config="dataset_sample_loader_key.yaml", - split_part="train", + with get_loader( + get_train_dataset( + dataset_path, + batch_size=10, worker_config=no_worker_config, - training=False, - sample_type=CaptioningSample, - ) - captions = set(sample["caption"] for sample in self.samples) - keys = set( - f"parts/data-{idx // 30:d}.tar/{idx:06d}" for idx in range(len(self.samples)) + parallel_shard_iters=2, + virtual_epoch_length=2, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=TestTaskEncoder(), ) - with get_loader(ds.build()) as loader: - for sample in loader: - assert sample.caption[:4] == "" - captions.remove(sample.caption[4:]) - keys.remove(sample.__key__) - assert len(captions) == 0 - assert len(keys) == 0 - - def test_exclusion(self): - ds = get_dataset_from_config( - self.dataset_path, - dataset_config="dataset_exclude.yaml", + ) as loader: + assert len(loader) == 2 + + def hist(data): + """Histogram function""" + r = defaultdict(lambda: 0) + for k in data: + r[k] += 1 + return r + + print([[batch.__key__ for batch in loader] for _ in range(100)]) + keys = [key for _ in range(100) for batch in loader for key in batch.__key__] + # 100 iterations, 2 virtual epoch size, batch size 10 + print(len(keys), keys) + keyhist = hist(keys) + print(sorted(keyhist.items())) + print(sorted(keyhist.items(), key=lambda x: (x[1], x[0]))) + assert len(keys) == 100 * 2 * 10 + # Data should be approximately sampled uniformly (40+-1 samples per key) + assert len(keyhist) == 50 + assert all(v in (39, 40, 41) for v in keyhist.values()) + + with get_loader( + get_val_dataset( + dataset_path, split_part="train", + batch_size=10, worker_config=no_worker_config, - training=False, - sample_type=CaptioningSample, + task_encoder=TestTaskEncoder(), ) + ) as loader2: + assert len(loader2) == 5 + # The order in the split is shuffled this way + assert list(key for batch in loader2 for key in batch.__key__) == [ + f"parts/data-1.tar/{i:06d}" for i in range(30, 50) + ] + [f"parts/data-0.tar/{i:06d}" for i in range(30)] - with get_loader(ds.build()) as loader: - keys = [entry.__key__ for entry in loader] - assert keys == [ - f"parts/data-1.tar/{i:06d}" for i in list(range(30, 35)) + list(range(40, 50)) - ], keys - - def test_loader(self): - torch.manual_seed(42) - class TestTaskEncoder(DefaultTaskEncoder): - def __init__(self): - super().__init__(raw_batch_type=CaptioningBatch) - - def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: - return EncodedCaptioningSample.derive_from( - sample, - image=sample.image, - caption=torch.frombuffer(bytearray(sample.caption.encode()), dtype=torch.uint8), - ) +def test_default_dataset(dataset_path, samples): + torch.manual_seed(42) - with get_loader( + with ( + get_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=10, worker_config=no_worker_config, - parallel_shard_iters=2, - virtual_epoch_length=2, shuffle_buffer_size=None, max_samples_per_sequence=None, - task_encoder=TestTaskEncoder(), ) - ) as loader: - assert len(loader) == 2 - - def hist(data): - """Histogram function""" - r = defaultdict(lambda: 0) - for k in data: - r[k] += 1 - return r - - print([[batch.__key__ for batch in loader] for _ in range(100)]) - keys = [key for _ in range(100) for batch in loader for key in batch.__key__] - # 100 iterations, 2 virtual epoch size, batch size 10 - print(len(keys), keys) - keyhist = hist(keys) - print(sorted(keyhist.items())) - print(sorted(keyhist.items(), key=lambda x: (x[1], x[0]))) - assert len(keys) == 100 * 2 * 10 - # Data should be approximately sampled uniformly (40+-1 samples per key) - assert len(keyhist) == 50 - assert all(v in (39, 40, 41) for v in keyhist.values()) - - with get_loader( + ) as train_loader, + get_loader( get_val_dataset( - self.dataset_path, + dataset_path, split_part="train", batch_size=10, worker_config=no_worker_config, - task_encoder=TestTaskEncoder(), ) - ) as loader2: - assert len(loader2) == 5 - # The order in the split is shuffled this way - assert list(key for batch in loader2 for key in batch.__key__) == [ - f"parts/data-1.tar/{i:06d}" for i in range(30, 50) - ] + [f"parts/data-0.tar/{i:06d}" for i in range(30)] - - def test_default_dataset(self): - torch.manual_seed(42) - - with ( - get_loader( - get_train_dataset( - self.dataset_path, - batch_size=10, - worker_config=no_worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - ) as train_loader, - get_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=10, - worker_config=no_worker_config, - ) - ) as val_loader, - ): - n_samples = 0 - for i, sample in zip(range(100), train_loader): - assert sample.image.shape == (10, 3, 100, 100) - n_samples += sample.image.shape[0] - assert n_samples == 1000 - n_samples = 0 - for sample in val_loader: - assert sample.image.shape == (10, 3, 100, 100) - n_samples += sample.image.shape[0] - assert n_samples == 50 - - def test_no_batching(self): - with get_loader( - get_train_dataset( - self.dataset_path, - batch_size=None, - worker_config=no_worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - ) as train_loader: - one_sample = next(iter(train_loader)) + ) as val_loader, + ): + n_samples = 0 + for i, sample in zip(range(100), train_loader): + assert sample.image.shape == (10, 3, 100, 100) + n_samples += sample.image.shape[0] + assert n_samples == 1000 + n_samples = 0 + for sample in val_loader: + assert sample.image.shape == (10, 3, 100, 100) + n_samples += sample.image.shape[0] + assert n_samples == 50 + + +def test_no_batching(dataset_path, samples): + with get_loader( + get_train_dataset( + dataset_path, + batch_size=None, + worker_config=no_worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + ) as train_loader: + one_sample = next(iter(train_loader)) - # Single sample without batching - assert isinstance(one_sample.image, torch.Tensor) - assert isinstance(one_sample.caption, str) + # Single sample without batching + assert isinstance(one_sample.image, torch.Tensor) + assert isinstance(one_sample.caption, str) - def test_dataset_len(self): - torch.manual_seed(42) - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=4) +def test_dataset_len(dataset_path, samples): + torch.manual_seed(42) - train_dataset = get_train_dataset( - self.dataset_path, - batch_size=11, - worker_config=worker_config, - virtual_epoch_length=12, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - with get_loader(train_dataset) as train_loader: - assert len(train_dataset) == 12 - assert len(train_loader) == 12 - assert len(list(train_loader)) == 12 + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=4) - val_dataset = get_val_dataset( - self.dataset_path, split_part="train", batch_size=1, worker_config=no_worker_config - ) - with get_loader(val_dataset) as val_loader: - assert len(val_loader) == 50 - assert len(list(val_loader)) == 50 + train_dataset = get_train_dataset( + dataset_path, + batch_size=11, + worker_config=worker_config, + virtual_epoch_length=12, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + with get_loader(train_dataset) as train_loader: + assert len(train_dataset) == 12 + assert len(train_loader) == 12 + assert len(list(train_loader)) == 12 val_dataset = get_val_dataset( - self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config + dataset_path, split_part="train", batch_size=1, worker_config=no_worker_config ) - with get_loader(val_dataset) as val_loader: - # n samples: ceil(50 / 11) // 4 * 4 - assert len(val_dataset) == 8 - assert len(val_loader) == 8 - assert len(list(val_loader)) == 8 - assert [len(entry.__key__) for entry in val_loader] == [11, 11, 11, 11, 2, 1, 2, 1] - assert sum(len(entry.__key__) for entry in val_loader) == 50 - - def test_multirank_dataset(self): - torch.manual_seed(42) - - worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) - - train_dataset = get_train_dataset( - self.dataset_path, - batch_size=11, - worker_config=worker_config_r0, - virtual_epoch_length=12, - shuffle_buffer_size=None, - max_samples_per_sequence=None, + with get_loader(val_dataset) as val_loader: + assert len(val_loader) == 50 + assert len(list(val_loader)) == 50 + + val_dataset = get_val_dataset( + dataset_path, split_part="train", batch_size=11, worker_config=worker_config + ) + with get_loader(val_dataset) as val_loader: + # n samples: ceil(50 / 11) // 4 * 4 + assert len(val_dataset) == 8 + assert len(val_loader) == 8 + assert len(list(val_loader)) == 8 + assert [len(entry.__key__) for entry in val_loader] == [11, 11, 11, 11, 2, 1, 2, 1] + assert sum(len(entry.__key__) for entry in val_loader) == 50 + + +def test_multirank_dataset(dataset_path, samples): + torch.manual_seed(42) + + worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) + worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) + + train_dataset = get_train_dataset( + dataset_path, + batch_size=11, + worker_config=worker_config_r0, + virtual_epoch_length=12, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + with get_loader(train_dataset) as train_loader: + assert len(train_dataset) == 12 + assert len(train_loader) == 12 + assert len(list(train_loader)) == 12 + + val_dataset0 = get_val_dataset( + dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r0 + ) + with get_loader(val_dataset0) as val_loader0: + print(len(val_loader0)) + assert len(val_loader0) == 25 + keys0 = set(key for entry in val_loader0 for key in entry.__key__) + assert len(keys0) == 25 + + val_dataset0b11 = get_val_dataset( + dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r0 + ) + with get_loader(val_dataset0b11) as val_loader0b11: + assert len(val_dataset0b11) == 4 + assert len(val_loader0b11) == 4 + assert len(list(val_loader0b11)) == 4 + keys0b11 = set(key for entry in val_loader0b11 for key in entry.__key__) + print([len(entry.__key__) for entry in val_loader0b11]) + assert [len(entry.__key__) for entry in val_loader0b11] == [11, 11, 2, 1] + assert len(keys0b11) == 25 + + assert keys0b11 == keys0 + + val_dataset1 = get_val_dataset( + dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r1 + ) + with get_loader(val_dataset1) as val_loader1: + print(len(val_loader1)) + assert len(val_loader1) == 25 + keys1 = set(key for entry in val_loader1 for key in entry.__key__) + assert len(keys1) == 25 + print(sorted(keys1)) + print(sorted(keys0)) + assert keys1.isdisjoint(keys0) + + val_dataset1b11 = get_val_dataset( + dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r1 + ) + with get_loader(val_dataset1b11) as val_loader1b11: + assert len(val_dataset1b11) == 4 + assert len(val_loader1b11) == 4 + assert len(list(val_loader1b11)) == 4 + keys1b11 = set(key for entry in val_loader1b11 for key in entry.__key__) + print([len(entry.__key__) for entry in val_loader1b11]) + assert [len(entry.__key__) for entry in val_loader1b11] == [11, 11, 2, 1] + assert len(keys1b11) == 25 + assert keys1b11.isdisjoint(keys0b11) + + assert keys1b11 == keys1 + + +def test_weight_aug(dataset_path, samples): + class WeightAugmentTaskEncoder(AugmentTaskEncoder): + def __init__(self, task_encoder: TaskEncoder, weight: float, target_data_class: type): + super().__init__(task_encoder) + self.weight = weight + self.target_data_class = target_data_class + + def encode_sample(self, sample): + sample = super().encode_sample(sample) + return self.target_data_class.extend(sample, weight=self.weight) + + torch.manual_seed(42) + + @edataclass + class WeightedCaptioningBatch(Batch): + image: torch.Tensor + caption: List[str] + weight: float + + with get_loader( + get_val_dataset( + dataset_path, + split_part="train", + batch_size=10, + worker_config=no_worker_config, + task_encoder=WeightAugmentTaskEncoder( + DefaultTaskEncoder(), + weight=0.8, + target_data_class=WeightedCaptioningBatch, + ), ) - with get_loader(train_dataset) as train_loader: - assert len(train_dataset) == 12 - assert len(train_loader) == 12 - assert len(list(train_loader)) == 12 + ) as loader: + for data in loader: + assert data.weight == [0.8] * 10 - val_dataset0 = get_val_dataset( - self.dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r0 - ) - with get_loader(val_dataset0) as val_loader0: - print(len(val_loader0)) - assert len(val_loader0) == 25 - keys0 = set(key for entry in val_loader0 for key in entry.__key__) - assert len(keys0) == 25 - - val_dataset0b11 = get_val_dataset( - self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r0 - ) - with get_loader(val_dataset0b11) as val_loader0b11: - assert len(val_dataset0b11) == 4 - assert len(val_loader0b11) == 4 - assert len(list(val_loader0b11)) == 4 - keys0b11 = set(key for entry in val_loader0b11 for key in entry.__key__) - print([len(entry.__key__) for entry in val_loader0b11]) - assert [len(entry.__key__) for entry in val_loader0b11] == [11, 11, 2, 1] - assert len(keys0b11) == 25 - - assert keys0b11 == keys0 - - val_dataset1 = get_val_dataset( - self.dataset_path, split_part="train", batch_size=1, worker_config=worker_config_r1 - ) - with get_loader(val_dataset1) as val_loader1: - print(len(val_loader1)) - assert len(val_loader1) == 25 - keys1 = set(key for entry in val_loader1 for key in entry.__key__) - assert len(keys1) == 25 - print(sorted(keys1)) - print(sorted(keys0)) - assert keys1.isdisjoint(keys0) - - val_dataset1b11 = get_val_dataset( - self.dataset_path, split_part="train", batch_size=11, worker_config=worker_config_r1 - ) - with get_loader(val_dataset1b11) as val_loader1b11: - assert len(val_dataset1b11) == 4 - assert len(val_loader1b11) == 4 - assert len(list(val_loader1b11)) == 4 - keys1b11 = set(key for entry in val_loader1b11 for key in entry.__key__) - print([len(entry.__key__) for entry in val_loader1b11]) - assert [len(entry.__key__) for entry in val_loader1b11] == [11, 11, 2, 1] - assert len(keys1b11) == 25 - assert keys1b11.isdisjoint(keys0b11) - - assert keys1b11 == keys1 - - def test_weight_aug(self): - class WeightAugmentTaskEncoder(AugmentTaskEncoder): - def __init__(self, task_encoder: TaskEncoder, weight: float, target_data_class: type): - super().__init__(task_encoder) - self.weight = weight - self.target_data_class = target_data_class - - def encode_sample(self, sample): - sample = super().encode_sample(sample) - return self.target_data_class.extend(sample, weight=self.weight) - - torch.manual_seed(42) - - @edataclass - class WeightedCaptioningBatch(Batch): - image: torch.Tensor - caption: List[str] - weight: float - - with get_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=10, - worker_config=no_worker_config, - task_encoder=WeightAugmentTaskEncoder( - DefaultTaskEncoder(), - weight=0.8, - target_data_class=WeightedCaptioningBatch, - ), - ) - ) as loader: - for data in loader: - assert data.weight == [0.8] * 10 - - def test_blending(self): - torch.manual_seed(42) - - with get_loader( - BlendDataset( - ( - get_train_dataset( - self.dataset_path, - batch_size=10, - worker_config=no_worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 2, + +def test_blending(dataset_path, samples): + torch.manual_seed(42) + + with get_loader( + BlendDataset( + ( + get_train_dataset( + dataset_path, + batch_size=10, + worker_config=no_worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - ( - get_train_dataset( - self.dataset_path, - batch_size=20, - worker_config=no_worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 8, + 2, + ), + ( + get_train_dataset( + dataset_path, + batch_size=20, + worker_config=no_worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - worker_config=no_worker_config, - ) - ) as loader: - bs_hist = {10: 0, 20: 0} - for i, sample in zip(range(1000), loader): - bs_hist[sample.image.shape[0]] += 1 - print(bs_hist) - assert 150 <= bs_hist[10] <= 250 - assert 750 <= bs_hist[20] <= 850 - - def test_mixing_homogeneous(self): - @dataclass - class TestBatch(Batch): - image: torch.Tensor - caption: List[str] - source: int - - class TestTaskEncoder(TaskEncoder): - def __init__(self, source: int): - self.source = source - - def encode_batch(self, batch): - return TestBatch.extend(batch, source=self.source) - - with get_loader( - MixBatchDataset( - ( - get_train_dataset( - self.dataset_path, - batch_size=1, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(source=0), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 2, + 8, + ), + worker_config=no_worker_config, + ) + ) as loader: + bs_hist = {10: 0, 20: 0} + for i, sample in zip(range(1000), loader): + bs_hist[sample.image.shape[0]] += 1 + print(bs_hist) + assert 150 <= bs_hist[10] <= 250 + assert 750 <= bs_hist[20] <= 850 + + +def test_mixing_homogeneous(dataset_path, samples): + @dataclass + class TestBatch(Batch): + image: torch.Tensor + caption: List[str] + source: int + + class TestTaskEncoder(TaskEncoder): + def __init__(self, source: int): + self.source = source + + def encode_batch(self, batch): + return TestBatch.extend(batch, source=self.source) + + with get_loader( + MixBatchDataset( + ( + get_train_dataset( + dataset_path, + batch_size=1, + worker_config=no_worker_config, + task_encoder=TestTaskEncoder(source=0), + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - ( - get_train_dataset( - self.dataset_path, - batch_size=1, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(source=1), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 8, + 2, + ), + ( + get_train_dataset( + dataset_path, + batch_size=1, + worker_config=no_worker_config, + task_encoder=TestTaskEncoder(source=1), + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - batch_size=10, - batch_mix_fn=homogeneous_concat_mix, - worker_config=no_worker_config, - ) - ) as loader: - source_hist = {0: 0, 1: 0} - for i, sample in zip(range(1000), loader): - assert sample.image.shape == (10, 3, 100, 100) - for source in sample.source: - source_hist[source] += 1 - assert 1500 <= source_hist[0] <= 2500 - assert 7500 <= source_hist[1] <= 8500 - - def test_mixing_heterogeneous(self): - @dataclass - class TestBatch1(Batch): - image: torch.Tensor - caption: List[str] - source: int - - @dataclass - class TestBatch2(TestBatch1): - pass - - class TestTaskEncoder(TaskEncoder): - def __init__(self, source: int, batch_cls: Type[TestBatch1]): - self.source = source - self.batch_cls = batch_cls - - def encode_batch(self, batch): - return self.batch_cls.extend(batch, source=self.source) - - with get_loader( - MixBatchDataset( - ( - get_train_dataset( - self.dataset_path, - batch_size=1, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(source=0, batch_cls=TestBatch1), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 2, + 8, + ), + batch_size=10, + batch_mix_fn=homogeneous_concat_mix, + worker_config=no_worker_config, + ) + ) as loader: + source_hist = {0: 0, 1: 0} + for i, sample in zip(range(1000), loader): + assert sample.image.shape == (10, 3, 100, 100) + for source in sample.source: + source_hist[source] += 1 + assert 1500 <= source_hist[0] <= 2500 + assert 7500 <= source_hist[1] <= 8500 + + +def test_mixing_heterogeneous(dataset_path, samples): + @dataclass + class TestBatch1(Batch): + image: torch.Tensor + caption: List[str] + source: int + + @dataclass + class TestBatch2(TestBatch1): + pass + + class TestTaskEncoder(TaskEncoder): + def __init__(self, source: int, batch_cls: Type[TestBatch1]): + self.source = source + self.batch_cls = batch_cls + + def encode_batch(self, batch): + return self.batch_cls.extend(batch, source=self.source) + + with get_loader( + MixBatchDataset( + ( + get_train_dataset( + dataset_path, + batch_size=1, + worker_config=no_worker_config, + task_encoder=TestTaskEncoder(source=0, batch_cls=TestBatch1), + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - ( - get_train_dataset( - self.dataset_path, - batch_size=1, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(source=1, batch_cls=TestBatch2), - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - 8, + 2, + ), + ( + get_train_dataset( + dataset_path, + batch_size=1, + worker_config=no_worker_config, + task_encoder=TestTaskEncoder(source=1, batch_cls=TestBatch2), + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - batch_size=10, - worker_config=no_worker_config, - ) - ) as loader: - source_hist = {0: 0, 1: 0} - for i, samples in zip(range(1000), loader): - assert len(samples) == 10 - for sample in samples: - assert sample.image.shape == (1, 3, 100, 100) - source_hist[sample.source] += 1 - assert 1500 <= source_hist[0] <= 2500 - assert 7500 <= source_hist[1] <= 8500 - - def test_val_limit(self): - torch.manual_seed(42) - - with get_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=2, - worker_config=no_worker_config, - limit=3, - ) - ) as loader: - assert len(loader) == 3 + 8, + ), + batch_size=10, + worker_config=no_worker_config, + ) + ) as loader: + source_hist = {0: 0, 1: 0} + for i, samples in zip(range(1000), loader): + assert len(samples) == 10 + for sample in samples: + assert sample.image.shape == (1, 3, 100, 100) + source_hist[sample.source] += 1 + assert 1500 <= source_hist[0] <= 2500 + assert 7500 <= source_hist[1] <= 8500 + + +def test_val_limit(dataset_path, samples): + torch.manual_seed(42) + + with get_loader( + get_val_dataset( + dataset_path, + split_part="train", + batch_size=2, + worker_config=no_worker_config, + limit=3, + ) + ) as loader: + assert len(loader) == 3 - samples = [[batch.__key__ for batch in loader] for _ in range(10)] - print(samples) - for s in samples: - print(" -", s) - assert all(samples[0] == one_ep_samples for one_ep_samples in samples) + samples = [[batch.__key__ for batch in loader] for _ in range(10)] + print(samples) + for s in samples: + print(" -", s) + assert all(samples[0] == one_ep_samples for one_ep_samples in samples) - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) - with get_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=2, - worker_config=worker_config, - limit=3, - ) - ) as loader: - assert len(loader) == 3 - - samples_wrk2 = [[batch.__key__ for batch in loader] for _ in range(10)] - print(samples_wrk2) - for s in samples_wrk2: - print(" -", s) - assert all( - all(a == b for a, b in zip(samples_wrk2[0], one_ep_samples)) - for one_ep_samples in samples_wrk2 + with get_loader( + get_val_dataset( + dataset_path, + split_part="train", + batch_size=2, + worker_config=worker_config, + limit=3, + ) + ) as loader: + assert len(loader) == 3 + + samples_wrk2 = [[batch.__key__ for batch in loader] for _ in range(10)] + print(samples_wrk2) + for s in samples_wrk2: + print(" -", s) + assert all( + all(a == b for a, b in zip(samples_wrk2[0], one_ep_samples)) + for one_ep_samples in samples_wrk2 + ) + + +def test_current_batch_index(dataset_path, samples): + # Tests if the get_current_batch_index works properly + torch.manual_seed(42) + + class TestTaskEncoder(TaskEncoder): + @stateless(restore_seeds=True) + def encode_sample(self, sample): + # print("si stack:", WorkerConfig._sample_index_stack) + return ExtendedCaptioningSample.extend( + sample, + batch_index=self.current_batch_index, + sample_index=self.current_sample_index, + rand_num=random.randint(0, 1000), ) - def test_current_batch_index(self): - # Tests if the get_current_batch_index works properly - torch.manual_seed(42) - - class TestTaskEncoder(TaskEncoder): - @stateless(restore_seeds=True) - def encode_sample(self, sample): - # print("si stack:", WorkerConfig._sample_index_stack) - return ExtendedCaptioningSample.extend( - sample, - batch_index=self.current_batch_index, - sample_index=self.current_sample_index, - rand_num=random.randint(0, 1000), - ) + # First, test simple single main-thread loader with accessing get_current_batch_index + with get_loader( + get_train_dataset( + dataset_path, + batch_size=2, + task_encoder=TestTaskEncoder(), + worker_config=no_worker_config, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) as loader: + batches = list(zip(range(20), loader)) + print("bi", [batch.batch_index for batch_idx, batch in batches]) + assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + + print("si", [batch.sample_index for batch_idx, batch in batches]) + assert all( + all( + si == sample_offset + batch_idx * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches + ) - # First, test simple single main-thread loader with accessing get_current_batch_index - with get_loader( + print("pk", [batch.__key__ for batch_idx, batch in batches]) + print("rk", [batch.__restore_key__ for batch_idx, batch in batches]) + assert loader.can_restore_sample() + + # These need to be hard coded to detect breaking changes + # If a change is expected, update the values with the ones printed below + ref_batch_rand_nums = [ + [661, 762], + [206, 470], + [130, 283], + [508, 61], + [625, 661], + [296, 376], + [632, 514], + [715, 406], + [555, 27], + [760, 36], + [607, 610], + [825, 219], + [564, 832], + [876, 512], + [632, 605], + [357, 738], + [40, 378], + [609, 444], + [610, 367], + [367, 69], + ] + + batch_rand_nums = [] + for batch_idx, batch in batches: + restore_batch = loader.restore_sample(batch.__restore_key__) + assert restore_batch.__key__ == batch.__key__ + assert restore_batch.batch_index == batch.batch_index + assert restore_batch.sample_index == batch.sample_index + assert restore_batch.rand_num == batch.rand_num + + batch_rand_nums.append(restore_batch.rand_num) + assert np.allclose(restore_batch.image, batch.image) + + # For constructing the test data above: + print("batch_rand_nums: ", batch_rand_nums) + assert batch_rand_nums == ref_batch_rand_nums + + # Now, test multi-worker loader with accessing get_current_batch_index + worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) + worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) + + with ( + get_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=2, task_encoder=TestTaskEncoder(), - worker_config=no_worker_config, + worker_config=worker_config_r0, shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) as loader: - batches = list(zip(range(20), loader)) - print("bi", [batch.batch_index for batch_idx, batch in batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches - ) - - print("si", [batch.sample_index for batch_idx, batch in batches]) - assert all( - all( - si == sample_offset + batch_idx * 2 - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches - ) - - print("pk", [batch.__key__ for batch_idx, batch in batches]) - print("rk", [batch.__restore_key__ for batch_idx, batch in batches]) - assert loader.can_restore_sample() - - # These need to be hard coded to detect breaking changes - # If a change is expected, update the values with the ones printed below - ref_batch_rand_nums = [ - [661, 762], - [206, 470], - [130, 283], - [508, 61], - [625, 661], - [296, 376], - [632, 514], - [715, 406], - [555, 27], - [760, 36], - [607, 610], - [825, 219], - [564, 832], - [876, 512], - [632, 605], - [357, 738], - [40, 378], - [609, 444], - [610, 367], - [367, 69], - ] - - batch_rand_nums = [] - for batch_idx, batch in batches: - restore_batch = loader.restore_sample(batch.__restore_key__) - assert restore_batch.__key__ == batch.__key__ - assert restore_batch.batch_index == batch.batch_index - assert restore_batch.sample_index == batch.sample_index - assert restore_batch.rand_num == batch.rand_num - - batch_rand_nums.append(restore_batch.rand_num) - assert np.allclose(restore_batch.image, batch.image) - - # For constructing the test data above: - print("batch_rand_nums: ", batch_rand_nums) - assert batch_rand_nums == ref_batch_rand_nums - - # Now, test multi-worker loader with accessing get_current_batch_index - worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) - - with ( - get_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r0, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ) - ) as loader, - get_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r1, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ) - ) as loader_r1, - ): - batches = list(zip(range(20), loader)) - print("bir0", [batch.batch_index for batch_idx, batch in batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches - ) - - print("sir0", [batch.sample_index for batch_idx, batch in batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches - ) - assert all( - all( - si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches - ) - - batches_r1 = list(zip(range(20), loader_r1)) - print("bir0", [batch.batch_index for batch_idx, batch in batches_r1]) - print("sir1", [batch.sample_index for batch_idx, batch in batches_r1]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 - ) - assert all( - all( - si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches_r1 - ) - - # Now, test multi-worker loader with accessing get_current_batch_index and save/restore state - with ( - get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r0, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ) - ) as loader, - get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=2, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r1, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ) - ) as loader_r1, - ): - batches = list(zip(range(20), loader)) - print([batch.batch_index for batch_idx, batch in batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches + ) as loader, + get_loader( + get_train_dataset( + dataset_path, + batch_size=2, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r1, + shuffle_buffer_size=20, + max_samples_per_sequence=10, ) - assert all( - all( - si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches + ) as loader_r1, + ): + batches = list(zip(range(20), loader)) + print("bir0", [batch.batch_index for batch_idx, batch in batches]) + assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + + print("sir0", [batch.sample_index for batch_idx, batch in batches]) + assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + assert all( + all( + si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) + for sample_offset, si in enumerate(batch.sample_index) ) + for batch_idx, batch in batches + ) - batches_r1 = list(zip(range(20), loader_r1)) - print([batch.batch_index for batch_idx, batch in batches_r1]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 - ) - assert all( - all( - si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches_r1 + batches_r1 = list(zip(range(20), loader_r1)) + print("bir0", [batch.batch_index for batch_idx, batch in batches_r1]) + print("sir1", [batch.sample_index for batch_idx, batch in batches_r1]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 + ) + assert all( + all( + si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) + for sample_offset, si in enumerate(batch.sample_index) ) + for batch_idx, batch in batches_r1 + ) - # Save and restore state - state = loader.save_state_rank() - - # Restore state and check if the batch index is restored correctly - with get_savable_loader( + # Now, test multi-worker loader with accessing get_current_batch_index and save/restore state + with ( + get_savable_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=2, task_encoder=TestTaskEncoder(), worker_config=worker_config_r0, shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ).with_restored_state_rank(state) as loader: - batches = list(zip(range(20, 40), loader)) - print([batch.batch_index for batch_idx, batch in batches]) - print([batch.sample_index for batch_idx, batch in batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches - ) - assert all( - all( - si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches - ) - - def test_current_batch_index_generator(self): - # Tests if the get_current_batch_index works properly - torch.manual_seed(42) - - class TestTaskEncoder(TaskEncoder): - @stateless(restore_seeds=True) - def encode_sample(self, sample): - # print("si stack:", WorkerConfig._sample_index_stack) - yield ExtendedCaptioningSample.extend( - sample, - batch_index=self.current_batch_index, - sample_index=self.current_sample_index, - rand_num=random.randint(0, 1000) + 0, - ) - - yield ExtendedCaptioningSample.extend( - sample, - batch_index=self.current_batch_index, - sample_index=self.current_sample_index, - rand_num=random.randint(0, 1000) + 1000, - ) - - # First, test simple single main-thread loader with accessing get_current_batch_index - with get_loader( + ) as loader, + get_savable_loader( get_train_dataset( - self.dataset_path, - batch_size=3, + dataset_path, + batch_size=2, task_encoder=TestTaskEncoder(), - worker_config=no_worker_config, + worker_config=worker_config_r1, shuffle_buffer_size=20, max_samples_per_sequence=10, ) - ) as loader: - batches = list(zip(range(20), loader)) - print("bi", [batch.batch_index for batch_idx, batch in batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches + ) as loader_r1, + ): + batches = list(zip(range(20), loader)) + print([batch.batch_index for batch_idx, batch in batches]) + assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + assert all( + all( + si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) + for sample_offset, si in enumerate(batch.sample_index) ) + for batch_idx, batch in batches + ) - print("si", [batch.sample_index for batch_idx, batch in batches]) - assert all( - all( - si == (sample_offset + batch_idx * 3) // 2 - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches + batches_r1 = list(zip(range(20), loader_r1)) + print([batch.batch_index for batch_idx, batch in batches_r1]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 + ) + assert all( + all( + si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) + for sample_offset, si in enumerate(batch.sample_index) ) + for batch_idx, batch in batches_r1 + ) - print("rk", [batch.__restore_key__ for batch_idx, batch in batches]) - assert loader.can_restore_sample() - - # These need to be hard coded to detect breaking changes - # If a change is expected, update the values with the ones printed below - ref_batch_rand_nums = [ - [661, 1747, 762], - [1171, 206, 1921], - [470, 1705, 130], - [1722, 283, 1990], - [508, 1041, 61], - [1102, 625, 1559], - [661, 1512, 296], - [1866, 376, 1345], - [632, 1176, 514], - [1652, 715, 1702], - [406, 1552, 555], - [1303, 27, 1520], - [760, 1380, 36], - [1869, 607, 1292], - [610, 1084, 825], - [1113, 219, 1102], - [564, 1695, 832], - [1612, 876, 2000], - [512, 1308, 632], - [1425, 605, 1931], - ] - - batch_rand_nums = [] - for batch_idx, batch in batches: - restore_batch = loader.restore_sample(batch.__restore_key__) - assert restore_batch.batch_index == batch.batch_index - assert restore_batch.sample_index == batch.sample_index - assert restore_batch.rand_num == batch.rand_num - - batch_rand_nums.append(restore_batch.rand_num) - assert np.allclose(restore_batch.image, batch.image) - - # For constructing the test data above: - print("batch_rand_nums: ", batch_rand_nums) - assert batch_rand_nums == ref_batch_rand_nums - - # Now, test multi-worker loader with accessing get_current_batch_index - worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) + # Save and restore state + state = loader.save_state_rank() - with ( - get_loader( - get_train_dataset( - self.dataset_path, - batch_size=3, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r0, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ) - ) as loader, - get_loader( - get_train_dataset( - self.dataset_path, - batch_size=3, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r1, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ) - ) as loader_r1, - ): - batches = list(zip(range(20), loader)) - print("bir0", [batch.batch_index for batch_idx, batch in batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches + # Restore state and check if the batch index is restored correctly + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=2, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r0, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ).with_restored_state_rank(state) as loader: + batches = list(zip(range(20, 40), loader)) + print([batch.batch_index for batch_idx, batch in batches]) + print([batch.sample_index for batch_idx, batch in batches]) + assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + assert all( + all( + si == 2 * sample_offset + (batch_idx * 2 - batch_idx % 2) + for sample_offset, si in enumerate(batch.sample_index) ) + for batch_idx, batch in batches + ) - print("sir0", [batch.sample_index for batch_idx, batch in batches]) - # [[0, 0, 2], [1, 1, 3], [2, 4, 4], [3, 5, 5], [6, 6, 8], [7, 7, 9], [8, 10, 10], [9, 11, 11], [12, 12, 14], [13, 13, 15], [14, 16, 16], [15, 17, 17], [18, 18, 20], [19, 19, 21], [20, 22, 22], [21, 23, 23], [24, 24, 26], [25, 25, 27], [26, 28, 28], [27, 29, 29]] - assert all( - all( - si - == batch_idx - + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches - ) - batches_r1 = list(zip(range(20), loader_r1)) - print("bir0", [batch.batch_index for batch_idx, batch in batches_r1]) - print("sir1", [batch.sample_index for batch_idx, batch in batches_r1]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 - ) - assert all( - all( - si - == batch_idx - + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches_r1 - ) +def test_current_batch_index_generator(dataset_path, samples): + # Tests if the get_current_batch_index works properly + torch.manual_seed(42) - # Now, test multi-worker loader with accessing get_current_batch_index and save/restore state - with ( - get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=3, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r0, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ), - ) as loader, - get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=3, - task_encoder=TestTaskEncoder(), - worker_config=worker_config_r1, - shuffle_buffer_size=20, - max_samples_per_sequence=10, - ), - ) as loader_r1, - ): - batches = list(zip(range(20), loader)) - print("bi:", [batch.batch_index for batch_idx, batch in batches]) - print("si:", [batch.sample_index for batch_idx, batch in batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches - ) - assert all( - all( - si - == batch_idx - + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches + class TestTaskEncoder(TaskEncoder): + @stateless(restore_seeds=True) + def encode_sample(self, sample): + # print("si stack:", WorkerConfig._sample_index_stack) + yield ExtendedCaptioningSample.extend( + sample, + batch_index=self.current_batch_index, + sample_index=self.current_sample_index, + rand_num=random.randint(0, 1000) + 0, ) - batches_r1 = list(zip(range(20), loader_r1)) - print([batch.batch_index for batch_idx, batch in batches_r1]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 + yield ExtendedCaptioningSample.extend( + sample, + batch_index=self.current_batch_index, + sample_index=self.current_sample_index, + rand_num=random.randint(0, 1000) + 1000, ) - assert all( - all( - si - == batch_idx - + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches_r1 - ) - - # Save and restore state - state = loader.save_state_rank() - # Iter next 20 from the loader - cmp_batches = list(zip(range(20, 40), loader)) - print("bi:", [batch.batch_index for batch_idx, batch in cmp_batches]) - print("si:", [batch.sample_index for batch_idx, batch in cmp_batches]) - print("rnd:", [batch.rand_num for batch_idx, batch in cmp_batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in cmp_batches - ) - assert all( - all( - si - == batch_idx - + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in cmp_batches + # First, test simple single main-thread loader with accessing get_current_batch_index + with get_loader( + get_train_dataset( + dataset_path, + batch_size=3, + task_encoder=TestTaskEncoder(), + worker_config=no_worker_config, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ) + ) as loader: + batches = list(zip(range(20), loader)) + print("bi", [batch.batch_index for batch_idx, batch in batches]) + assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + + print("si", [batch.sample_index for batch_idx, batch in batches]) + assert all( + all( + si == (sample_offset + batch_idx * 3) // 2 + for sample_offset, si in enumerate(batch.sample_index) ) + for batch_idx, batch in batches + ) - # Restore state and check if the batch index is restored correctly - with get_savable_loader( + print("rk", [batch.__restore_key__ for batch_idx, batch in batches]) + assert loader.can_restore_sample() + + # These need to be hard coded to detect breaking changes + # If a change is expected, update the values with the ones printed below + ref_batch_rand_nums = [ + [661, 1747, 762], + [1171, 206, 1921], + [470, 1705, 130], + [1722, 283, 1990], + [508, 1041, 61], + [1102, 625, 1559], + [661, 1512, 296], + [1866, 376, 1345], + [632, 1176, 514], + [1652, 715, 1702], + [406, 1552, 555], + [1303, 27, 1520], + [760, 1380, 36], + [1869, 607, 1292], + [610, 1084, 825], + [1113, 219, 1102], + [564, 1695, 832], + [1612, 876, 2000], + [512, 1308, 632], + [1425, 605, 1931], + ] + + batch_rand_nums = [] + for batch_idx, batch in batches: + restore_batch = loader.restore_sample(batch.__restore_key__) + assert restore_batch.batch_index == batch.batch_index + assert restore_batch.sample_index == batch.sample_index + assert restore_batch.rand_num == batch.rand_num + + batch_rand_nums.append(restore_batch.rand_num) + assert np.allclose(restore_batch.image, batch.image) + + # For constructing the test data above: + print("batch_rand_nums: ", batch_rand_nums) + assert batch_rand_nums == ref_batch_rand_nums + + # Now, test multi-worker loader with accessing get_current_batch_index + worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) + worker_config_r1 = WorkerConfig(rank=1, world_size=2, num_workers=2) + + with ( + get_loader( get_train_dataset( - self.dataset_path, + dataset_path, batch_size=3, task_encoder=TestTaskEncoder(), worker_config=worker_config_r0, shuffle_buffer_size=20, max_samples_per_sequence=10, - ), - ).with_restored_state_rank(state) as loader: - batches = list(zip(range(20, 40), loader)) - print("bi:", [batch.batch_index for batch_idx, batch in batches]) - print("si:", [batch.sample_index for batch_idx, batch in batches]) - print("rnd:", [batch.rand_num for batch_idx, batch in batches]) - assert all( - all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches ) - assert all( - all( - si - == batch_idx - + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 - for sample_offset, si in enumerate(batch.sample_index) - ) - for batch_idx, batch in batches - ) - assert all( - all(b1s == b2s for b1s, b2s in zip(b1.rand_num, b2.rand_num)) - for (_b1idx, b1), (_b2idx, b2) in zip(batches, cmp_batches) - ) - - def test_packing(self): - torch.manual_seed(42) - - class TestTaskEncoder(DefaultTaskEncoder): - def __init__(self): - super().__init__(raw_batch_type=CaptioningBatch) - - @stateless - def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: - return EncodedCaptioningSample.derive_from( - sample, - image=sample.image, - caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8), - ) - - def select_samples_to_pack( - self, samples: List[EncodedCaptioningSample] - ) -> List[List[EncodedCaptioningSample]]: - assert len(samples) == 21 - return [samples[:1], samples[1 : 1 + 4], samples[1 + 4 : 1 + 4 + 16]] - - @stateless - def pack_selected_samples( - self, samples: List[EncodedCaptioningSample] - ) -> EncodedCaptioningSample: - return EncodedCaptioningSample( - __key__=",".join([sample.__key__ for sample in samples]), - __restore_key__=None, - image=torch.stack([sample.image for sample in samples]), - caption=torch.cat([sample.caption for sample in samples]), - ) - - with get_loader( + ) as loader, + get_loader( get_train_dataset( - self.dataset_path, - batch_size=2, - packing_buffer_size=21, - worker_config=no_worker_config, - virtual_epoch_length=6, - shuffle_buffer_size=None, - max_samples_per_sequence=None, + dataset_path, + batch_size=3, task_encoder=TestTaskEncoder(), + worker_config=worker_config_r1, + shuffle_buffer_size=20, + max_samples_per_sequence=10, ) - ) as loader: - assert len(loader) == 6 - - samples = list(loader) - - print([batch.__key__ for batch in samples]) - print([batch.__restore_key__ for batch in samples]) - print([len(batch.__key__) for batch in samples]) - print([[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples]) - - # Each batch should have 2 samples - assert [len(batch.__key__) for batch in samples] == [ - 2, - 2, - 2, - 2, - 2, - 2, - ] - - # The packs of lengths 1, 4, 16 should be unrolled repeatedly across the batches of size 2 - assert [ - [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples - ] == [[1, 4], [16, 1], [4, 16], [1, 4], [16, 1], [4, 16]] - - restored_sample_1 = loader.restore_sample(samples[1].__restore_key__) - assert restored_sample_1.__key__ == samples[1].__key__ - assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ + ) as loader_r1, + ): + batches = list(zip(range(20), loader)) + print("bir0", [batch.batch_index for batch_idx, batch in batches]) + assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + + print("sir0", [batch.sample_index for batch_idx, batch in batches]) + # [[0, 0, 2], [1, 1, 3], [2, 4, 4], [3, 5, 5], [6, 6, 8], [7, 7, 9], [8, 10, 10], [9, 11, 11], [12, 12, 14], [13, 13, 15], [14, 16, 16], [15, 17, 17], [18, 18, 20], [19, 19, 21], [20, 22, 22], [21, 23, 23], [24, 24, 26], [25, 25, 27], [26, 28, 28], [27, 29, 29]] + assert all( + all( + si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches + ) - worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) + batches_r1 = list(zip(range(20), loader_r1)) + print("bir0", [batch.batch_index for batch_idx, batch in batches_r1]) + print("sir1", [batch.sample_index for batch_idx, batch in batches_r1]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 + ) + assert all( + all( + si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches_r1 + ) - with get_savable_loader( + # Now, test multi-worker loader with accessing get_current_batch_index and save/restore state + with ( + get_savable_loader( get_train_dataset( - self.dataset_path, - batch_size=2, - packing_buffer_size=21, - worker_config=worker_config_r0, - virtual_epoch_length=8, - shuffle_buffer_size=None, - max_samples_per_sequence=None, + dataset_path, + batch_size=3, task_encoder=TestTaskEncoder(), + worker_config=worker_config_r0, + shuffle_buffer_size=20, + max_samples_per_sequence=10, ), - ) as loader_r0: - samples_r0 = list(loader_r0) - assert [ - [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples_r0 - ] == [[1, 4], [1, 4], [16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4]] - - restored_sample_1 = loader_r0.restore_sample(samples_r0[1].__restore_key__) - assert restored_sample_1.__key__ == samples_r0[1].__key__ - assert restored_sample_1.__restore_key__ == samples_r0[1].__restore_key__ - - rank_state_r0 = loader_r0.save_state_rank() - samples_r0_cmp = list(loader_r0) - assert [ - [len(batch_key.split(",")) for batch_key in batch.__key__] - for batch in samples_r0_cmp - ] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]] - - with get_savable_loader( + ) as loader, + get_savable_loader( get_train_dataset( - self.dataset_path, - batch_size=2, - packing_buffer_size=21, - worker_config=worker_config_r0, - virtual_epoch_length=8, - shuffle_buffer_size=None, - max_samples_per_sequence=None, + dataset_path, + batch_size=3, task_encoder=TestTaskEncoder(), + worker_config=worker_config_r1, + shuffle_buffer_size=20, + max_samples_per_sequence=10, ), - ).with_restored_state_rank(rank_state_r0) as loader_r0: - samples_r0_restored = list(loader_r0) - print("cmp", [batch.__key__ for batch in samples_r0_cmp]) - print("rst", [batch.__key__ for batch in samples_r0_restored]) - assert [ - [len(batch_key.split(",")) for batch_key in batch.__key__] - for batch in samples_r0_restored - ] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]] - - assert all( - s0.__key__ == s1.__key__ for s0, s1 in zip(samples_r0_cmp, samples_r0_restored) + ) as loader_r1, + ): + batches = list(zip(range(20), loader)) + print("bi:", [batch.batch_index for batch_idx, batch in batches]) + print("si:", [batch.sample_index for batch_idx, batch in batches]) + assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + assert all( + all( + si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) ) + for batch_idx, batch in batches + ) - def test_packing_val(self): - torch.manual_seed(42) + batches_r1 = list(zip(range(20), loader_r1)) + print([batch.batch_index for batch_idx, batch in batches_r1]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches_r1 + ) + assert all( + all( + si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches_r1 + ) - class TestTaskEncoder(DefaultTaskEncoder): - def __init__(self): - super().__init__(raw_batch_type=CaptioningBatch) + # Save and restore state + state = loader.save_state_rank() - @stateless - def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: - return EncodedCaptioningSample.derive_from( - sample, - image=sample.image, - caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8), - ) + # Iter next 20 from the loader + cmp_batches = list(zip(range(20, 40), loader)) + print("bi:", [batch.batch_index for batch_idx, batch in cmp_batches]) + print("si:", [batch.sample_index for batch_idx, batch in cmp_batches]) + print("rnd:", [batch.rand_num for batch_idx, batch in cmp_batches]) + assert all( + all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in cmp_batches + ) + assert all( + all( + si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in cmp_batches + ) - def select_samples_to_pack( - self, samples: List[EncodedCaptioningSample] - ) -> List[List[EncodedCaptioningSample]]: - assert len(samples) in (1 + 3 + 5 + 2, 50 % 11) - if len(samples) < 11: - return [] - return [ - samples[1 + 3 + 5 : 1 + 3 + 5 + 2], - samples[1 + 3 : 1 + 3 + 5], - samples[1 : 1 + 3], - samples[:1], - ] + # Restore state and check if the batch index is restored correctly + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=3, + task_encoder=TestTaskEncoder(), + worker_config=worker_config_r0, + shuffle_buffer_size=20, + max_samples_per_sequence=10, + ), + ).with_restored_state_rank(state) as loader: + batches = list(zip(range(20, 40), loader)) + print("bi:", [batch.batch_index for batch_idx, batch in batches]) + print("si:", [batch.sample_index for batch_idx, batch in batches]) + print("rnd:", [batch.rand_num for batch_idx, batch in batches]) + assert all(all(bi == batch_idx for bi in batch.batch_index) for batch_idx, batch in batches) + assert all( + all( + si == batch_idx + (batch_idx // 4 + ((batch_idx // 2 % 2) + sample_offset) // 2) * 2 + for sample_offset, si in enumerate(batch.sample_index) + ) + for batch_idx, batch in batches + ) + assert all( + all(b1s == b2s for b1s, b2s in zip(b1.rand_num, b2.rand_num)) + for (_b1idx, b1), (_b2idx, b2) in zip(batches, cmp_batches) + ) - @stateless - def pack_selected_samples( - self, samples: List[EncodedCaptioningSample] - ) -> EncodedCaptioningSample: - return EncodedCaptioningSample( - __key__=",".join([sample.__key__ for sample in samples]), - __restore_key__=None, - image=torch.stack([sample.image for sample in samples]), - caption=torch.cat([sample.caption for sample in samples]), - ) - with get_loader( - get_val_dataset( - self.dataset_path, - batch_size=2, - packing_buffer_size=11, - worker_config=no_worker_config, - task_encoder=TestTaskEncoder(), - split_part="train", +def test_packing(dataset_path, samples): + torch.manual_seed(42) + + class TestTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) + + @stateless + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8), ) - ) as loader: - assert len(loader) == 25, f"len(loader) == {len(loader)}" - samples = list(loader) + def select_samples_to_pack( + self, samples: List[EncodedCaptioningSample] + ) -> List[List[EncodedCaptioningSample]]: + assert len(samples) == 21 + return [samples[:1], samples[1 : 1 + 4], samples[1 + 4 : 1 + 4 + 16]] + + @stateless + def pack_selected_samples( + self, samples: List[EncodedCaptioningSample] + ) -> EncodedCaptioningSample: + return EncodedCaptioningSample( + __key__=",".join([sample.__key__ for sample in samples]), + __restore_key__=None, + image=torch.stack([sample.image for sample in samples]), + caption=torch.cat([sample.caption for sample in samples]), + ) - print([batch.__key__ for batch in samples]) - print([batch.__restore_key__ for batch in samples]) - print([len(batch.__key__) for batch in samples]) - print([[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples]) + with get_loader( + get_train_dataset( + dataset_path, + batch_size=2, + packing_buffer_size=21, + worker_config=no_worker_config, + virtual_epoch_length=6, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=TestTaskEncoder(), + ) + ) as loader: + assert len(loader) == 6 + + samples = list(loader) + + print([batch.__key__ for batch in samples]) + print([batch.__restore_key__ for batch in samples]) + print([len(batch.__key__) for batch in samples]) + print([[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples]) + + # Each batch should have 2 samples + assert [len(batch.__key__) for batch in samples] == [ + 2, + 2, + 2, + 2, + 2, + 2, + ] + + # The packs of lengths 1, 4, 16 should be unrolled repeatedly across the batches of size 2 + assert [ + [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples + ] == [[1, 4], [16, 1], [4, 16], [1, 4], [16, 1], [4, 16]] + + restored_sample_1 = loader.restore_sample(samples[1].__restore_key__) + assert restored_sample_1.__key__ == samples[1].__key__ + assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ + + worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) + + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=2, + packing_buffer_size=21, + worker_config=worker_config_r0, + virtual_epoch_length=8, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=TestTaskEncoder(), + ), + ) as loader_r0: + samples_r0 = list(loader_r0) + assert [ + [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples_r0 + ] == [[1, 4], [1, 4], [16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4]] + + restored_sample_1 = loader_r0.restore_sample(samples_r0[1].__restore_key__) + assert restored_sample_1.__key__ == samples_r0[1].__key__ + assert restored_sample_1.__restore_key__ == samples_r0[1].__restore_key__ + + rank_state_r0 = loader_r0.save_state_rank() + samples_r0_cmp = list(loader_r0) + assert [ + [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples_r0_cmp + ] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]] + + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=2, + packing_buffer_size=21, + worker_config=worker_config_r0, + virtual_epoch_length=8, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=TestTaskEncoder(), + ), + ).with_restored_state_rank(rank_state_r0) as loader_r0: + samples_r0_restored = list(loader_r0) + print("cmp", [batch.__key__ for batch in samples_r0_cmp]) + print("rst", [batch.__key__ for batch in samples_r0_restored]) + assert [ + [len(batch_key.split(",")) for batch_key in batch.__key__] + for batch in samples_r0_restored + ] == [[16, 1], [16, 1], [4, 16], [4, 16], [1, 4], [1, 4], [16, 1], [16, 1]] + + assert all(s0.__key__ == s1.__key__ for s0, s1 in zip(samples_r0_cmp, samples_r0_restored)) + + +def test_packing_val(dataset_path, samples): + torch.manual_seed(42) + + class TestTaskEncoder(DefaultTaskEncoder): + def __init__(self): + super().__init__(raw_batch_type=CaptioningBatch) + + @stateless + def encode_sample(self, sample: CaptioningSample) -> EncodedCaptioningSample: + return EncodedCaptioningSample.derive_from( + sample, + image=sample.image, + caption=torch.frombuffer(sample.caption.encode(), dtype=torch.uint8), + ) - # Each batch should have 2 samples - assert [len(batch.__key__) for batch in samples] == [ - 2, - 2, - 2, - 2, - 2, - 2, - 2, - 2, + def select_samples_to_pack( + self, samples: List[EncodedCaptioningSample] + ) -> List[List[EncodedCaptioningSample]]: + assert len(samples) in (1 + 3 + 5 + 2, 50 % 11) + if len(samples) < 11: + return [] + return [ + samples[1 + 3 + 5 : 1 + 3 + 5 + 2], + samples[1 + 3 : 1 + 3 + 5], + samples[1 : 1 + 3], + samples[:1], ] - # The packs of lengths 1, 4, 16 should be unrolled repeatedly across the batches of size 2 - assert [ - [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples - ] == [[2, 5], [3, 1], [2, 5], [3, 1], [2, 5], [3, 1], [2, 5], [3, 1]] - - restored_sample_1 = loader.restore_sample(samples[1].__restore_key__) - assert restored_sample_1.__key__ == samples[1].__key__ - assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ - - def test_group_batch(self): - class GroupingTaskEncoder( - TaskEncoder[CaptioningSample, CaptioningSample, CaptioningSample, CaptioningSample] - ): - @stateless - def encode_sample(self, sample: CaptioningSample) -> CaptioningSample: - sample.caption = sample.__key__.split("/")[-2] - return sample - - def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, int]: - if sample.caption == "data-0.tar": - return "shard1", 4 - elif sample.caption == "data-1.tar": - return "shard2", 8 - else: - assert False - - @stateless - def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: - return CaptioningEncodedBatch.extend(batch) - - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) - with get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=None, - worker_config=worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=GroupingTaskEncoder(), - ), - ) as loader: - batches = list(zip(range(40), loader)) - print([batch.__key__ for idx, batch in batches]) - - assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches) - assert all( - all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches + @stateless + def pack_selected_samples( + self, samples: List[EncodedCaptioningSample] + ) -> EncodedCaptioningSample: + return EncodedCaptioningSample( + __key__=",".join([sample.__key__ for sample in samples]), + __restore_key__=None, + image=torch.stack([sample.image for sample in samples]), + caption=torch.cat([sample.caption for sample in samples]), ) - worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - - with get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=None, - worker_config=worker_config_r0, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=GroupingTaskEncoder(), - ), - ) as loader_r0: - batches = list(zip(range(40), loader_r0)) + with get_loader( + get_val_dataset( + dataset_path, + batch_size=2, + packing_buffer_size=11, + worker_config=no_worker_config, + task_encoder=TestTaskEncoder(), + split_part="train", + ) + ) as loader: + assert len(loader) == 25, f"len(loader) == {len(loader)}" + + samples = list(loader) + + print([batch.__key__ for batch in samples]) + print([batch.__restore_key__ for batch in samples]) + print([len(batch.__key__) for batch in samples]) + print([[len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples]) + + # Each batch should have 2 samples + assert [len(batch.__key__) for batch in samples] == [ + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + ] + + # The packs of lengths 1, 4, 16 should be unrolled repeatedly across the batches of size 2 + assert [ + [len(batch_key.split(",")) for batch_key in batch.__key__] for batch in samples + ] == [[2, 5], [3, 1], [2, 5], [3, 1], [2, 5], [3, 1], [2, 5], [3, 1]] + + restored_sample_1 = loader.restore_sample(samples[1].__restore_key__) + assert restored_sample_1.__key__ == samples[1].__key__ + assert restored_sample_1.__restore_key__ == samples[1].__restore_key__ + + +def test_group_batch(dataset_path, samples): + class GroupingTaskEncoder( + TaskEncoder[CaptioningSample, CaptioningSample, CaptioningSample, CaptioningSample] + ): + @stateless + def encode_sample(self, sample: CaptioningSample) -> CaptioningSample: + sample.caption = sample.__key__.split("/")[-2] + return sample + + def batch_group_criterion(self, sample: CaptioningSample) -> Tuple[Hashable, int]: + if sample.caption == "data-0.tar": + return "shard1", 4 + elif sample.caption == "data-1.tar": + return "shard2", 8 + else: + assert False + + @stateless + def encode_batch(self, batch: CaptioningSample) -> CaptioningEncodedBatch: + return CaptioningEncodedBatch.extend(batch) + + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=None, + worker_config=worker_config, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=GroupingTaskEncoder(), + ), + ) as loader: + batches = list(zip(range(40), loader)) + print([batch.__key__ for idx, batch in batches]) - print([batch.__key__ for idx, batch in batches]) + assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches) + assert all(all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches) - assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches) - assert all( - all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches - ) + worker_config_r0 = WorkerConfig(rank=0, world_size=2, num_workers=2) - state = loader_r0.save_state_rank() + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=None, + worker_config=worker_config_r0, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=GroupingTaskEncoder(), + ), + ) as loader_r0: + batches = list(zip(range(40), loader_r0)) - cmp_samples = list(zip(range(40, 80), loader_r0)) - print([batch.__key__ for idx, batch in cmp_samples]) + print([batch.__key__ for idx, batch in batches]) - with get_savable_loader( - get_train_dataset( - self.dataset_path, - batch_size=None, - worker_config=worker_config_r0, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=GroupingTaskEncoder(), - ), - ).with_restored_state_rank(state) as loader_r0: - cmp_samples_rest = list(zip(range(40, 80), loader_r0)) - print([batch.__key__ for idx, batch in cmp_samples_rest]) - - assert len(cmp_samples) == len(cmp_samples_rest) - assert all( - len(cmp_sample.caption) == len(cmp_sample_rest.caption) - for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest) - ) - assert all( - all( - cmp_cap == cmp_cap_rest - for cmp_cap, cmp_cap_rest in zip(cmp_sample.caption, cmp_sample_rest.caption) - ) - for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest) - ) + assert all(isinstance(batch, CaptioningEncodedBatch) for idx, batch in batches) + assert all(all(key == batch.caption[0] for key in batch.caption) for idx, batch in batches) - def test_debug_dataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - worker_log_level=3, - worker_debug_path=str(self.dataset_path) + "/worker_debug/{worker_id}.jsonl", - ) + state = loader_r0.save_state_rank() - # Reset this to 0 to make sure the test is deterministic - DataLoader._next_id = 0 + cmp_samples = list(zip(range(40, 80), loader_r0)) + print([batch.__key__ for idx, batch in cmp_samples]) - with get_savable_loader( - get_val_dataset( - self.dataset_path, - split_part="train", - batch_size=5, - worker_config=worker_config, - ), - ) as loader: - assert len(loader) == 10 - - samples = [[batch.__key__ for batch in loader] for _ in range(2)] - print(samples) - - debug_log_path = self.dataset_path / "worker_debug" - assert (debug_log_path / "0.jsonl").is_file() - assert (debug_log_path / "1.jsonl").is_file() - assert (debug_log_path / "2.jsonl").is_file() - - collected_keys_order = [[None] * 10 for _ in range(2)] - with (debug_log_path / "0.jsonl").open() as rf: - for line in rf: - line_data = json.loads(line) - print(line_data) - if line_data["t"] == "DataLoader.epoch_iter.yield": - for i in range(len(collected_keys_order)): - if collected_keys_order[i][line_data["epoch_sample_idx"]] is None: - collected_keys_order[i][line_data["epoch_sample_idx"]] = line_data[ - "keys" - ] - break - else: - assert False, "Too many entries for key" - - print(collected_keys_order) - assert collected_keys_order == samples - - runner = CliRunner() - result = runner.invoke( - analyze_debug_command, - [ - str(debug_log_path), - "--include-modality", - "train,val", - "--heatmap-path", - str(self.dataset_path / "heatmap.png"), - ], - catch_exceptions=False, - ) - print(result.stdout) - assert result.exit_code == 0, "Debug analysis failed, see output" - assert "Analyzing 3 logs" in result.stdout - assert "Found 50 unique sample keys, 20 steps" in result.stdout - - def test_validate_captioning_dataset(self): - runner = CliRunner() - result = runner.invoke( - lint_command, - [str(self.dataset_path), "--split-parts=train"], - catch_exceptions=False, - ) - assert result.exit_code == 0, "Validation failed, see output" - - def test_prepare_dataset(self): - runner = CliRunner() - result = runner.invoke( - prepare_command, - [str(self.dataset_path)], - catch_exceptions=False, - input="y\n1,0,0\ny\n0\nY\npng\ntxt\n", - ) - assert result.exit_code == 0, "Prepare failed, see output" - assert "Done" in result.stdout, "Prepare failed, see output" - - def test_preview_captioning_dataset(self): - runner = CliRunner() - result = runner.invoke( - preview_command, - [str(self.dataset_path), "--split-parts=train"], - input="n\n", - catch_exceptions=False, + with get_savable_loader( + get_train_dataset( + dataset_path, + batch_size=None, + worker_config=worker_config_r0, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=GroupingTaskEncoder(), + ), + ).with_restored_state_rank(state) as loader_r0: + cmp_samples_rest = list(zip(range(40, 80), loader_r0)) + print([batch.__key__ for idx, batch in cmp_samples_rest]) + + assert len(cmp_samples) == len(cmp_samples_rest) + assert all( + len(cmp_sample.caption) == len(cmp_sample_rest.caption) + for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest) ) - # First sample! - assert "__key__ (): 'parts/data-1.tar/000030'" in result.stdout - assert result.exit_code == 0, "Preview failed, see output" - - def test_info_captioning_dataset(self): - runner = CliRunner() - result = runner.invoke( - info_command, - [str(self.dataset_path)], - catch_exceptions=False, + assert all( + all( + cmp_cap == cmp_cap_rest + for cmp_cap, cmp_cap_rest in zip(cmp_sample.caption, cmp_sample_rest.caption) + ) + for (idx, cmp_sample), (idx, cmp_sample_rest) in zip(cmp_samples, cmp_samples_rest) ) - print(result.stdout) - assert "50 samples" in result.stdout - assert "2 shards" in result.stdout - assert str(self.dataset_path) in result.stdout - assert "train" in result.stdout - assert result.exit_code == 0, "Preview failed, see output" -if __name__ == "__main__": - unittest.main() +def test_debug_dataset(dataset_path, samples): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + worker_log_level=3, + worker_debug_path=str(dataset_path) + "/worker_debug/{worker_id}.jsonl", + ) + + # Reset this to 0 to make sure the test is deterministic + DataLoader._next_id = 0 + + with get_savable_loader( + get_val_dataset( + dataset_path, + split_part="train", + batch_size=5, + worker_config=worker_config, + ), + ) as loader: + assert len(loader) == 10 + + samples = [[batch.__key__ for batch in loader] for _ in range(2)] + print(samples) + + debug_log_path = dataset_path / "worker_debug" + assert (debug_log_path / "0.jsonl").is_file() + assert (debug_log_path / "1.jsonl").is_file() + assert (debug_log_path / "2.jsonl").is_file() + + collected_keys_order = [[None] * 10 for _ in range(2)] + with (debug_log_path / "0.jsonl").open() as rf: + for line in rf: + line_data = json.loads(line) + print(line_data) + if line_data["t"] == "DataLoader.epoch_iter.yield": + for i in range(len(collected_keys_order)): + if collected_keys_order[i][line_data["epoch_sample_idx"]] is None: + collected_keys_order[i][line_data["epoch_sample_idx"]] = line_data[ + "keys" + ] + break + else: + assert False, "Too many entries for key" + + print(collected_keys_order) + assert collected_keys_order == samples + + runner = CliRunner() + result = runner.invoke( + analyze_debug_command, + [ + str(debug_log_path), + "--include-modality", + "train,val", + "--heatmap-path", + str(dataset_path / "heatmap.png"), + ], + catch_exceptions=False, + ) + print(result.stdout) + assert result.exit_code == 0, "Debug analysis failed, see output" + assert "Analyzing 3 logs" in result.stdout + assert "Found 50 unique sample keys, 20 steps" in result.stdout + + +def test_validate_captioning_dataset(dataset_path, samples): + runner = CliRunner() + result = runner.invoke( + lint_command, + [str(dataset_path), "--split-parts=train"], + catch_exceptions=False, + ) + assert result.exit_code == 0, "Validation failed, see output" + + +def test_prepare_dataset(dataset_path, samples): + runner = CliRunner() + result = runner.invoke( + prepare_command, + [str(dataset_path)], + catch_exceptions=False, + input="y\n1,0,0\ny\n0\nY\npng\ntxt\n", + ) + assert result.exit_code == 0, "Prepare failed, see output" + assert "Done" in result.stdout, "Prepare failed, see output" + + +def test_preview_captioning_dataset(dataset_path, samples): + runner = CliRunner() + result = runner.invoke( + preview_command, + [str(dataset_path), "--split-parts=train"], + input="n\n", + catch_exceptions=False, + ) + # First sample! + assert "__key__ (): 'parts/data-1.tar/000030'" in result.stdout + assert result.exit_code == 0, "Preview failed, see output" + + +def test_info_captioning_dataset(dataset_path, samples): + runner = CliRunner() + result = runner.invoke( + info_command, + [str(dataset_path)], + catch_exceptions=False, + ) + print(result.stdout) + assert "50 samples" in result.stdout + assert "2 shards" in result.stdout + assert str(dataset_path) in result.stdout + assert "train" in result.stdout + assert result.exit_code == 0, "Preview failed, see output" diff --git a/tests/test_dataset_det.py b/tests/test_dataset_det.py index ab90daad..355d9eb3 100644 --- a/tests/test_dataset_det.py +++ b/tests/test_dataset_det.py @@ -8,11 +8,11 @@ import random import sys import tempfile -import unittest import warnings from collections import Counter from pathlib import Path +import pytest import torch import webdataset as wds import yaml @@ -58,378 +58,485 @@ def _norng_state(state): return state -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - self.dataset_path.mkdir(exist_ok=True, parents=True) +@pytest.fixture +def dataset_path(temp_dir): + """Create the main dataset directory with test data.""" + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) - # Create a small dummy captioning dataset - self.create_text_test_dataset(self.dataset_path) + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) - # Create temporary directories for checkpoint files - self.checkpoint_dir = Path(self.temp_dir.name) / "checkpoints" - self.checkpoint_dir.mkdir(exist_ok=True, parents=True) + # Create a small dummy captioning dataset + create_text_test_dataset(dataset_path) - self.redist_dir = Path(self.temp_dir.name) / "redist_checkpoints" - self.redist_dir.mkdir(exist_ok=True, parents=True) + print(dataset_path) + return dataset_path - print(self.dataset_path) - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() +@pytest.fixture +def checkpoint_dir(dataset_path): + """Create checkpoint directory for test files.""" + checkpoint_dir = dataset_path / "checkpoints" + checkpoint_dir.mkdir(exist_ok=True, parents=True) + return checkpoint_dir - @staticmethod - def create_text_test_dataset(path: Path): - """Creates a small dummy test dataset for testing purposes.""" - # Create num_samples unique captions - (path / "parts").mkdir(exist_ok=True, parents=True) +@pytest.fixture +def redist_dir(dataset_path): + """Create redistribution directory for test files.""" + redist_dir = dataset_path / "redist_checkpoints" + redist_dir.mkdir(exist_ok=True, parents=True) + return redist_dir - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=100) as shard_writer: - for idx in range(55): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{idx:06d}", - "txt": f"{idx}".encode(), - }, - ) - # Also create smaller shards, to verify distributions - if idx in (1, 3, 6, 10, 20, 30, 40, 50): - shard_writer.next_stream() - total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, - ) - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: TextSample", - "field_map:", - " text: txt", - ] - ) - ) +def create_text_test_dataset(path: Path): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) - # Split with alternating train/val shards - with open(path / MAIN_FOLDER_NAME / "split2.yaml", "w") as f: - yaml.dump( + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=100) as shard_writer: + for idx in range(55): + # Write individual files to shards + shard_writer.write( { - "split_parts": { - "train": [ - "parts/data-4.tar", - "parts/data-0.tar", - "parts/data-2.tar", - ], - "val": [ - "parts/data-1.tar", - "parts/data-3.tar", - "parts/data-5.tar", - ], - } + "__key__": f"{idx:06d}", + "txt": f"{idx}".encode(), }, - f, ) + # Also create smaller shards, to verify distributions + if idx in (1, 3, 6, 10, 20, 30, 40, 50): + shard_writer.next_stream() + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: TextSample", + "field_map:", + " text: txt", + ] + ) + ) - def test_split_parts(self): - with open(self.dataset_path / MAIN_FOLDER_NAME / "split.yaml", "r") as f: - print(f.read()) - with open(self.dataset_path / MAIN_FOLDER_NAME / "split2.yaml", "r") as f: - print(f.read()) - - ds = get_dataset_from_config( - self.dataset_path, - split_config="split2.yaml", - split_part="train", - worker_config=WorkerConfig(rank=0, world_size=1, num_workers=0), - training=False, - sample_type=TextSample, + # Split with alternating train/val shards + with open(path / MAIN_FOLDER_NAME / "split2.yaml", "w") as f: + yaml.dump( + { + "split_parts": { + "train": [ + "parts/data-4.tar", + "parts/data-0.tar", + "parts/data-2.tar", + ], + "val": [ + "parts/data-1.tar", + "parts/data-3.tar", + "parts/data-5.tar", + ], + } + }, + f, ) - with get_loader(ds.build()) as dl: - all_keys = [sample.__key__ for sample in dl] - assert all_keys == [ - "parts/data-4.tar/000011", # Shard 4 first - "parts/data-4.tar/000012", - "parts/data-4.tar/000013", - "parts/data-4.tar/000014", - "parts/data-4.tar/000015", - "parts/data-4.tar/000016", - "parts/data-4.tar/000017", - "parts/data-4.tar/000018", - "parts/data-4.tar/000019", - "parts/data-4.tar/000020", - "parts/data-0.tar/000000", # Shard 0 - "parts/data-0.tar/000001", - "parts/data-2.tar/000004", # Shard 2 - "parts/data-2.tar/000005", - "parts/data-2.tar/000006", - ] - def test_text_dataset(self): - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) - ds = get_dataset_from_config( - self.dataset_path, - split_part="train", - training=False, - sample_type=TextSample, - worker_config=worker_config, - ).build() +def test_split_parts(dataset_path): + with open(dataset_path / MAIN_FOLDER_NAME / "split.yaml", "r") as f: + print(f.read()) + with open(dataset_path / MAIN_FOLDER_NAME / "split2.yaml", "r") as f: + print(f.read()) + + ds = get_dataset_from_config( + dataset_path, + split_config="split2.yaml", + split_part="train", + worker_config=WorkerConfig(rank=0, world_size=1, num_workers=0), + training=False, + sample_type=TextSample, + ) + with get_loader(ds.build()) as dl: + all_keys = [sample.__key__ for sample in dl] + assert all_keys == [ + "parts/data-4.tar/000011", # Shard 4 first + "parts/data-4.tar/000012", + "parts/data-4.tar/000013", + "parts/data-4.tar/000014", + "parts/data-4.tar/000015", + "parts/data-4.tar/000016", + "parts/data-4.tar/000017", + "parts/data-4.tar/000018", + "parts/data-4.tar/000019", + "parts/data-4.tar/000020", + "parts/data-0.tar/000000", # Shard 0 + "parts/data-0.tar/000001", + "parts/data-2.tar/000004", # Shard 2 + "parts/data-2.tar/000005", + "parts/data-2.tar/000006", + ] - # Check len operator - assert len(ds) == 55 - # Check if iterating returns the same - with get_loader(ds) as l1: - iter1 = list(l1) - with get_loader(ds) as l2: - iter2 = list(l2) - assert len(iter1) == 55 - assert len(iter2) == 55 - assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2)) - with get_loader(ds) as l3: - assert all(f"{idx}" == x.text for idx, x in enumerate(l3)) - - def test_epoch(self): - torch.manual_seed(42) - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=5) +def test_text_dataset(dataset_path): + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) - # Without shuffle buffer, should yield everything exactly once - ds3 = get_dataset_from_config( - self.dataset_path, + def new_ds(): + return get_dataset_from_config( + dataset_path, split_part="train", - training=True, + training=False, sample_type=TextSample, worker_config=worker_config, - ) - with get_loader(ds3.build()) as loader5: - order9 = [data.text for idx, data in zip(range(55), loader5)] - print(order9) - print(Counter(order9)) - assert all(v == 1 for v in Counter(order9).values()) + ).build() + + ds = new_ds() + + # Check len operator + assert len(ds) == 55 + # Check if iterating returns the same + with get_loader(ds) as l1: + iter1 = list(l1) + with get_loader(new_ds()) as l2: + iter2 = list(l2) + assert len(iter1) == 55 + assert len(iter2) == 55 + assert all(elem1.__key__ == elem2.__key__ for elem1, elem2 in zip(iter1, iter2)) + with get_loader(new_ds()) as l3: + assert all(f"{idx}" == x.text for idx, x in enumerate(l3)) + + +def test_epoch(dataset_path): + torch.manual_seed(42) + + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=5) + + # Without shuffle buffer, should yield everything exactly once + ds3 = get_dataset_from_config( + dataset_path, + split_part="train", + training=True, + sample_type=TextSample, + worker_config=worker_config, + ) + with get_loader(ds3.build()) as loader5: + order9 = [data.text for idx, data in zip(range(55), loader5)] + print(order9) + print(Counter(order9)) + assert all(v == 1 for v in Counter(order9).values()) + + +def test_determinism(dataset_path): + worker_config2 = WorkerConfig(rank=0, world_size=1, num_workers=2) + worker_config2b = WorkerConfig(rank=0, world_size=1, num_workers=2, seed_offset=43) + worker_config4 = WorkerConfig(rank=0, world_size=1, num_workers=4) + + # This seed is used by the dataset to shuffle the data + torch.manual_seed(42) + ds1 = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config2, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + ds1b = get_train_dataset( # Same but different seed + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config2b, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + ds2 = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config2, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + ds3 = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config4, + batch_size=1, + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + + # Fork the dataset twice + with get_loader(ds1) as loader1, get_loader(ds2) as loader2: + order4 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)] + order5 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)] + order6 = [data.text[0] for idx, data in zip(range(55 * 20), loader2)] + print(order4) + print(Counter(order4)) + # +-1 is possible due to the random shuffling (actually +-2 is possible) + assert all(17 <= v <= 22 for v in Counter(order4).values()) + + assert order4 != order5 + assert order4 == order6 + + with get_loader(ds1b) as loader3: + order7 = [data.text[0] for idx, data in zip(range(55 * 20), loader3)] + assert order6 != order7 + + with get_loader(ds3) as loader4: + order8 = [data.text[0] for idx, data in zip(range(55 * 100), loader4)] + assert order6 != order8[: len(order6)] + print(Counter(order8)) + assert all(90 <= v <= 110 for v in Counter(order8).values()) + + +def test_determinism_taskencoder(dataset_path): + class TestTaskEncoder(DefaultTaskEncoder): + @stateless(restore_seeds=True) + def encode_sample(self, sample: TextSample) -> TextSample: + rand_str = f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}" + return TextSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavors__=sample.__subflavors__, + text=sample.text + rand_str, + ) - def test_determinism(self): - worker_config2 = WorkerConfig(rank=0, world_size=1, num_workers=2) - worker_config2b = WorkerConfig(rank=0, world_size=1, num_workers=2, seed_offset=43) - worker_config4 = WorkerConfig(rank=0, world_size=1, num_workers=4) + for num_workers in [0, 1]: + worker_config1 = WorkerConfig(rank=0, world_size=1, num_workers=num_workers) # This seed is used by the dataset to shuffle the data torch.manual_seed(42) - ds1 = get_train_dataset( - self.dataset_path, + ds1a = get_train_dataset( + dataset_path, split_part="train", sample_type=TextSample, - worker_config=worker_config2, + worker_config=worker_config1, batch_size=1, shuffle_buffer_size=42, max_samples_per_sequence=2, + task_encoder=TestTaskEncoder(), ) - ds1b = get_train_dataset( # Same but different seed - self.dataset_path, + + torch.manual_seed(44) + ds1b = get_train_dataset( + dataset_path, split_part="train", sample_type=TextSample, - worker_config=worker_config2b, + worker_config=worker_config1, batch_size=1, shuffle_buffer_size=42, max_samples_per_sequence=2, + task_encoder=TestTaskEncoder(), ) - ds2 = get_train_dataset( - self.dataset_path, + + # Fork the dataset twice + with get_loader(ds1a) as loader1a, get_loader(ds1b) as loader1b: + order1a = [data.text[0] for idx, data in zip(range(55 * 20), loader1a)] + order1b = [data.text[0] for idx, data in zip(range(55 * 20), loader1b)] + + assert order1a == order1b + assert order1a == order1b + + +def test_determinism_taskencoder_save_restore(dataset_path): + class TestTaskEncoder(DefaultTaskEncoder): + @stateless(restore_seeds=True) + def encode_sample(self, sample: TextSample) -> TextSample: + rand_str = ( + f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}" + + f"_{WorkerConfig.active_worker_config.worker_seed()}" + + f"_{self.current_batch_index}_{self.current_sample_index}" + ) + print(f"For sample {sample.__restore_key__}: {sample.text}{rand_str}") + + return TextSample( + __key__=sample.__key__, + __restore_key__=sample.__restore_key__, + __subflavors__=sample.__subflavors__, + text=sample.text + rand_str, + ) + + for num_workers in [1, 0]: + worker_config1 = WorkerConfig(rank=0, world_size=1, num_workers=num_workers) + + # This seed is used by the dataset to shuffle the data + torch.manual_seed(42) + ds1a = get_train_dataset( + dataset_path, split_part="train", sample_type=TextSample, - worker_config=worker_config2, + worker_config=worker_config1, batch_size=1, shuffle_buffer_size=42, max_samples_per_sequence=2, + task_encoder=TestTaskEncoder(), ) - ds3 = get_train_dataset( - self.dataset_path, + + torch.manual_seed(44) + ds1b = get_train_dataset( + dataset_path, split_part="train", sample_type=TextSample, - worker_config=worker_config4, + worker_config=worker_config1, batch_size=1, shuffle_buffer_size=42, max_samples_per_sequence=2, + task_encoder=TestTaskEncoder(), ) # Fork the dataset twice - with get_loader(ds1) as loader1, get_loader(ds2) as loader2: - order4 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)] - order5 = [data.text[0] for idx, data in zip(range(55 * 20), loader1)] - order6 = [data.text[0] for idx, data in zip(range(55 * 20), loader2)] - print(order4) - print(Counter(order4)) - # +-1 is possible due to the random shuffling (actually +-2 is possible) - assert all(17 <= v <= 22 for v in Counter(order4).values()) - - assert order4 != order5 - assert order4 == order6 - - with get_loader(ds1b) as loader3: - order7 = [data.text[0] for idx, data in zip(range(55 * 20), loader3)] - assert order6 != order7 - - with get_loader(ds3) as loader4: - order8 = [data.text[0] for idx, data in zip(range(55 * 100), loader4)] - assert order6 != order8[: len(order6)] - print(Counter(order8)) - assert all(90 <= v <= 110 for v in Counter(order8).values()) - - def test_determinism_taskencoder(self): - class TestTaskEncoder(DefaultTaskEncoder): - @stateless(restore_seeds=True) - def encode_sample(self, sample: TextSample) -> TextSample: - rand_str = f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}" - return TextSample( - __key__=sample.__key__, - __restore_key__=sample.__restore_key__, - __subflavors__=sample.__subflavors__, - text=sample.text + rand_str, - ) + with get_savable_loader(ds1a) as loader1a: + # Load 7 samples + _data_pre = [data.text[0] for idx, data in zip(range(7), loader1a)] - for num_workers in [0, 1]: - worker_config1 = WorkerConfig(rank=0, world_size=1, num_workers=num_workers) + # Then save state + state = loader1a.save_state_rank() - # This seed is used by the dataset to shuffle the data - torch.manual_seed(42) - ds1a = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config1, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - task_encoder=TestTaskEncoder(), - ) + print("iterating loader1a") + # Load another 20 samples + data_post = [data.text[0] for idx, data in zip(range(20), loader1a)] - torch.manual_seed(44) - ds1b = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config1, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - task_encoder=TestTaskEncoder(), - ) + # Restore state + with get_savable_loader(ds1b).with_restored_state_rank(state) as loader1b: + print("iterating loader1b") + # Load 20 samples again + data_restored = [data.text[0] for idx, data in zip(range(20), loader1b)] - # Fork the dataset twice - with get_loader(ds1a) as loader1a, get_loader(ds1b) as loader1b: - order1a = [data.text[0] for idx, data in zip(range(55 * 20), loader1a)] - order1b = [data.text[0] for idx, data in zip(range(55 * 20), loader1b)] - - assert order1a == order1b - assert order1a == order1b - - def test_determinism_taskencoder_save_restore(self): - class TestTaskEncoder(DefaultTaskEncoder): - @stateless(restore_seeds=True) - def encode_sample(self, sample: TextSample) -> TextSample: - rand_str = ( - f"_{torch.randint(0, 1000, (1,)).item()}_{random.randint(0, 1000)}" - + f"_{WorkerConfig.active_worker_config.worker_seed()}" - + f"_{self.current_batch_index}_{self.current_sample_index}" - ) - print(f"For sample {sample.__restore_key__}: {sample.text}{rand_str}") + print("Data post:", data_post) + print("Data restored:", data_restored) - return TextSample( - __key__=sample.__key__, - __restore_key__=sample.__restore_key__, - __subflavors__=sample.__subflavors__, - text=sample.text + rand_str, - ) + assert data_post == data_restored - for num_workers in [1, 0]: - worker_config1 = WorkerConfig(rank=0, world_size=1, num_workers=num_workers) - # This seed is used by the dataset to shuffle the data - torch.manual_seed(42) - ds1a = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config1, - batch_size=1, - shuffle_buffer_size=42, - max_samples_per_sequence=2, - task_encoder=TestTaskEncoder(), - ) +def test_restore_state(dataset_path): + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) + + count1 = 55 * 20 + count2 = 55 * 20 + sbs = 42 + # count1 = 4 + # count2 = 2 + # sbs = None + psi = None - torch.manual_seed(44) - ds1b = get_train_dataset( - self.dataset_path, + # This seed is used by the dataset to shuffle the data + torch.manual_seed(42) + + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + ) as loader: + # print("save state") + state_0 = loader.save_state_global(global_dst_rank=0) + # print("save state done") + order_1 = [data.text[0] for idx, data in zip(range(count1), loader)] + assert len(order_1) == count1 + # print("save state") + state_1 = loader.save_state_global(global_dst_rank=0) + # print("save state done") + order_2 = [data.text[0] for idx, data in zip(range(count2), loader)] + assert len(order_2) == count2 + + print("state0", state_0) + print("state1", state_1) + + torch.manual_seed(213) + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + ).with_restored_state_global(state_0, src_rank=None) as loader: + order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)] + order_4 = order_45[:count1] + order_5 = order_45[count1:] + # print("order1", order_1) + # print("order2", order_2) + # print("order4", order_4) + assert order_1 == order_4 + # print("order5", order_5) + assert order_2 == order_5 + + torch.manual_seed(145) + with get_savable_loader( + get_train_dataset( + dataset_path, split_part="train", sample_type=TextSample, - worker_config=worker_config1, + worker_config=worker_config, batch_size=1, - shuffle_buffer_size=42, + shuffle_buffer_size=sbs, max_samples_per_sequence=2, - task_encoder=TestTaskEncoder(), + parallel_shard_iters=psi, ) + ).with_restored_state_global(state_1, src_rank=None) as loader: + order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] + # print("order1", order_1) + # print("order2", order_2[:100]) + # print("order3", order_3[:100]) + assert order_2 == order_3 - # Fork the dataset twice - with get_savable_loader(ds1a) as loader1a: - # Load 7 samples - _data_pre = [data.text[0] for idx, data in zip(range(7), loader1a)] - - # Then save state - state = loader1a.save_state_rank() - - print("iterating loader1a") - # Load another 20 samples - data_post = [data.text[0] for idx, data in zip(range(20), loader1a)] - # Restore state - with get_savable_loader(ds1b).with_restored_state_rank(state) as loader1b: - print("iterating loader1b") - # Load 20 samples again - data_restored = [data.text[0] for idx, data in zip(range(20), loader1b)] +def test_restore_state_dist(dataset_path): + from multiprocessing import Manager, Process - print("Data post:", data_post) - print("Data restored:", data_restored) + import torch.distributed as dist - assert data_post == data_restored + world_size = 3 - def test_restore_state(self): - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=0) + count1 = 55 * 20 + count2 = 55 * 20 + sbs = 42 + psi = None - count1 = 55 * 20 - count2 = 55 * 20 - sbs = 42 - # count1 = 4 - # count2 = 2 - # sbs = None - psi = None + def phase1(rank: int, world_size: int, shared_dict: dict): + worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) # This seed is used by the dataset to shuffle the data torch.manual_seed(42) with get_savable_loader( get_train_dataset( - self.dataset_path, + dataset_path, split_part="train", sample_type=TextSample, worker_config=worker_config, @@ -439,24 +546,40 @@ def test_restore_state(self): parallel_shard_iters=psi, ) ) as loader: - # print("save state") state_0 = loader.save_state_global(global_dst_rank=0) - # print("save state done") order_1 = [data.text[0] for idx, data in zip(range(count1), loader)] assert len(order_1) == count1 - # print("save state") + + # print(f"Rank {rank}: order_1", order_1) + state_1 = loader.save_state_global(global_dst_rank=0) - # print("save state done") order_2 = [data.text[0] for idx, data in zip(range(count2), loader)] assert len(order_2) == count2 - print("state0", state_0) - print("state1", state_1) + shared_dict[(rank, "order_1")] = order_1 + shared_dict[(rank, "order_2")] = order_2 + + if rank == 0: + shared_dict["state_0"] = state_0 + shared_dict["state_1"] = state_1 + + def phase2(rank: int, world_size: int, shared_dict: dict): + order_1 = shared_dict[(rank, "order_1")] + order_2 = shared_dict[(rank, "order_2")] + + if rank == 0: + state_0 = shared_dict["state_0"] + state_1 = shared_dict["state_1"] + else: + state_0 = None + state_1 = None + + worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) torch.manual_seed(213) with get_savable_loader( get_train_dataset( - self.dataset_path, + dataset_path, split_part="train", sample_type=TextSample, worker_config=worker_config, @@ -465,429 +588,413 @@ def test_restore_state(self): max_samples_per_sequence=2, parallel_shard_iters=psi, ) - ).with_restored_state_global(state_0, src_rank=None) as loader: + ).with_restored_state_global(state_0, src_rank=0) as loader: order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)] order_4 = order_45[:count1] order_5 = order_45[count1:] - # print("order1", order_1) - # print("order2", order_2) - # print("order4", order_4) + + # print(f"Rank {rank}: order_4", order_4) + assert order_1 == order_4 - # print("order5", order_5) assert order_2 == order_5 - torch.manual_seed(145) - with get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - ).with_restored_state_global(state_1, src_rank=None) as loader: - order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] - # print("order1", order_1) - # print("order2", order_2[:100]) - # print("order3", order_3[:100]) - assert order_2 == order_3 - - def test_restore_state_dist(self): - from multiprocessing import Manager, Process + torch.manual_seed(213) + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + ).with_restored_state_global(state_1, src_rank=0) as loader: + order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] + assert order_2 == order_3 + + def init_process(rank, world_size, shared_dict, fn, backend="gloo"): + """Initializes the distributed environment.""" + dist.init_process_group( + backend=backend, + init_method="tcp://127.0.0.1:12355", + world_size=world_size, + rank=rank, + ) + fn(rank, world_size, shared_dict) + dist.destroy_process_group() + + with Manager() as manager: + shared_dict = manager.dict() + + # Phase 1 (save state) + processes = [] + for rank in range(world_size): + p = Process(target=init_process, args=(rank, world_size, shared_dict, phase1)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + # Phase 2 (restore state) + processes = [] + for rank in range(world_size): + p = Process(target=init_process, args=(rank, world_size, shared_dict, phase2)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + +def test_restore_state_workers(dataset_path): + worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) + + psi = 2 + sbs = 42 + n1 = 18 + n2 = 109 + n3 = 28 + + # This seed is used by the dataset to shuffle the data + torch.manual_seed(42) + ds = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + with get_savable_loader(ds) as loader: + # print("save state") + state_0 = loader.save_state_rank() + it1 = iter(loader) + # print("save state done") + order_1 = [data.text[0] for idx, data in zip(range(n1), it1)] + # print("save state") + # time.sleep(0.5) + state_1 = loader.save_state_rank() + # print("save state done") + order_2 = [data.text[0] for idx, data in zip(range(n2), it1)] + state_2 = loader.save_state_rank() + order_3 = [data.text[0] for idx, data in zip(range(n3), it1)] + + print("order_1", order_1) + print("order_2", order_2) + print("order_3", order_3) + + # print("state0", state_0) + print("state1", state_1) + print("state2", state_2) + + # Restoring the state of a new dataset should also yield the same + torch.manual_seed(42) + ds = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + with get_savable_loader(ds).with_restored_state_rank(state_0) as loader: + order_6 = [data.text[0] for idx, data in zip(range(n1), loader)] + print("order1", order_1) + print("order6", order_6) + assert order_6 == order_1 + + # Restoring the state of a new dataset should also yield the same + torch.manual_seed(42) + ds = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=sbs, + max_samples_per_sequence=2, + parallel_shard_iters=psi, + ) + with get_savable_loader(ds).with_restored_state_rank(state_1) as loader: + order_7 = [data.text[0] for idx, data in zip(range(n2), loader)] + print("order2", order_2[:100]) + print("order7", order_7[:100]) + assert order_7 == order_2 + + # Restoring the state of a new dataset should also yield the same + torch.manual_seed(42) + ds = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=worker_config, + batch_size=1, + max_samples_per_sequence=2, + shuffle_buffer_size=sbs, + parallel_shard_iters=psi, + ) + with get_savable_loader(ds).with_restored_state_rank(state_2) as loader: + order_8 = [data.text[0] for idx, data in zip(range(n3), loader)] + print("order3", order_3) + print("order8", order_8) + assert order_8 == order_3 + + +def test_invariance_global_samples(dataset_path): + # We'd like to ensure that the user can keep the same global batches + # (deterministic pseudo random order) when changing the number of ranks (world size). + + # This can be achieved by obeying a few constraints: + # - Global batch size must stay the same across runs + # - Global batch size must be a multiple of (micro-batch size * world_size * num_workers) + # - Global batch size = micro-batch size * world_size * num_workers * gradient_accum_steps + # - world_size * num_workers must stay the same across runs + # Set the same torch.manual_seed(...) on each rank before constructing the dataset and the data loader + + scenarios = [ + dict( + configs=(WorkerConfig(rank=0, world_size=1, num_workers=4),), + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=( + WorkerConfig(rank=0, world_size=2, num_workers=2), + WorkerConfig(rank=1, world_size=2, num_workers=2), + ), + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=( + WorkerConfig(rank=0, world_size=4, num_workers=1), + WorkerConfig(rank=1, world_size=4, num_workers=1), + WorkerConfig(rank=2, world_size=4, num_workers=1), + WorkerConfig(rank=3, world_size=4, num_workers=1), + ), + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=( + WorkerConfig(rank=0, world_size=2, num_workers=2), + WorkerConfig(rank=1, world_size=2, num_workers=2), + ), + micro_batch_size=1, # Micro-batch 1, more accum + global_batch_size=8, + ), + ] - import torch.distributed as dist + # Constraints to user: - world_size = 3 + global_batches_per_scenario = [] + for scenario in scenarios: + assert scenario["global_batch_size"] % scenario["micro_batch_size"] == 0, ( + "Global batch size must be a multiple of the micro-batch size." + ) - count1 = 55 * 20 - count2 = 55 * 20 - sbs = 42 - psi = None + world_size = len(scenario["configs"]) + gradient_accum_steps = scenario["global_batch_size"] // ( + scenario["micro_batch_size"] * world_size + ) - def phase1(rank: int, world_size: int, shared_dict: dict): - worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) + batches_per_rank = [] - # This seed is used by the dataset to shuffle the data + for rank_config in scenario["configs"]: torch.manual_seed(42) - - with get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - ) as loader: - state_0 = loader.save_state_global(global_dst_rank=0) - order_1 = [data.text[0] for idx, data in zip(range(count1), loader)] - assert len(order_1) == count1 - - # print(f"Rank {rank}: order_1", order_1) - - state_1 = loader.save_state_global(global_dst_rank=0) - order_2 = [data.text[0] for idx, data in zip(range(count2), loader)] - assert len(order_2) == count2 - - shared_dict[(rank, "order_1")] = order_1 - shared_dict[(rank, "order_2")] = order_2 - - if rank == 0: - shared_dict["state_0"] = state_0 - shared_dict["state_1"] = state_1 - - def phase2(rank: int, world_size: int, shared_dict: dict): - order_1 = shared_dict[(rank, "order_1")] - order_2 = shared_dict[(rank, "order_2")] - - if rank == 0: - state_0 = shared_dict["state_0"] - state_1 = shared_dict["state_1"] - else: - state_0 = None - state_1 = None - - worker_config = WorkerConfig(rank=rank, world_size=world_size, num_workers=0) - - torch.manual_seed(213) - with get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - ).with_restored_state_global(state_0, src_rank=0) as loader: - order_45 = [data.text[0] for idx, data in zip(range(count1 + count2), loader)] - order_4 = order_45[:count1] - order_5 = order_45[count1:] - - # print(f"Rank {rank}: order_4", order_4) - - assert order_1 == order_4 - assert order_2 == order_5 - - torch.manual_seed(213) - with get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - ).with_restored_state_global(state_1, src_rank=0) as loader: - order_3 = [data.text[0] for idx, data in zip(range(count2), loader)] - assert order_2 == order_3 - - def init_process(rank, world_size, shared_dict, fn, backend="gloo"): - """Initializes the distributed environment.""" - dist.init_process_group( - backend=backend, - init_method="tcp://127.0.0.1:12355", - world_size=world_size, - rank=rank, + ds = get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=rank_config, + batch_size=scenario["micro_batch_size"], + shuffle_buffer_size=42, + max_samples_per_sequence=2, ) - fn(rank, world_size, shared_dict) - dist.destroy_process_group() - - with Manager() as manager: - shared_dict = manager.dict() - - # Phase 1 (save state) - processes = [] - for rank in range(world_size): - p = Process(target=init_process, args=(rank, world_size, shared_dict, phase1)) - p.start() - processes.append(p) + with get_loader(ds) as loader: + micro_batches = [ + data.text + for idx, data in zip( + range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader + ) + ] + batches_per_rank.append(micro_batches) - for p in processes: - p.join() + # Compose global batches + global_batches_cur_rank = [] + batch_index = 0 + while batch_index < len(batches_per_rank[0]): + global_batch = [] + for _ in range(gradient_accum_steps): + for rank_batches in batches_per_rank: + global_batch.extend(rank_batches[batch_index]) + batch_index += 1 + if batch_index >= len(batches_per_rank[0]): + # last global batch may be smaller + break + global_batches_cur_rank.append(sorted(global_batch)) - # Phase 2 (restore state) - processes = [] - for rank in range(world_size): - p = Process(target=init_process, args=(rank, world_size, shared_dict, phase2)) - p.start() - processes.append(p) + global_batches_per_scenario.append(global_batches_cur_rank) - for p in processes: - p.join() + # Check that the global batches are the same - def test_restore_state_workers(self): - worker_config = WorkerConfig(rank=0, world_size=1, num_workers=2) + # Assert that all scenarios produced the same number of global batches + assert all( + len(global_batches) == len(global_batches_per_scenario[0]) + for global_batches in global_batches_per_scenario + ), "Number of global batches per scenario does not match." - psi = 2 - sbs = 42 - n1 = 18 - n2 = 109 - n3 = 28 + for global_batches in global_batches_per_scenario: + print("= Global batches per scenario") + for global_batch in global_batches: + print(" Global batch: ", global_batch) - # This seed is used by the dataset to shuffle the data - torch.manual_seed(42) - ds = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - with get_savable_loader(ds) as loader: - # print("save state") - state_0 = loader.save_state_rank() - it1 = iter(loader) - # print("save state done") - order_1 = [data.text[0] for idx, data in zip(range(n1), it1)] - # print("save state") - # time.sleep(0.5) - state_1 = loader.save_state_rank() - # print("save state done") - order_2 = [data.text[0] for idx, data in zip(range(n2), it1)] - state_2 = loader.save_state_rank() - order_3 = [data.text[0] for idx, data in zip(range(n3), it1)] - - print("order_1", order_1) - print("order_2", order_2) - print("order_3", order_3) - - # print("state0", state_0) - print("state1", state_1) - print("state2", state_2) - - # Restoring the state of a new dataset should also yield the same - torch.manual_seed(42) - ds = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - with get_savable_loader(ds).with_restored_state_rank(state_0) as loader: - order_6 = [data.text[0] for idx, data in zip(range(n1), loader)] - print("order1", order_1) - print("order6", order_6) - assert order_6 == order_1 + # Assert that all global batches are the same + for i in range(len(global_batches_per_scenario[0])): + for scenerio_idx, global_batches in enumerate(global_batches_per_scenario): + assert global_batches[i] == global_batches_per_scenario[0][i], ( + f"Global batch {i} of scenario {scenerio_idx} does not match." + ) - # Restoring the state of a new dataset should also yield the same - torch.manual_seed(42) - ds = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=sbs, - max_samples_per_sequence=2, - parallel_shard_iters=psi, - ) - with get_savable_loader(ds).with_restored_state_rank(state_1) as loader: - order_7 = [data.text[0] for idx, data in zip(range(n2), loader)] - print("order2", order_2[:100]) - print("order7", order_7[:100]) - assert order_7 == order_2 - # Restoring the state of a new dataset should also yield the same - torch.manual_seed(42) - ds = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=worker_config, - batch_size=1, - max_samples_per_sequence=2, - shuffle_buffer_size=sbs, - parallel_shard_iters=psi, - ) - with get_savable_loader(ds).with_restored_state_rank(state_2) as loader: - order_8 = [data.text[0] for idx, data in zip(range(n3), loader)] - print("order3", order_3) - print("order8", order_8) - assert order_8 == order_3 - - def test_invariance_global_samples(self): - # We'd like to ensure that the user can keep the same global batches - # (deterministic pseudo random order) when changing the number of ranks (world size). - - # This can be achieved by obeying a few constraints: - # - Global batch size must stay the same across runs - # - Global batch size must be a multiple of (micro-batch size * world_size * num_workers) - # - Global batch size = micro-batch size * world_size * num_workers * gradient_accum_steps - # - world_size * num_workers must stay the same across runs - # Set the same torch.manual_seed(...) on each rank before constructing the dataset and the data loader - - scenarios = [ - dict( - configs=(WorkerConfig(rank=0, world_size=1, num_workers=4),), - micro_batch_size=2, - global_batch_size=8, +def test_redist(dataset_path, checkpoint_dir, redist_dir): + scenarios = [ + dict( + configs=( + WorkerConfig(rank=0, world_size=2, num_workers=2), + WorkerConfig(rank=1, world_size=2, num_workers=2), ), - dict( - configs=( - WorkerConfig(rank=0, world_size=2, num_workers=2), - WorkerConfig(rank=1, world_size=2, num_workers=2), - ), - micro_batch_size=2, - global_batch_size=8, + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=(WorkerConfig(rank=0, world_size=1, num_workers=4),), + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=( + WorkerConfig(rank=0, world_size=4, num_workers=1), + WorkerConfig(rank=1, world_size=4, num_workers=1), + WorkerConfig(rank=2, world_size=4, num_workers=1), + WorkerConfig(rank=3, world_size=4, num_workers=1), ), - dict( - configs=( - WorkerConfig(rank=0, world_size=4, num_workers=1), - WorkerConfig(rank=1, world_size=4, num_workers=1), - WorkerConfig(rank=2, world_size=4, num_workers=1), - WorkerConfig(rank=3, world_size=4, num_workers=1), - ), - micro_batch_size=2, - global_batch_size=8, + micro_batch_size=2, + global_batch_size=8, + ), + dict( + configs=( + WorkerConfig(rank=0, world_size=2, num_workers=2), + WorkerConfig(rank=1, world_size=2, num_workers=2), ), - dict( - configs=( - WorkerConfig(rank=0, world_size=2, num_workers=2), - WorkerConfig(rank=1, world_size=2, num_workers=2), - ), - micro_batch_size=1, # Micro-batch 1, more accum - global_batch_size=8, + micro_batch_size=1, # Micro-batch 1, more accum + global_batch_size=8, + ), + dict( # Same as original + configs=( + WorkerConfig(rank=0, world_size=2, num_workers=2), + WorkerConfig(rank=1, world_size=2, num_workers=2), ), - ] + micro_batch_size=2, + global_batch_size=8, + ), + ] - # Constraints to user: + # === Stage 1 first generate a saved state using scenario 0 + checkpoint_files = [] - global_batches_per_scenario = [] - for scenario in scenarios: - assert scenario["global_batch_size"] % scenario["micro_batch_size"] == 0, ( - "Global batch size must be a multiple of the micro-batch size." - ) + global_batches_per_scenario = [] + scenario = scenarios[0] - world_size = len(scenario["configs"]) - gradient_accum_steps = scenario["global_batch_size"] // ( - scenario["micro_batch_size"] * world_size - ) + world_size = len(scenario["configs"]) + gradient_accum_steps = scenario["global_batch_size"] // ( + scenario["micro_batch_size"] * world_size + ) - batches_per_rank = [] + batches_per_rank = [] - for rank_config in scenario["configs"]: - torch.manual_seed(42) - ds = get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=rank_config, - batch_size=scenario["micro_batch_size"], - shuffle_buffer_size=42, - max_samples_per_sequence=2, - ) - with get_loader(ds) as loader: - micro_batches = [ - data.text - for idx, data in zip( - range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader - ) - ] - batches_per_rank.append(micro_batches) - - # Compose global batches - global_batches_cur_rank = [] - batch_index = 0 - while batch_index < len(batches_per_rank[0]): - global_batch = [] - for _ in range(gradient_accum_steps): - for rank_batches in batches_per_rank: - global_batch.extend(rank_batches[batch_index]) - batch_index += 1 - if batch_index >= len(batches_per_rank[0]): - # last global batch may be smaller - break - global_batches_cur_rank.append(sorted(global_batch)) - - global_batches_per_scenario.append(global_batches_cur_rank) - - # Check that the global batches are the same - - # Assert that all scenarios produced the same number of global batches - assert all( - len(global_batches) == len(global_batches_per_scenario[0]) - for global_batches in global_batches_per_scenario - ), "Number of global batches per scenario does not match." - - for global_batches in global_batches_per_scenario: - print("= Global batches per scenario") - for global_batch in global_batches: - print(" Global batch: ", global_batch) - - # Assert that all global batches are the same - for i in range(len(global_batches_per_scenario[0])): - for scenerio_idx, global_batches in enumerate(global_batches_per_scenario): - assert global_batches[i] == global_batches_per_scenario[0][i], ( - f"Global batch {i} of scenario {scenerio_idx} does not match." + for rank_config in scenario["configs"]: + with get_savable_loader( + get_train_dataset( + dataset_path, + split_part="train", + sample_type=TextSample, + worker_config=rank_config, + batch_size=scenario["micro_batch_size"], + shuffle_buffer_size=42, + max_samples_per_sequence=2, + ) + ) as loader: + # Throw away some samples to advance the loader state + num_pre_samples = 20 + for _ in zip(range(num_pre_samples), loader): + pass + + # Save the state to a file + checkpoint_file = checkpoint_dir / f"state_rank{rank_config.rank}.pt" + state = loader.save_state_rank() + torch.save(state, str(checkpoint_file)) + checkpoint_files.append(checkpoint_file) + + # Now capture the next micro-batches + micro_batches = [ + data.text + for idx, data in zip( + range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader ) - - def test_redist(self): - scenarios = [ - dict( - configs=( - WorkerConfig(rank=0, world_size=2, num_workers=2), - WorkerConfig(rank=1, world_size=2, num_workers=2), - ), - micro_batch_size=2, - global_batch_size=8, - ), - dict( - configs=(WorkerConfig(rank=0, world_size=1, num_workers=4),), - micro_batch_size=2, - global_batch_size=8, - ), - dict( - configs=( - WorkerConfig(rank=0, world_size=4, num_workers=1), - WorkerConfig(rank=1, world_size=4, num_workers=1), - WorkerConfig(rank=2, world_size=4, num_workers=1), - WorkerConfig(rank=3, world_size=4, num_workers=1), - ), - micro_batch_size=2, - global_batch_size=8, - ), - dict( - configs=( - WorkerConfig(rank=0, world_size=2, num_workers=2), - WorkerConfig(rank=1, world_size=2, num_workers=2), - ), - micro_batch_size=1, # Micro-batch 1, more accum - global_batch_size=8, - ), - dict( # Same as original - configs=( - WorkerConfig(rank=0, world_size=2, num_workers=2), - WorkerConfig(rank=1, world_size=2, num_workers=2), - ), - micro_batch_size=2, - global_batch_size=8, - ), - ] - - # === Stage 1 first generate a saved state using scenario 0 - checkpoint_files = [] - - global_batches_per_scenario = [] - scenario = scenarios[0] + ] + batches_per_rank.append(micro_batches) + + # Compose global batches + global_batches_cur_rank = [] + batch_index = 0 + while batch_index < len(batches_per_rank[0]): + global_batch = [] + for _ in range(gradient_accum_steps): + for rank_batches in batches_per_rank: + global_batch.extend(rank_batches[batch_index]) + batch_index += 1 + if batch_index >= len(batches_per_rank[0]): + # last global batch may be smaller + break + global_batches_cur_rank.append(sorted(global_batch)) + + global_batches_per_scenario.append(global_batches_cur_rank) + + # === Stage 2: Now check that the global batches are the same after redistribution + + for scenario in scenarios[1:]: + print(f"\n\nRunning scenario {scenario}") + # Redistribute the saved state + runner = CliRunner() + result = runner.invoke( + command_redist, + [ + "--new-world-size", + str(len(scenario["configs"])), + "--new-micro-batch-size", + str(scenario["micro_batch_size"]), + *[str(cpt) for cpt in checkpoint_files], + str(redist_dir), + ], + ) + print(result.output) + if result.exception is not None: + raise result.exception + assert result.exception is None, result.exception + assert result.exit_code == 0, "Redistribution failed" + + # Load state and check that the global batches are the same + assert scenario["global_batch_size"] % scenario["micro_batch_size"] == 0, ( + "Global batch size must be a multiple of the micro-batch size." + ) world_size = len(scenario["configs"]) gradient_accum_steps = scenario["global_batch_size"] // ( @@ -897,9 +1004,13 @@ def test_redist(self): batches_per_rank = [] for rank_config in scenario["configs"]: + state = torch.load( + str(redist_dir / f"state_rank{rank_config.rank}.pt"), weights_only=False + ) + with get_savable_loader( get_train_dataset( - self.dataset_path, + dataset_path, split_part="train", sample_type=TextSample, worker_config=rank_config, @@ -907,19 +1018,7 @@ def test_redist(self): shuffle_buffer_size=42, max_samples_per_sequence=2, ) - ) as loader: - # Throw away some samples to advance the loader state - num_pre_samples = 20 - for _ in zip(range(num_pre_samples), loader): - pass - - # Save the state to a file - checkpoint_file = self.checkpoint_dir / f"state_rank{rank_config.rank}.pt" - state = loader.save_state_rank() - torch.save(state, str(checkpoint_file)) - checkpoint_files.append(checkpoint_file) - - # Now capture the next micro-batches + ).with_restored_state_rank(state) as loader: micro_batches = [ data.text for idx, data in zip( @@ -944,105 +1043,24 @@ def test_redist(self): global_batches_per_scenario.append(global_batches_cur_rank) - # === Stage 2: Now check that the global batches are the same after redistribution - - for scenario in scenarios[1:]: - print(f"\n\nRunning scenario {scenario}") - # Redistribute the saved state - runner = CliRunner() - result = runner.invoke( - command_redist, - [ - "--new-world-size", - str(len(scenario["configs"])), - "--new-micro-batch-size", - str(scenario["micro_batch_size"]), - *[str(cpt) for cpt in checkpoint_files], - str(self.redist_dir), - ], - ) - print(result.output) - if result.exception is not None: - raise result.exception - assert result.exception is None, result.exception - assert result.exit_code == 0, "Redistribution failed" - - # Load state and check that the global batches are the same - assert scenario["global_batch_size"] % scenario["micro_batch_size"] == 0, ( - "Global batch size must be a multiple of the micro-batch size." - ) - - world_size = len(scenario["configs"]) - gradient_accum_steps = scenario["global_batch_size"] // ( - scenario["micro_batch_size"] * world_size - ) - - batches_per_rank = [] + # Check that the global batches are the same - for rank_config in scenario["configs"]: - state = torch.load( - str(self.redist_dir / f"state_rank{rank_config.rank}.pt"), weights_only=False - ) + print() - with get_savable_loader( - get_train_dataset( - self.dataset_path, - split_part="train", - sample_type=TextSample, - worker_config=rank_config, - batch_size=scenario["micro_batch_size"], - shuffle_buffer_size=42, - max_samples_per_sequence=2, - ) - ).with_restored_state_rank(state) as loader: - micro_batches = [ - data.text - for idx, data in zip( - range(55 * 8 // (world_size * scenario["micro_batch_size"])), loader - ) - ] - batches_per_rank.append(micro_batches) - - # Compose global batches - global_batches_cur_rank = [] - batch_index = 0 - while batch_index < len(batches_per_rank[0]): - global_batch = [] - for _ in range(gradient_accum_steps): - for rank_batches in batches_per_rank: - global_batch.extend(rank_batches[batch_index]) - batch_index += 1 - if batch_index >= len(batches_per_rank[0]): - # last global batch may be smaller - break - global_batches_cur_rank.append(sorted(global_batch)) - - global_batches_per_scenario.append(global_batches_cur_rank) - - # Check that the global batches are the same - - print() - - # Assert that all scenarios produced the same global batches - assert all( - len(global_batches) == len(global_batches_per_scenario[0]) - for global_batches in global_batches_per_scenario - ), "Number of global batches per scenario does not match." - - for idx, (global_batches, scenario) in enumerate( - zip(global_batches_per_scenario, scenarios) - ): - print(f"= Global batches per scenario {idx} {scenario}") - for global_batch in global_batches: - print(" Global batch: ", global_batch) - - # Assert that all global batches are the same - for i in range(len(global_batches_per_scenario[0])): - for scenerio_idx, global_batches in enumerate(global_batches_per_scenario): - assert global_batches[i] == global_batches_per_scenario[0][i], ( - f"Global batch {i} of scenario {scenerio_idx} does not match." - ) + # Assert that all scenarios produced the same global batches + assert all( + len(global_batches) == len(global_batches_per_scenario[0]) + for global_batches in global_batches_per_scenario + ), "Number of global batches per scenario does not match." + for idx, (global_batches, scenario) in enumerate(zip(global_batches_per_scenario, scenarios)): + print(f"= Global batches per scenario {idx} {scenario}") + for global_batch in global_batches: + print(" Global batch: ", global_batch) -if __name__ == "__main__": - unittest.main() + # Assert that all global batches are the same + for i in range(len(global_batches_per_scenario[0])): + for scenerio_idx, global_batches in enumerate(global_batches_per_scenario): + assert global_batches[i] == global_batches_per_scenario[0][i], ( + f"Global batch {i} of scenario {scenerio_idx} does not match." + ) diff --git a/tests/test_epathlib.py b/tests/test_epathlib.py index 76e45f19..640f4f83 100644 --- a/tests/test_epathlib.py +++ b/tests/test_epathlib.py @@ -9,273 +9,280 @@ import pickle import struct import sys -import unittest + +import pytest from megatron.energon.epathlib import EPath from tests.epath_s3_emulator import setup_s3_emulator -class TestEPath(unittest.TestCase): - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) +@pytest.fixture(autouse=True) +def setup_logging(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) - def tearDown(self): - pass - def test_basic(self): - """Some basic functionality tests""" +def test_basic(): + """Some basic functionality tests""" - p_rel = EPath("./subdir") - p_abs = EPath("/tmp") + p_rel = EPath("./subdir") + p_abs = EPath("/tmp") - p_comb = p_abs / p_rel - # Those should not raise: - assert not p_comb.is_file() - assert not p_abs.is_file() - assert p_abs.is_dir() + p_comb = p_abs / p_rel + # Those should not raise: + assert not p_comb.is_file() + assert not p_abs.is_file() + assert p_abs.is_dir() - def test_contextman(self): - """Test the context manager""" - tmp_file_path = "/tmp/testfile.bin" - # First create a file - with open(tmp_file_path, "wb") as f: - f.write(struct.pack("H10s", 1337, b"1234567890")) +def test_contextman(): + """Test the context manager""" - # Test context manager reading - p = EPath(tmp_file_path).open("rb") - print(p) - with p: - b = p.read() - assert isinstance(b, bytes) + tmp_file_path = "/tmp/testfile.bin" + # First create a file + with open(tmp_file_path, "wb") as f: + f.write(struct.pack("H10s", 1337, b"1234567890")) - num, data = struct.unpack("H10s", b) - logging.info(f"num: {num}") - assert num == 1337 - assert data == b"1234567890" + # Test context manager reading + p = EPath(tmp_file_path).open("rb") + print(p) + with p: + b = p.read() + assert isinstance(b, bytes) - # Test context manager writing - tmp_file_path2 = "/tmp/testfile2.bin" - with EPath(tmp_file_path2).open("wb") as p: - p.write(struct.pack("H10s", 1337, b"1234567890")) + num, data = struct.unpack("H10s", b) + logging.info(f"num: {num}") + assert num == 1337 + assert data == b"1234567890" - def test_localfs(self): - """Test the local filesystem""" - p = EPath("/tmp/testfile.bin") - with p.open("wb") as f: - f.write(b"dummycontent") + # Test context manager writing + tmp_file_path2 = "/tmp/testfile2.bin" + with EPath(tmp_file_path2).open("wb") as p: + p.write(struct.pack("H10s", 1337, b"1234567890")) + + +def test_localfs(): + """Test the local filesystem""" + p = EPath("/tmp/testfile.bin") + with p.open("wb") as f: + f.write(b"dummycontent") + assert p.is_file() + assert p.size() == 12 + with p.open("rb") as f: + assert f.read() == b"dummycontent" + + # Test relative paths + revert_dir = os.getcwd() + try: + os.chdir("/tmp") + p = EPath("testfile.bin") + assert str(p) == "/tmp/testfile.bin" assert p.is_file() assert p.size() == 12 with p.open("rb") as f: assert f.read() == b"dummycontent" - # Test relative paths - revert_dir = os.getcwd() - try: - os.chdir("/tmp") - p = EPath("testfile.bin") - assert str(p) == "/tmp/testfile.bin" - assert p.is_file() - assert p.size() == 12 - with p.open("rb") as f: - assert f.read() == b"dummycontent" - - p = EPath("nonexisting/../testfile.bin") - assert str(p) == "/tmp/testfile.bin" - - p = EPath("../tmp/testfile.bin") - assert str(p) == "/tmp/testfile.bin" - finally: - os.chdir(revert_dir) + p = EPath("nonexisting/../testfile.bin") + assert str(p) == "/tmp/testfile.bin" - p.unlink() - assert p.is_file() is False + p = EPath("../tmp/testfile.bin") + assert str(p) == "/tmp/testfile.bin" + finally: + os.chdir(revert_dir) - def test_glob(self): - """Test the glob functionality""" + p.unlink() + assert p.is_file() is False - # First create some files - for i in range(10): - with open(f"/tmp/epathtestfile_{i}.bin", "wb") as f: - f.write(b"dummycontent") - # Test globbing - p = EPath("/tmp").glob("epathtestfile_*.bin") - - logging.info(f"p: {p}, type of p: {type(p)}") - elems = list(p) - assert len(elems) == 10 - for i, e in enumerate(elems): - logging.info(f"glob_result[{i}]: {e}") - assert isinstance(e, EPath) - assert e.is_file() - - # Test globbing with a pattern - p = EPath("/tmp").glob("epathtestfile_[0-3].bin") - assert len(list(p)) == 4 - - def test_s3_path_resolution(self): - """Test s3 path resolution""" - rclone_config_path = EPath("/tmp/XDG_CONFIG_HOME/.config/rclone/rclone.conf") - with rclone_config_path.open("w") as f: - f.write( - "\n".join( - [ - "[s3]", - "type = s3", - "env_auth = false", - "access_key_id = dummy", - "secret_access_key = dummy", - "region = dummy", - "endpoint = https://localhost", - ] - ) +def test_glob(): + """Test the glob functionality""" + + # First create some files + for i in range(10): + with open(f"/tmp/epathtestfile_{i}.bin", "wb") as f: + f.write(b"dummycontent") + + # Test globbing + p = EPath("/tmp").glob("epathtestfile_*.bin") + + logging.info(f"p: {p}, type of p: {type(p)}") + elems = list(p) + assert len(elems) == 10 + for i, e in enumerate(elems): + logging.info(f"glob_result[{i}]: {e}") + assert isinstance(e, EPath) + assert e.is_file() + + # Test globbing with a pattern + p = EPath("/tmp").glob("epathtestfile_[0-3].bin") + assert len(list(p)) == 4 + + +def test_s3_path_resolution(): + """Test s3 path resolution""" + rclone_config_path = EPath("/tmp/XDG_CONFIG_HOME/.config/rclone/rclone.conf") + with rclone_config_path.open("w") as f: + f.write( + "\n".join( + [ + "[s3]", + "type = s3", + "env_auth = false", + "access_key_id = dummy", + "secret_access_key = dummy", + "region = dummy", + "endpoint = https://localhost", + ] ) + ) - orig_xdg_config_home = os.environ.get("XDG_CONFIG_HOME") - os.environ["XDG_CONFIG_HOME"] = "/tmp/XDG_CONFIG_HOME/.config" - os.environ["HOME"] = "/tmp/XDG_CONFIG_HOME" - # Hack to clear the cache of the rclone config for msc to get the "s3" profile - from multistorageclient.rclone import read_rclone_config - - read_rclone_config.cache_clear() - try: - # Test globbing - p = EPath("msc://s3/tmp/path/subpath.txt") - assert str(p) == "msc://s3/tmp/path/subpath.txt", str(p) - - p2 = p / ".." / "subpath2.txt" - assert str(p2) == "msc://s3/tmp/path/subpath2.txt", str(p2) - - p3 = EPath("msc://s3/tmp/path/.././subpath.txt") - assert str(p3) == "msc://s3/tmp/subpath.txt", str(p3) - - p4 = p3.parent / "../bla/bla/bla/../../../no/../subpath2.txt" - assert str(p4) == "msc://s3/subpath2.txt", str(p4) - - # Test warning for deprecated rclone protocol - with self.assertWarns((DeprecationWarning, FutureWarning)) as warning: - # Test rclone backwards compatibility - pr = EPath("rclone://s3/tmp/path/.././subpath.txt") - assert str(pr) == "msc://s3/tmp/subpath.txt", str(pr) - assert "deprecated" in str(warning.warnings[0].message) - - # Test pickle / unpickle - p4serialized = pickle.dumps(p4) - # No secret must be serialized - assert b"dummy" not in p4serialized - finally: - if orig_xdg_config_home is not None: - os.environ["XDG_CONFIG_HOME"] = orig_xdg_config_home - else: - del os.environ["XDG_CONFIG_HOME"] - rclone_config_path.unlink() - - def test_multi_storage_client(self): - """Test the Multi-Storage Client integration""" - # Test path handling - p = EPath("msc://default/etc/resolv.conf") - assert str(p) == "/etc/resolv.conf", str(p) - assert p.is_file() + orig_xdg_config_home = os.environ.get("XDG_CONFIG_HOME") + os.environ["XDG_CONFIG_HOME"] = "/tmp/XDG_CONFIG_HOME/.config" + os.environ["HOME"] = "/tmp/XDG_CONFIG_HOME" + # Hack to clear the cache of the rclone config for msc to get the "s3" profile + from multistorageclient.rclone import read_rclone_config - p2 = p / ".." / "hosts" - assert str(p2) == "/etc/hosts", str(p2) + read_rclone_config.cache_clear() + try: + # Test globbing + p = EPath("msc://s3/tmp/path/subpath.txt") + assert str(p) == "msc://s3/tmp/path/subpath.txt", str(p) - # Test glob - p3 = EPath("msc://default/etc/") - assert p3.is_dir() - for i in p3.glob("*.conf"): - assert str(i).endswith(".conf") + p2 = p / ".." / "subpath2.txt" + assert str(p2) == "msc://s3/tmp/path/subpath2.txt", str(p2) - # Test open file - assert p.size() > 0 - with p.open("r") as fp: - assert len(fp.read()) > 0 - - # Test move and delete - p4 = EPath("msc://default/tmp/random_file_0001") - p4.unlink() - with p4.open("w") as fp: - fp.write("*****") - assert p4.is_file() - p5 = EPath("msc://default/tmp/random_file_0002") - p5.unlink() - assert p5.is_file() is False - p4.move(p5) - assert p5.is_file() - assert p4.is_file() is False - p5.unlink() - assert p5.is_file() is False + p3 = EPath("msc://s3/tmp/path/.././subpath.txt") + assert str(p3) == "msc://s3/tmp/subpath.txt", str(p3) + + p4 = p3.parent / "../bla/bla/bla/../../../no/../subpath2.txt" + assert str(p4) == "msc://s3/subpath2.txt", str(p4) + + # Test warning for deprecated rclone protocol + with pytest.warns((DeprecationWarning, FutureWarning)) as warning: + # Test rclone backwards compatibility + pr = EPath("rclone://s3/tmp/path/.././subpath.txt") + assert str(pr) == "msc://s3/tmp/subpath.txt", str(pr) + assert "deprecated" in str(warning[0].message) # Test pickle / unpickle - p5serialized = pickle.dumps(p5) - p5unserialized = pickle.loads(p5serialized) - assert p5unserialized == p5 - assert str(p5unserialized) == str(p5) - - def test_multiprocessing(self): - """Test EPath in multiprocessing context""" - p = EPath("/tmp/path/subpath.txt") - - orig_start_method = multiprocessing.get_start_method() - try: - multiprocessing.set_start_method("spawn", force=True) - - proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) - proc.start() - proc.join() - assert proc.exitcode == 0 - - multiprocessing.set_start_method("fork", force=True) - - proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) - proc.start() - proc.join() - assert proc.exitcode == 0 - finally: - multiprocessing.set_start_method(orig_start_method, force=True) - - def test_multiprocessing_msc(self): - """Test EPath in multiprocessing context""" - p = EPath("msc://default/tmp/random_file_0001") - with p.open("w") as fp: - fp.write("*****") - - orig_start_method = multiprocessing.get_start_method() - try: - multiprocessing.set_start_method("spawn", force=True) - - proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) - proc.start() - proc.join() - assert proc.exitcode == 0 - - multiprocessing.set_start_method("fork", force=True) - - proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) - proc.start() - proc.join() - assert proc.exitcode == 0 - finally: - multiprocessing.set_start_method(orig_start_method, force=True) - p.unlink() - - def test_msc_s3(self): - # Test S3 with MSC - with setup_s3_emulator(profile_name="s3test_msc"): - p = EPath("msc://s3test_msc/test/dir/file.txt") - assert not p.is_file() - p.write_text("dummy") - assert p.is_file() - assert p.size() > 0 - assert p.read_text() == "dummy" - # TODO: Fix when fixed in MSC. - # assert EPath("msc://s3test_msc/test").is_dir() - assert EPath("msc://s3test_msc/test/dir").is_dir() - p.unlink() - assert not p.is_file() - # assert not EPath("msc://s3test_msc/test").is_dir() - assert not EPath("msc://s3test_msc/test/dir").is_dir() + p4serialized = pickle.dumps(p4) + # No secret must be serialized + assert b"dummy" not in p4serialized + finally: + if orig_xdg_config_home is not None: + os.environ["XDG_CONFIG_HOME"] = orig_xdg_config_home + else: + del os.environ["XDG_CONFIG_HOME"] + rclone_config_path.unlink() + + +def test_multi_storage_client(): + """Test the Multi-Storage Client integration""" + # Test path handling + p = EPath("msc://default/etc/resolv.conf") + assert str(p) == "/etc/resolv.conf", str(p) + assert p.is_file() + + p2 = p / ".." / "hosts" + assert str(p2) == "/etc/hosts", str(p2) + + # Test glob + p3 = EPath("msc://default/etc/") + assert p3.is_dir() + for i in p3.glob("*.conf"): + assert str(i).endswith(".conf") + + # Test open file + assert p.size() > 0 + with p.open("r") as fp: + assert len(fp.read()) > 0 + + # Test move and delete + p4 = EPath("msc://default/tmp/random_file_0001") + p4.unlink() + with p4.open("w") as fp: + fp.write("*****") + assert p4.is_file() + p5 = EPath("msc://default/tmp/random_file_0002") + p5.unlink() + assert p5.is_file() is False + p4.move(p5) + assert p5.is_file() + assert p4.is_file() is False + p5.unlink() + assert p5.is_file() is False + + # Test pickle / unpickle + p5serialized = pickle.dumps(p5) + p5unserialized = pickle.loads(p5serialized) + assert p5unserialized == p5 + assert str(p5unserialized) == str(p5) + + +def test_multiprocessing(): + """Test EPath in multiprocessing context""" + p = EPath("/tmp/path/subpath.txt") + + orig_start_method = multiprocessing.get_start_method() + try: + multiprocessing.set_start_method("spawn", force=True) + + proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) + proc.start() + proc.join() + assert proc.exitcode == 0 + + multiprocessing.set_start_method("fork", force=True) + + proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) + proc.start() + proc.join() + assert proc.exitcode == 0 + finally: + multiprocessing.set_start_method(orig_start_method, force=True) + + +def test_multiprocessing_msc(): + """Test EPath in multiprocessing context""" + p = EPath("msc://default/tmp/random_file_0001") + with p.open("w") as fp: + fp.write("*****") + + orig_start_method = multiprocessing.get_start_method() + try: + multiprocessing.set_start_method("spawn", force=True) + + proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) + proc.start() + proc.join() + assert proc.exitcode == 0 + + multiprocessing.set_start_method("fork", force=True) + + proc = multiprocessing.Process(target=_multiproc_test_func, args=(p, True)) + proc.start() + proc.join() + assert proc.exitcode == 0 + finally: + multiprocessing.set_start_method(orig_start_method, force=True) + p.unlink() + + +def test_msc_s3(): + # Test S3 with MSC + with setup_s3_emulator(profile_name="s3test_msc"): + p = EPath("msc://s3test_msc/test/dir/file.txt") + assert not p.is_file() + p.write_text("dummy") + assert p.is_file() + assert p.size() > 0 + assert p.read_text() == "dummy" + # TODO: Fix when fixed in MSC. + # assert EPath("msc://s3test_msc/test").is_dir() + assert EPath("msc://s3test_msc/test/dir").is_dir() + p.unlink() + assert not p.is_file() + # assert not EPath("msc://s3test_msc/test").is_dir() + assert not EPath("msc://s3test_msc/test/dir").is_dir() def _multiproc_test_func(p: EPath, test_function: bool): @@ -283,7 +290,3 @@ def _multiproc_test_func(p: EPath, test_function: bool): print(f"str: {str(p)}") if test_function: print(f"is_file: {p.is_file()}") - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_file_cache_pool.py b/tests/test_file_cache_pool.py index 7dc94b77..25a18c8e 100644 --- a/tests/test_file_cache_pool.py +++ b/tests/test_file_cache_pool.py @@ -5,10 +5,11 @@ import tempfile import threading import time -import unittest from pathlib import Path from typing import Any, Dict, Optional +import pytest + from megatron.energon.cache import DecodeFileStore, FileCacheLazy, FileStore, FileStoreCachePool from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder from megatron.energon.source_info import SourceInfo @@ -40,608 +41,608 @@ def decode(self, fname: str, raw: bytes) -> Any: return f"{fname}: {raw.decode()}" -class TestFileStoreCachePool(unittest.TestCase): - """Test cases for FileStoreCachePool""" - - def setUp(self): - """Setup test environment before each test""" - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.temp_path = Path(self.temp_dir.name) - - def tearDown(self): - """Clean up after each test""" - self.temp_dir.cleanup() - - def test_get_method(self): - """Test the synchronous get method""" - # Create mock file stores - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - "file2": b"test data 2", - "file3": b"test data 3", - } - ) - - mock_decode_file_store = DecodeFileStore( - decoder=MockDecoder(), - inner_reader=mock_raw_file_store, - ) - pool = FileStoreCachePool(parent_cache_dir=self.temp_path) - try: - # get should directly read from the dataset without caching - sample_for_source_info = {"__sources__": []} - result = pool.get(mock_raw_file_store, "file1", sample_for_source_info) - assert result == b"test data 1" - assert len(sample_for_source_info["__sources__"]) == 1 - assert ( - sample_for_source_info["__sources__"][0].dataset_path - == mock_raw_file_store.get_path() - ) - assert sample_for_source_info["__sources__"][0].index is None - assert sample_for_source_info["__sources__"][0].shard_name is None - assert sample_for_source_info["__sources__"][0].file_names == ("file1",) - - # get should directly read from the dataset without caching - sample_for_source_info = {"__sources__": []} - result = pool.get(mock_decode_file_store, "file1", sample_for_source_info) - assert result == "file1: test data 1" - assert len(sample_for_source_info["__sources__"]) == 1 - assert ( - sample_for_source_info["__sources__"][0].dataset_path - == mock_decode_file_store.get_path() - ) - assert sample_for_source_info["__sources__"][0].index is None - assert sample_for_source_info["__sources__"][0].shard_name is None - assert sample_for_source_info["__sources__"][0].file_names == ("file1",) - finally: - pool.close() - - def test_get_lazy_method(self): - """Test the lazy get method for background prefetching""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path) - # Create mock file stores - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - } +"""Test cases for FileStoreCachePool""" + + +@pytest.fixture +def temp_dir(): + """Setup test environment before each test""" + # Create a temporary directory + temp_dir = tempfile.TemporaryDirectory() + temp_path = Path(temp_dir.name) + yield temp_path + temp_dir.cleanup() + + +def test_get_method(temp_dir): + """Test the synchronous get method""" + # Create mock file stores + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + "file2": b"test data 2", + "file3": b"test data 3", + } + ) + + mock_decode_file_store = DecodeFileStore( + decoder=MockDecoder(), + inner_reader=mock_raw_file_store, + ) + pool = FileStoreCachePool(parent_cache_dir=temp_dir) + try: + # get should directly read from the dataset without caching + sample_for_source_info = {"__sources__": []} + result = pool.get(mock_raw_file_store, "file1", sample_for_source_info) + assert result == b"test data 1" + assert len(sample_for_source_info["__sources__"]) == 1 + assert ( + sample_for_source_info["__sources__"][0].dataset_path == mock_raw_file_store.get_path() ) - try: - # Request lazy loading - lazy_ref = pool.get_lazy(mock_raw_file_store, "file1") - - # Verify the return type - assert isinstance(lazy_ref, FileCacheLazy) - - # Wait for the background task - lazy_ref.entry.send_to_cache_future.result() - - # Check that the file exists in the cache directory - cache_files = list(pool.cache_dir.glob("*")) - assert len(cache_files) == 1 - - # Get the data - result = lazy_ref.get() - assert result == b"test data 1" - finally: - pool.close() - - def test_shared_references(self): - """Test that multiple references share the same background task""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path) - # Create mock file stores - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - } + assert sample_for_source_info["__sources__"][0].index is None + assert sample_for_source_info["__sources__"][0].shard_name is None + assert sample_for_source_info["__sources__"][0].file_names == ("file1",) + + # get should directly read from the dataset without caching + sample_for_source_info = {"__sources__": []} + result = pool.get(mock_decode_file_store, "file1", sample_for_source_info) + assert result == "file1: test data 1" + assert len(sample_for_source_info["__sources__"]) == 1 + assert ( + sample_for_source_info["__sources__"][0].dataset_path + == mock_decode_file_store.get_path() ) - try: - # Check that the file exists in the cache directory - cache_files = list(pool.cache_dir.rglob("*")) - assert len(cache_files) == 0 - - # Request lazy loading for the same file twice - lazy_ref1 = pool.get_lazy(mock_raw_file_store, "file1") - lazy_ref2 = pool.get_lazy(mock_raw_file_store, "file1") - - # Check that they share the same entry - assert lazy_ref1.entry is lazy_ref2.entry - - # Check that refcount is 2 - assert lazy_ref1.entry.refcount == 2 - - # Wait for the background task - lazy_ref1.entry.send_to_cache_future.result() - - # Check that the file exists in the cache directory - cache_files = list(pool.cache_dir.rglob("*")) - assert len(cache_files) == 1, cache_files - - # Get data from both references - sample_with_source_info = {"__sources__": []} - result1 = lazy_ref1.get(sample_with_source_info) - assert lazy_ref1.entry.refcount == 1 - sample_with_source_info2 = {"__sources__": []} - result2 = lazy_ref2.get(sample_with_source_info2) - assert lazy_ref1.entry.refcount == 0 - - # Check that the file exists in the cache directory - cache_files = list(pool.cache_dir.rglob("*")) - assert len(cache_files) == 0 - - assert result1 == b"test data 1" - assert result2 == b"test data 1" - assert ( - sample_with_source_info["__sources__"][0].dataset_path - == sample_with_source_info2["__sources__"][0].dataset_path - ) - assert sample_with_source_info["__sources__"][0].index is None - assert sample_with_source_info["__sources__"][0].shard_name is None - assert ( - sample_with_source_info["__sources__"][0].file_names - == sample_with_source_info2["__sources__"][0].file_names - ) - finally: - pool.close() - - def test_cache_size_management(self): - """Test that the cache respects size limits and evicts files""" - # Create a cache pool with strict limits - pool = FileStoreCachePool( - parent_cache_dir=self.temp_path, - max_cache_size_gbytes=0.0001, # ~100KB - max_cache_count=2, - num_workers=1, + assert sample_for_source_info["__sources__"][0].index is None + assert sample_for_source_info["__sources__"][0].shard_name is None + assert sample_for_source_info["__sources__"][0].file_names == ("file1",) + finally: + pool.close() + + +def test_get_lazy_method(temp_dir): + """Test the lazy get method for background prefetching""" + pool = FileStoreCachePool(parent_cache_dir=temp_dir) + # Create mock file stores + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + } + ) + try: + # Request lazy loading + lazy_ref = pool.get_lazy(mock_raw_file_store, "file1") + + # Verify the return type + assert isinstance(lazy_ref, FileCacheLazy) + + # Wait for the background task + lazy_ref.entry.send_to_cache_future.result() + + # Check that the file exists in the cache directory + cache_files = list(pool.cache_dir.glob("*")) + assert len(cache_files) == 1 + + # Get the data + result = lazy_ref.get() + assert result == b"test data 1" + finally: + pool.close() + + +def test_shared_references(temp_dir): + """Test that multiple references share the same background task""" + pool = FileStoreCachePool(parent_cache_dir=temp_dir) + # Create mock file stores + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + } + ) + try: + # Check that the file exists in the cache directory + cache_files = list(pool.cache_dir.rglob("*")) + assert len(cache_files) == 0 + + # Request lazy loading for the same file twice + lazy_ref1 = pool.get_lazy(mock_raw_file_store, "file1") + lazy_ref2 = pool.get_lazy(mock_raw_file_store, "file1") + + # Check that they share the same entry + assert lazy_ref1.entry is lazy_ref2.entry + + # Check that refcount is 2 + assert lazy_ref1.entry.refcount == 2 + + # Wait for the background task + lazy_ref1.entry.send_to_cache_future.result() + + # Check that the file exists in the cache directory + cache_files = list(pool.cache_dir.rglob("*")) + assert len(cache_files) == 1, cache_files + + # Get data from both references + sample_with_source_info = {"__sources__": []} + result1 = lazy_ref1.get(sample_with_source_info) + assert lazy_ref1.entry.refcount == 1 + sample_with_source_info2 = {"__sources__": []} + result2 = lazy_ref2.get(sample_with_source_info2) + assert lazy_ref1.entry.refcount == 0 + + # Check that the file exists in the cache directory + cache_files = list(pool.cache_dir.rglob("*")) + assert len(cache_files) == 0 + + assert result1 == b"test data 1" + assert result2 == b"test data 1" + assert ( + sample_with_source_info["__sources__"][0].dataset_path + == sample_with_source_info2["__sources__"][0].dataset_path ) - # Set to a safe byte size - pool.max_cache_size = 75_000 - - mock_raw_file_store = MockFileStore( - { - "large_file1": b"a" * 50_000, - "large_file2": b"b" * 50_000, - "large_file3": b"c" * 50_000, - "large_file4": b"d" * 25_000, - "large_file5": b"e" * 25_000, - "large_file6": b"f" * 25_000, - } + assert sample_with_source_info["__sources__"][0].index is None + assert sample_with_source_info["__sources__"][0].shard_name is None + assert ( + sample_with_source_info["__sources__"][0].file_names + == sample_with_source_info2["__sources__"][0].file_names ) - - try: - # Enqueue all fetches - lazy1 = pool.get_lazy(mock_raw_file_store, "large_file1") - lazy2 = pool.get_lazy(mock_raw_file_store, "large_file2") - lazy3 = pool.get_lazy(mock_raw_file_store, "large_file3") - lazy4 = pool.get_lazy(mock_raw_file_store, "large_file4") - lazy2_2 = pool.get_lazy(mock_raw_file_store, "large_file2") - lazy2_3 = pool.get_lazy(mock_raw_file_store, "large_file2") - lazy3_2 = pool.get_lazy(mock_raw_file_store, "large_file3") - lazy5 = pool.get_lazy(mock_raw_file_store, "large_file5") - lazy6 = pool.get_lazy(mock_raw_file_store, "large_file6") - lazy6_2 = pool.get_lazy(mock_raw_file_store, "large_file6") - - def status(): - return [ - ( - name, - lazy.entry.refcount, - "consumed" - if lazy._data - else ("cached" if lazy.entry.send_to_cache_future.done() else "pending"), + finally: + pool.close() + + +def test_cache_size_management(temp_dir): + """Test that the cache respects size limits and evicts files""" + # Create a cache pool with strict limits + pool = FileStoreCachePool( + parent_cache_dir=temp_dir, + max_cache_size_gbytes=0.0001, # ~100KB + max_cache_count=2, + num_workers=1, + ) + # Set to a safe byte size + pool.max_cache_size = 75_000 + + mock_raw_file_store = MockFileStore( + { + "large_file1": b"a" * 50_000, + "large_file2": b"b" * 50_000, + "large_file3": b"c" * 50_000, + "large_file4": b"d" * 25_000, + "large_file5": b"e" * 25_000, + "large_file6": b"f" * 25_000, + } + ) + + try: + # Enqueue all fetches + lazy1 = pool.get_lazy(mock_raw_file_store, "large_file1") + lazy2 = pool.get_lazy(mock_raw_file_store, "large_file2") + lazy3 = pool.get_lazy(mock_raw_file_store, "large_file3") + lazy4 = pool.get_lazy(mock_raw_file_store, "large_file4") + lazy2_2 = pool.get_lazy(mock_raw_file_store, "large_file2") + lazy2_3 = pool.get_lazy(mock_raw_file_store, "large_file2") + lazy3_2 = pool.get_lazy(mock_raw_file_store, "large_file3") + lazy5 = pool.get_lazy(mock_raw_file_store, "large_file5") + lazy6 = pool.get_lazy(mock_raw_file_store, "large_file6") + lazy6_2 = pool.get_lazy(mock_raw_file_store, "large_file6") + + def status(): + return [ + ( + name, + lazy.entry.refcount, + "consumed" + if lazy._data + else ("cached" if lazy.entry.send_to_cache_future.done() else "pending"), + ) + for lazy, name in ( + [ + (lazy1, "1"), + (lazy2, "2"), + (lazy2_2, "2_2"), + (lazy2_3, "2_3"), + (lazy3, "3"), + (lazy3_2, "3_2"), + (lazy4, "4"), + (lazy5, "5"), + (lazy6, "6"), + ] + + ([(lazy6_2, "6_2")] if lazy6_2 is not None else []) + ) + ] + + def txt_status(): + out = [] + for lazy in [ + lazy1, + lazy2, + lazy2_2, + lazy2_3, + lazy3, + lazy3_2, + lazy4, + lazy5, + lazy6, + ] + ([lazy6_2] if lazy6_2 is not None else []): + if lazy._data is not None: + out.append( + f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] consumed" ) - for lazy, name in ( - [ - (lazy1, "1"), - (lazy2, "2"), - (lazy2_2, "2_2"), - (lazy2_3, "2_3"), - (lazy3, "3"), - (lazy3_2, "3_2"), - (lazy4, "4"), - (lazy5, "5"), - (lazy6, "6"), - ] - + ([(lazy6_2, "6_2")] if lazy6_2 is not None else []) + elif lazy.entry.send_to_cache_future.done(): + out.append( + f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] cached" ) - ] - - def txt_status(): - out = [] - for lazy in [ - lazy1, - lazy2, - lazy2_2, - lazy2_3, - lazy3, - lazy3_2, - lazy4, - lazy5, - lazy6, - ] + ([lazy6_2] if lazy6_2 is not None else []): - if lazy._data is not None: - out.append( - f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] consumed" - ) - elif lazy.entry.send_to_cache_future.done(): - out.append( - f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] cached" - ) - else: - out.append( - f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] pending" - ) - return ( - f"Cached Count: {pool.current_cache_count}, Cache size: {pool.current_cache_size}\n" - + "\n".join(out) - ) - - # lazy2_2 and lazy2_3 should share the same entry as lazy2 - assert lazy2_2.entry is lazy2.entry - assert lazy2_3.entry is lazy2.entry - - lazy1.entry.send_to_cache_future.result(timeout=1) - # Wait for the background tasks to finish - time.sleep(0.5) - - print("Checking cache status") - # They should not be able to finish, because the cache is full - # Queue state: [2<50>, 3<50>, 4<25>, 5<25>, 6<25>], cached out: [1<50>], removed: [] - assert status() == [ - ("1", 1, "cached"), - ("2", 3, "pending"), - ("2_2", 3, "pending"), - ("2_3", 3, "pending"), - ("3", 2, "pending"), - ("3_2", 2, "pending"), - ("4", 1, "pending"), - ("5", 1, "pending"), - ("6", 2, "pending"), - ("6_2", 2, "pending"), - ], txt_status() - - # Check cache count and size before second file - assert pool.current_cache_count == 1, pool.current_cache_count - assert pool.current_cache_size == 50_000, pool.current_cache_size - - print("Fetching lazy2_3") - # Now, fetching the second file should still work directly and ignore the caching - # But it will requeue fetching the second file to the background thread for the remaining lazies. - result2_3 = lazy2_3.get() - assert result2_3 == b"b" * 50_000 - - # They should not be able to finish, because the cache is full - # Queue state: [3<50>, 4<25>, 5<25>, 6<25>, 2<50>], cached out: [1<50>], removed: [] - assert status() == [ - ("1", 1, "cached"), - ("2", 2, "pending"), - ("2_2", 2, "pending"), - ("2_3", 2, "consumed"), - ("3", 2, "pending"), - ("3_2", 2, "pending"), - ("4", 1, "pending"), - ("5", 1, "pending"), - ("6", 2, "pending"), - ("6_2", 2, "pending"), - ], txt_status() - - # Fetch - result1 = lazy1.get() - assert result1 == b"a" * 50_000 - - lazy3.entry.send_to_cache_future.result(timeout=1) - - time.sleep(0.5) - - # Second file is now queued at the end. - # File 3 and 4 should now be cached. - # Queue state: [5<25>, 6<25>, 2<50>], cached out: [3<50>, 4<25>], removed: [1<50>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "pending"), - ("2_2", 2, "pending"), - ("2_3", 2, "consumed"), - ("3", 2, "cached"), - ("3_2", 2, "cached"), - ("4", 1, "cached"), - ("5", 1, "pending"), - ("6", 2, "pending"), - ("6_2", 2, "pending"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 75_000 - - result3 = lazy3.get() - assert result3 == b"c" * 50_000 - - time.sleep(0.5) - - # Space by large_file3 is still occupied in cache - # Queue state: [5<25>, 6<25>, 2<50>], cached out: [3<50>, 4<25>], removed: [1<50>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "pending"), - ("2_2", 2, "pending"), - ("2_3", 2, "consumed"), - ("3", 1, "consumed"), - ("3_2", 1, "cached"), - ("4", 1, "cached"), - ("5", 1, "pending"), - ("6", 2, "pending"), - ("6_2", 2, "pending"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 75_000 - - result3_2 = lazy3_2.get() - assert result3_2 == b"c" * 50_000 - - time.sleep(0.5) - - # Space by large_file3 was freed now, 4, 5, and 6 should fit now, large_file2 not yet - # Queue state: [6<25>, 2<50>], cached out: [5<25>, 4<25>], removed: [1<50>, 3<50>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "pending"), - ("2_2", 2, "pending"), - ("2_3", 2, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 1, "cached"), - ("5", 1, "cached"), - ("6", 2, "pending"), - ("6_2", 2, "pending"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 50_000 - - result4 = lazy4.get() - assert result4 == b"d" * 25_000 - - time.sleep(0.5) - - # Nothing changed, no space for large_file2 still - # Queue state: [6<25>, 2<50>], cached out: [5<25>, 4<25>], removed: [1<50>, 3<50>, 4<25>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "pending"), - ("2_2", 2, "pending"), - ("2_3", 2, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 1, "cached"), - ("6", 2, "cached"), - ("6_2", 2, "cached"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 50_000 - - result5 = lazy5.get() - assert result5 == b"e" * 25_000 - - time.sleep(0.5) - - # Now large_file2 can be cached - # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "cached"), - ("2_2", 2, "cached"), - ("2_3", 2, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 0, "consumed"), - ("6", 2, "cached"), - ("6_2", 2, "cached"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 75_000 - - result6 = lazy6.get() - assert result6 == b"f" * 25_000 - - # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 2, "cached"), - ("2_2", 2, "cached"), - ("2_3", 2, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 0, "consumed"), - ("6", 1, "consumed"), - ("6_2", 1, "cached"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 75_000 - - result2 = lazy2.get() - assert result2 == b"b" * 50_000 - - # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 1, "consumed"), - ("2_2", 1, "cached"), - ("2_3", 1, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 0, "consumed"), - ("6", 1, "consumed"), - ("6_2", 1, "cached"), - ], txt_status() - assert pool.current_cache_count == 2 - assert pool.current_cache_size == 75_000 - - result2_2 = lazy2_2.get() - assert result2_2 == b"b" * 50_000 - - # Cache should only contain large_file6 now - # Queue state: [], cached out: [6<25>], removed: [1<50>, 3<50>, 4<25>, 5<25>, 2<50>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 0, "consumed"), - ("2_2", 0, "consumed"), - ("2_3", 0, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 0, "consumed"), - ("6", 1, "consumed"), - ("6_2", 1, "cached"), - ], txt_status() - assert pool.current_cache_count == 1, txt_status() - assert pool.current_cache_size == 25_000 - - # Delete the last reference to large_file6, it should be removed from the cache - lazy6_2 = None - gc.collect() - - # Cache should be empty now - # Queue state: [], cached out: [], removed: [1<50>, 3<50>, 4<25>, 5<25>, 6<25>, 2<50>] - assert status() == [ - ("1", 0, "consumed"), - ("2", 0, "consumed"), - ("2_2", 0, "consumed"), - ("2_3", 0, "consumed"), - ("3", 0, "consumed"), - ("3_2", 0, "consumed"), - ("4", 0, "consumed"), - ("5", 0, "consumed"), - ("6", 0, "consumed"), - ], txt_status() - assert pool.current_cache_count == 0, txt_status() - assert pool.current_cache_size == 0 - # Check that the cache directory is empty - assert not list(pool.cache_dir.glob("*")) - finally: - pool.close() - - def test_raw_method(self): - """Test the 'raw' caching method with DecodeFileStore""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path, method="raw") - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - } - ) - mock_decode_file_store = DecodeFileStore( - decoder=MockDecoder(), - inner_reader=mock_raw_file_store, - ) - try: - # Request lazy loading - lazy_ref = pool.get_lazy(mock_decode_file_store, "file1") - - # Wait for background task - time.sleep(0.5) - - # Get the data - should be decoded - sample_with_source_info = {"__sources__": []} - result = lazy_ref.get(sample_with_source_info) - assert result == "file1: test data 1" - assert ( - sample_with_source_info["__sources__"][0].dataset_path - == mock_decode_file_store.get_path() + else: + out.append( + f" - {lazy.fname} [{lazy.entry.data_size}b, {lazy.entry.refcount}refs] pending" + ) + return ( + f"Cached Count: {pool.current_cache_count}, Cache size: {pool.current_cache_size}\n" + + "\n".join(out) ) - assert sample_with_source_info["__sources__"][0].index is None - assert sample_with_source_info["__sources__"][0].shard_name is None - assert sample_with_source_info["__sources__"][0].file_names == ("file1",) - finally: - pool.close() - - def test_pickle_method(self): - """Test the 'pickle' caching method""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path, method="pickle") - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - } + + # lazy2_2 and lazy2_3 should share the same entry as lazy2 + assert lazy2_2.entry is lazy2.entry + assert lazy2_3.entry is lazy2.entry + + lazy1.entry.send_to_cache_future.result(timeout=1) + # Wait for the background tasks to finish + time.sleep(0.5) + + print("Checking cache status") + # They should not be able to finish, because the cache is full + # Queue state: [2<50>, 3<50>, 4<25>, 5<25>, 6<25>], cached out: [1<50>], removed: [] + assert status() == [ + ("1", 1, "cached"), + ("2", 3, "pending"), + ("2_2", 3, "pending"), + ("2_3", 3, "pending"), + ("3", 2, "pending"), + ("3_2", 2, "pending"), + ("4", 1, "pending"), + ("5", 1, "pending"), + ("6", 2, "pending"), + ("6_2", 2, "pending"), + ], txt_status() + + # Check cache count and size before second file + assert pool.current_cache_count == 1, pool.current_cache_count + assert pool.current_cache_size == 50_000, pool.current_cache_size + + print("Fetching lazy2_3") + # Now, fetching the second file should still work directly and ignore the caching + # But it will requeue fetching the second file to the background thread for the remaining lazies. + result2_3 = lazy2_3.get() + assert result2_3 == b"b" * 50_000 + + # They should not be able to finish, because the cache is full + # Queue state: [3<50>, 4<25>, 5<25>, 6<25>, 2<50>], cached out: [1<50>], removed: [] + assert status() == [ + ("1", 1, "cached"), + ("2", 2, "pending"), + ("2_2", 2, "pending"), + ("2_3", 2, "consumed"), + ("3", 2, "pending"), + ("3_2", 2, "pending"), + ("4", 1, "pending"), + ("5", 1, "pending"), + ("6", 2, "pending"), + ("6_2", 2, "pending"), + ], txt_status() + + # Fetch + result1 = lazy1.get() + assert result1 == b"a" * 50_000 + + lazy3.entry.send_to_cache_future.result(timeout=1) + + time.sleep(0.5) + + # Second file is now queued at the end. + # File 3 and 4 should now be cached. + # Queue state: [5<25>, 6<25>, 2<50>], cached out: [3<50>, 4<25>], removed: [1<50>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "pending"), + ("2_2", 2, "pending"), + ("2_3", 2, "consumed"), + ("3", 2, "cached"), + ("3_2", 2, "cached"), + ("4", 1, "cached"), + ("5", 1, "pending"), + ("6", 2, "pending"), + ("6_2", 2, "pending"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 75_000 + + result3 = lazy3.get() + assert result3 == b"c" * 50_000 + + time.sleep(0.5) + + # Space by large_file3 is still occupied in cache + # Queue state: [5<25>, 6<25>, 2<50>], cached out: [3<50>, 4<25>], removed: [1<50>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "pending"), + ("2_2", 2, "pending"), + ("2_3", 2, "consumed"), + ("3", 1, "consumed"), + ("3_2", 1, "cached"), + ("4", 1, "cached"), + ("5", 1, "pending"), + ("6", 2, "pending"), + ("6_2", 2, "pending"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 75_000 + + result3_2 = lazy3_2.get() + assert result3_2 == b"c" * 50_000 + + time.sleep(0.5) + + # Space by large_file3 was freed now, 4, 5, and 6 should fit now, large_file2 not yet + # Queue state: [6<25>, 2<50>], cached out: [5<25>, 4<25>], removed: [1<50>, 3<50>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "pending"), + ("2_2", 2, "pending"), + ("2_3", 2, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 1, "cached"), + ("5", 1, "cached"), + ("6", 2, "pending"), + ("6_2", 2, "pending"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 50_000 + + result4 = lazy4.get() + assert result4 == b"d" * 25_000 + + time.sleep(0.5) + + # Nothing changed, no space for large_file2 still + # Queue state: [6<25>, 2<50>], cached out: [5<25>, 4<25>], removed: [1<50>, 3<50>, 4<25>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "pending"), + ("2_2", 2, "pending"), + ("2_3", 2, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 1, "cached"), + ("6", 2, "cached"), + ("6_2", 2, "cached"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 50_000 + + result5 = lazy5.get() + assert result5 == b"e" * 25_000 + + time.sleep(0.5) + + # Now large_file2 can be cached + # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "cached"), + ("2_2", 2, "cached"), + ("2_3", 2, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 0, "consumed"), + ("6", 2, "cached"), + ("6_2", 2, "cached"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 75_000 + + result6 = lazy6.get() + assert result6 == b"f" * 25_000 + + # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 2, "cached"), + ("2_2", 2, "cached"), + ("2_3", 2, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 0, "consumed"), + ("6", 1, "consumed"), + ("6_2", 1, "cached"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 75_000 + + result2 = lazy2.get() + assert result2 == b"b" * 50_000 + + # Queue state: [], cached out: [6<25>, 2<50>], removed: [1<50>, 3<50>, 4<25>, 5<25>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 1, "consumed"), + ("2_2", 1, "cached"), + ("2_3", 1, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 0, "consumed"), + ("6", 1, "consumed"), + ("6_2", 1, "cached"), + ], txt_status() + assert pool.current_cache_count == 2 + assert pool.current_cache_size == 75_000 + + result2_2 = lazy2_2.get() + assert result2_2 == b"b" * 50_000 + + # Cache should only contain large_file6 now + # Queue state: [], cached out: [6<25>], removed: [1<50>, 3<50>, 4<25>, 5<25>, 2<50>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 0, "consumed"), + ("2_2", 0, "consumed"), + ("2_3", 0, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 0, "consumed"), + ("6", 1, "consumed"), + ("6_2", 1, "cached"), + ], txt_status() + assert pool.current_cache_count == 1, txt_status() + assert pool.current_cache_size == 25_000 + + # Delete the last reference to large_file6, it should be removed from the cache + lazy6_2 = None + gc.collect() + + # Cache should be empty now + # Queue state: [], cached out: [], removed: [1<50>, 3<50>, 4<25>, 5<25>, 6<25>, 2<50>] + assert status() == [ + ("1", 0, "consumed"), + ("2", 0, "consumed"), + ("2_2", 0, "consumed"), + ("2_3", 0, "consumed"), + ("3", 0, "consumed"), + ("3_2", 0, "consumed"), + ("4", 0, "consumed"), + ("5", 0, "consumed"), + ("6", 0, "consumed"), + ], txt_status() + assert pool.current_cache_count == 0, txt_status() + assert pool.current_cache_size == 0 + # Check that the cache directory is empty + assert not list(pool.cache_dir.glob("*")) + finally: + pool.close() + + +def test_raw_method(temp_dir): + """Test the 'raw' caching method with DecodeFileStore""" + pool = FileStoreCachePool(parent_cache_dir=temp_dir, method="raw") + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + } + ) + mock_decode_file_store = DecodeFileStore( + decoder=MockDecoder(), + inner_reader=mock_raw_file_store, + ) + try: + # Request lazy loading + lazy_ref = pool.get_lazy(mock_decode_file_store, "file1") + + # Wait for background task + time.sleep(0.5) + + # Get the data - should be decoded + sample_with_source_info = {"__sources__": []} + result = lazy_ref.get(sample_with_source_info) + assert result == "file1: test data 1" + assert ( + sample_with_source_info["__sources__"][0].dataset_path + == mock_decode_file_store.get_path() ) - mock_decode_file_store = DecodeFileStore( - decoder=MockDecoder(), - inner_reader=mock_raw_file_store, + assert sample_with_source_info["__sources__"][0].index is None + assert sample_with_source_info["__sources__"][0].shard_name is None + assert sample_with_source_info["__sources__"][0].file_names == ("file1",) + finally: + pool.close() + + +def test_pickle_method(temp_dir): + """Test the 'pickle' caching method""" + pool = FileStoreCachePool(parent_cache_dir=temp_dir, method="pickle") + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + } + ) + mock_decode_file_store = DecodeFileStore( + decoder=MockDecoder(), + inner_reader=mock_raw_file_store, + ) + try: + # Request lazy loading + lazy_ref = pool.get_lazy(mock_decode_file_store, "file1") + + # Wait for background task + lazy_ref.entry.send_to_cache_future.result() + + # Get the data - should be unpickled correctly + sample_with_source_info = {"__sources__": []} + result = lazy_ref.get(sample_with_source_info) + assert result == "file1: test data 1" + assert ( + sample_with_source_info["__sources__"][0].dataset_path + == mock_decode_file_store.get_path() ) - try: - # Request lazy loading - lazy_ref = pool.get_lazy(mock_decode_file_store, "file1") - - # Wait for background task - lazy_ref.entry.send_to_cache_future.result() - - # Get the data - should be unpickled correctly - sample_with_source_info = {"__sources__": []} - result = lazy_ref.get(sample_with_source_info) - assert result == "file1: test data 1" - assert ( - sample_with_source_info["__sources__"][0].dataset_path - == mock_decode_file_store.get_path() - ) - assert sample_with_source_info["__sources__"][0].index is None - assert sample_with_source_info["__sources__"][0].shard_name is None - assert sample_with_source_info["__sources__"][0].file_names == ("file1",) - - # Request lazy loading - lazy_ref = pool.get_lazy(mock_raw_file_store, "file1") - - # Wait for background task - lazy_ref.entry.send_to_cache_future.result() - - # Get the data - should be unpickled correctly - sample_with_source_info = {"__sources__": []} - result = lazy_ref.get(sample_with_source_info) - assert result == b"test data 1" - assert ( - sample_with_source_info["__sources__"][0].dataset_path - == mock_raw_file_store.get_path() - ) - assert sample_with_source_info["__sources__"][0].index is None - assert sample_with_source_info["__sources__"][0].shard_name is None - assert sample_with_source_info["__sources__"][0].file_names == ("file1",) - finally: - pool.close() - - def test_concurrent_access(self): - """Test concurrent access to the cache pool""" - pool = FileStoreCachePool(parent_cache_dir=self.temp_path) - mock_raw_file_store = MockFileStore( - { - "file1": b"test data 1", - } + assert sample_with_source_info["__sources__"][0].index is None + assert sample_with_source_info["__sources__"][0].shard_name is None + assert sample_with_source_info["__sources__"][0].file_names == ("file1",) + + # Request lazy loading + lazy_ref = pool.get_lazy(mock_raw_file_store, "file1") + + # Wait for background task + lazy_ref.entry.send_to_cache_future.result() + + # Get the data - should be unpickled correctly + sample_with_source_info = {"__sources__": []} + result = lazy_ref.get(sample_with_source_info) + assert result == b"test data 1" + assert ( + sample_with_source_info["__sources__"][0].dataset_path == mock_raw_file_store.get_path() ) - results = [] - - def worker(filename): - lazy_ref = pool.get_lazy(mock_raw_file_store, filename) - result, source_info = lazy_ref.get() - results.append(result) - assert source_info.dataset_path == mock_raw_file_store.get_path() - assert source_info.index is None - assert source_info.shard_name is None - assert source_info.file_names == (filename,) - - try: - # Start multiple threads accessing the same file - threads = [] - for i in range(5): - t = threading.Thread(target=worker, args=("file1",)) - threads.append(t) - t.start() - - # Wait for all threads to complete - for t in threads: - t.join() - - # All threads should get the correct result - for r in results: - assert r == b"test data 1" - finally: - pool.close() - - -if __name__ == "__main__": - unittest.main() + assert sample_with_source_info["__sources__"][0].index is None + assert sample_with_source_info["__sources__"][0].shard_name is None + assert sample_with_source_info["__sources__"][0].file_names == ("file1",) + finally: + pool.close() + + +def test_concurrent_access(temp_dir): + """Test concurrent access to the cache pool""" + pool = FileStoreCachePool(parent_cache_dir=temp_dir) + mock_raw_file_store = MockFileStore( + { + "file1": b"test data 1", + } + ) + results = [] + + def worker(filename): + lazy_ref = pool.get_lazy(mock_raw_file_store, filename) + result, source_info = lazy_ref.get() + results.append(result) + assert source_info.dataset_path == mock_raw_file_store.get_path() + assert source_info.index is None + assert source_info.shard_name is None + assert source_info.file_names == (filename,) + + try: + # Start multiple threads accessing the same file + threads = [] + for i in range(5): + t = threading.Thread(target=worker, args=("file1",)) + threads.append(t) + t.start() + + # Wait for all threads to complete + for t in threads: + t.join() + + # All threads should get the correct result + for r in results: + assert r == b"test data 1" + finally: + pool.close() diff --git a/tests/test_jsonl_dataset.py b/tests/test_jsonl_dataset.py index de32200d..d2a51270 100644 --- a/tests/test_jsonl_dataset.py +++ b/tests/test_jsonl_dataset.py @@ -9,12 +9,12 @@ import random import sys import tempfile -import unittest import warnings from collections import Counter from pathlib import Path from typing import Iterable +import pytest import torch from click.testing import CliRunner @@ -58,256 +58,256 @@ class SimpleCookingTaskEncoder(DefaultTaskEncoder): cookers = [Cooker(cook=cook_text)] -class TestJsonlDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - random.seed(42) +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() + + +@pytest.fixture +def dataset_path(temp_dir): + """Create dataset path and setup test data.""" + random.seed(42) + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + + # Create a small dummy datasets + create_text_test_dataset(dataset_path / "ds1.jsonl", range(55), range(55)) + create_text_test_dataset(dataset_path / "ds2.jsonl", range(100, 155), range(100, 155)) + create_text_test_dataset(dataset_path / "ds3.jsonl", range(200, 255), range(55)) + + mds_all_path = dataset_path / "metadataset_all.yaml" + with open(mds_all_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - path: ds1.jsonl", + " subflavors:", + " ds: ds1", + " - path: ds2.jsonl", + " subflavors:", + " ds: ds2", + " - path: ds3.jsonl", + " subflavors:", + " ds: ds3", + ] + ) + ) - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) + return dataset_path - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - self.dataset_path.mkdir(exist_ok=True, parents=True) +def create_text_test_dataset( + path: Path, txt_range: Iterable[int], key_range: Iterable[int], prefix: str = "" +): + """Creates a small dummy test dataset for testing purposes.""" - # Create a small dummy datasets - self.create_text_test_dataset(self.dataset_path / "ds1.jsonl", range(55), range(55)) - self.create_text_test_dataset( - self.dataset_path / "ds2.jsonl", range(100, 155), range(100, 155) - ) - self.create_text_test_dataset(self.dataset_path / "ds3.jsonl", range(200, 255), range(55)) - - self.mds_all_path = self.dataset_path / "metadataset_all.yaml" - with open(self.mds_all_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - path: ds1.jsonl", - " subflavors:", - " ds: ds1", - " - path: ds2.jsonl", - " subflavors:", - " ds: ds2", - " - path: ds3.jsonl", - " subflavors:", - " ds: ds3", - ] - ) - ) + # Write jsonl file + with open(path, "w") as wf: + for key, txt in zip(key_range, txt_range): + # Write JSON entries to the file, one per line. + wf.write(json.dumps({"idx": key, "txt": f"{prefix}{txt}"}) + "\n") - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() + from megatron.energon.flavors import CrudeJsonlDatasetFactory - @staticmethod - def create_text_test_dataset( - path: Path, txt_range: Iterable[int], key_range: Iterable[int], prefix: str = "" - ): - """Creates a small dummy test dataset for testing purposes.""" + CrudeJsonlDatasetFactory.prepare_dataset(path) - # Write jsonl file - with open(path, "w") as wf: - for key, txt in zip(key_range, txt_range): - # Write JSON entries to the file, one per line. - wf.write(json.dumps({"idx": key, "txt": f"{prefix}{txt}"}) + "\n") - from megatron.energon.flavors import CrudeJsonlDatasetFactory +def test_dataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) - CrudeJsonlDatasetFactory.prepare_dataset(path) + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "ds1.jsonl", + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=SimpleCookingTaskEncoder(), + ) + print(len(train_dataset)) + assert len(train_dataset) == 55, f"Expected 55 samples, got {len(train_dataset)}" + + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 55 + assert all(v == 10 for v in Counter(train_order1).values()) + + +def test_metadataset_all(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + seed_offset=42, + ) - def test_dataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset_all.yaml", + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=SimpleCookingTaskEncoder(), + ) + print(len(train_dataset)) + assert len(train_dataset) == 55 * 3, f"Expected 55 * 3 samples, got {len(train_dataset)}" - # Train mode dataset - train_dataset = get_train_dataset( - self.dataset_path / "ds1.jsonl", - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=SimpleCookingTaskEncoder(), - ) - print(len(train_dataset)) - assert len(train_dataset) == 55, f"Expected 55 samples, got {len(train_dataset)}" + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 55 * 3 + assert all(2 <= v <= 5 for v in Counter(train_order1).values()) - with get_loader(train_dataset) as train_loader1: - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 55 - assert all(v == 10 for v in Counter(train_order1).values()) - - def test_metadataset_all(self): - torch.manual_seed(42) + +def test_metadataset_multirank(dataset_path): + torch.manual_seed(42) + + sample_counts = Counter() + expected_lens = [19, 19, 17] + + for cur_rank in range(3): worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, + rank=cur_rank, + world_size=3, + num_workers=5, seed_offset=42, ) # Train mode dataset train_dataset = get_train_dataset( - self.mds_all_path, + dataset_path / "ds1.jsonl", worker_config=worker_config, batch_size=1, shuffle_buffer_size=None, max_samples_per_sequence=None, task_encoder=SimpleCookingTaskEncoder(), + repeat=False, ) print(len(train_dataset)) - assert len(train_dataset) == 55 * 3, f"Expected 55 * 3 samples, got {len(train_dataset)}" + assert len(train_dataset) == expected_lens[cur_rank], ( + f"Expected {expected_lens[cur_rank]} samples, got {len(train_dataset)}" + ) with get_loader(train_dataset) as train_loader1: - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 55 * 3 - assert all(2 <= v <= 5 for v in Counter(train_order1).values()) - - def test_metadataset_multirank(self): - torch.manual_seed(42) - - sample_counts = Counter() - expected_lens = [19, 19, 17] - - for cur_rank in range(3): - worker_config = WorkerConfig( - rank=cur_rank, - world_size=3, - num_workers=5, - seed_offset=42, - ) - - # Train mode dataset - train_dataset = get_train_dataset( - self.dataset_path / "ds1.jsonl", - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - task_encoder=SimpleCookingTaskEncoder(), - repeat=False, - ) - print(len(train_dataset)) - assert len(train_dataset) == expected_lens[cur_rank], ( - f"Expected {expected_lens[cur_rank]} samples, got {len(train_dataset)}" - ) + for data in train_loader1: + sample_counts[int(data.text[0])] += 1 - with get_loader(train_dataset) as train_loader1: - for data in train_loader1: - sample_counts[int(data.text[0])] += 1 + for i in range(55): + assert sample_counts[i] == 1, ( + f"Sample {i} should have been seen exactly once, but was seen {sample_counts[i]} times." + ) - for i in range(55): - assert sample_counts[i] == 1, ( - f"Sample {i} should have been seen exactly once, but was seen {sample_counts[i]} times." - ) - def test_s3(self): - # Create a joined dataset configuration - mixed_mds_path = self.dataset_path / "metadataset_mixed.yaml" - with open(mixed_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " path: msc://s3test_jsonl_dataset/test/dataset/metadataset_all.yaml", - ] - ) +def test_s3(dataset_path): + # Create a joined dataset configuration + mixed_mds_path = dataset_path / "metadataset_mixed.yaml" + with open(mixed_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " path: msc://s3test_jsonl_dataset/test/dataset/metadataset_all.yaml", + ] ) - - with setup_s3_emulator(profile_name="s3test_jsonl_dataset") as emu: - # Upload the dataset to the S3 emulator - # EPath(self.dataset_path).copy(EPath("msc://s3/test/dataset")) - emu.add_file(self.dataset_path, "test/dataset") - - with get_loader( - get_train_dataset( - mixed_mds_path, - worker_config=WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ), - batch_size=1, - shuffle_buffer_size=10, - max_samples_per_sequence=None, - virtual_epoch_length=55 * 10, - task_encoder=SimpleCookingTaskEncoder(), - ) - ) as train_dataset: - data = list(enumerate(train_dataset)) - assert len(data) == 55 * 10, len(data) - cnt = Counter(t for _, entry in data for t in entry.text) - assert len(cnt) == 55 * 3 - assert all(2 <= v <= 5 for v in cnt.values()) - - def test_prepare(self): - print("Creating new dataset") - with open(self.dataset_path / "ds_prep.jsonl", "w") as f: - for i in range(10): - f.write(json.dumps({"idx": i, "txt": f"{i}"}) + "\n\n") - - runner = CliRunner() - result = runner.invoke( - prepare_command, - [str(self.dataset_path / "ds_prep.jsonl")], - catch_exceptions=False, ) - print(result.stdout) - assert result.exit_code == 0, "Prepare failed, see output" - assert "Done" in result.stdout, "Prepare failed, see output" - assert "Found 10 samples" in result.stdout, "Prepare failed, see output" - assert (self.dataset_path / "ds_prep.jsonl.idx").exists() - torch.manual_seed(42) + with setup_s3_emulator(profile_name="s3test_jsonl_dataset") as emu: + # Upload the dataset to the S3 emulator + # EPath(dataset_path).copy(EPath("msc://s3/test/dataset")) + emu.add_file(dataset_path, "test/dataset") - # Train mode dataset with get_loader( get_train_dataset( - self.dataset_path / "ds_prep.jsonl", + mixed_mds_path, worker_config=WorkerConfig( rank=0, world_size=1, - num_workers=0, - seed_offset=42, + num_workers=2, ), batch_size=1, - shuffle_buffer_size=None, + shuffle_buffer_size=10, max_samples_per_sequence=None, + virtual_epoch_length=55 * 10, task_encoder=SimpleCookingTaskEncoder(), ) - ) as train_loader: - assert len(train_loader) == 10, f"Expected 10 samples, got {len(train_loader)}" - - train_order1 = [text for _, data in zip(range(50), train_loader) for text in data.text] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 10 - assert all(v == 5 for v in Counter(train_order1).values()) - - -if __name__ == "__main__": - unittest.main() + ) as train_dataset: + data = list(enumerate(train_dataset)) + assert len(data) == 55 * 10, len(data) + cnt = Counter(t for _, entry in data for t in entry.text) + assert len(cnt) == 55 * 3 + assert all(2 <= v <= 5 for v in cnt.values()) + + +def test_prepare(dataset_path): + print("Creating new dataset") + with open(dataset_path / "ds_prep.jsonl", "w") as f: + for i in range(10): + f.write(json.dumps({"idx": i, "txt": f"{i}"}) + "\n\n") + + runner = CliRunner() + result = runner.invoke( + prepare_command, + [str(dataset_path / "ds_prep.jsonl")], + catch_exceptions=False, + ) + print(result.stdout) + assert result.exit_code == 0, "Prepare failed, see output" + assert "Done" in result.stdout, "Prepare failed, see output" + assert "Found 10 samples" in result.stdout, "Prepare failed, see output" + assert (dataset_path / "ds_prep.jsonl.idx").exists() + + torch.manual_seed(42) + + # Train mode dataset + with get_loader( + get_train_dataset( + dataset_path / "ds_prep.jsonl", + worker_config=WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ), + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + task_encoder=SimpleCookingTaskEncoder(), + ) + ) as train_loader: + assert len(train_loader) == 10, f"Expected 10 samples, got {len(train_loader)}" + + train_order1 = [text for _, data in zip(range(50), train_loader) for text in data.text] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 10 + assert all(v == 5 for v in Counter(train_order1).values()) diff --git a/tests/test_metadataset.py b/tests/test_metadataset.py index 03be3892..2b4fa1cc 100644 --- a/tests/test_metadataset.py +++ b/tests/test_metadataset.py @@ -9,13 +9,13 @@ import sys import tempfile import time -import unittest import warnings from collections import Counter from pathlib import Path from typing import Any, Iterable import numpy as np +import pytest import torch import webdataset as wds @@ -150,1277 +150,1268 @@ def assert_nested_equal(a: Any, b: Any, path: str = "") -> None: raise AssertionError(mismatch_details) -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - (self.dataset_path / "ds1").mkdir(exist_ok=True, parents=True) - (self.dataset_path / "ds2").mkdir(exist_ok=True, parents=True) - - # Create a small dummy captioning dataset - self.create_text_test_dataset(self.dataset_path / "ds1", range(55), range(55)) - self.create_text_test_dataset(self.dataset_path / "ds2", range(100, 155), range(100, 155)) - self.create_text_test_dataset(self.dataset_path / "ds3", range(200, 255), range(0, 55)) - - self.mds_path = self.dataset_path / "metadataset.yaml" - with open(self.mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: Metadataset", - "splits:", - " train:", - " datasets:", - " - weight: 1", - " path: ds1", - " subflavor: ds1", - " subflavors:", - " source: metadataset.yaml", - " number: 43", - " mds: mds", - " shuffle_over_epochs_multiplier: 3", - " - weight: 1", - " path: ds2", - " subflavor: ds2", - " subflavors:", - " source: metadataset.yaml", - " number: 44", - " mds: mds", - " val:", - " datasets:", - " - weight: 1", - " path: ds1", - " split_part: train", - " - weight: 1", - " path: ds2", - " split_part: train", - ] - ) +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() + + +@pytest.fixture +def dataset_path(temp_dir): + """Create dataset path and setup test data.""" + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + + (dataset_path / "ds1").mkdir(exist_ok=True, parents=True) + (dataset_path / "ds2").mkdir(exist_ok=True, parents=True) + + # Create a small dummy captioning dataset + create_text_test_dataset(dataset_path / "ds1", range(55), range(55)) + create_text_test_dataset(dataset_path / "ds2", range(100, 155), range(100, 155)) + create_text_test_dataset(dataset_path / "ds3", range(200, 255), range(0, 55)) + + mds_path = dataset_path / "metadataset.yaml" + with open(mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: Metadataset", + "splits:", + " train:", + " datasets:", + " - weight: 1", + " path: ds1", + " subflavor: ds1", + " subflavors:", + " source: metadataset.yaml", + " number: 43", + " mds: mds", + " shuffle_over_epochs_multiplier: 3", + " - weight: 1", + " path: ds2", + " subflavor: ds2", + " subflavors:", + " source: metadataset.yaml", + " number: 44", + " mds: mds", + " val:", + " datasets:", + " - weight: 1", + " path: ds1", + " split_part: train", + " - weight: 1", + " path: ds2", + " split_part: train", + ] ) - self.nested_mds_path = self.dataset_path / "nested_metadataset.yaml" - with open(self.nested_mds_path, "w") as f: - f.write( - "\n".join( - [ - "splits:", - " train:", - " datasets:", - " - weight: 4", - " path: ./metadataset.yaml", - " split_part: train", - " subflavor: train", - " subflavors:", - " source: nested_metadataset.yaml", - " mds: nested_train", - " - path: ./metadataset.yaml", - " split_part: val", - " subflavors:", - " source: nested_metadataset.yaml", - " mds: nested_val", - ] - ) + ) + nested_mds_path = dataset_path / "nested_metadataset.yaml" + with open(nested_mds_path, "w") as f: + f.write( + "\n".join( + [ + "splits:", + " train:", + " datasets:", + " - weight: 4", + " path: ./metadataset.yaml", + " split_part: train", + " subflavor: train", + " subflavors:", + " source: nested_metadataset.yaml", + " mds: nested_train", + " - path: ./metadataset.yaml", + " split_part: val", + " subflavors:", + " source: nested_metadataset.yaml", + " mds: nested_val", + ] ) - print(self.dataset_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): - """Creates a small dummy test dataset for testing purposes.""" - - # Create num_samples unique captions - (path / "parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: - for key, txt in zip(key_range, txt_range): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{key:06d}", - "txt": f"{txt}".encode(), - }, - ) - total_shards = shard_writer.shard + ) + print(dataset_path) + return dataset_path - from megatron.energon.flavors import BaseWebdatasetFactory - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, - ) +def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): + """Creates a small dummy test dataset for testing purposes.""" - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: TextSample", - "field_map:", - " text: txt", - "subflavors:", - " source: dataset.yaml", - " dataset.yaml: true", - " number: 42", - ] - ) - ) + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) - def test_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for key, txt in zip(key_range, txt_range): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{key:06d}", + "txt": f"{txt}".encode(), + }, + ) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: TextSample", + "field_map:", + " text: txt", + "subflavors:", + " source: dataset.yaml", + " dataset.yaml: true", + " number: 42", + ] + ) ) - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - print(len(train_dataset)) - assert len(train_dataset) == 11 - with get_loader(train_dataset) as train_loader1: - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text +def test_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 11 + + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 110 + assert all(48 <= v <= 52 for v in Counter(train_order1).values()) + + train_subflavors = [ + subflavor["__subflavor__"] + for idx, data in zip(range(55), train_loader1) + for subflavor in data.__subflavors__ + ] + print("train_subflavors[:10]", train_subflavors[:10]) + print("Counter(train_subflavors)", Counter(train_subflavors)) + assert len(Counter(train_subflavors)) == 2 + assert all(250 <= v <= 300 for v in Counter(train_subflavors).values()) + + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=25, + max_samples_per_sequence=25, + ) + print(len(train_dataset)) + assert len(train_dataset) == 11 + + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 110 + assert all(48 <= v <= 52 for v in Counter(train_order1).values()) + + # Val mode dataset + val_dataset = get_val_dataset( + dataset_path / "metadataset.yaml", worker_config=worker_config, batch_size=10 + ) + print(len(val_dataset)) + assert len(val_dataset) == 11 + + with get_loader(val_dataset) as val_loader1: + val_order1 = [text for data in val_loader1 for text in data.text] + assert len(val_order1) == 110 + print(Counter(val_order1)) + assert all(v == 1 for v in Counter(val_order1).values()) + + +def test_nested_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + ) + + dataset = load_dataset(dataset_path / "nested_metadataset.yaml") + + raw_datasets = dataset.get_datasets( + training=False, split_part="train", worker_config=worker_config + ) + assert raw_datasets.blend_mode == DatasetBlendMode.DATASET_WEIGHT + assert [raw_dataset.weight for raw_dataset in raw_datasets.datasets] == [0.4, 0.4, 0.1, 0.1] + assert [raw_dataset.dataset.paths[0].name for raw_dataset in raw_datasets.datasets] == [ + "ds1", + "ds2", + "ds1", + "ds2", + ] + print([raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets]) + assert [raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets] == [ + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 43, + "mds": "nested_train", + "__subflavor__": "train", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 44, + "mds": "nested_train", + "__subflavor__": "train", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 42, + "mds": "nested_val", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 42, + "mds": "nested_val", + }, + ] + + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "nested_metadataset.yaml", + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 22 + + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 110 + assert all(48 <= v <= 53 for v in Counter(train_order1).values()) + + train_subflavors = [ + subflavor.get("__subflavor__") + for idx, data in zip(range(55), train_loader1) + for subflavor in data.__subflavors__ + ] + cnt = Counter(train_subflavors) + print(train_subflavors[:10]) + print(cnt) + avg = 55 * 10 / 5 + assert len(Counter(train_subflavors)) == 2 + assert avg * 4 - 40 < cnt["train"] < avg * 4 + 40 + assert avg - 10 < cnt[None] < avg + 10 + + train_subflavorss = [ + tuple(subflavor.items()) + for idx, data in zip(range(55), train_loader1) + for subflavor in data.__subflavors__ + ] + cnt = Counter(train_subflavorss) + print(train_subflavorss[:10]) + print(cnt) + assert len(Counter(train_subflavorss)) == 3 + assert ( + avg * 2 - 20 + < cnt[ + ( + ("source", "nested_metadataset.yaml"), + ("dataset.yaml", True), + ("number", 43), + ("__subflavor__", "train"), + ("mds", "nested_train"), + ) ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 110 - assert all(48 <= v <= 52 for v in Counter(train_order1).values()) - - train_subflavors = [ - subflavor["__subflavor__"] - for idx, data in zip(range(55), train_loader1) - for subflavor in data.__subflavors__ + < avg * 2 + 20 + ) + assert ( + avg * 2 - 20 + < cnt[ + ( + ("source", "nested_metadataset.yaml"), + ("dataset.yaml", True), + ("number", 44), + ("__subflavor__", "train"), + ("mds", "nested_train"), + ) ] - print("train_subflavors[:10]", train_subflavors[:10]) - print("Counter(train_subflavors)", Counter(train_subflavors)) - assert len(Counter(train_subflavors)) == 2 - assert all(250 <= v <= 300 for v in Counter(train_subflavors).values()) - - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=25, - max_samples_per_sequence=25, + < avg * 2 + 20 ) - print(len(train_dataset)) - assert len(train_dataset) == 11 - - with get_loader(train_dataset) as train_loader1: - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + assert ( + avg * 1 - 20 + < cnt[ + ( + ("source", "nested_metadataset.yaml"), + ("dataset.yaml", True), + ("number", 42), + ("mds", "nested_val"), + ) ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 110 - assert all(48 <= v <= 52 for v in Counter(train_order1).values()) - - # Val mode dataset - val_dataset = get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10) - print(len(val_dataset)) - assert len(val_dataset) == 11 - - with get_loader(val_dataset) as val_loader1: - val_order1 = [text for data in val_loader1 for text in data.text] - assert len(val_order1) == 110 - print(Counter(val_order1)) - assert all(v == 1 for v in Counter(val_order1).values()) - - def test_nested_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, + < avg * 1 + 20 ) - dataset = load_dataset(self.nested_mds_path) - - raw_datasets = dataset.get_datasets( - training=False, split_part="train", worker_config=worker_config - ) - assert raw_datasets.blend_mode == DatasetBlendMode.DATASET_WEIGHT - assert [raw_dataset.weight for raw_dataset in raw_datasets.datasets] == [0.4, 0.4, 0.1, 0.1] - assert [raw_dataset.dataset.paths[0].name for raw_dataset in raw_datasets.datasets] == [ - "ds1", - "ds2", - "ds1", - "ds2", + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=25, + max_samples_per_sequence=25, + ) + print(len(train_dataset)) + assert len(train_dataset) == 11 + + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text ] - print([raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets]) - assert [raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets] == [ - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 43, - "mds": "nested_train", - "__subflavor__": "train", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 44, - "mds": "nested_train", - "__subflavor__": "train", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 42, - "mds": "nested_val", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 42, - "mds": "nested_val", - }, - ] - - # Train mode dataset - train_dataset = get_train_dataset( - self.nested_mds_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - print(len(train_dataset)) - assert len(train_dataset) == 22 - - with get_loader(train_dataset) as train_loader1: - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 110 - assert all(48 <= v <= 53 for v in Counter(train_order1).values()) - - train_subflavors = [ - subflavor.get("__subflavor__") - for idx, data in zip(range(55), train_loader1) - for subflavor in data.__subflavors__ - ] - cnt = Counter(train_subflavors) - print(train_subflavors[:10]) - print(cnt) - avg = 55 * 10 / 5 - assert len(Counter(train_subflavors)) == 2 - assert avg * 4 - 40 < cnt["train"] < avg * 4 + 40 - assert avg - 10 < cnt[None] < avg + 10 - - train_subflavorss = [ - tuple(subflavor.items()) - for idx, data in zip(range(55), train_loader1) - for subflavor in data.__subflavors__ - ] - cnt = Counter(train_subflavorss) - print(train_subflavorss[:10]) - print(cnt) - assert len(Counter(train_subflavorss)) == 3 - assert ( - avg * 2 - 20 - < cnt[ - ( - ("source", "nested_metadataset.yaml"), - ("dataset.yaml", True), - ("number", 43), - ("__subflavor__", "train"), - ("mds", "nested_train"), - ) - ] - < avg * 2 + 20 - ) - assert ( - avg * 2 - 20 - < cnt[ - ( - ("source", "nested_metadataset.yaml"), - ("dataset.yaml", True), - ("number", 44), - ("__subflavor__", "train"), - ("mds", "nested_train"), - ) - ] - < avg * 2 + 20 - ) - assert ( - avg * 1 - 20 - < cnt[ - ( - ("source", "nested_metadataset.yaml"), - ("dataset.yaml", True), - ("number", 42), - ("mds", "nested_val"), - ) - ] - < avg * 1 + 20 + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 110 + assert all(48 <= v <= 52 for v in Counter(train_order1).values()) + + # Val mode dataset + val_dataset = get_val_dataset( + dataset_path / "metadataset.yaml", worker_config=worker_config, batch_size=10 + ) + print(len(val_dataset)) + assert len(val_dataset) == 11 + + with get_loader(val_dataset) as val_loader1: + val_order1 = [text for data in val_loader1 for text in data.text] + assert len(val_order1) == 110 + print(Counter(val_order1)) + assert all(v == 1 for v in Counter(val_order1).values()) + + +def test_worker_sample_balance(dataset_path): + torch.manual_seed(42) + + for num_workers in [6, 30]: + samples_per_global_worker = Counter() + + for rank in range(2): + wc = WorkerConfig( + rank=rank, + world_size=2, + num_workers=num_workers, ) - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - shuffle_buffer_size=25, - max_samples_per_sequence=25, - ) - print(len(train_dataset)) - assert len(train_dataset) == 11 + train_dataset = get_train_dataset( + dataset_path / "nested_metadataset.yaml", + worker_config=wc, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) - with get_loader(train_dataset) as train_loader1: - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 110 - assert all(48 <= v <= 52 for v in Counter(train_order1).values()) - - # Val mode dataset - val_dataset = get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10) - print(len(val_dataset)) - assert len(val_dataset) == 11 - - with get_loader(val_dataset) as val_loader1: - val_order1 = [text for data in val_loader1 for text in data.text] - assert len(val_order1) == 110 - print(Counter(val_order1)) - assert all(v == 1 for v in Counter(val_order1).values()) - - def test_worker_sample_balance(self): - torch.manual_seed(42) - - for num_workers in [6, 30]: - samples_per_global_worker = Counter() - - for rank in range(2): - wc = WorkerConfig( - rank=rank, - world_size=2, - num_workers=num_workers, - ) + blend_dataset = get_blend_dataset(train_dataset) + assert isinstance(blend_dataset, BlendDataset) - train_dataset = get_train_dataset( - self.nested_mds_path, - worker_config=wc, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) + ds_weights = blend_dataset.dataset_weights + assert len(ds_weights) == 4 # 4 datasets - blend_dataset = get_blend_dataset(train_dataset) - assert isinstance(blend_dataset, BlendDataset) + # We are now going to count the number of samples that was assigned to each + # globally unique worker. This corresponds to the shard_ranges that energon + # prints out when the dataset is built. - ds_weights = blend_dataset.dataset_weights - assert len(ds_weights) == 4 # 4 datasets + for ds, w in ds_weights: + worker_slice_offsets = ds.dataset.dataset.workers_slice_offsets + assert len(worker_slice_offsets) == num_workers - # We are now going to count the number of samples that was assigned to each - # globally unique worker. This corresponds to the shard_ranges that energon - # prints out when the dataset is built. + for worker_idx, slice_offsets in enumerate(worker_slice_offsets): + samples_per_global_worker[(rank, worker_idx)] += ( + slice_offsets[-1] - slice_offsets[0] + ) + print(samples_per_global_worker) + + # Check the sample assignnent is balanced across all global workers + if num_workers == 6: + assert list(samples_per_global_worker.values()) == [ + 19, # rank 0 + 18, + 18, + 19, + 18, + 18, + 19, # rank 1 + 18, + 18, + 19, + 18, + 18, + ] + elif num_workers == 30: + # This should match the pattern of the first 40 items of a generalized bit + # reversal sequence of length 60. + # Given 4 * 55 = 220 samples modulo 60 workers, is 40 remaining samples + assert list(samples_per_global_worker.values()) == [ + 4, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 4, + 4, + 3, + 4, + 3, + 4, + 3, + ] - for ds, w in ds_weights: - worker_slice_offsets = ds.dataset.dataset.workers_slice_offsets - assert len(worker_slice_offsets) == num_workers - for worker_idx, slice_offsets in enumerate(worker_slice_offsets): - samples_per_global_worker[(rank, worker_idx)] += ( - slice_offsets[-1] - slice_offsets[0] - ) - print(samples_per_global_worker) - - # Check the sample assignnent is balanced across all global workers - if num_workers == 6: - assert list(samples_per_global_worker.values()) == [ - 19, # rank 0 - 18, - 18, - 19, - 18, - 18, - 19, # rank 1 - 18, - 18, - 19, - 18, - 18, - ] - elif num_workers == 30: - # This should match the pattern of the first 40 items of a generalized bit - # reversal sequence of length 60. - # Given 4 * 55 = 220 samples modulo 60 workers, is 40 remaining samples - assert list(samples_per_global_worker.values()) == [ - 4, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 4, - 4, - 3, - 4, - 3, - 4, - 3, - ] +def test_save_restore_state_train(dataset_path): + torch.manual_seed(42) - def test_save_restore_state_train(self): - torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, + def new_loader(): + return get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + parallel_shard_iters=2, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + shuffle_over_epochs_multiplier=2, + ), ) - def new_loader(): - return get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - parallel_shard_iters=2, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - shuffle_over_epochs_multiplier=2, - ), - ) - - # Train mode dataset - with new_loader() as loader: - state_0 = loader.save_state_rank() - order_0 = [data.text for idx, data in zip(range(10), loader)] - state_1 = loader.save_state_rank() - # print("save state done") - order_1 = [data.text for idx, data in zip(range(20), loader)] - - state_2 = loader.save_state_rank() - # print("save state done") - # Iterated 30 samples, afterwards 50 samples. Checkpoint should be around that - order_2 = [data.text for idx, data in zip(range(20), loader)] - - state_3 = loader.save_state_rank() - # print("save state done") - # Iterated 50 samples, afterwards 53 samples. Checkpoint should be around that - order_3 = [data.text for idx, data in zip(range(3), loader)] - - state_4 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 53 samples, afterwards 54 samples. Checkpoint should be around that - order_4 = [data.text for idx, data in zip(range(1), loader)] - - state_5 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 54 samples, afterwards 55 samples. Checkpoint should be around that - order_5 = [data.text for idx, data in zip(range(1), loader)] - - state_6 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 55 samples, afterwards 75 samples. Checkpoint should be around that - order_6 = [data.text for idx, data in zip(range(70), loader)] - - with new_loader().with_restored_state_rank(state_1) as loader: - print("state_1:", _norng_state(state_1)) - order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)] - assert order_1 == order_1_rest - - with new_loader().with_restored_state_rank(state_0) as loader: - order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)] - assert order_0 == order_0_rest - - with new_loader().with_restored_state_rank(state_2) as loader: - print("state_2:", _norng_state(state_2)) - order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)] - print("order_2:", order_2) - print("order_2_rest:", order_2_rest) - assert order_2 == order_2_rest - - with new_loader().with_restored_state_rank(state_3) as loader: - print("state_3:", _norng_state(state_3)) - order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)] - print("order_3:", order_3) - print("order_3_rest:", order_3_rest) - assert order_3 == order_3_rest - - with new_loader().with_restored_state_rank(state_4) as loader: - print("state_4:", _norng_state(state_4)) - order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)] - print("order_4:", order_4) - print("order_4_rest:", order_4_rest) - assert order_4 == order_4_rest - - with new_loader().with_restored_state_rank(state_5) as loader: - print("state_5:", _norng_state(state_5)) - order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)] - print("order_5:", order_5) - print("order_5_rest:", order_5_rest) - assert order_5 == order_5_rest - - with new_loader().with_restored_state_rank(state_6) as loader: - print("state_6:", _norng_state(state_6)) - order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)] - print("order_6:", order_6) - print("order_6_rest:", order_6_rest) - assert order_6 == order_6_rest - - wrk_cfg = worker_config.config() - assert wrk_cfg == { - "rank": 0, - "world_size": 1, - "num_workers": 0, - "data_parallel_group": None, - } - print("loader.config():") - print(loader.config()) - print() - reference_config = { - "type": "MapDataset", + # Train mode dataset + with new_loader() as loader: + state_0 = loader.save_state_rank() + order_0 = [data.text for idx, data in zip(range(10), loader)] + state_1 = loader.save_state_rank() + # print("save state done") + order_1 = [data.text for idx, data in zip(range(20), loader)] + + state_2 = loader.save_state_rank() + # print("save state done") + # Iterated 30 samples, afterwards 50 samples. Checkpoint should be around that + order_2 = [data.text for idx, data in zip(range(20), loader)] + + state_3 = loader.save_state_rank() + # print("save state done") + # Iterated 50 samples, afterwards 53 samples. Checkpoint should be around that + order_3 = [data.text for idx, data in zip(range(3), loader)] + + state_4 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 53 samples, afterwards 54 samples. Checkpoint should be around that + order_4 = [data.text for idx, data in zip(range(1), loader)] + + state_5 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 54 samples, afterwards 55 samples. Checkpoint should be around that + order_5 = [data.text for idx, data in zip(range(1), loader)] + + state_6 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 55 samples, afterwards 75 samples. Checkpoint should be around that + order_6 = [data.text for idx, data in zip(range(70), loader)] + + with new_loader().with_restored_state_rank(state_1) as loader: + print("state_1:", _norng_state(state_1)) + order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)] + assert order_1 == order_1_rest + + with new_loader().with_restored_state_rank(state_0) as loader: + order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)] + assert order_0 == order_0_rest + + with new_loader().with_restored_state_rank(state_2) as loader: + print("state_2:", _norng_state(state_2)) + order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)] + print("order_2:", order_2) + print("order_2_rest:", order_2_rest) + assert order_2 == order_2_rest + + with new_loader().with_restored_state_rank(state_3) as loader: + print("state_3:", _norng_state(state_3)) + order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)] + print("order_3:", order_3) + print("order_3_rest:", order_3_rest) + assert order_3 == order_3_rest + + with new_loader().with_restored_state_rank(state_4) as loader: + print("state_4:", _norng_state(state_4)) + order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)] + print("order_4:", order_4) + print("order_4_rest:", order_4_rest) + assert order_4 == order_4_rest + + with new_loader().with_restored_state_rank(state_5) as loader: + print("state_5:", _norng_state(state_5)) + order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)] + print("order_5:", order_5) + print("order_5_rest:", order_5_rest) + assert order_5 == order_5_rest + + with new_loader().with_restored_state_rank(state_6) as loader: + print("state_6:", _norng_state(state_6)) + order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)] + print("order_6:", order_6) + print("order_6_rest:", order_6_rest) + assert order_6 == order_6_rest + + wrk_cfg = worker_config.config() + assert wrk_cfg == { + "rank": 0, + "world_size": 1, + "num_workers": 0, + "data_parallel_group": None, + } + print("loader.config():") + print(loader.config()) + print() + reference_config = { + "type": "MapDataset", + "dataset": { + "type": "BatchDataset", + "batch_size": 10, + "batcher": "megatron.energon.task_encoder.base.DefaultTaskEncoder.batch", + "batcher_stateless": True, + "drop_last": False, + "error_handler": "megatron.energon.wrappers._log_exception.log_exception", + "worker_config": wrk_cfg, "dataset": { - "type": "BatchDataset", - "batch_size": 10, - "batcher": "megatron.energon.task_encoder.base.DefaultTaskEncoder.batch", - "batcher_stateless": True, - "drop_last": False, - "error_handler": "megatron.energon.wrappers._log_exception.log_exception", - "worker_config": wrk_cfg, + "type": "MapDataset", "dataset": { - "type": "MapDataset", - "dataset": { - "type": "BlendDataset", - "dataset_weights": [ - ( - { - "type": "RepeatDataset", + "type": "BlendDataset", + "dataset_weights": [ + ( + { + "type": "RepeatDataset", + "dataset": { + "type": "MapDataset", "dataset": { - "type": "MapDataset", - "dataset": { - "type": "WebdatasetSampleLoaderDataset", - "joins": 1, - "len": 55, - "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], - "worker_config": wrk_cfg, - "shuffle_over_epochs": 6, - "parallel_slice_iters": 2, - }, - "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", - "map_fn_config": { - "type": "StandardWebdatasetFactory", - "training": True, - "_path": str(self.dataset_path / "ds1"), - "shards": [ - { - "name": "parts/data-0.tar", - "count": 10, - "_path": str( - self.dataset_path / "ds1/parts/data-0.tar" - ), - }, - { - "name": "parts/data-1.tar", - "count": 10, - "_path": str( - self.dataset_path / "ds1/parts/data-1.tar" - ), - }, - { - "name": "parts/data-2.tar", - "count": 10, - "_path": str( - self.dataset_path / "ds1/parts/data-2.tar" - ), - }, - { - "name": "parts/data-3.tar", - "count": 10, - "_path": str( - self.dataset_path / "ds1/parts/data-3.tar" - ), - }, - { - "name": "parts/data-4.tar", - "count": 10, - "_path": str( - self.dataset_path / "ds1/parts/data-4.tar" - ), - }, - { - "name": "parts/data-5.tar", - "count": 5, - "_path": str( - self.dataset_path / "ds1/parts/data-5.tar" - ), - }, - ], - "sample_excludes": [], - "shuffle_over_epochs": 6, - "parallel_shard_iters": 2, - "max_samples_per_sequence": None, - "subset": None, - "subflavors": { - "source": "metadataset.yaml", - "dataset.yaml": True, - "number": 43, - "mds": "mds", - "__subflavor__": "ds1", + "type": "WebdatasetSampleLoaderDataset", + "joins": 1, + "len": 55, + "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], + "worker_config": wrk_cfg, + "shuffle_over_epochs": 6, + "parallel_slice_iters": 2, + }, + "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", + "map_fn_config": { + "type": "StandardWebdatasetFactory", + "training": True, + "_path": str(dataset_path / "ds1"), + "shards": [ + { + "name": "parts/data-0.tar", + "count": 10, + "_path": str(dataset_path / "ds1/parts/data-0.tar"), }, - "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", - "image_decode": "torchrgb", - "av_decode": "AVDecoder", - "video_decode_audio": False, - "guess_content": False, + { + "name": "parts/data-1.tar", + "count": 10, + "_path": str(dataset_path / "ds1/parts/data-1.tar"), + }, + { + "name": "parts/data-2.tar", + "count": 10, + "_path": str(dataset_path / "ds1/parts/data-2.tar"), + }, + { + "name": "parts/data-3.tar", + "count": 10, + "_path": str(dataset_path / "ds1/parts/data-3.tar"), + }, + { + "name": "parts/data-4.tar", + "count": 10, + "_path": str(dataset_path / "ds1/parts/data-4.tar"), + }, + { + "name": "parts/data-5.tar", + "count": 5, + "_path": str(dataset_path / "ds1/parts/data-5.tar"), + }, + ], + "sample_excludes": [], + "shuffle_over_epochs": 6, + "parallel_shard_iters": 2, + "max_samples_per_sequence": None, + "subset": None, + "subflavors": { + "source": "metadataset.yaml", + "dataset.yaml": True, + "number": 43, + "mds": "mds", + "__subflavor__": "ds1", }, - "map_fn_stateless": True, + "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", + "image_decode": "torchrgb", + "av_decode": "AVDecoder", + "video_decode_audio": False, + "guess_content": False, }, - "repeats": None, - "worker_config": wrk_cfg, + "map_fn_stateless": True, }, - 0.5, - ), - ( - { - "type": "RepeatDataset", + "repeats": None, + "worker_config": wrk_cfg, + }, + 0.5, + ), + ( + { + "type": "RepeatDataset", + "dataset": { + "type": "MapDataset", "dataset": { - "type": "MapDataset", - "dataset": { - "type": "WebdatasetSampleLoaderDataset", - "joins": 1, - "len": 55, - "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], - "worker_config": wrk_cfg, - "shuffle_over_epochs": 2, - "parallel_slice_iters": 2, - }, - "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", - "map_fn_config": { - "type": "StandardWebdatasetFactory", - "training": True, - "_path": str(self.dataset_path / "ds2"), - "shards": [ - { - "name": "parts/data-0.tar", - "count": 10, - "_path": str( - self.dataset_path / "ds2/parts/data-0.tar" - ), - }, - { - "name": "parts/data-1.tar", - "count": 10, - "_path": str( - self.dataset_path / "ds2/parts/data-1.tar" - ), - }, - { - "name": "parts/data-2.tar", - "count": 10, - "_path": str( - self.dataset_path / "ds2/parts/data-2.tar" - ), - }, - { - "name": "parts/data-3.tar", - "count": 10, - "_path": str( - self.dataset_path / "ds2/parts/data-3.tar" - ), - }, - { - "name": "parts/data-4.tar", - "count": 10, - "_path": str( - self.dataset_path / "ds2/parts/data-4.tar" - ), - }, - { - "name": "parts/data-5.tar", - "count": 5, - "_path": str( - self.dataset_path / "ds2/parts/data-5.tar" - ), - }, - ], - "sample_excludes": [], - "shuffle_over_epochs": 2, - "parallel_shard_iters": 2, - "max_samples_per_sequence": None, - "subset": None, - "subflavors": { - "source": "metadataset.yaml", - "dataset.yaml": True, - "number": 44, - "mds": "mds", - "__subflavor__": "ds2", + "type": "WebdatasetSampleLoaderDataset", + "joins": 1, + "len": 55, + "slice_offsets": [[0, 10, 20, 30, 40, 50, 55]], + "worker_config": wrk_cfg, + "shuffle_over_epochs": 2, + "parallel_slice_iters": 2, + }, + "map_fn": "megatron.energon.flavors.webdataset.base_webdataset.BaseWebdatasetFactory._load_sample_raw", + "map_fn_config": { + "type": "StandardWebdatasetFactory", + "training": True, + "_path": str(dataset_path / "ds2"), + "shards": [ + { + "name": "parts/data-0.tar", + "count": 10, + "_path": str(dataset_path / "ds2/parts/data-0.tar"), + }, + { + "name": "parts/data-1.tar", + "count": 10, + "_path": str(dataset_path / "ds2/parts/data-1.tar"), + }, + { + "name": "parts/data-2.tar", + "count": 10, + "_path": str(dataset_path / "ds2/parts/data-2.tar"), + }, + { + "name": "parts/data-3.tar", + "count": 10, + "_path": str(dataset_path / "ds2/parts/data-3.tar"), + }, + { + "name": "parts/data-4.tar", + "count": 10, + "_path": str(dataset_path / "ds2/parts/data-4.tar"), + }, + { + "name": "parts/data-5.tar", + "count": 5, + "_path": str(dataset_path / "ds2/parts/data-5.tar"), }, - "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", - "image_decode": "torchrgb", - "av_decode": "AVDecoder", - "video_decode_audio": False, - "guess_content": False, + ], + "sample_excludes": [], + "shuffle_over_epochs": 2, + "parallel_shard_iters": 2, + "max_samples_per_sequence": None, + "subset": None, + "subflavors": { + "source": "metadataset.yaml", + "dataset.yaml": True, + "number": 44, + "mds": "mds", + "__subflavor__": "ds2", }, - "map_fn_stateless": True, + "sample_loader": "megatron.energon.flavors.webdataset.default_generic_webdataset.DefaultGenericWebdatasetFactory.__init__..", + "image_decode": "torchrgb", + "av_decode": "AVDecoder", + "video_decode_audio": False, + "guess_content": False, }, - "repeats": None, - "worker_config": wrk_cfg, + "map_fn_stateless": True, }, - 0.5, - ), - ], - "worker_config": wrk_cfg, - }, - "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_sample", - "map_fn_stateless": True, + "repeats": None, + "worker_config": wrk_cfg, + }, + 0.5, + ), + ], + "worker_config": wrk_cfg, }, + "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_sample", + "map_fn_stateless": True, }, - "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_batch", - "map_fn_stateless": True, - } - print("Comparing dataset configs in test_save_restore_state_train.") - assert_nested_equal(loader.config(), reference_config) - - def test_save_restore_state_train_workers(self): - torch.manual_seed(42) - - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=1, - seed_offset=42, - ) - - def new_loader(): - return get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - parallel_shard_iters=2, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ) - - # Train mode dataset - with new_loader() as loader: - state_0 = loader.save_state_rank() - order_0 = [data.text for idx, data in zip(range(10), loader)] - time.sleep(0.5) - state_1 = loader.save_state_rank() - # print("save state done") - order_1 = [data.text for idx, data in zip(range(20), loader)] - - # Ensure a checkpoint is created on next() - time.sleep(1.5) - - state_2 = loader.save_state_rank() - # print("save state done") - # Iterated 30 samples, afterwards 50 samples. Checkpoint should be around that - order_2 = [data.text for idx, data in zip(range(20), loader)] - - state_3 = loader.save_state_rank() - # print("save state done") - # Iterated 50 samples, afterwards 53 samples. Checkpoint should be around that - order_3 = [data.text for idx, data in zip(range(3), loader)] - - # Ensure a checkpoint is created on next() - time.sleep(1.5) - - state_4 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 1 samples, afterwards 54 samples. Checkpoint should be around that - order_4 = [data.text for idx, data in zip(range(1), loader)] - - # Ensure a checkpoint is created on next() - time.sleep(1.5) - - state_5 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 1 samples, afterwards 55 samples. Checkpoint should be around that - order_5 = [data.text for idx, data in zip(range(1), loader)] - - # Ensure a checkpoint is created on next() - time.sleep(1.5) - - state_6 = loader.save_state_rank() - # print("save state done") - # Dataset size is 55, want to save one sample before end of epoch - # Iterated 1 samples, afterwards 55 samples. Checkpoint should be around that - order_6 = [data.text for idx, data in zip(range(10), loader)] - - with new_loader().with_restored_state_rank(state_1) as loader: - print("state_1:", _norng_state(state_1)) - order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)] - print("order_1:", order_1) - print("order_1_rest:", order_1_rest) - assert order_1 == order_1_rest - - with new_loader().with_restored_state_rank(state_0) as loader: - order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)] - assert order_0 == order_0_rest - - with new_loader().with_restored_state_rank(state_2) as loader: - print("state_2:", _norng_state(state_2)) - order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)] - print("order_2:", order_2) - print("order_2_rest:", order_2_rest) - assert order_2 == order_2_rest - - with new_loader().with_restored_state_rank(state_3) as loader: - print("state_3:", _norng_state(state_3)) - order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)] - print("order_3:", order_3) - print("order_3_rest:", order_3_rest) - assert order_3 == order_3_rest - - with new_loader().with_restored_state_rank(state_4) as loader: - print("state_4:", _norng_state(state_4)) - order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)] - print("order_4:", order_4) - print("order_4_rest:", order_4_rest) - assert order_4 == order_4_rest - - with new_loader().with_restored_state_rank(state_5) as loader: - print("state_5:", _norng_state(state_5)) - order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)] - print("order_5:", order_5) - print("order_5_rest:", order_5_rest) - assert order_5 == order_5_rest - - with new_loader().with_restored_state_rank(state_6) as loader: - print("state_6:", _norng_state(state_6)) - order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)] - print("order_6:", order_6) - print("order_6_rest:", order_6_rest) - assert order_6 == order_6_rest - - def test_save_restore_state_train_epochize_workers(self): - torch.manual_seed(42) - psi = 2 - vel = 19 - sbs = 10 - - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - seed_offset=42, - ) - - # Train mode dataset - torch.manual_seed(42) - with get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=1, - parallel_shard_iters=psi, - virtual_epoch_length=vel, - shuffle_buffer_size=sbs, - max_samples_per_sequence=sbs, - ), - ) as loader: - state_0 = loader.save_state_rank() - order_1 = [data.text[0] for data in loader] - state_1 = loader.save_state_rank() - order_2 = [data.text[0] for data in loader] - state_2 = loader.save_state_rank() - order_3 = [data.text[0] for idx, data in zip(range(17), loader)] - - torch.manual_seed(42) - with get_savable_loader( + }, + "map_fn": "megatron.energon.task_encoder.base.DefaultTaskEncoder.encode_batch", + "map_fn_stateless": True, + } + print("Comparing dataset configs in test_save_restore_state_train.") + assert_nested_equal(loader.config(), reference_config) + + +def test_save_restore_state_train_workers(dataset_path): + torch.manual_seed(42) + + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=1, + seed_offset=42, + ) + + def new_loader(): + return get_savable_loader( get_train_dataset( - self.mds_path, + dataset_path / "metadataset.yaml", worker_config=worker_config, - batch_size=1, - parallel_shard_iters=psi, - virtual_epoch_length=vel, - shuffle_buffer_size=sbs, - max_samples_per_sequence=sbs, - ), - ).with_restored_state_rank(state_0) as loader: - print("state_0:", _norng_state(state_0)) - order_5 = [data.text[0] for data in loader] - print("order_1:", order_1) - print("order_5:", order_5) - assert order_1 == order_5 - - torch.manual_seed(42) - with get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=1, - parallel_shard_iters=psi, - virtual_epoch_length=vel, - shuffle_buffer_size=sbs, - max_samples_per_sequence=sbs, - ), - ).with_restored_state_rank(state_1) as loader: - print("state_1:", _norng_state(state_1)) - order_6 = [data.text[0] for data in loader] - print("order_2:", order_2) - print("order_6:", order_6) - assert order_2 == order_6 - - torch.manual_seed(42) - with get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=1, - parallel_shard_iters=psi, - virtual_epoch_length=vel, - shuffle_buffer_size=sbs, - max_samples_per_sequence=sbs, + batch_size=10, + parallel_shard_iters=2, + shuffle_buffer_size=None, + max_samples_per_sequence=None, ), - ).with_restored_state_rank(state_2) as loader: - print("state_2:", _norng_state(state_2)) - order_7 = [data.text[0] for idx, data in zip(range(17), loader)] - print("order_3:", order_3) - print("order_7:", order_7) - assert order_3 == order_7 - - def test_save_restore_state_val(self): - torch.manual_seed(42) - - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, ) - # Train mode dataset - with get_savable_loader( - get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10), - ) as loader: - state_0 = loader.save_state_rank() - order_1 = [data.text for idx, data in zip(range(55 * 20), loader)] - state_1 = loader.save_state_rank() - # print("save state done") - order_2 = [data.text for idx, data in zip(range(55 * 20), loader)] - - with get_savable_loader( - get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10), - ).with_restored_state_rank(state_1) as loader: - order_3 = [data.text for idx, data in zip(range(55 * 20), loader)] - assert order_2 == order_3 - - with get_savable_loader( - get_val_dataset(self.mds_path, worker_config=worker_config, batch_size=10), - ).with_restored_state_rank(state_0) as loader: - order_4 = [data.text for idx, data in zip(range(55 * 20), loader)] - assert order_1 == order_4 - - def test_blending_randomness(self): - import random - - import numpy - - for num_workers in [0, 1, 2]: # Especially also check the num_workers=0 case - world_size = 4 - micro_batch_size = 1 - seed = 42 - - configs = ( - WorkerConfig(rank=0, world_size=world_size, num_workers=num_workers), - WorkerConfig(rank=1, world_size=world_size, num_workers=num_workers), - WorkerConfig(rank=2, world_size=world_size, num_workers=num_workers), + # Train mode dataset + with new_loader() as loader: + state_0 = loader.save_state_rank() + order_0 = [data.text for idx, data in zip(range(10), loader)] + time.sleep(0.5) + state_1 = loader.save_state_rank() + # print("save state done") + order_1 = [data.text for idx, data in zip(range(20), loader)] + + # Ensure a checkpoint is created on next() + time.sleep(1.5) + + state_2 = loader.save_state_rank() + # print("save state done") + # Iterated 30 samples, afterwards 50 samples. Checkpoint should be around that + order_2 = [data.text for idx, data in zip(range(20), loader)] + + state_3 = loader.save_state_rank() + # print("save state done") + # Iterated 50 samples, afterwards 53 samples. Checkpoint should be around that + order_3 = [data.text for idx, data in zip(range(3), loader)] + + # Ensure a checkpoint is created on next() + time.sleep(1.5) + + state_4 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 1 samples, afterwards 54 samples. Checkpoint should be around that + order_4 = [data.text for idx, data in zip(range(1), loader)] + + # Ensure a checkpoint is created on next() + time.sleep(1.5) + + state_5 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 1 samples, afterwards 55 samples. Checkpoint should be around that + order_5 = [data.text for idx, data in zip(range(1), loader)] + + # Ensure a checkpoint is created on next() + time.sleep(1.5) + + state_6 = loader.save_state_rank() + # print("save state done") + # Dataset size is 55, want to save one sample before end of epoch + # Iterated 1 samples, afterwards 55 samples. Checkpoint should be around that + order_6 = [data.text for idx, data in zip(range(10), loader)] + + with new_loader().with_restored_state_rank(state_1) as loader: + print("state_1:", _norng_state(state_1)) + order_1_rest = [data.text for idx, data in zip(range(len(order_1)), loader)] + print("order_1:", order_1) + print("order_1_rest:", order_1_rest) + assert order_1 == order_1_rest + + with new_loader().with_restored_state_rank(state_0) as loader: + order_0_rest = [data.text for idx, data in zip(range(len(order_0)), loader)] + assert order_0 == order_0_rest + + with new_loader().with_restored_state_rank(state_2) as loader: + print("state_2:", _norng_state(state_2)) + order_2_rest = [data.text for idx, data in zip(range(len(order_2)), loader)] + print("order_2:", order_2) + print("order_2_rest:", order_2_rest) + assert order_2 == order_2_rest + + with new_loader().with_restored_state_rank(state_3) as loader: + print("state_3:", _norng_state(state_3)) + order_3_rest = [data.text for idx, data in zip(range(len(order_3)), loader)] + print("order_3:", order_3) + print("order_3_rest:", order_3_rest) + assert order_3 == order_3_rest + + with new_loader().with_restored_state_rank(state_4) as loader: + print("state_4:", _norng_state(state_4)) + order_4_rest = [data.text for idx, data in zip(range(len(order_4)), loader)] + print("order_4:", order_4) + print("order_4_rest:", order_4_rest) + assert order_4 == order_4_rest + + with new_loader().with_restored_state_rank(state_5) as loader: + print("state_5:", _norng_state(state_5)) + order_5_rest = [data.text for idx, data in zip(range(len(order_5)), loader)] + print("order_5:", order_5) + print("order_5_rest:", order_5_rest) + assert order_5 == order_5_rest + + with new_loader().with_restored_state_rank(state_6) as loader: + print("state_6:", _norng_state(state_6)) + order_6_rest = [data.text for idx, data in zip(range(len(order_6)), loader)] + print("order_6:", order_6) + print("order_6_rest:", order_6_rest) + assert order_6 == order_6_rest + + +def test_save_restore_state_train_epochize_workers(dataset_path): + torch.manual_seed(42) + psi = 2 + vel = 19 + sbs = 10 + + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + seed_offset=42, + ) + + # Train mode dataset + torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=1, + parallel_shard_iters=psi, + virtual_epoch_length=vel, + shuffle_buffer_size=sbs, + max_samples_per_sequence=sbs, + ), + ) as loader: + state_0 = loader.save_state_rank() + order_1 = [data.text[0] for data in loader] + state_1 = loader.save_state_rank() + order_2 = [data.text[0] for data in loader] + state_2 = loader.save_state_rank() + order_3 = [data.text[0] for idx, data in zip(range(17), loader)] + + torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=1, + parallel_shard_iters=psi, + virtual_epoch_length=vel, + shuffle_buffer_size=sbs, + max_samples_per_sequence=sbs, + ), + ).with_restored_state_rank(state_0) as loader: + print("state_0:", _norng_state(state_0)) + order_5 = [data.text[0] for data in loader] + print("order_1:", order_1) + print("order_5:", order_5) + assert order_1 == order_5 + + torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=1, + parallel_shard_iters=psi, + virtual_epoch_length=vel, + shuffle_buffer_size=sbs, + max_samples_per_sequence=sbs, + ), + ).with_restored_state_rank(state_1) as loader: + print("state_1:", _norng_state(state_1)) + order_6 = [data.text[0] for data in loader] + print("order_2:", order_2) + print("order_6:", order_6) + assert order_2 == order_6 + + torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=1, + parallel_shard_iters=psi, + virtual_epoch_length=vel, + shuffle_buffer_size=sbs, + max_samples_per_sequence=sbs, + ), + ).with_restored_state_rank(state_2) as loader: + print("state_2:", _norng_state(state_2)) + order_7 = [data.text[0] for idx, data in zip(range(17), loader)] + print("order_3:", order_3) + print("order_7:", order_7) + assert order_3 == order_7 + + +def test_save_restore_state_val(dataset_path): + torch.manual_seed(42) + + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Train mode dataset + with get_savable_loader( + get_val_dataset( + dataset_path / "metadataset.yaml", worker_config=worker_config, batch_size=10 + ), + ) as loader: + state_0 = loader.save_state_rank() + order_1 = [data.text for idx, data in zip(range(55 * 20), loader)] + state_1 = loader.save_state_rank() + # print("save state done") + order_2 = [data.text for idx, data in zip(range(55 * 20), loader)] + + with get_savable_loader( + get_val_dataset( + dataset_path / "metadataset.yaml", worker_config=worker_config, batch_size=10 + ), + ).with_restored_state_rank(state_1) as loader: + order_3 = [data.text for idx, data in zip(range(55 * 20), loader)] + assert order_2 == order_3 + + with get_savable_loader( + get_val_dataset( + dataset_path / "metadataset.yaml", worker_config=worker_config, batch_size=10 + ), + ).with_restored_state_rank(state_0) as loader: + order_4 = [data.text for idx, data in zip(range(55 * 20), loader)] + assert order_1 == order_4 + + +def test_blending_randomness(dataset_path): + import random + + import numpy + + for num_workers in [0, 1, 2]: # Especially also check the num_workers=0 case + world_size = 4 + micro_batch_size = 1 + seed = 42 + + configs = ( + WorkerConfig(rank=0, world_size=world_size, num_workers=num_workers), + WorkerConfig(rank=1, world_size=world_size, num_workers=num_workers), + WorkerConfig(rank=2, world_size=world_size, num_workers=num_workers), + ) + + all_ranks_subflavors = [] + for rank_config in configs: + torch.manual_seed(seed) + numpy.random.seed(seed) + random.seed(seed) + + ds = get_train_dataset( + dataset_path / "metadataset.yaml", + split_part="train", + worker_config=rank_config, + batch_size=micro_batch_size, + shuffle_buffer_size=None, + max_samples_per_sequence=None, ) + with get_loader(ds) as loader: + subflavors = [ + data.__subflavors__[0].get("__subflavor__") + for idx, data in zip(range(25), loader) + ] - all_ranks_subflavors = [] - for rank_config in configs: - torch.manual_seed(seed) - numpy.random.seed(seed) - random.seed(seed) - - ds = get_train_dataset( - self.mds_path, - split_part="train", - worker_config=rank_config, - batch_size=micro_batch_size, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - with get_loader(ds) as loader: - subflavors = [ - data.__subflavors__[0].get("__subflavor__") - for idx, data in zip(range(25), loader) - ] + all_ranks_subflavors.append(subflavors) - all_ranks_subflavors.append(subflavors) + print(f"Subflavors for rank {rank_config.rank}:", subflavors) - print(f"Subflavors for rank {rank_config.rank}:", subflavors) + # Assert that all ranks got different data + for i in range(len(all_ranks_subflavors)): + for j in range(i + 1, len(all_ranks_subflavors)): + assert all_ranks_subflavors[i] != all_ranks_subflavors[j], ( + f"Rank {i} and rank {j} got the same subflavors." + ) - # Assert that all ranks got different data - for i in range(len(all_ranks_subflavors)): - for j in range(i + 1, len(all_ranks_subflavors)): - assert all_ranks_subflavors[i] != all_ranks_subflavors[j], ( - f"Rank {i} and rank {j} got the same subflavors." - ) - def test_slice_iter_shuffle_over_epochs(self): - torch.manual_seed(42) +def test_slice_iter_shuffle_over_epochs(dataset_path): + torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + def new_loader(): + return get_savable_loader( + get_train_dataset( + dataset_path / "metadataset.yaml", + worker_config=worker_config, + batch_size=10, + parallel_shard_iters=2, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + shuffle_over_epochs_multiplier=-1, + ), ) - def new_loader(): - return get_savable_loader( - get_train_dataset( - self.mds_path, - worker_config=worker_config, - batch_size=10, - parallel_shard_iters=2, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - shuffle_over_epochs_multiplier=-1, - ), - ) + # Train mode dataset + with new_loader() as loader: + _ = [data.text for idx, data in zip(range(1000), loader)] - # Train mode dataset - with new_loader() as loader: - _ = [data.text for idx, data in zip(range(1000), loader)] - def test_save_restore_next(self): - torch.manual_seed(42) +def test_save_restore_next(dataset_path): + torch.manual_seed(42) - wc = WorkerConfig( - rank=0, - world_size=1, - num_workers=6, - ) + wc = WorkerConfig( + rank=0, + world_size=1, + num_workers=6, + ) - with get_savable_loader( - get_train_dataset( - self.nested_mds_path, - worker_config=wc, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ) as initial_loader: - skip_initial = 9 - - previous_cp = initial_loader.save_state_rank() - print("initial_samples:") - for i, sample in zip(range(skip_initial), initial_loader): - print(f"sample[@{i}]: {sample.text}") - print("previous_cp:", previous_cp) - with get_savable_loader( - get_train_dataset( - self.nested_mds_path, - worker_config=wc, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ).with_restored_state_rank(previous_cp) as rst_loader: - for i, rst_sample in zip(range(1), rst_loader): - print(f"rst_sample[@{i}]: {rst_sample.text}") - assert sample.text == rst_sample.text, f"{sample} != {rst_sample}" - assert sample.__key__ == rst_sample.__key__, f"{sample} != {rst_sample}" - assert sample.__restore_key__ == rst_sample.__restore_key__, ( - f"{sample} != {rst_sample}" - ) - previous_cp = initial_loader.save_state_rank() - - # Iterate 10 samples, the save state and store the next 10 samples for reference. - state_initial = initial_loader.save_state_rank() - print("state_initial:", str(state_initial)) - initial_samples = [sample for _, sample in zip(range(20), initial_loader)] - print( - "initial_samples:" - + "".join( - f"\n [@{idx}] {sample.text}" - for idx, sample in enumerate(initial_samples, start=skip_initial) + with get_savable_loader( + get_train_dataset( + dataset_path / "nested_metadataset.yaml", + worker_config=wc, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ) as initial_loader: + skip_initial = 9 + + previous_cp = initial_loader.save_state_rank() + print("initial_samples:") + for i, sample in zip(range(skip_initial), initial_loader): + print(f"sample[@{i}]: {sample.text}") + print("previous_cp:", previous_cp) + with get_savable_loader( + get_train_dataset( + dataset_path / "nested_metadataset.yaml", + worker_config=wc, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ).with_restored_state_rank(previous_cp) as rst_loader: + for i, rst_sample in zip(range(1), rst_loader): + print(f"rst_sample[@{i}]: {rst_sample.text}") + assert sample.text == rst_sample.text, f"{sample} != {rst_sample}" + assert sample.__key__ == rst_sample.__key__, f"{sample} != {rst_sample}" + assert sample.__restore_key__ == rst_sample.__restore_key__, ( + f"{sample} != {rst_sample}" ) + previous_cp = initial_loader.save_state_rank() + + # Iterate 10 samples, the save state and store the next 10 samples for reference. + state_initial = initial_loader.save_state_rank() + print("state_initial:", str(state_initial)) + initial_samples = [sample for _, sample in zip(range(20), initial_loader)] + print( + "initial_samples:" + + "".join( + f"\n [@{idx}] {sample.text}" + for idx, sample in enumerate(initial_samples, start=skip_initial) ) - - second_loader = get_savable_loader( - get_train_dataset( - self.nested_mds_path, - worker_config=wc, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), ) - second_loader.restore_state_rank(state_initial) + + second_loader = get_savable_loader( + get_train_dataset( + dataset_path / "nested_metadataset.yaml", + worker_config=wc, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ) + second_loader.restore_state_rank(state_initial) + # Save the state again, to check that it is the same as the just restored state + same_state = second_loader.save_state_rank() + print("same_state:", same_state) + assert_nested_equal(same_state, state_initial) + assert same_state is state_initial + + # This will propagate the state to the workers. + second_loader.start() + try: # Save the state again, to check that it is the same as the just restored state same_state = second_loader.save_state_rank() print("same_state:", same_state) assert_nested_equal(same_state, state_initial) - assert same_state is state_initial - - # This will propagate the state to the workers. - second_loader.start() - try: - # Save the state again, to check that it is the same as the just restored state - same_state = second_loader.save_state_rank() - print("same_state:", same_state) - assert_nested_equal(same_state, state_initial) - for offset in range(10): + for offset in range(10): + try: + # Save state and restore in next loader + state_offset = second_loader.save_state_rank() + # Get 1 sample from the current loader + samples = [sample for _, sample in zip(range(1), second_loader)] + assert len(samples) == 1 + sample = samples[0] + + # Check that the sample is the same as the initial loader's reference sample + print(f"sample[@{offset + skip_initial}]: {sample.text}") try: - # Save state and restore in next loader - state_offset = second_loader.save_state_rank() - # Get 1 sample from the current loader - samples = [sample for _, sample in zip(range(1), second_loader)] - assert len(samples) == 1 - sample = samples[0] - - # Check that the sample is the same as the initial loader's reference sample - print(f"sample[@{offset + skip_initial}]: {sample.text}") - try: - assert sample.text == initial_samples[offset].text, ( - f"{sample} != {initial_samples[offset]}" - ) - assert sample.__key__ == initial_samples[offset].__key__, ( - f"{sample} != {initial_samples[offset]}" - ) - assert sample.__restore_key__ == initial_samples[offset].__restore_key__, ( - f"{sample} != {initial_samples[offset]}" - ) - except Exception as e: - print( - "samples:" - + f"\n [@{offset + skip_initial}] {sample.text}" - + "".join( - f"\n [@{idx}] {sample.text}" - for idx, sample in zip( - range(skip_initial + offset + 1, skip_initial + offset + 6), - second_loader, - ) + assert sample.text == initial_samples[offset].text, ( + f"{sample} != {initial_samples[offset]}" + ) + assert sample.__key__ == initial_samples[offset].__key__, ( + f"{sample} != {initial_samples[offset]}" + ) + assert sample.__restore_key__ == initial_samples[offset].__restore_key__, ( + f"{sample} != {initial_samples[offset]}" + ) + except Exception as e: + print( + "samples:" + + f"\n [@{offset + skip_initial}] {sample.text}" + + "".join( + f"\n [@{idx}] {sample.text}" + for idx, sample in zip( + range(skip_initial + offset + 1, skip_initial + offset + 6), + second_loader, ) ) - raise ValueError( - f"Failed to iterate @{offset + skip_initial} samples" - ) from e - - # Restore state in a new loader - with get_savable_loader( - get_train_dataset( - self.nested_mds_path, - worker_config=wc, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ).with_restored_state_rank(state_offset) as ref_loader: - # Get 1 sample from the restored loader - next_loader_samples = [sample for _, sample in zip(range(6), ref_loader)] - assert len(next_loader_samples) == 6 - next_loader_sample = next_loader_samples[0] - print( - "next_loader_samples:" - + f"\n [@{offset + skip_initial}] {sample.text}" - + "".join( - f"\n [@{idx}] {sample}" - for idx, sample in zip( - range(skip_initial + offset, skip_initial + offset + 6), - next_loader_samples, - ) + ) + raise ValueError(f"Failed to iterate @{offset + skip_initial} samples") from e + + # Restore state in a new loader + with get_savable_loader( + get_train_dataset( + dataset_path / "nested_metadataset.yaml", + worker_config=wc, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ), + ).with_restored_state_rank(state_offset) as ref_loader: + # Get 1 sample from the restored loader + next_loader_samples = [sample for _, sample in zip(range(6), ref_loader)] + assert len(next_loader_samples) == 6 + next_loader_sample = next_loader_samples[0] + print( + "next_loader_samples:" + + f"\n [@{offset + skip_initial}] {sample.text}" + + "".join( + f"\n [@{idx}] {sample}" + for idx, sample in zip( + range(skip_initial + offset, skip_initial + offset + 6), + next_loader_samples, ) ) - assert next_loader_sample.text == sample.text, ( - f"{next_loader_sample} != {sample}" - ) - assert next_loader_sample.__key__ == sample.__key__, ( - f"{next_loader_sample} != {sample}" - ) - assert next_loader_sample.__restore_key__ == sample.__restore_key__, ( - f"{next_loader_sample} != {sample}" - ) - except Exception as e: - raise ValueError(f"Failed to iterate @{skip_initial}+{offset} samples") from e - finally: - second_loader.shutdown() - - -if __name__ == "__main__": - unittest.main() + ) + assert next_loader_sample.text == sample.text, ( + f"{next_loader_sample} != {sample}" + ) + assert next_loader_sample.__key__ == sample.__key__, ( + f"{next_loader_sample} != {sample}" + ) + assert next_loader_sample.__restore_key__ == sample.__restore_key__, ( + f"{next_loader_sample} != {sample}" + ) + except Exception as e: + raise ValueError(f"Failed to iterate @{skip_initial}+{offset} samples") from e + finally: + second_loader.shutdown() diff --git a/tests/test_metadataset_fewsamp.py b/tests/test_metadataset_fewsamp.py index d2cddd5d..4b8f43c1 100644 --- a/tests/test_metadataset_fewsamp.py +++ b/tests/test_metadataset_fewsamp.py @@ -7,11 +7,11 @@ import logging import sys import tempfile -import unittest import warnings from pathlib import Path from typing import Iterable +import pytest import torch import webdataset as wds @@ -61,195 +61,190 @@ def get_blend_dataset(ds: SavableDataset): raise ValueError("No blend dataset found") -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - (self.dataset_path / "ds1").mkdir(exist_ok=True, parents=True) - (self.dataset_path / "ds2").mkdir(exist_ok=True, parents=True) - (self.dataset_path / "ds3").mkdir(exist_ok=True, parents=True) - - # Create a small dummy captioning dataset - self.create_text_test_dataset(self.dataset_path / "ds1", range(55), range(55)) - self.create_text_test_dataset(self.dataset_path / "ds2", range(100, 107), range(100, 107)) - self.create_text_test_dataset(self.dataset_path / "ds3", range(200, 255), range(0, 55)) - - self.mds_path = self.dataset_path / "metadataset_v2.yaml" - with open(self.mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " path: ds1", - " - weight: 1", - " path: ds2", - " - weight: 1", - " path: ds3", - ] - ) +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() + + +@pytest.fixture +def dataset_path(temp_dir): + """Create dataset path and setup test data.""" + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + + (dataset_path / "ds1").mkdir(exist_ok=True, parents=True) + (dataset_path / "ds2").mkdir(exist_ok=True, parents=True) + (dataset_path / "ds3").mkdir(exist_ok=True, parents=True) + + # Create a small dummy captioning dataset + create_text_test_dataset(dataset_path / "ds1", range(55), range(55)) + create_text_test_dataset(dataset_path / "ds2", range(100, 107), range(100, 107)) + create_text_test_dataset(dataset_path / "ds3", range(200, 255), range(0, 55)) + + mds_path = dataset_path / "metadataset_v2.yaml" + with open(mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " path: ds1", + " - weight: 1", + " path: ds2", + " - weight: 1", + " path: ds3", + ] ) - - print(self.dataset_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): - """Creates a small dummy test dataset for testing purposes.""" - - # Create num_samples unique captions - (path / "parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: - for key, txt in zip(key_range, txt_range): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{key:06d}", - "txt": f"{txt}".encode(), - }, - ) - total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, ) - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: TextWebdataset", - "field_map:", - " text: txt", - "subflavors:", - " source: dataset.yaml", - " dataset.yaml: true", - " number: 42", - ] - ) - ) + print(dataset_path) + return dataset_path + + +def create_text_test_dataset(path: Path, txt_range: Iterable[int], key_range: Iterable[int]): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) - def test_metadataset_few_samples_save_restore(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=32, - num_workers=1, - seed_offset=42, + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for key, txt in zip(key_range, txt_range): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{key:06d}", + "txt": f"{txt}".encode(), + }, + ) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: TextWebdataset", + "field_map:", + " text: txt", + "subflavors:", + " source: dataset.yaml", + " dataset.yaml: true", + " number: 42", + ] + ) ) - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_path, + +def test_metadataset_few_samples_save_restore(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=32, + num_workers=1, + seed_offset=42, + ) + + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset_v2.yaml", + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=100, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 4 + + # The middle dataset should have 0 samples assigned to this rank + blend_ds = get_blend_dataset(train_dataset) + assert len(blend_ds.dataset_weights[1][0].dataset.dataset.workers_slice_offsets[0]) == 1 + assert len(blend_ds.dataset_weights[1][0].dataset.dataset) == 0 + + with get_savable_loader( + train_dataset, + ) as train_loader: + # Load 3 samples + list(zip(train_loader, range(3))) + + # Save state mid epoch + state1 = train_loader.save_state_rank() + + # Load 5 samples + data1b = list(zip(train_loader, range(5))) + + # Restore state + with get_savable_loader( + get_train_dataset( + dataset_path / "metadataset_v2.yaml", worker_config=worker_config, batch_size=1, shuffle_buffer_size=100, max_samples_per_sequence=None, - ) - print(len(train_dataset)) - assert len(train_dataset) == 4 + ), + ).with_restored_state_rank(state1) as train_loader: + # Load 5 samples + data2_restore = list(zip(train_loader, range(5))) - # The middle dataset should have 0 samples assigned to this rank - blend_ds = get_blend_dataset(train_dataset) - assert len(blend_ds.dataset_weights[1][0].dataset.dataset.workers_slice_offsets[0]) == 1 - assert len(blend_ds.dataset_weights[1][0].dataset.dataset) == 0 + # Check that the restored state is the same + order1b = [(s[0].__key__[0], int(s[0].text[0])) for s in data1b] + order2 = [(s[0].__key__[0], int(s[0].text[0])) for s in data2_restore] - with get_savable_loader( - train_dataset, - ) as train_loader: - # Load 3 samples - list(zip(train_loader, range(3))) + print("order1b") + print(order1b) + print("order2") + print(order2) - # Save state mid epoch - state1 = train_loader.save_state_rank() + assert order1b == order2, "The restored state does not match the original state." - # Load 5 samples - data1b = list(zip(train_loader, range(5))) - # Restore state +def test_too_few_samples(dataset_path): + # Will only give a single sample, as there are 117 samples in total, and 100 ranks + ws = 100 + lens = [] + for i_rank in range(ws): + worker_config = WorkerConfig(rank=i_rank, world_size=ws, num_workers=0) with get_savable_loader( get_train_dataset( - self.mds_path, - worker_config=worker_config, + dataset_path / "metadataset_v2.yaml", batch_size=1, - shuffle_buffer_size=100, + worker_config=worker_config, + shuffle_buffer_size=None, max_samples_per_sequence=None, ), - ).with_restored_state_rank(state1) as train_loader: - # Load 5 samples - data2_restore = list(zip(train_loader, range(5))) - - # Check that the restored state is the same - order1b = [(s[0].__key__[0], int(s[0].text[0])) for s in data1b] - order2 = [(s[0].__key__[0], int(s[0].text[0])) for s in data2_restore] - - print("order1b") - print(order1b) - print("order2") - print(order2) - - assert order1b == order2, "The restored state does not match the original state." - - def test_too_few_samples(self): - # Will only give a single sample, as there are 117 samples in total, and 100 ranks - ws = 100 - lens = [] - for i_rank in range(ws): - worker_config = WorkerConfig(rank=i_rank, world_size=ws, num_workers=0) - with get_savable_loader( - get_train_dataset( - self.mds_path, - batch_size=1, - worker_config=worker_config, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ) as loader: - lens.append(len(loader)) - - txts = [] - - for i, sample in zip(range(10), loader): - txts.extend(sample.text) - - assert len(set(txts)) == len(loader), ( - f"Rank {i_rank} should have exactly {len(loader)} sample, but got {txts}" - ) - - assert lens == [ - 2 if i in [0, 3, 6, 12, 18, 25, 31, 37, 43, 50, 56, 62, 68, 75, 81, 87, 93] else 1 - for i in range(100) - ] - - -if __name__ == "__main__": - # unittest.main() - ds = TestDataset() - ds.setUp() - ds.test_metadataset_few_samples_save_restore() - ds.tearDown() + ) as loader: + lens.append(len(loader)) + + txts = [] + + for i, sample in zip(range(10), loader): + txts.extend(sample.text) + + assert len(set(txts)) == len(loader), ( + f"Rank {i_rank} should have exactly {len(loader)} sample, but got {txts}" + ) + + assert lens == [ + 2 if i in [0, 3, 6, 12, 18, 25, 31, 37, 43, 50, 56, 62, 68, 75, 81, 87, 93] else 1 + for i in range(100) + ] diff --git a/tests/test_metadataset_v2.py b/tests/test_metadataset_v2.py index 973055d3..6063cb16 100644 --- a/tests/test_metadataset_v2.py +++ b/tests/test_metadataset_v2.py @@ -8,13 +8,13 @@ import random import sys import tempfile -import unittest import warnings from collections import Counter from pathlib import Path from typing import Iterable from unittest.mock import patch +import pytest import torch import webdataset as wds @@ -64,290 +64,551 @@ def _norng_state(state): @edataclass -class TestJoinedSample(Sample): +class JoinedSample(Sample): text1: torch.Tensor text2: torch.Tensor @staticmethod - def from_joined(ds1: TextSample, ds2: TextSample) -> "TestJoinedSample": - return TestJoinedSample.derive_from( + def from_joined(ds1: TextSample, ds2: TextSample) -> "JoinedSample": + return JoinedSample.derive_from( ds1, text1=ds1.text, text2=ds2.text, ) -def test_joiner(text1: TextSample, text2: TextSample) -> TestJoinedSample: - return TestJoinedSample.derive_from(text1, text1=f"j{text1.text}", text2=f"j{text2.text}") - - -class TestDataset(unittest.TestCase): - # Set up the test fixture - def setUp(self): - random.seed(42) - - logging.basicConfig(stream=sys.stderr, level=logging.INFO) - warnings.simplefilter("ignore", ResourceWarning) - - # Create a temporary directory - self.temp_dir = tempfile.TemporaryDirectory() - self.dataset_path = Path(self.temp_dir.name) - # self.dataset_path = Path("./test_dataset") - - self.dataset_path.mkdir(exist_ok=True, parents=True) - - # Create a small dummy datasets - self.create_text_test_dataset(self.dataset_path / "ds1", range(55), range(55)) - self.create_text_test_dataset(self.dataset_path / "ds2", range(100, 155), range(100, 155)) - self.create_text_test_dataset(self.dataset_path / "ds3", range(200, 255), range(55)) - - # Create a shuffled dataset for joining with the ds1. It has overlap but includes more samples - shuffled_range_100 = list(range(100)) - random.shuffle(shuffled_range_100) - - self.create_text_test_dataset( - self.dataset_path / "ds1b", shuffled_range_100, shuffled_range_100, prefix="B" +def my_joiner(text1: TextSample, text2: TextSample) -> JoinedSample: + return JoinedSample.derive_from(text1, text1=f"j{text1.text}", text2=f"j{text2.text}") + + +@pytest.fixture +def temp_dir(): + """Create a temporary directory for test files.""" + temp_dir = tempfile.TemporaryDirectory() + yield temp_dir + gc.collect() + temp_dir.cleanup() + + +@pytest.fixture +def dataset_path(temp_dir): + """Create dataset path and setup test data.""" + random.seed(42) + logging.basicConfig(stream=sys.stderr, level=logging.INFO) + warnings.simplefilter("ignore", ResourceWarning) + + dataset_path = Path(temp_dir.name) + dataset_path.mkdir(exist_ok=True, parents=True) + + # Create a small dummy datasets + create_text_test_dataset(dataset_path / "ds1", range(55), range(55)) + create_text_test_dataset(dataset_path / "ds2", range(100, 155), range(100, 155)) + create_text_test_dataset(dataset_path / "ds3", range(200, 255), range(55)) + + # Create a shuffled dataset for joining with the ds1. It has overlap but includes more samples + shuffled_range_100 = list(range(100)) + random.shuffle(shuffled_range_100) + + create_text_test_dataset( + dataset_path / "ds1b", shuffled_range_100, shuffled_range_100, prefix="B" + ) + + shuffled_range_100 = list(range(100)) + random.shuffle(shuffled_range_100) + create_text_test_dataset( + dataset_path / "ds1c", shuffled_range_100, shuffled_range_100, prefix="C" + ) + + mds_path = dataset_path / "metadataset_v2.yaml" + with open(mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " path: ds1", + " subflavors:", + " source: metadataset_v2.yaml", + " number: 43", + " mds: mds", + " shuffle_over_epochs_multiplier: 3", + " - weight: 1", + " path: ds2", + " subflavors:", + " source: metadataset_v2.yaml", + " number: 44", + " mds: mds", + " val:", + " blend:", + " - weight: 1", + " path: ds1", + " split_part: train", + " - weight: 1", + " path: ds2", + " split_part: train", + ] + ) ) - - shuffled_range_100 = list(range(100)) - random.shuffle(shuffled_range_100) - self.create_text_test_dataset( - self.dataset_path / "ds1c", shuffled_range_100, shuffled_range_100, prefix="C" + nested_mds_path = dataset_path / "nested_metadataset_v2.yaml" + with open(nested_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 4", + " path: ./metadataset_v2.yaml", + " split_part: train", + " subflavors:", + " source: nested_metadataset.yaml", + " mds: nested_train", + " - path: ./metadataset_v2.yaml", + " split_part: val", + " subflavors:", + " source: nested_metadataset.yaml", + " mds: nested_val", + ] + ) ) - - self.mds_path = self.dataset_path / "metadataset_v2.yaml" - with open(self.mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " path: ds1", - " subflavors:", - " source: metadataset_v2.yaml", - " number: 43", - " mds: mds", - " shuffle_over_epochs_multiplier: 3", - " - weight: 1", - " path: ds2", - " subflavors:", - " source: metadataset_v2.yaml", - " number: 44", - " mds: mds", - " val:", - " blend:", - " - weight: 1", - " path: ds1", - " split_part: train", - " - weight: 1", - " path: ds2", - " split_part: train", - ] - ) + print(dataset_path) + return dataset_path + + +def create_text_test_dataset( + path: Path, txt_range: Iterable[int], key_range: Iterable[int], prefix: str = "" +): + """Creates a small dummy test dataset for testing purposes.""" + + # Create num_samples unique captions + (path / "parts").mkdir(exist_ok=True, parents=True) + + # Initialize the ShardWriter + with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: + for key, txt in zip(key_range, txt_range): + # Write individual files to shards + shard_writer.write( + { + "__key__": f"{key:06d}", + "txt": f"{prefix}{txt}".encode(), + }, ) - self.nested_mds_path = self.dataset_path / "nested_metadataset_v2.yaml" - with open(self.nested_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 4", - " path: ./metadataset_v2.yaml", - " split_part: train", - " subflavors:", - " source: nested_metadataset.yaml", - " mds: nested_train", - " - path: ./metadataset_v2.yaml", - " split_part: val", - " subflavors:", - " source: nested_metadataset.yaml", - " mds: nested_val", - ] - ) + total_shards = shard_writer.shard + + from megatron.energon.flavors import BaseWebdatasetFactory + + BaseWebdatasetFactory.prepare_dataset( + path, + [f"parts/data-{{0..{total_shards - 1}}}.tar"], + split_parts_ratio=[("train", 1.0)], + shuffle_seed=None, + ) + + with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: + f.write( + "\n".join( + [ + "sample_type:", + " __module__: megatron.energon", + " __class__: TextSample", + "field_map:", + " text: txt", + "subflavors:", + " source: dataset.yaml", + " dataset.yaml: true", + " number: 42", + ] ) - print(self.dataset_path) - - def tearDown(self): - # Remove all temporary files - gc.collect() - self.temp_dir.cleanup() - - @staticmethod - def create_text_test_dataset( - path: Path, txt_range: Iterable[int], key_range: Iterable[int], prefix: str = "" - ): - """Creates a small dummy test dataset for testing purposes.""" - - # Create num_samples unique captions - (path / "parts").mkdir(exist_ok=True, parents=True) - - # Initialize the ShardWriter - with wds.ShardWriter(f"{path}/parts/data-%d.tar", maxcount=10) as shard_writer: - for key, txt in zip(key_range, txt_range): - # Write individual files to shards - shard_writer.write( - { - "__key__": f"{key:06d}", - "txt": f"{prefix}{txt}".encode(), - }, - ) - total_shards = shard_writer.shard - - from megatron.energon.flavors import BaseWebdatasetFactory - - BaseWebdatasetFactory.prepare_dataset( - path, - [f"parts/data-{{0..{total_shards - 1}}}.tar"], - split_parts_ratio=[("train", 1.0)], - shuffle_seed=None, ) - with open(path / MAIN_FOLDER_NAME / "dataset.yaml", "w") as f: - f.write( - "\n".join( - [ - "sample_type:", - " __module__: megatron.energon", - " __class__: TextSample", - "field_map:", - " text: txt", - "subflavors:", - " source: dataset.yaml", - " dataset.yaml: true", - " number: 42", - ] - ) - ) - def test_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, +def test_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Train mode dataset + train_dataset = get_train_dataset( + dataset_path / "metadataset_v2.yaml", + worker_config=worker_config, + batch_size=10, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 11 + + with get_loader(train_dataset) as train_loader1: + train_order1 = [ + text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text + ] + print(train_order1[:10]) + print(Counter(train_order1)) + assert len(Counter(train_order1)) == 110 + assert all(48 <= v <= 52 for v in Counter(train_order1).values()) + + +def test_nested_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + ) + + dataset = load_dataset(dataset_path / "nested_metadataset_v2.yaml") + + raw_datasets = dataset.get_datasets( + training=False, split_part="train", worker_config=worker_config + ) + assert raw_datasets.blend_mode == DatasetBlendMode.DATASET_WEIGHT + assert [raw_dataset.weight for raw_dataset in raw_datasets.datasets] == [ + 0.4, + 0.4, + 0.1, + 0.1, + ], [raw_dataset.weight for raw_dataset in raw_datasets.datasets] + assert [raw_dataset.dataset.paths[0].name for raw_dataset in raw_datasets.datasets] == [ + "ds1", + "ds2", + "ds1", + "ds2", + ] + print([raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets]) + assert [raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets] == [ + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 43, + "mds": "nested_train", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 44, + "mds": "nested_train", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 42, + "mds": "nested_val", + }, + { + "source": "nested_metadataset.yaml", + "dataset.yaml": True, + "number": 42, + "mds": "nested_val", + }, + ] + + +def test_joined_metadataset(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + joined_mds_path = dataset_path / "joined_metadataset_v2.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " join:", + " ds1:", + " path: ds1", + " subflavors:", + " source1: ds1", + " number: 43", + " ds2:", + " path: ds3", + " subflavors:", + " source2: ds3", + " number: 44", + " joiner:", + f" __module__: {JoinedSample.__module__}", + f" __class__: {JoinedSample.__name__}", + ] + ) ) - - # Train mode dataset - train_dataset = get_train_dataset( - self.mds_path, + prepare_metadataset(EPath(joined_mds_path)) + + # Train mode dataset + train_dataset = get_train_dataset( + joined_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 55 + + with get_savable_loader( + train_dataset, + ) as train_loader: + data = list(zip(range(2 * 55), train_loader)) + txt1_order = [data.text1[0] for idx, data in data] + txt2_order = [data.text2[0] for idx, data in data] + key_order = [data.__key__[0] for idx, data in data] + # ds1 has 55 samples, key range 0:55, txt range 0:55 + # ds3 has 28 samples, key range 0:55, txt range 200:255 + # Joining results in: 0:55 + print("txt1:", txt1_order) + # Joining results in: 200:255 + print("txt2:", txt2_order) + # Joining results in: 0:55 + print("key:", key_order) + # Check matching + assert all(int(txt1) + 200 == int(txt2) for txt1, txt2 in zip(txt1_order, txt2_order)) + # Check frequency + assert set(txt1_order) == set(str(i) for i in range(0, 55)) + assert set(txt2_order) == set(str(i) for i in range(200, 255)) + # Every item must occurr 2 times (2*55). + assert Counter(txt1_order).most_common(1)[0][1] == 2 + + state = train_loader.save_state_rank() + + # Iterate 60 more items + data = list(zip(range(60), train_loader)) + txt1_order = [data.text1 for idx, data in data] + txt2_order = [data.text2 for idx, data in data] + key_order = [data.__key__ for idx, data in data] + + # Restore state + with get_savable_loader( + get_train_dataset( + joined_mds_path, worker_config=worker_config, - batch_size=10, + batch_size=1, shuffle_buffer_size=None, max_samples_per_sequence=None, + ), + ).with_restored_state_rank(state) as train_loader: + # Iterate 360 more items + data = list(zip(range(60), train_loader)) + txt1_order_rest = [data.text1 for idx, data in data] + txt2_order_rest = [data.text2 for idx, data in data] + key_order_rest = [data.__key__ for idx, data in data] + + # Verify matching + assert txt1_order == txt1_order_rest + assert txt2_order == txt2_order_rest + assert key_order == key_order_rest + + +def test_joined_metadataset_joiner(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + joined_mds_path = dataset_path / "joined_metadataset_joiner.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " join:", + " text1:", + " path: ds1", + " subflavors:", + " source1: ds1", + " number: 43", + " text2:", + " path: ds3", + " subflavors:", + " source2: ds3", + " number: 44", + " joiner:", + f" __module__: {my_joiner.__module__}", + f" __function__: {my_joiner.__name__}", + ] + ) ) - print(len(train_dataset)) - assert len(train_dataset) == 11 - - with get_loader(train_dataset) as train_loader1: - train_order1 = [ - text for idx, data in zip(range(55 * 10), train_loader1) for text in data.text - ] - print(train_order1[:10]) - print(Counter(train_order1)) - assert len(Counter(train_order1)) == 110 - assert all(48 <= v <= 52 for v in Counter(train_order1).values()) - - def test_nested_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - ) - - dataset = load_dataset(self.nested_mds_path) - - raw_datasets = dataset.get_datasets( - training=False, split_part="train", worker_config=worker_config + prepare_metadataset(EPath(joined_mds_path)) + + # Train mode dataset + train_dataset = get_train_dataset( + joined_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 55 + + with get_savable_loader( + train_dataset, + ) as train_loader: + data = list(zip(range(2 * 55), train_loader)) + txt1_order = [data.text1[0] for idx, data in data] + txt2_order = [data.text2[0] for idx, data in data] + key_order = [data.__key__[0] for idx, data in data] + # ds1 has 55 samples, key range 0:55, txt range 0:55 + # ds3 has 28 samples, key range 0:55, txt range 200:255 + # Joining results in: 0:55, with prefix "j" + print("txt1:", txt1_order) + # Joining results in: 200:255, with prefix "j" + print("txt2:", txt2_order) + # Joining results in: 0:55 + print("key:", key_order) + # Check matching + assert all( + int(txt1[1:]) + 200 == int(txt2[1:]) for txt1, txt2 in zip(txt1_order, txt2_order) ) - assert raw_datasets.blend_mode == DatasetBlendMode.DATASET_WEIGHT - assert [raw_dataset.weight for raw_dataset in raw_datasets.datasets] == [ - 0.4, - 0.4, - 0.1, - 0.1, - ], [raw_dataset.weight for raw_dataset in raw_datasets.datasets] - assert [raw_dataset.dataset.paths[0].name for raw_dataset in raw_datasets.datasets] == [ - "ds1", - "ds2", - "ds1", - "ds2", - ] - print([raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets]) - assert [raw_dataset.dataset.subflavors for raw_dataset in raw_datasets.datasets] == [ - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 43, - "mds": "nested_train", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 44, - "mds": "nested_train", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 42, - "mds": "nested_val", - }, - { - "source": "nested_metadataset.yaml", - "dataset.yaml": True, - "number": 42, - "mds": "nested_val", - }, - ] - - def test_joined_metadataset(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, + # Check frequency + assert set(txt1_order) == set(f"j{i}" for i in range(0, 55)) + assert set(txt2_order) == set(f"j{i}" for i in range(200, 255)) + # Every item must occurr 2 times (2*55). + assert Counter(txt1_order).most_common(1)[0][1] == 2 + + +def test_left_join(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + joined_mds_path = dataset_path / "left_join.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " join:", + " text1:", + " path: ds1", + " subflavors:", + " source1: ds1", + " number: 43", + " text2:", + " path: ds1b", + " nonmatch: skip", + " subflavors:", + " source2: ds1b", + " number: 44", + " joiner:", + f" __module__: {my_joiner.__module__}", + f" __function__: {my_joiner.__name__}", + ] + ) ) - - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "joined_metadataset_v2.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " join:", - " ds1:", - " path: ds1", - " subflavors:", - " source1: ds1", - " number: 43", - " ds2:", - " path: ds3", - " subflavors:", - " source2: ds3", - " number: 44", - " joiner:", - f" __module__: {TestJoinedSample.__module__}", - f" __class__: {TestJoinedSample.__name__}", - ] - ) + prepare_metadataset(EPath(joined_mds_path)) + + # Train mode dataset + train_dataset = get_train_dataset( + joined_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 55, len(train_dataset) + + with get_savable_loader( + train_dataset, + ) as train_loader: + data = list(zip(range(2 * 55), train_loader)) + txt1_order = [data.text1[0] for idx, data in data] + txt2_order = [data.text2[0] for idx, data in data] + key_order = [data.__key__[0] for idx, data in data] + # ds1 has 55 samples, key range 0:55, txt range 0:55 + # ds3 has 28 samples, key range 0:55, txt range 200:255 + # Joining results in: 0:55, with prefix "j" + print("txt1:", txt1_order) + # Joining results in: 200:255, with prefix "j" + print("txt2:", txt2_order) + # Joining results in: 0:55 + print("key:", key_order) + # Check matching + assert all(int(txt1[1:]) == int(txt2[2:]) for txt1, txt2 in zip(txt1_order, txt2_order)) + # Check frequency + assert set(txt1_order) == set(f"j{i}" for i in range(55)) + assert set(txt2_order) == set(f"jB{i}" for i in range(55)) + # Every item must occurr 2 times (2*55). + assert Counter(txt1_order).most_common(1)[0][1] == 2 + + # Test that changing the file works as expected + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " join:", + " text1:", + " path: ds1c", + " subflavors:", + " source1: ds1c", + " number: 43", + " text2:", + " path: ds1b", + " nonmatch: skip", + " subflavors:", + " source2: ds1b", + " number: 44", + " joiner:", + f" __module__: {my_joiner.__module__}", + f" __function__: {my_joiner.__name__}", + " - weight: 1", + " join:", + " text1:", + " path: ds1b", + " text2:", + " path: ds1", + " nonmatch: skip", + " joiner:", + f" __module__: {my_joiner.__module__}", + f" __function__: {my_joiner.__name__}", + ] ) - prepare_metadataset(EPath(joined_mds_path)) + ) + # Expect this to fail. Preparation does not match! + with pytest.raises(Exception): # Train mode dataset train_dataset = get_train_dataset( joined_mds_path, @@ -356,581 +617,377 @@ def test_joined_metadataset(self): shuffle_buffer_size=None, max_samples_per_sequence=None, ) - print(len(train_dataset)) - assert len(train_dataset) == 55 - with get_savable_loader( - train_dataset, - ) as train_loader: - data = list(zip(range(2 * 55), train_loader)) - txt1_order = [data.text1[0] for idx, data in data] - txt2_order = [data.text2[0] for idx, data in data] - key_order = [data.__key__[0] for idx, data in data] - # ds1 has 55 samples, key range 0:55, txt range 0:55 - # ds3 has 28 samples, key range 0:55, txt range 200:255 - # Joining results in: 0:55 - print("txt1:", txt1_order) - # Joining results in: 200:255 - print("txt2:", txt2_order) - # Joining results in: 0:55 - print("key:", key_order) - # Check matching - assert all(int(txt1) + 200 == int(txt2) for txt1, txt2 in zip(txt1_order, txt2_order)) - # Check frequency - assert set(txt1_order) == set(str(i) for i in range(0, 55)) - assert set(txt2_order) == set(str(i) for i in range(200, 255)) - # Every item must occurr 2 times (2*55). - assert Counter(txt1_order).most_common(1)[0][1] == 2 - - state = train_loader.save_state_rank() - - # Iterate 60 more items - data = list(zip(range(60), train_loader)) - txt1_order = [data.text1 for idx, data in data] - txt2_order = [data.text2 for idx, data in data] - key_order = [data.__key__ for idx, data in data] - - # Restore state - with get_savable_loader( - get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ), - ).with_restored_state_rank(state) as train_loader: - # Iterate 360 more items - data = list(zip(range(60), train_loader)) - txt1_order_rest = [data.text1 for idx, data in data] - txt2_order_rest = [data.text2 for idx, data in data] - key_order_rest = [data.__key__ for idx, data in data] - - # Verify matching - assert txt1_order == txt1_order_rest - assert txt2_order == txt2_order_rest - assert key_order == key_order_rest - - def test_joined_metadataset_joiner(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, + # Shall succeed after preparation + prepare_metadataset(EPath(joined_mds_path)) + train_dataset = get_train_dataset( + joined_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + # Check that there are no remainder files + cache_folder = joined_mds_path.with_name(joined_mds_path.name + ".cache") + assert sum(1 for f in cache_folder.iterdir() if f.is_file()) == 2, list(cache_folder.iterdir()) + + +def test_left_join_exclude(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + orig_split_path = dataset_path / "ds1" / ".nv-meta" / "split.yaml" + exclude_split_path = dataset_path / "ds1" / ".nv-meta" / "exclude_split.yaml" + with open(exclude_split_path, "w") as f: + f.write( + "\n".join( + [ + orig_split_path.read_text(), + "exclude:", + ' - "parts/data-0.tar/000000"', + ' - "parts/data-0.tar/000001"', + ' - "parts/data-0.tar/000002"', + ' - "parts/data-0.tar/000003"', + ' - "parts/data-0.tar/000004"', + ' - "parts/data-1.tar"', + ' - "parts/data-2.tar/000029"', + ] + ) ) - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "joined_metadataset_joiner.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " join:", - " text1:", - " path: ds1", - " subflavors:", - " source1: ds1", - " number: 43", - " text2:", - " path: ds3", - " subflavors:", - " source2: ds3", - " number: 44", - " joiner:", - f" __module__: {test_joiner.__module__}", - f" __function__: {test_joiner.__name__}", - ] - ) + # Create a joined dataset configuration + joined_mds_path = dataset_path / "left_join.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend:", + " - weight: 1", + " join:", + " text1:", + " path: ds1", + " split_config: exclude_split.yaml", + " text2:", + " path: ds1b", + " nonmatch: skip", + " joiner:", + f" __module__: {my_joiner.__module__}", + f" __function__: {my_joiner.__name__}", + ] ) - prepare_metadataset(EPath(joined_mds_path)) - - # Train mode dataset - train_dataset = get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, ) - print(len(train_dataset)) - assert len(train_dataset) == 55 - - with get_savable_loader( - train_dataset, - ) as train_loader: - data = list(zip(range(2 * 55), train_loader)) - txt1_order = [data.text1[0] for idx, data in data] - txt2_order = [data.text2[0] for idx, data in data] - key_order = [data.__key__[0] for idx, data in data] - # ds1 has 55 samples, key range 0:55, txt range 0:55 - # ds3 has 28 samples, key range 0:55, txt range 200:255 - # Joining results in: 0:55, with prefix "j" - print("txt1:", txt1_order) - # Joining results in: 200:255, with prefix "j" - print("txt2:", txt2_order) - # Joining results in: 0:55 - print("key:", key_order) - # Check matching - assert all( - int(txt1[1:]) + 200 == int(txt2[1:]) for txt1, txt2 in zip(txt1_order, txt2_order) + prepare_metadataset(EPath(joined_mds_path)) + + # Train mode dataset + train_dataset = get_train_dataset( + joined_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + print(len(train_dataset)) + assert len(train_dataset) == 55 - 16, len(train_dataset) + + with get_savable_loader( + train_dataset, + ) as train_loader: + data = list(zip(range(2 * 55), train_loader)) + txt1_order = [data.text1[0] for idx, data in data] + txt2_order = [data.text2[0] for idx, data in data] + key_order = [data.__key__[0] for idx, data in data] + # ds1 has 55 samples, key range 0:55, txt range 0:55 + # ds3 has 28 samples, key range 0:55, txt range 200:255 + # Joining results in: 0:55, with prefix "j" + print("txt1:", txt1_order) + # Joining results in: 200:255, with prefix "j" + print("txt2:", txt2_order) + # Joining results in: 0:55 + print("key:", key_order) + # Check matching + assert all(int(txt1[1:]) == int(txt2[2:]) for txt1, txt2 in zip(txt1_order, txt2_order)) + # Check frequency + set_filtered_nums = set(range(5, 10)) | set(range(20, 29)) | set(range(30, 55)) + assert set(txt1_order) == set(f"j{i}" for i in set_filtered_nums) + assert set(txt2_order) == set(f"jB{i}" for i in set_filtered_nums) + + +def test_joined_metadataset_prepare_mock(dataset_path): + torch.manual_seed(42) + + # Create a joined dataset configuration + joined_mds_path = dataset_path / "joined_metadataset_prepare_mock.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " join:", + " - path: ds1", + " - path: ds3", + " joiner:", + " __module__: __main__", + " __class__: NonExistantSample", + ] ) - # Check frequency - assert set(txt1_order) == set(f"j{i}" for i in range(0, 55)) - assert set(txt2_order) == set(f"j{i}" for i in range(200, 255)) - # Every item must occurr 2 times (2*55). - assert Counter(txt1_order).most_common(1)[0][1] == 2 - - def test_left_join(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, ) - - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "left_join.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " join:", - " text1:", - " path: ds1", - " subflavors:", - " source1: ds1", - " number: 43", - " text2:", - " path: ds1b", - " nonmatch: skip", - " subflavors:", - " source2: ds1b", - " number: 44", - " joiner:", - f" __module__: {test_joiner.__module__}", - f" __function__: {test_joiner.__name__}", - ] - ) + prepare_metadataset(EPath(joined_mds_path)) + + # Create a joined dataset configuration + joined_mds_path = dataset_path / "joined_metadataset_prepare_mock2.yaml" + with open(joined_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " join:", + " - path: ds1", + " - path: ds3", + " joiner:", + " __module__: non_existant_module", + " __class__: MyCaptioningSample", + ] + ) + ) + prepare_metadataset(EPath(joined_mds_path)) + + +def test_metadataset_fixed_epochs(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + fixed_epochs_mds_path = dataset_path / "metadataset_fixed_epochs.yaml" + with open(fixed_epochs_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend_epochized:", + " - repetitions: 2", + " path: ds1", + " subflavors:", + " source: ds1", + " number: 43", + " - repetitions: 3", + " path: ds2", + " subflavors:", + " source: ds2", + " number: 42", + ] ) - prepare_metadataset(EPath(joined_mds_path)) + ) - # Train mode dataset - train_dataset = get_train_dataset( - joined_mds_path, + # Train mode dataset + train_dataset = get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + repeat=False, + ) + print(len(train_dataset)) + assert len(train_dataset) == 5 * 55, len(train_dataset) + + with get_savable_loader( + train_dataset, + ) as train_loader: + data = list(enumerate(train_loader)) + txt_order = [data.text[0] for idx, data in data] + key_order = [data.__subflavors__[0]["source"] + "/" + data.__key__[0] for idx, data in data] + print("txt1:", txt_order) + print("key:", key_order) + assert len(txt_order) == 5 * 55, Counter(txt_order) + ds1_keys = [key for key in key_order if key.startswith("ds1/")] + ds2_keys = [key for key in key_order if key.startswith("ds2/")] + txt_cnt = Counter(txt_order) + ds1_key_cnt = Counter(ds1_keys) + ds2_key_cnt = Counter(ds2_keys) + assert len(ds1_keys) == 2 * 55, (len(ds1_keys), ds1_key_cnt) + assert len(ds2_keys) == 3 * 55, (len(ds2_keys), ds2_key_cnt) + assert all(ds1_key_cnt[key] == 2 for key in ds1_keys) + assert all(ds2_key_cnt[key] == 3 for key in ds2_keys) + assert all(txt_cnt[key] in (2, 3) for key in txt_order) + + # Next epoch + data = list(enumerate(train_loader)) + print([data.text[0] for idx, data in data]) + assert len(data) == 5 * 55, len(data) + + # Next epoch + data1 = list(zip(range(3 * 55), train_loader)) + assert len(data1) == 3 * 55, len(data1) + # Save state mid epoch + state1 = train_loader.save_state_rank() + print(state1) + + data2 = list(enumerate(train_loader)) + assert len(data2) == 2 * 55 + txt_order = [data.text[0] for idx, data in data1 + data2] + key_order = [ + data.__subflavors__[0]["source"] + "/" + data.__key__[0] for idx, data in data1 + data2 + ] + assert len(txt_order) == 5 * 55, Counter(txt_order) + ds1_keys = [key for key in key_order if key.startswith("ds1/")] + ds2_keys = [key for key in key_order if key.startswith("ds2/")] + txt_cnt = Counter(txt_order) + ds1_key_cnt = Counter(ds1_keys) + ds2_key_cnt = Counter(ds2_keys) + assert len(ds1_keys) == 2 * 55, (len(ds1_keys), ds1_key_cnt) + assert len(ds2_keys) == 3 * 55, (len(ds2_keys), ds2_key_cnt) + assert all(ds1_key_cnt[key] == 2 for key in ds1_keys) + assert all(ds2_key_cnt[key] == 3 for key in ds2_keys) + assert all(txt_cnt[key] in (2, 3) for key in txt_order) + + # Restore state + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, worker_config=worker_config, batch_size=1, shuffle_buffer_size=None, max_samples_per_sequence=None, + repeat=False, + ), + ).with_restored_state_rank(state1) as train_loader: + data2_restore = list(enumerate(train_loader)) + assert len(data2_restore) == 2 * 55 + txt_order_rst = [data.text[0] for idx, data in data1 + data2_restore] + key_order_rst = [ + data.__subflavors__[0]["source"] + "/" + data.__key__[0] + for idx, data in data1 + data2_restore + ] + assert len(txt_order_rst) == 5 * 55, Counter(txt_order_rst) + assert txt_order_rst == txt_order + assert key_order_rst == key_order + ds1_keys_rst = [key for key in key_order_rst if key.startswith("ds1/")] + ds2_keys_rst = [key for key in key_order_rst if key.startswith("ds2/")] + txt_cnt_rst = Counter(txt_order_rst) + ds1_key_cnt_rst = Counter(ds1_keys_rst) + ds2_key_cnt_rst = Counter(ds2_keys_rst) + assert len(ds1_keys_rst) == 2 * 55, (len(ds1_keys_rst), ds1_key_cnt_rst) + assert len(ds2_keys_rst) == 3 * 55, (len(ds2_keys_rst), ds2_key_cnt_rst) + assert all(ds1_key_cnt_rst[key] == 2 for key in ds1_keys_rst) + assert all(ds2_key_cnt_rst[key] == 3 for key in ds2_keys_rst) + assert all(txt_cnt_rst[key] in (2, 3) for key in txt_order_rst) + + +def test_metadataset_fixed_fractional_epochs(dataset_path): + torch.manual_seed(42) + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + + # Create a joined dataset configuration + fixed_epochs_mds_path = dataset_path / "metadataset_fixed_epochs.yaml" + with open(fixed_epochs_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " blend_epochized:", + " - repetitions: 0.7", + " path: ds1", + " subflavors:", + " source: ds1", + " number: 43", + " - repetitions: 1.5", + " path: ds2", + " subflavors:", + " source: ds2", + " number: 42", + ] + ) ) - print(len(train_dataset)) - assert len(train_dataset) == 55, len(train_dataset) - with get_savable_loader( - train_dataset, - ) as train_loader: - data = list(zip(range(2 * 55), train_loader)) - txt1_order = [data.text1[0] for idx, data in data] - txt2_order = [data.text2[0] for idx, data in data] - key_order = [data.__key__[0] for idx, data in data] - # ds1 has 55 samples, key range 0:55, txt range 0:55 - # ds3 has 28 samples, key range 0:55, txt range 200:255 - # Joining results in: 0:55, with prefix "j" - print("txt1:", txt1_order) - # Joining results in: 200:255, with prefix "j" - print("txt2:", txt2_order) - # Joining results in: 0:55 - print("key:", key_order) - # Check matching - assert all(int(txt1[1:]) == int(txt2[2:]) for txt1, txt2 in zip(txt1_order, txt2_order)) - # Check frequency - assert set(txt1_order) == set(f"j{i}" for i in range(55)) - assert set(txt2_order) == set(f"jB{i}" for i in range(55)) - # Every item must occurr 2 times (2*55). - assert Counter(txt1_order).most_common(1)[0][1] == 2 - - # Test that changing the file works as expected - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " join:", - " text1:", - " path: ds1c", - " subflavors:", - " source1: ds1c", - " number: 43", - " text2:", - " path: ds1b", - " nonmatch: skip", - " subflavors:", - " source2: ds1b", - " number: 44", - " joiner:", - f" __module__: {test_joiner.__module__}", - f" __function__: {test_joiner.__name__}", - " - weight: 1", - " join:", - " text1:", - " path: ds1b", - " text2:", - " path: ds1", - " nonmatch: skip", - " joiner:", - f" __module__: {test_joiner.__module__}", - f" __function__: {test_joiner.__name__}", - ] - ) - ) + # ===== Part 1: Verify fractions ===== - # Expect this to fail. Preparation does not match! - with self.assertRaises(Exception): - # Train mode dataset - train_dataset = get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) + # Train mode dataset + train_dataset = get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ) - # Shall succeed after preparation - prepare_metadataset(EPath(joined_mds_path)) - train_dataset = get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - # Check that there are no remainder files - cache_folder = joined_mds_path.with_name(joined_mds_path.name + ".cache") - assert sum(1 for f in cache_folder.iterdir() if f.is_file()) == 2, list( - cache_folder.iterdir() - ) + with get_savable_loader( + train_dataset, + ) as train_loader: + assert len(train_loader) == 38 + 55 + 27, len(train_loader) - def test_left_join_exclude(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) + data = list(enumerate(train_loader)) - # Create a joined dataset configuration - orig_split_path = self.dataset_path / "ds1" / ".nv-meta" / "split.yaml" - exclude_split_path = self.dataset_path / "ds1" / ".nv-meta" / "exclude_split.yaml" - with open(exclude_split_path, "w") as f: - f.write( - "\n".join( - [ - orig_split_path.read_text(), - "exclude:", - ' - "parts/data-0.tar/000000"', - ' - "parts/data-0.tar/000001"', - ' - "parts/data-0.tar/000002"', - ' - "parts/data-0.tar/000003"', - ' - "parts/data-0.tar/000004"', - ' - "parts/data-1.tar"', - ' - "parts/data-2.tar/000029"', - ] - ) - ) + # Check the overall number of samples + # Should be 0.7*len(ds1) + 1.5*len(ds2) = 0.7*55 + 1.5*55 = 38 + 55 + 27 (floor rounding) + assert len(data) == 38 + 55 + 27, len(data) - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "left_join.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend:", - " - weight: 1", - " join:", - " text1:", - " path: ds1", - " split_config: exclude_split.yaml", - " text2:", - " path: ds1b", - " nonmatch: skip", - " joiner:", - f" __module__: {test_joiner.__module__}", - f" __function__: {test_joiner.__name__}", - ] - ) - ) - prepare_metadataset(EPath(joined_mds_path)) + sample_counts = Counter([int(s[1].text[0]) for s in data]) - # Train mode dataset - train_dataset = get_train_dataset( - joined_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - ) - print(len(train_dataset)) - assert len(train_dataset) == 55 - 16, len(train_dataset) + # The first 70% of samples from ds1 (0 to incl. 37) should be repeated only once + assert all(sample_counts[sample] == 1 for sample in range(38)) - with get_savable_loader( - train_dataset, - ) as train_loader: - data = list(zip(range(2 * 55), train_loader)) - txt1_order = [data.text1[0] for idx, data in data] - txt2_order = [data.text2[0] for idx, data in data] - key_order = [data.__key__[0] for idx, data in data] - # ds1 has 55 samples, key range 0:55, txt range 0:55 - # ds3 has 28 samples, key range 0:55, txt range 200:255 - # Joining results in: 0:55, with prefix "j" - print("txt1:", txt1_order) - # Joining results in: 200:255, with prefix "j" - print("txt2:", txt2_order) - # Joining results in: 0:55 - print("key:", key_order) - # Check matching - assert all(int(txt1[1:]) == int(txt2[2:]) for txt1, txt2 in zip(txt1_order, txt2_order)) - # Check frequency - set_filtered_nums = set(range(5, 10)) | set(range(20, 29)) | set(range(30, 55)) - assert set(txt1_order) == set(f"j{i}" for i in set_filtered_nums) - assert set(txt2_order) == set(f"jB{i}" for i in set_filtered_nums) - - def test_joined_metadataset_prepare_mock(self): - torch.manual_seed(42) - - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "joined_metadataset_prepare_mock.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " join:", - " - path: ds1", - " - path: ds3", - " joiner:", - " __module__: __main__", - " __class__: NonExistantSample", - ] - ) - ) - prepare_metadataset(EPath(joined_mds_path)) - - # Create a joined dataset configuration - joined_mds_path = self.dataset_path / "joined_metadataset_prepare_mock2.yaml" - with open(joined_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " join:", - " - path: ds1", - " - path: ds3", - " joiner:", - " __module__: non_existant_module", - " __class__: MyCaptioningSample", - ] - ) - ) - prepare_metadataset(EPath(joined_mds_path)) + # Since ds2 is repeated 1.5 times, the first 50% of samples from ds2 (100 to incl. 126) should be repeated twice + assert all(sample_counts[sample] == 2 for sample in range(100, 127)) - def test_metadataset_fixed_epochs(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) + # The remaining samples from ds2 (127 to incl. 154) should be repeated only once + assert all(sample_counts[sample] == 1 for sample in range(127, 155)) - # Create a joined dataset configuration - fixed_epochs_mds_path = self.dataset_path / "metadataset_fixed_epochs.yaml" - with open(fixed_epochs_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend_epochized:", - " - repetitions: 2", - " path: ds1", - " subflavors:", - " source: ds1", - " number: 43", - " - repetitions: 3", - " path: ds2", - " subflavors:", - " source: ds2", - " number: 42", - ] - ) - ) + # ===== Part 2: Save and restore state ===== - # Train mode dataset - train_dataset = get_train_dataset( + # Now let's check if the state is stored and restored correctly + + with get_savable_loader( + get_train_dataset( fixed_epochs_mds_path, worker_config=worker_config, batch_size=1, shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, max_samples_per_sequence=None, repeat=False, - ) - print(len(train_dataset)) - assert len(train_dataset) == 5 * 55, len(train_dataset) - - with get_savable_loader( - train_dataset, - ) as train_loader: - data = list(enumerate(train_loader)) - txt_order = [data.text[0] for idx, data in data] - key_order = [ - data.__subflavors__[0]["source"] + "/" + data.__key__[0] for idx, data in data - ] - print("txt1:", txt_order) - print("key:", key_order) - assert len(txt_order) == 5 * 55, Counter(txt_order) - ds1_keys = [key for key in key_order if key.startswith("ds1/")] - ds2_keys = [key for key in key_order if key.startswith("ds2/")] - txt_cnt = Counter(txt_order) - ds1_key_cnt = Counter(ds1_keys) - ds2_key_cnt = Counter(ds2_keys) - assert len(ds1_keys) == 2 * 55, (len(ds1_keys), ds1_key_cnt) - assert len(ds2_keys) == 3 * 55, (len(ds2_keys), ds2_key_cnt) - assert all(ds1_key_cnt[key] == 2 for key in ds1_keys) - assert all(ds2_key_cnt[key] == 3 for key in ds2_keys) - assert all(txt_cnt[key] in (2, 3) for key in txt_order) - - # Next epoch - data = list(enumerate(train_loader)) - print([data.text[0] for idx, data in data]) - assert len(data) == 5 * 55, len(data) - - # Next epoch - data1 = list(zip(range(3 * 55), train_loader)) - assert len(data1) == 3 * 55, len(data1) - # Save state mid epoch - state1 = train_loader.save_state_rank() - print(state1) - - data2 = list(enumerate(train_loader)) - assert len(data2) == 2 * 55 - txt_order = [data.text[0] for idx, data in data1 + data2] - key_order = [ - data.__subflavors__[0]["source"] + "/" + data.__key__[0] - for idx, data in data1 + data2 - ] - assert len(txt_order) == 5 * 55, Counter(txt_order) - ds1_keys = [key for key in key_order if key.startswith("ds1/")] - ds2_keys = [key for key in key_order if key.startswith("ds2/")] - txt_cnt = Counter(txt_order) - ds1_key_cnt = Counter(ds1_keys) - ds2_key_cnt = Counter(ds2_keys) - assert len(ds1_keys) == 2 * 55, (len(ds1_keys), ds1_key_cnt) - assert len(ds2_keys) == 3 * 55, (len(ds2_keys), ds2_key_cnt) - assert all(ds1_key_cnt[key] == 2 for key in ds1_keys) - assert all(ds2_key_cnt[key] == 3 for key in ds2_keys) - assert all(txt_cnt[key] in (2, 3) for key in txt_order) - - # Restore state - with get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - max_samples_per_sequence=None, - repeat=False, - ), - ).with_restored_state_rank(state1) as train_loader: - data2_restore = list(enumerate(train_loader)) - assert len(data2_restore) == 2 * 55 - txt_order_rst = [data.text[0] for idx, data in data1 + data2_restore] - key_order_rst = [ - data.__subflavors__[0]["source"] + "/" + data.__key__[0] - for idx, data in data1 + data2_restore - ] - assert len(txt_order_rst) == 5 * 55, Counter(txt_order_rst) - assert txt_order_rst == txt_order - assert key_order_rst == key_order - ds1_keys_rst = [key for key in key_order_rst if key.startswith("ds1/")] - ds2_keys_rst = [key for key in key_order_rst if key.startswith("ds2/")] - txt_cnt_rst = Counter(txt_order_rst) - ds1_key_cnt_rst = Counter(ds1_keys_rst) - ds2_key_cnt_rst = Counter(ds2_keys_rst) - assert len(ds1_keys_rst) == 2 * 55, (len(ds1_keys_rst), ds1_key_cnt_rst) - assert len(ds2_keys_rst) == 3 * 55, (len(ds2_keys_rst), ds2_key_cnt_rst) - assert all(ds1_key_cnt_rst[key] == 2 for key in ds1_keys_rst) - assert all(ds2_key_cnt_rst[key] == 3 for key in ds2_keys_rst) - assert all(txt_cnt_rst[key] in (2, 3) for key in txt_order_rst) - - def test_metadataset_fixed_fractional_epochs(self): - torch.manual_seed(42) - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) + ), + ) as train_loader: + data1 = list(zip(range(95), train_loader)) + state1 = train_loader.save_state_rank() - # Create a joined dataset configuration - fixed_epochs_mds_path = self.dataset_path / "metadataset_fixed_epochs.yaml" - with open(fixed_epochs_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " blend_epochized:", - " - repetitions: 0.7", - " path: ds1", - " subflavors:", - " source: ds1", - " number: 43", - " - repetitions: 1.5", - " path: ds2", - " subflavors:", - " source: ds2", - " number: 42", - ] - ) - ) - - # ===== Part 1: Verify fractions ===== - - # Train mode dataset - train_dataset = get_train_dataset( + with get_savable_loader( + get_train_dataset( fixed_epochs_mds_path, worker_config=worker_config, batch_size=1, @@ -939,187 +996,136 @@ def test_metadataset_fixed_fractional_epochs(self): parallel_shard_iters=1, max_samples_per_sequence=None, repeat=False, - ) - - with get_savable_loader( - train_dataset, - ) as train_loader: - assert len(train_loader) == 38 + 55 + 27, len(train_loader) - - data = list(enumerate(train_loader)) - - # Check the overall number of samples - # Should be 0.7*len(ds1) + 1.5*len(ds2) = 0.7*55 + 1.5*55 = 38 + 55 + 27 (floor rounding) - assert len(data) == 38 + 55 + 27, len(data) - - sample_counts = Counter([int(s[1].text[0]) for s in data]) - - # The first 70% of samples from ds1 (0 to incl. 37) should be repeated only once - assert all(sample_counts[sample] == 1 for sample in range(38)) - - # Since ds2 is repeated 1.5 times, the first 50% of samples from ds2 (100 to incl. 126) should be repeated twice - assert all(sample_counts[sample] == 2 for sample in range(100, 127)) - - # The remaining samples from ds2 (127 to incl. 154) should be repeated only once - assert all(sample_counts[sample] == 1 for sample in range(127, 155)) - - # ===== Part 2: Save and restore state ===== - - # Now let's check if the state is stored and restored correctly - - with get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ), - ) as train_loader: - data1 = list(zip(range(95), train_loader)) - state1 = train_loader.save_state_rank() + ), + ).with_restored_state_rank(state1) as train_loader: + data2_restore = list(enumerate(train_loader)) - with get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ), - ).with_restored_state_rank(state1) as train_loader: - data2_restore = list(enumerate(train_loader)) + total_samples_save_restore = len(data1) + len(data2_restore) - total_samples_save_restore = len(data1) + len(data2_restore) + assert total_samples_save_restore == len(data), ( + "Total number of samples do not match when using save/restore" + ) - assert total_samples_save_restore == len(data), ( - "Total number of samples do not match when using save/restore" - ) + sample_counts_save_restore = Counter( + [int(s[1].text[0]) for d in [data1, data2_restore] for s in d] + ) - sample_counts_save_restore = Counter( - [int(s[1].text[0]) for d in [data1, data2_restore] for s in d] - ) + assert sample_counts_save_restore == sample_counts, ( + "Sample counts do not match when using save/restore" + ) - assert sample_counts_save_restore == sample_counts, ( - "Sample counts do not match when using save/restore" - ) + # ===== Part 3: Check if the state is restored correctly when saving right at the end of a dataset ===== - # ===== Part 3: Check if the state is restored correctly when saving right at the end of a dataset ===== + torch.manual_seed(42) - torch.manual_seed(42) + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ), + ) as train_loader: + ds1_counter = 0 + data1 = [] + for idx, sample in enumerate(train_loader): + data1.append((idx, sample)) + if sample.__subflavors__[0]["source"] == "ds1": + ds1_counter += 1 + if ds1_counter == 38: + # Stop right after the last sample from ds1 + break - with get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ), - ) as train_loader: - ds1_counter = 0 - data1 = [] - for idx, sample in enumerate(train_loader): - data1.append((idx, sample)) - if sample.__subflavors__[0]["source"] == "ds1": - ds1_counter += 1 - if ds1_counter == 38: - # Stop right after the last sample from ds1 - break - - state1 = train_loader.save_state_rank() - - with get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ), - ).with_restored_state_rank(state1) as train_loader: - data2_restore = list(enumerate(train_loader)) + state1 = train_loader.save_state_rank() - total_samples_save_restore = len(data1) + len(data2_restore) + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ), + ).with_restored_state_rank(state1) as train_loader: + data2_restore = list(enumerate(train_loader)) - assert total_samples_save_restore == len(data), ( - "Total number of samples do not match when using save/restore" - ) + total_samples_save_restore = len(data1) + len(data2_restore) - sample_counts_save_restore = Counter( - [int(s[1].text[0]) for d in [data1, data2_restore] for s in d] - ) + assert total_samples_save_restore == len(data), ( + "Total number of samples do not match when using save/restore" + ) - assert sample_counts_save_restore == sample_counts, ( - "Sample counts do not match when using save/restore" - ) + sample_counts_save_restore = Counter( + [int(s[1].text[0]) for d in [data1, data2_restore] for s in d] + ) - # Try in repeat mode - # Train mode dataset - with get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - ), - ) as train_loader: - data = list(zip(range(200), train_loader)) - assert len(train_loader) == 38 + 55 + 27, len(train_loader) + assert sample_counts_save_restore == sample_counts, ( + "Sample counts do not match when using save/restore" + ) - # Check the overall number of samples - # Should be 0.7*len(ds1) + 1.5*len(ds2) = 38 + 55 + 27 (floor rounding) - assert len(data) == 200, len(data) + # Try in repeat mode + # Train mode dataset + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + ), + ) as train_loader: + data = list(zip(range(200), train_loader)) + assert len(train_loader) == 38 + 55 + 27, len(train_loader) + + # Check the overall number of samples + # Should be 0.7*len(ds1) + 1.5*len(ds2) = 38 + 55 + 27 (floor rounding) + assert len(data) == 200, len(data) + + # ===== Part 4: Test count for multiple workers ===== + + worker_config = WorkerConfig( + rank=0, + world_size=2, + num_workers=2, + seed_offset=42, + ) + + # Train mode dataset + with get_savable_loader( + get_train_dataset( + fixed_epochs_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ), + ) as train_loader: + # TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py + assert len(train_loader) == 58, len(train_loader) - # ===== Part 4: Test count for multiple workers ===== + data = list(enumerate(train_loader)) - worker_config = WorkerConfig( - rank=0, - world_size=2, - num_workers=2, - seed_offset=42, - ) + # Check the overall number of samples + # Should be 0.7*len(ds1)55 + 1.5*len(ds2)55 = 38 + 55 + 27 (floor rounding) + # TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py + assert len(data) == 58, len(data) - # Train mode dataset - with get_savable_loader( - get_train_dataset( - fixed_epochs_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ), - ) as train_loader: - # TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py - assert len(train_loader) == 58, len(train_loader) - - data = list(enumerate(train_loader)) - # Check the overall number of samples - # Should be 0.7*len(ds1)55 + 1.5*len(ds2)55 = 38 + 55 + 27 (floor rounding) - # TODO: This should be exactly 60. There is a corresponding TODO in the repeat_dataset.py - assert len(data) == 58, len(data) +def test_watchdog_dataset(dataset_path): + with patch.object(WatchdogDataset, "_watchdog_trigger") as mock_watchdog_trigger: - @patch.object(WatchdogDataset, "_watchdog_trigger") - def test_watchdog_dataset(self, mock_watchdog_trigger): class TestTaskEncoder(DefaultTaskEncoder): def __init__(self): super().__init__() @@ -1145,7 +1151,7 @@ def encode_sample(self, sample: TextSample) -> TextSample: # Train mode dataset train_dataset = get_train_dataset( - self.mds_path, + dataset_path / "metadataset_v2.yaml", worker_config=worker_config, batch_size=1, shuffle_buffer_size=None, @@ -1165,85 +1171,40 @@ def encode_sample(self, sample: TextSample) -> TextSample: mock_watchdog_trigger.assert_called() - def test_dataset_absolute_nested_subset_fail(self): - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) - ratio_mds_path = self.dataset_path / "metadataset_ratio.yaml" - with open(ratio_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - # Absolute range on outer level should fail - " subset: {range: [50, 55]}", - " blend_epochized:", - " - path: ds1", - " subflavors:", - " source: ds1", - " number: 43", - " - repetitions: 2", - " path: ds2", - " subflavors:", - " source: ds2", - " number: 42", - ] - ) - ) - try: - with get_loader( - get_train_dataset( - ratio_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ) - ): - assert False, "Should have failed" - except Exception as e: - assert "only allowed for a leaf dataset" in str( - e - ) or "only use absolute subset ranges for a leaf dataset" in str(e), str(e) - return - - def test_dataset_with_subset_end_keyword(self): - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, - ) - ratio_mds_path = self.dataset_path / "metadataset_ratio.yaml" - with open(ratio_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - # Absolute range: [50, end] - # I.e. corresponds to sample range: [50, 55] (end is not included, so up to 54) - " subset: {range: [50, end]}", - " path: ds1", - " subflavors:", - " source: ds1", - " number: 43", - ] - ) +def test_dataset_absolute_nested_subset_fail(dataset_path): + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + ratio_mds_path = dataset_path / "metadataset_ratio.yaml" + with open(ratio_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + # Absolute range on outer level should fail + " subset: {range: [50, 55]}", + " blend_epochized:", + " - path: ds1", + " subflavors:", + " source: ds1", + " number: 43", + " - repetitions: 2", + " path: ds2", + " subflavors:", + " source: ds2", + " number: 42", + ] ) + ) + try: with get_loader( get_train_dataset( ratio_mds_path, @@ -1255,215 +1216,260 @@ def test_dataset_with_subset_end_keyword(self): max_samples_per_sequence=None, repeat=False, ) - ) as loader: - all_numbers = [int(s.text[0]) for s in loader] - - assert all_numbers == [50, 51, 52, 53, 54], "Subset range [50, end] should be [50, 55]" + ): + assert False, "Should have failed" + except Exception as e: + assert "only allowed for a leaf dataset" in str( + e + ) or "only use absolute subset ranges for a leaf dataset" in str(e), str(e) + return + + +def test_dataset_with_subset_end_keyword(dataset_path): + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + ratio_mds_path = dataset_path / "metadataset_ratio.yaml" + with open(ratio_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + # Absolute range: [50, end] + # I.e. corresponds to sample range: [50, 55] (end is not included, so up to 54) + " subset: {range: [50, end]}", + " path: ds1", + " subflavors:", + " source: ds1", + " number: 43", + ] + ) + ) - def test_dataset_with_subset_ratio(self): - worker_config = WorkerConfig( - rank=0, - world_size=1, - num_workers=0, - seed_offset=42, + with get_loader( + get_train_dataset( + ratio_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, ) - ratio_mds_path = self.dataset_path / "metadataset_ratio.yaml" - with open(ratio_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - # 20% of the dataset will be from ds1, 80% from ds2 - # I.e. sample range: [0.2*55, 0.8*55] = [11, 44] - " subset: {range: [20%, 80%]}", - " blend_epochized:", - " - path: ds1", - " subflavors:", - " source: ds1", - " number: 43", - " - repetitions: 2", - " path: ds2", - " subflavors:", - " source: ds2", - " number: 42", - ] - ) + ) as loader: + all_numbers = [int(s.text[0]) for s in loader] + + assert all_numbers == [50, 51, 52, 53, 54], "Subset range [50, end] should be [50, 55]" + + +def test_dataset_with_subset_ratio(dataset_path): + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + seed_offset=42, + ) + ratio_mds_path = dataset_path / "metadataset_ratio.yaml" + with open(ratio_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + # 20% of the dataset will be from ds1, 80% from ds2 + # I.e. sample range: [0.2*55, 0.8*55] = [11, 44] + " subset: {range: [20%, 80%]}", + " blend_epochized:", + " - path: ds1", + " subflavors:", + " source: ds1", + " number: 43", + " - repetitions: 2", + " path: ds2", + " subflavors:", + " source: ds2", + " number: 42", + ] ) + ) - with get_loader( - get_train_dataset( - ratio_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, - ) - ) as loader: - data = list(enumerate(loader)) - assert len(data) == 33 + 33 * 2, len(data) - - sample_counts = Counter([int(s[1].text[0]) for s in data]) - assert all(sample_counts[sample] == 0 for sample in range(11)), sample_counts - assert all(sample_counts[sample] == 1 for sample in range(11, 44)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(44, 55)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(100, 111)), sample_counts - assert all(sample_counts[sample] == 2 for sample in range(111, 144)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(144, 155)), sample_counts - assert sample_counts.total() == 33 + 33 * 2, sample_counts.total() - - # Combine with subset_samples - - ratio2_mds_path = self.dataset_path / "metadataset_ratio2.yaml" - with open(ratio2_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - # take [10, 30] from ds1, [20, 40] from ds2 and then only [20%, 80%] - # I.e. sample range: [14, 26], 2 * [124, 136] - " subset: {range: [20%, 80%]}", - " blend_epochized:", - " - path: ds1", - " subset: {range: [10, 30]}", - " subflavors:", - " source: ds1", - " number: 43", - " - repetitions: 2", - " subset: {range: [20, 40]}", - " path: ds2", - " subflavors:", - " source: ds2", - " number: 42", - ] - ) + with get_loader( + get_train_dataset( + ratio_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ) + ) as loader: + data = list(enumerate(loader)) + assert len(data) == 33 + 33 * 2, len(data) + + sample_counts = Counter([int(s[1].text[0]) for s in data]) + assert all(sample_counts[sample] == 0 for sample in range(11)), sample_counts + assert all(sample_counts[sample] == 1 for sample in range(11, 44)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(44, 55)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(100, 111)), sample_counts + assert all(sample_counts[sample] == 2 for sample in range(111, 144)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(144, 155)), sample_counts + assert sample_counts.total() == 33 + 33 * 2, sample_counts.total() + + # Combine with subset_samples + + ratio2_mds_path = dataset_path / "metadataset_ratio2.yaml" + with open(ratio2_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + # take [10, 30] from ds1, [20, 40] from ds2 and then only [20%, 80%] + # I.e. sample range: [14, 26], 2 * [124, 136] + " subset: {range: [20%, 80%]}", + " blend_epochized:", + " - path: ds1", + " subset: {range: [10, 30]}", + " subflavors:", + " source: ds1", + " number: 43", + " - repetitions: 2", + " subset: {range: [20, 40]}", + " path: ds2", + " subflavors:", + " source: ds2", + " number: 42", + ] ) + ) - with get_loader( - get_train_dataset( - ratio2_mds_path, - worker_config=worker_config, - batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, - max_samples_per_sequence=None, - repeat=False, + with get_loader( + get_train_dataset( + ratio2_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ) + ) as loader: + data = list(enumerate(loader)) + assert len(data) == 12 + 12 * 2, len(data) + + sample_counts = Counter([int(s[1].text[0]) for s in data]) + assert all(sample_counts[sample] == 0 for sample in range(14)), sample_counts + assert all(sample_counts[sample] == 1 for sample in range(14, 26)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(26, 55)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(100, 124)), sample_counts + assert all(sample_counts[sample] == 2 for sample in range(124, 136)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(136, 155)), sample_counts + assert sample_counts.total() == 12 + 12 * 2, sample_counts.total() + + # Combine with subset_ratio and subset_samples and nested metadataset + nested_mds_path = dataset_path / "metadataset_nested_subset.yaml" + with open(nested_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " subset: {range: [0%, 50%]}", + " blend_epochized:", + " - path: ds3", + # take [30, 50] from ds3, then first 50%, resulting in samples [230, 240] + " subset: {range: [30, 50]}", + " subflavors:", + " source: ds3", + " number: 45", + " - repetitions: 2", + # Inner sample range: [14, 26], 2 * [124, 136], total=12*3=36 + # Applying subset ratio 25%-75%: [17, 23], 2*[127, 133], total=3*6=18 + # Applying outer 50%: [17, 20], 2*[127, 130], total=3*3=9 + # Applying repetition: 2*[17, 20], 4*[127, 130], total=2*9=18 + " subset: {range: [25%, 75%]}", + " path: metadataset_ratio2.yaml", + ] ) - ) as loader: - data = list(enumerate(loader)) - assert len(data) == 12 + 12 * 2, len(data) - - sample_counts = Counter([int(s[1].text[0]) for s in data]) - assert all(sample_counts[sample] == 0 for sample in range(14)), sample_counts - assert all(sample_counts[sample] == 1 for sample in range(14, 26)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(26, 55)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(100, 124)), sample_counts - assert all(sample_counts[sample] == 2 for sample in range(124, 136)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(136, 155)), sample_counts - assert sample_counts.total() == 12 + 12 * 2, sample_counts.total() - - # Combine with subset_ratio and subset_samples and nested metadataset - nested_mds_path = self.dataset_path / "metadataset_nested_subset.yaml" - with open(nested_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " subset: {range: [0%, 50%]}", - " blend_epochized:", - " - path: ds3", - # take [30, 50] from ds3, then first 50%, resulting in samples [230, 240] - " subset: {range: [30, 50]}", - " subflavors:", - " source: ds3", - " number: 45", - " - repetitions: 2", - # Inner sample range: [14, 26], 2 * [124, 136], total=12*3=36 - # Applying subset ratio 25%-75%: [17, 23], 2*[127, 133], total=3*6=18 - # Applying outer 50%: [17, 20], 2*[127, 130], total=3*3=9 - # Applying repetition: 2*[17, 20], 4*[127, 130], total=2*9=18 - " subset: {range: [25%, 75%]}", - " path: metadataset_ratio2.yaml", - ] - ) + ) + + with get_loader( + get_train_dataset( + nested_mds_path, + worker_config=worker_config, + batch_size=1, + shuffle_buffer_size=None, + shuffle_over_epochs_multiplier=None, + parallel_shard_iters=1, + max_samples_per_sequence=None, + repeat=False, + ) + ) as loader: + data = list(enumerate(loader)) + assert len(data) == 10 + 9 * 2, len(data) + sample_counts = Counter([int(s[1].text[0]) for s in data]) + assert all(sample_counts[sample] == 0 for sample in range(17)), sample_counts + assert all(sample_counts[sample] == 2 for sample in range(17, 20)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(20, 55)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(100, 127)), sample_counts + assert all(sample_counts[sample] == 4 for sample in range(127, 130)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(130, 155)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(200, 230)), sample_counts + assert all(sample_counts[sample] == 1 for sample in range(230, 240)), sample_counts + assert all(sample_counts[sample] == 0 for sample in range(240, 255)), sample_counts + assert sample_counts.total() == 10 + 9 * 2, sample_counts.total() + + +def test_s3(dataset_path): + # Create a joined dataset configuration + mixed_mds_path = dataset_path / "metadataset_mixed.yaml" + with open(mixed_mds_path, "w") as f: + f.write( + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " path: msc://s3test_metadataset/test/dataset/nested_metadataset_v2.yaml", + ] ) + ) + + with setup_s3_emulator(profile_name="s3test_metadataset") as emu: + # Upload the dataset to the S3 emulator + # EPath(dataset_path).copy(EPath("msc://s3/test/dataset")) + emu.add_file(dataset_path, "test/dataset") with get_loader( get_train_dataset( - nested_mds_path, - worker_config=worker_config, + mixed_mds_path, + worker_config=WorkerConfig( + rank=0, + world_size=1, + num_workers=2, + ), batch_size=1, - shuffle_buffer_size=None, - shuffle_over_epochs_multiplier=None, - parallel_shard_iters=1, + shuffle_buffer_size=10, max_samples_per_sequence=None, - repeat=False, + virtual_epoch_length=10, ) - ) as loader: - data = list(enumerate(loader)) - assert len(data) == 10 + 9 * 2, len(data) - sample_counts = Counter([int(s[1].text[0]) for s in data]) - assert all(sample_counts[sample] == 0 for sample in range(17)), sample_counts - assert all(sample_counts[sample] == 2 for sample in range(17, 20)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(20, 55)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(100, 127)), sample_counts - assert all(sample_counts[sample] == 4 for sample in range(127, 130)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(130, 155)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(200, 230)), sample_counts - assert all(sample_counts[sample] == 1 for sample in range(230, 240)), sample_counts - assert all(sample_counts[sample] == 0 for sample in range(240, 255)), sample_counts - assert sample_counts.total() == 10 + 9 * 2, sample_counts.total() - - def test_s3(self): - # Create a joined dataset configuration - mixed_mds_path = self.dataset_path / "metadataset_mixed.yaml" - with open(mixed_mds_path, "w") as f: - f.write( - "\n".join( - [ - "__module__: megatron.energon", - "__class__: MetadatasetV2", - "splits:", - " train:", - " path: msc://s3test_metadataset/test/dataset/nested_metadataset_v2.yaml", - ] - ) - ) - - with setup_s3_emulator(profile_name="s3test_metadataset") as emu: - # Upload the dataset to the S3 emulator - # EPath(self.dataset_path).copy(EPath("msc://s3/test/dataset")) - emu.add_file(self.dataset_path, "test/dataset") - - with get_loader( - get_train_dataset( - mixed_mds_path, - worker_config=WorkerConfig( - rank=0, - world_size=1, - num_workers=2, - ), - batch_size=1, - shuffle_buffer_size=10, - max_samples_per_sequence=None, - virtual_epoch_length=10, - ) - ) as train_dataset: - data = list(enumerate(train_dataset)) - assert len(data) == 10, len(data) - - -if __name__ == "__main__": - unittest.main() + ) as train_dataset: + data = list(enumerate(train_dataset)) + assert len(data) == 10, len(data) diff --git a/tests/test_transforms.py b/tests/test_transforms.py index c76bd087..ddca5718 100644 --- a/tests/test_transforms.py +++ b/tests/test_transforms.py @@ -6,10 +6,10 @@ import logging import random import sys -import unittest from typing import Any, Tuple import numpy as np +import pytest import torch from PIL import Image from torchvision import transforms as T @@ -18,293 +18,287 @@ from megatron.energon.transforms import CustomTransform, MergeTransform -class TestTransforms(unittest.TestCase): - # Set up the test fixture - def setUp(self): - logging.basicConfig(stream=sys.stderr, level=logging.INFO) +@pytest.fixture(autouse=True) +def setup_logging(): + logging.basicConfig(stream=sys.stderr, level=logging.INFO) - def tearDown(self): - pass - @staticmethod - def reset_rng(seed: int = 42): - """Creates a PIL image with random noise.""" - np.random.seed(seed) - torch.manual_seed(seed) - random.seed(seed) +def reset_rng(seed: int = 42): + """Creates a PIL image with random noise.""" + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) - @staticmethod - def get_test_image(width: int, height: int): - """Creates a PIL image with random noise.""" - arr = np.zeros((width, height, 3), dtype=np.uint8) +def get_test_image(width: int, height: int): + """Creates a PIL image with random noise.""" - # Some colorful borders - arr[0, :, :] = [255, 0, 0] - arr[:, 0, :] = [255, 255, 0] - arr[-1, :, :] = [255, 255, 255] - arr[:, -1, :] = [0, 255, 0] + arr = np.zeros((width, height, 3), dtype=np.uint8) - # A single white pixel - if width > 3 and height > 3: - arr[3, 3, :] = [255, 255, 255] + # Some colorful borders + arr[0, :, :] = [255, 0, 0] + arr[:, 0, :] = [255, 255, 0] + arr[-1, :, :] = [255, 255, 255] + arr[:, -1, :] = [0, 255, 0] - # And in the middle some noise - if width > 10 and height > 10: - arr[5:-5, 5:-5, :] = np.random.randint(0, 255, (width - 10, height - 10, 3)) + # A single white pixel + if width > 3 and height > 3: + arr[3, 3, :] = [255, 255, 255] - return Image.fromarray(arr) + # And in the middle some noise + if width > 10 and height > 10: + arr[5:-5, 5:-5, :] = np.random.randint(0, 255, (width - 10, height - 10, 3)) - @staticmethod - def get_test_image_soft(width: int, height: int): - """Creates a PIL image smooth content""" + return Image.fromarray(arr) - arr = np.zeros((width, height, 3), dtype=np.uint8) - # Fill red channel the image with a smooth gradient from left to right. - arr[:, :, 0] = np.arange(width)[:, None] / width * 255 - # The same for green from top to bottom: - arr[:, :, 1] = np.arange(height)[None, :] / height * 255 +def get_test_image_soft(width: int, height: int): + """Creates a PIL image smooth content""" - return Image.fromarray(arr) + arr = np.zeros((width, height, 3), dtype=np.uint8) - def _apply_and_compare( - self, testable_transform, img, atol=2, seed=42, msg=None, only_nonblack=False - ): - # Then transform using our method - merge_transform = MergeTransform([testable_transform]) + # Fill red channel the image with a smooth gradient from left to right. + arr[:, :, 0] = np.arange(width)[:, None] / width * 255 + # The same for green from top to bottom: + arr[:, :, 1] = np.arange(height)[None, :] / height * 255 - self.reset_rng(seed=seed) - test_result = merge_transform(img) + return Image.fromarray(arr) - # And also transform using torchvision directly - self.reset_rng(seed=seed) - ref_result = testable_transform(img) - # Then compare the sizes and the images contents - self.assertEqual(test_result.size, ref_result.size) +def _apply_and_compare(testable_transform, img, atol=2, seed=42, msg=None, only_nonblack=False): + # Then transform using our method + merge_transform = MergeTransform([testable_transform]) - # Check that image contents are close - np_test = np.array(test_result) - np_ref = np.array(ref_result) + reset_rng(seed=seed) + test_result = merge_transform(img) - if only_nonblack: - nonblack_mask = (np_test > 0) & (np_ref > 0) - np_test = np_test[nonblack_mask] - np_ref = np_ref[nonblack_mask] + # And also transform using torchvision directly + reset_rng(seed=seed) + ref_result = testable_transform(img) - # The maximum allowed difference between pixel values is 2 (uint8) - self.assertTrue(np.allclose(np_test, np_ref, atol=atol), msg=msg) + # Then compare the sizes and the images contents + assert test_result.size == ref_result.size - def test_resize(self): - """Tests ResizeMapper""" + # Check that image contents are close + np_test = np.array(test_result) + np_ref = np.array(ref_result) - MAX_SIZE = 150 - # These are the different setups we test. Each entry is a tuple of - # (source size, resize_kwargs) + if only_nonblack: + nonblack_mask = (np_test > 0) & (np_ref > 0) + np_test = np_test[nonblack_mask] + np_ref = np_ref[nonblack_mask] - size_list = [ # source size (w, h), resize_kwargs - [(100, 100), {"size": (100, 100)}], - [(200, 50), {"size": (100, 100)}], - [(50, 50), {"size": (100, 100)}], - [(500, 500), {"size": (10, 10)}], - [(1, 2), {"size": (1, 3)}], # Scale width by 1.5x - [(50, 100), {"size": 100, "max_size": MAX_SIZE}], # Test max_size - ] + # The maximum allowed difference between pixel values is 2 (uint8) + assert np.allclose(np_test, np_ref, atol=atol), msg - for source_size, resize_kwargs in size_list: - logging.info( - f"Testing Resize with source size {source_size} and resize_kwargs {resize_kwargs}" - ) - # Create a test image of the given source size - img = TestTransforms.get_test_image(*source_size) - transform = T.Resize(**resize_kwargs, interpolation=InterpolationMode.NEAREST) +def test_resize(): + """Tests ResizeMapper""" - self._apply_and_compare( - transform, - img, - msg=f"Resize: source_size={source_size}, resize_kwargs={resize_kwargs}", - ) + MAX_SIZE = 150 + # These are the different setups we test. Each entry is a tuple of + # (source size, resize_kwargs) - def test_random_resized_crop(self): - """Tests RandomResizedCropMapper""" + size_list = [ # source size (w, h), resize_kwargs + [(100, 100), {"size": (100, 100)}], + [(200, 50), {"size": (100, 100)}], + [(50, 50), {"size": (100, 100)}], + [(500, 500), {"size": (10, 10)}], + [(1, 2), {"size": (1, 3)}], # Scale width by 1.5x + [(50, 100), {"size": 100, "max_size": MAX_SIZE}], # Test max_size + ] - randcrop = T.RandomResizedCrop( - 90, scale=(0.3, 0.7), ratio=(0.75, 1.3), interpolation=InterpolationMode.BILINEAR + for source_size, resize_kwargs in size_list: + logging.info( + f"Testing Resize with source size {source_size} and resize_kwargs {resize_kwargs}" ) - source_size = (50, 60) - - logging.info(f"Testing RandomResizedCrop with source size {source_size}") # Create a test image of the given source size - img = TestTransforms.get_test_image_soft(*source_size) + img = get_test_image(*source_size) + transform = T.Resize(**resize_kwargs, interpolation=InterpolationMode.NEAREST) - self._apply_and_compare(randcrop, img, msg="RandomResizedCrop") + _apply_and_compare( + transform, + img, + msg=f"Resize: source_size={source_size}, resize_kwargs={resize_kwargs}", + ) - def test_random_flip(self): - source_size = (55, 33) - img = TestTransforms.get_test_image(*source_size) - logging.info("Testing RandomHorizontalFlip 5 times") - for idx in range(5): - randhflip = T.RandomHorizontalFlip(p=0.8) - self._apply_and_compare(randhflip, img, seed=idx, msg="RandomHorizontalFlip") +def test_random_resized_crop(): + """Tests RandomResizedCropMapper""" - logging.info("Testing RandomVerticalFlip 5 times") - for idx in range(5): - randvflip = T.RandomVerticalFlip(p=0.8) - self._apply_and_compare(randvflip, img, seed=idx, msg="RandomVerticalFlip") + randcrop = T.RandomResizedCrop( + 90, scale=(0.3, 0.7), ratio=(0.75, 1.3), interpolation=InterpolationMode.BILINEAR + ) + source_size = (50, 60) - def test_random_rotation(self): - source_size = (55, 33) - img = TestTransforms.get_test_image_soft(*source_size) + logging.info(f"Testing RandomResizedCrop with source size {source_size}") - logging.info("Testing RandomRotation without expand") - for idx in range(5): - randrot = T.RandomRotation((-90, 269), interpolation=InterpolationMode.BILINEAR) - self._apply_and_compare( - randrot, - img, - seed=idx, - msg="RandomRotation without expand", - ) + # Create a test image of the given source size + img = get_test_image_soft(*source_size) - logging.info("Testing RandomRotation with expand") - for idx in range(5): - randrot = T.RandomRotation( - (-180, 269), interpolation=InterpolationMode.BILINEAR, expand=True - ) - self._apply_and_compare( - randrot, - img, - seed=idx, - msg="RandomRotation with expand", - ) + _apply_and_compare(randcrop, img, msg="RandomResizedCrop") - def test_random_crop(self): - source_size = (155, 120) - img = TestTransforms.get_test_image(*source_size) - - size_list = [ # crop size (w, h) - (155, 120), # Same size - (100, 50), - 3, # Single int as size - 120, - (155, 8), # One dimension same size - ] - - logging.info("Testing RandomCrop") - for idx, size in enumerate(size_list): - randcrop = T.RandomCrop(size) - self._apply_and_compare( - randcrop, - img, - seed=idx, - msg=f"RandomCrop: crop size={size}", - ) - # Test `pad_if_needed` (Crop size larger than image size) - randcrop = T.RandomCrop((500, 500), pad_if_needed=True) - self._apply_and_compare(randcrop, img) +def test_random_flip(): + source_size = (55, 33) + img = get_test_image(*source_size) - def test_random_perspective(self): - source_size = (128, 133) - img = TestTransforms.get_test_image_soft(*source_size) + logging.info("Testing RandomHorizontalFlip 5 times") + for idx in range(5): + randhflip = T.RandomHorizontalFlip(p=0.8) + _apply_and_compare(randhflip, img, seed=idx, msg="RandomHorizontalFlip") - logging.info("Testing RandomPerspective") - for idx in range(5): - randpersp = T.RandomPerspective(interpolation=InterpolationMode.BILINEAR) - self._apply_and_compare( - randpersp, - img, - seed=idx, - msg=f"RandomPerspective: source_size={source_size}", - only_nonblack=True, # Sometimes one pixel is off - ) + logging.info("Testing RandomVerticalFlip 5 times") + for idx in range(5): + randvflip = T.RandomVerticalFlip(p=0.8) + _apply_and_compare(randvflip, img, seed=idx, msg="RandomVerticalFlip") - def test_center_crop(self): - source_size_list = [ # source size (w, h) - (155, 120), - (154, 119), - ] - crop_size_list = [ # crop size (w, h) - (155, 120), # Same size - (100, 50), - 3, # Single int as size - 120, - (200, 50), # Large than image in x direction - (50, 200), # Large than image in y direction - (200, 200), # Large than image in both directions - ] +def test_random_rotation(): + source_size = (55, 33) + img = get_test_image_soft(*source_size) - logging.info("Testing CenterCrop") + logging.info("Testing RandomRotation without expand") + for idx in range(5): + randrot = T.RandomRotation((-90, 269), interpolation=InterpolationMode.BILINEAR) + _apply_and_compare( + randrot, + img, + seed=idx, + msg="RandomRotation without expand", + ) - for source_size in source_size_list: - img = TestTransforms.get_test_image(*source_size) + logging.info("Testing RandomRotation with expand") + for idx in range(5): + randrot = T.RandomRotation( + (-180, 269), interpolation=InterpolationMode.BILINEAR, expand=True + ) + _apply_and_compare( + randrot, + img, + seed=idx, + msg="RandomRotation with expand", + ) - for idx, crop_size in enumerate(crop_size_list): - centcrop = T.CenterCrop(crop_size) - self._apply_and_compare( - centcrop, - img, - seed=idx, - msg=f"CenterCrop: source_size={source_size}, crop_size={crop_size}", - ) - def test_custom(self): - """Tests if a custom transform works""" +def test_random_crop(): + source_size = (155, 120) + img = get_test_image(*source_size) + + size_list = [ # crop size (w, h) + (155, 120), # Same size + (100, 50), + 3, # Single int as size + 120, + (155, 8), # One dimension same size + ] + + logging.info("Testing RandomCrop") + for idx, size in enumerate(size_list): + randcrop = T.RandomCrop(size) + _apply_and_compare( + randcrop, + img, + seed=idx, + msg=f"RandomCrop: crop size={size}", + ) - source_size = (128, 133) + # Test `pad_if_needed` (Crop size larger than image size) + randcrop = T.RandomCrop((500, 500), pad_if_needed=True) + _apply_and_compare(randcrop, img) - class FixedTranslate(CustomTransform): - """Translates the image by 5 pixels in both x and y direction""" - def __init__(self): - pass +def test_random_perspective(): + source_size = (128, 133) + img = get_test_image_soft(*source_size) - def apply_transform( - self, matrix: np.ndarray, dst_size: np.ndarray - ) -> Tuple[Any, Any, Any]: - matrix = self.translate(5, 5) @ matrix - return matrix, dst_size, (self.__class__.__name__, (5, 5)) + logging.info("Testing RandomPerspective") + for idx in range(5): + randpersp = T.RandomPerspective(interpolation=InterpolationMode.BILINEAR) + _apply_and_compare( + randpersp, + img, + seed=idx, + msg=f"RandomPerspective: source_size={source_size}", + only_nonblack=True, # Sometimes one pixel is off + ) - img = TestTransforms.get_test_image(*source_size) - merge_transform = MergeTransform([FixedTranslate()]) - test_result = merge_transform(img) +def test_center_crop(): + source_size_list = [ # source size (w, h) + (155, 120), + (154, 119), + ] - reference_img = Image.new(img.mode, img.size, (0, 0, 0)) - reference_img.paste(img, (5, 5)) + crop_size_list = [ # crop size (w, h) + (155, 120), # Same size + (100, 50), + 3, # Single int as size + 120, + (200, 50), # Large than image in x direction + (50, 200), # Large than image in y direction + (200, 200), # Large than image in both directions + ] - self.assertTrue( - np.allclose(np.array(test_result), np.array(reference_img), atol=1), - msg="FixedTranslate", - ) + logging.info("Testing CenterCrop") + + for source_size in source_size_list: + img = get_test_image(*source_size) - def test_merge(self): - """Tests if two merged transforms yield the same result. - Merging RandomCrop and RandomPerspective.""" + for idx, crop_size in enumerate(crop_size_list): + centcrop = T.CenterCrop(crop_size) + _apply_and_compare( + centcrop, + img, + seed=idx, + msg=f"CenterCrop: source_size={source_size}, crop_size={crop_size}", + ) - source_size = (128, 133) - img = TestTransforms.get_test_image_soft(*source_size) - randcrop = T.RandomCrop((70, 70)) - randrot = T.RandomRotation((45, 269), interpolation=InterpolationMode.BILINEAR) +def test_custom(): + """Tests if a custom transform works""" - merge_transform = MergeTransform([randrot, randcrop]) - self.reset_rng(1) - test_result = merge_transform(img) + source_size = (128, 133) - self.reset_rng(1) - ref_result = randcrop(randrot(img)) + class FixedTranslate(CustomTransform): + """Translates the image by 5 pixels in both x and y direction""" - self.assertTrue( - np.allclose(np.array(test_result), np.array(ref_result), atol=1), - msg="MergeTransform of RandomRotation and RandomCrop", - ) + def __init__(self): + pass + + def apply_transform(self, matrix: np.ndarray, dst_size: np.ndarray) -> Tuple[Any, Any, Any]: + matrix = self.translate(5, 5) @ matrix + return matrix, dst_size, (self.__class__.__name__, (5, 5)) + + img = get_test_image(*source_size) + + merge_transform = MergeTransform([FixedTranslate()]) + test_result = merge_transform(img) + + reference_img = Image.new(img.mode, img.size, (0, 0, 0)) + reference_img.paste(img, (5, 5)) + + assert np.allclose(np.array(test_result), np.array(reference_img), atol=1), "FixedTranslate" + + +def test_merge(): + """Tests if two merged transforms yield the same result. + Merging RandomCrop and RandomPerspective.""" + + source_size = (128, 133) + img = get_test_image_soft(*source_size) + + randcrop = T.RandomCrop((70, 70)) + randrot = T.RandomRotation((45, 269), interpolation=InterpolationMode.BILINEAR) + + merge_transform = MergeTransform([randrot, randcrop]) + reset_rng(1) + test_result = merge_transform(img) + reset_rng(1) + ref_result = randcrop(randrot(img)) -if __name__ == "__main__": - unittest.main() + assert np.allclose(np.array(test_result), np.array(ref_result), atol=1), ( + "MergeTransform of RandomRotation and RandomCrop" + ) diff --git a/tests/test_typed_converter.py b/tests/test_typed_converter.py new file mode 100644 index 00000000..74f2ee73 --- /dev/null +++ b/tests/test_typed_converter.py @@ -0,0 +1,465 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause +import dataclasses +from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Set, Tuple, Type, Union + +import pytest + +from megatron.energon import typed_converter + + +class MyNamedTuple(NamedTuple): + x: int + y: str + + +@dataclasses.dataclass +class MyDataClass: + a: int + b: str = "default" + + +class ComplexNamedTuple(NamedTuple): + id: int + name: str + active: bool + + +@dataclasses.dataclass +class NestedDataClass: + value: float + metadata: Dict[str, Any] + tags: List[str] + + +@dataclasses.dataclass +class ExtendedNestedDataClass(NestedDataClass): + extra: str + + +def my_call(a: int, b: str) -> NestedDataClass: + return NestedDataClass(value=a, metadata={"b": b}, tags=[b]) + + +def my_call_extended(a: int, b: str) -> ExtendedNestedDataClass: + return ExtendedNestedDataClass(value=a, metadata={"b": b}, tags=[b], extra=f"extra_{b}") + + +@dataclasses.dataclass +class ComprehensiveDataClass: + # Primitive types + string_field: str + int_field: int + float_field: float + bool_field: bool + + # Optional types + optional_string: Optional[str] = None + optional_int: Optional[int] = None + + # Union types + union_field: Union[str, int] = "default" + union_optional: Union[str, None] = None + + # List types + string_list: List[str] = dataclasses.field(default_factory=list) + int_list: List[int] = dataclasses.field(default_factory=list) + nested_list: List[List[str]] = dataclasses.field(default_factory=list) + + # Dict types + string_dict: Dict[str, str] = dataclasses.field(default_factory=dict) + mixed_dict: Dict[str, Any] = dataclasses.field(default_factory=dict) + nested_dict: Dict[str, Dict[str, int]] = dataclasses.field(default_factory=dict) + + # Tuple types + fixed_tuple: Tuple[str, int, bool] = ("default", 0, False) + variable_tuple: Tuple[str, ...] = ("single",) + + # Set types + set_field: Set[int] = dataclasses.field(default_factory=set) + + # Literal types + status: Literal["active", "inactive", "pending"] = "pending" + priority: Literal[1, 2, 3, 4, 5] = 3 + + # Nested dataclass + nested: Optional[NestedDataClass] = None + + # Referencing a type + type_ref: Type[NestedDataClass] = NestedDataClass + + # Referencing a function + function_ref: Callable[[int, str], NestedDataClass] = my_call + + # NamedTuple + named_tuple: Optional[ComplexNamedTuple] = None + + # Any type + any_field: Any = None + + +def test_raw_to_typed_namedtuple(): + parser = typed_converter.JsonParser() + raw = {"x": 42, "y": "foo"} + result = parser.raw_to_typed(raw, MyNamedTuple) + assert isinstance(result, MyNamedTuple) + assert result.x == 42 + assert result.y == "foo" + + +def test_raw_to_typed_dataclass(): + parser = typed_converter.JsonParser() + raw = {"a": 7, "b": "bar"} + result = parser.raw_to_typed(raw, MyDataClass) + assert isinstance(result, MyDataClass) + assert result.a == 7 + assert result.b == "bar" + + +def test_raw_to_typed_dataclass_default(): + parser = typed_converter.JsonParser() + raw = {"a": 5} + result = parser.raw_to_typed(raw, MyDataClass) + assert result.a == 5 + assert result.b == "default" + + +def test_raw_to_typed_union(): + parser = typed_converter.JsonParser() + assert parser.raw_to_typed(123, Union[int, str]) == 123 + assert parser.raw_to_typed("abc", Union[int, str]) == "abc" + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(1.5, Union[int, str]) + + +def test_raw_to_typed_optional(): + parser = typed_converter.JsonParser() + assert parser.raw_to_typed(None, Optional[int]) is None + assert parser.raw_to_typed(10, Optional[int]) == 10 + + +def test_raw_to_typed_list(): + parser = typed_converter.JsonParser() + raw = [1, 2, 3] + result = parser.raw_to_typed(raw, List[int]) + assert result == [1, 2, 3] + + +def test_raw_to_typed_dict(): + parser = typed_converter.JsonParser() + raw = {"foo": 1, "bar": 2} + result = parser.raw_to_typed(raw, Dict[str, int]) + assert result == {"foo": 1, "bar": 2} + + +def test_raw_to_typed_set(): + parser = typed_converter.JsonParser() + raw = [1, 2, 3] + result = parser.raw_to_typed(raw, Set[int]) + assert result == {1, 2, 3} + + +def test_raw_to_typed_literal(): + parser = typed_converter.JsonParser() + assert parser.raw_to_typed("yes", Literal["yes", "no"]) == "yes" + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed("maybe", Literal["yes", "no"]) + + +def test_to_json_object_namedtuple(): + obj = MyNamedTuple(x=1, y="abc") + json_obj = typed_converter.to_json_object(obj) + assert json_obj == {"x": 1, "y": "abc"} + + +def test_to_json_object_dataclass(): + obj = MyDataClass(a=2, b="xyz") + json_obj = typed_converter.to_json_object(obj) + assert json_obj == {"a": 2, "b": "xyz"} + + +def test_to_json_object_list(): + obj = [1, 2, 3] + json_obj = typed_converter.to_json_object(obj) + assert json_obj == [1, 2, 3] + + +def test_to_json_object_dict(): + obj = {"foo": 1, "bar": 2} + json_obj = typed_converter.to_json_object(obj) + assert json_obj == {"foo": 1, "bar": 2} + + +def test_isinstance_deep(): + assert typed_converter._isinstance_deep(1, int) + assert not typed_converter._isinstance_deep(1, str) + assert not typed_converter._isinstance_deep(1, float) + assert not typed_converter._isinstance_deep("1", int) + assert not typed_converter._isinstance_deep("1", float) + assert typed_converter._isinstance_deep([1, 2], List[int]) + assert not typed_converter._isinstance_deep([1, "a"], List[int]) + assert typed_converter._isinstance_deep({"a": 1}, Dict[str, int]) + assert not typed_converter._isinstance_deep({"a": "b"}, Dict[str, int]) + + +def test_missing_value_error(): + parser = typed_converter.JsonParser() + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(typed_converter._missing_value, int) + + +def test_strict_extra_keys(): + parser = typed_converter.JsonParser(strict=True) + raw = {"a": 1, "b": "foo", "extra": 123} + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw, MyDataClass) + + +def test_non_strict_extra_keys(): + parser = typed_converter.JsonParser(strict=False) + raw = {"a": 1, "b": "foo", "extra": 123} + result = parser.raw_to_typed(raw, MyDataClass) + assert result.a == 1 + assert result.b == "foo" + + +def test_comprehensive_dataclass(): + """Test a complex dataclass with all supported types.""" + parser = typed_converter.JsonParser() + + # Create comprehensive raw data + raw_data = { + "string_field": "test_string", + "int_field": 42, + "float_field": 3.14159, + "bool_field": True, + "optional_string": "optional_value", + "optional_int": 100, + "union_field": 123, # Using int instead of string + "union_optional": "union_string", + "string_list": ["item1", "item2", "item3"], + "int_list": [1, 2, 3, 4, 5], + "nested_list": [["a", "b"], ["c", "d"]], + "string_dict": {"key1": "value1", "key2": "value2"}, + "mixed_dict": {"str_key": "string", "int_key": 42, "bool_key": True}, + "nested_dict": {"outer1": {"inner1": 1, "inner2": 2}, "outer2": {"inner3": 3}}, + "fixed_tuple": ["tuple_string", 99, True], + "variable_tuple": ["var1", "var2", "var3"], + "set_field": [1, 2, 3], + "status": "active", + "priority": 5, + "nested": { + "value": 2.71828, + "metadata": {"nested_key": "nested_value", "count": 42}, + "tags": ["tag1", "tag2"], + }, + "named_tuple": {"id": 123, "name": "test_name", "active": False}, + "any_field": {"arbitrary": "data", "number": 999}, + } + + # Convert raw data to typed object + result = parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Verify all fields + assert result.string_field == "test_string" + assert result.int_field == 42 + assert result.float_field == 3.14159 + assert result.bool_field is True + assert result.optional_string == "optional_value" + assert result.optional_int == 100 + assert result.union_field == 123 + assert result.union_optional == "union_string" + assert result.string_list == ["item1", "item2", "item3"] + assert result.int_list == [1, 2, 3, 4, 5] + assert result.nested_list == [["a", "b"], ["c", "d"]] + assert result.string_dict == {"key1": "value1", "key2": "value2"} + assert result.mixed_dict == {"str_key": "string", "int_key": 42, "bool_key": True} + assert result.nested_dict == {"outer1": {"inner1": 1, "inner2": 2}, "outer2": {"inner3": 3}} + assert result.fixed_tuple == ("tuple_string", 99, True) + assert result.variable_tuple == ("var1", "var2", "var3") + assert result.set_field == {1, 2, 3} + assert result.status == "active" + assert result.priority == 5 + + # Verify nested dataclass + assert isinstance(result.nested, NestedDataClass) + assert result.nested.value == 2.71828 + assert result.nested.metadata == {"nested_key": "nested_value", "count": 42} + assert result.nested.tags == ["tag1", "tag2"] + + # Verify NamedTuple + assert isinstance(result.named_tuple, ComplexNamedTuple) + assert result.named_tuple.id == 123 + assert result.named_tuple.name == "test_name" + assert result.named_tuple.active is False + + # Verify Any field + assert result.any_field == {"arbitrary": "data", "number": 999} + + # Test conversion back to JSON + json_obj = typed_converter.to_json_object(result) + + # Verify JSON conversion preserves data + assert json_obj["string_field"] == "test_string" + assert json_obj["int_field"] == 42 + assert json_obj["float_field"] == 3.14159 + assert json_obj["bool_field"] is True + assert json_obj["optional_string"] == "optional_value" + assert json_obj["optional_int"] == 100 + assert json_obj["union_field"] == 123 + assert json_obj["union_optional"] == "union_string" + assert json_obj["string_list"] == ["item1", "item2", "item3"] + assert json_obj["int_list"] == [1, 2, 3, 4, 5] + assert json_obj["nested_list"] == [["a", "b"], ["c", "d"]] + assert json_obj["string_dict"] == {"key1": "value1", "key2": "value2"} + assert json_obj["mixed_dict"] == {"str_key": "string", "int_key": 42, "bool_key": True} + assert json_obj["nested_dict"] == { + "outer1": {"inner1": 1, "inner2": 2}, + "outer2": {"inner3": 3}, + } + assert json_obj["fixed_tuple"] == ["tuple_string", 99, True] + assert json_obj["variable_tuple"] == ["var1", "var2", "var3"] + assert json_obj["set_field"] == [1, 2, 3] + assert json_obj["status"] == "active" + assert json_obj["priority"] == 5 + assert json_obj["nested"]["value"] == 2.71828 + assert json_obj["nested"]["metadata"] == {"nested_key": "nested_value", "count": 42} + assert json_obj["nested"]["tags"] == ["tag1", "tag2"] + assert json_obj["named_tuple"]["id"] == 123 + assert json_obj["named_tuple"]["name"] == "test_name" + assert json_obj["named_tuple"]["active"] is False + assert json_obj["any_field"] == {"arbitrary": "data", "number": 999} + assert json_obj["function_ref"]["__module__"] == my_call.__module__ + assert json_obj["function_ref"]["__function__"] == my_call.__name__ + assert json_obj["type_ref"]["__module__"] == NestedDataClass.__module__ + assert json_obj["type_ref"]["__class__"] == NestedDataClass.__name__ + + +def test_comprehensive_dataclass_with_defaults(): + """Test comprehensive dataclass with minimal data using defaults.""" + parser = typed_converter.JsonParser() + + # Minimal raw data - only required fields + raw_data = {"string_field": "minimal", "int_field": 1, "float_field": 1.0, "bool_field": False} + + result = parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Verify required fields + assert result.string_field == "minimal" + assert result.int_field == 1 + assert result.float_field == 1.0 + assert result.bool_field is False + + # Verify defaults + assert result.optional_string is None + assert result.optional_int is None + assert result.union_field == "default" + assert result.union_optional is None + assert result.string_list == [] + assert result.int_list == [] + assert result.nested_list == [] + assert result.string_dict == {} + assert result.mixed_dict == {} + assert result.nested_dict == {} + assert result.fixed_tuple == ("default", 0, False) + assert result.variable_tuple == ("single",) + assert result.set_field == set() + assert result.status == "pending" + assert result.priority == 3 + assert result.nested is None + assert result.named_tuple is None + assert result.any_field is None + + +def test_comprehensive_dataclass_error_cases(): + """Test error cases for comprehensive dataclass.""" + parser = typed_converter.JsonParser() + + # Test invalid literal value + raw_data = { + "string_field": "test", + "int_field": 1, + "float_field": 1.0, + "bool_field": False, + "status": "invalid_status", # Should fail + } + + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Test invalid union type + raw_data = { + "string_field": "test", + "int_field": 1, + "float_field": 1.0, + "bool_field": False, + "union_field": 1.5, # Should fail - not str or int + } + + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Test invalid list element type + raw_data = { + "string_field": "test", + "int_field": 1, + "float_field": 1.0, + "bool_field": False, + "string_list": ["valid", 123, "also_valid"], # Should fail - int in string list + } + + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Test invalid tuple length + raw_data = { + "string_field": "test", + "int_field": 1, + "float_field": 1.0, + "bool_field": False, + "fixed_tuple": ["only", "two"], # Should fail - needs 3 elements + } + + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + +def test_comprehensive_dataclass_strict_mode(): + """Test comprehensive dataclass in strict mode with extra keys.""" + parser = typed_converter.JsonParser(strict=True) + + raw_data = { + "string_field": "test", + "int_field": 1, + "float_field": 1.0, + "bool_field": False, + "extra_field": "should_fail", # Should fail in strict mode + "nested": { + "__module__": my_call_extended.__module__, + "__function__": my_call_extended.__name__, + "a": 42, + "b": "Hello", + }, + "function_ref": { + "__module__": my_call_extended.__module__, + "__function__": my_call_extended.__name__, + }, + "type_ref": { + "__module__": ExtendedNestedDataClass.__module__, + "__class__": ExtendedNestedDataClass.__name__, + }, + } + + with pytest.raises(typed_converter.JsonValueError): + parser.raw_to_typed(raw_data, ComprehensiveDataClass) + + # Test non-strict mode allows extra keys + parser_non_strict = typed_converter.JsonParser(strict=False) + result = parser_non_strict.raw_to_typed(raw_data, ComprehensiveDataClass) + assert result.string_field == "test" + assert result.int_field == 1 + assert result.nested == ExtendedNestedDataClass( + value=42, metadata={"b": "Hello"}, tags=["Hello"], extra="extra_Hello" + ) + assert result.function_ref == my_call_extended + assert result.type_ref == ExtendedNestedDataClass diff --git a/tests/test_typedconverter_extended.py b/tests/test_typedconverter_extended.py new file mode 100644 index 00000000..43842921 --- /dev/null +++ b/tests/test_typedconverter_extended.py @@ -0,0 +1,2 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. +# SPDX-License-Identifier: BSD-3-Clause diff --git a/tests/test_weakref.py b/tests/test_weakref.py index 7d104b86..69e4aadb 100644 --- a/tests/test_weakref.py +++ b/tests/test_weakref.py @@ -1,7 +1,6 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause import multiprocessing -import unittest from dataclasses import dataclass from megatron.energon.fork_hook import ( @@ -14,187 +13,170 @@ ) -class TestWeakref(unittest.TestCase): - def test_weakcallbacks(self): - # Just test the internal WeakCallbacks class. - a_called = 0 - fn_called = 0 +def test_weakcallbacks(): + # Just test the internal WeakCallbacks class. + a_called = 0 + fn_called = 0 - class A: - def method(self): - nonlocal a_called - a_called += 1 + class A: + def method(self): + nonlocal a_called + a_called += 1 - def fn(): - nonlocal fn_called - fn_called += 1 + def fn(): + nonlocal fn_called + fn_called += 1 - a = A() + a = A() - registry = WeakCallbacks() + registry = WeakCallbacks() - registry.add_hook(a.method) - registry.add_hook(fn) - registry.add_hook(a.method) + registry.add_hook(a.method) + registry.add_hook(fn) + registry.add_hook(a.method) - registry.run() + registry.run() - assert a_called == 1, a_called - assert fn_called == 1, fn_called + assert a_called == 1, a_called + assert fn_called == 1, fn_called - assert len(registry._hooks) == 2, len(registry._hooks) + assert len(registry._hooks) == 2, len(registry._hooks) - del a + del a - assert len(registry._hooks) == 1, len(registry._hooks) + assert len(registry._hooks) == 1, len(registry._hooks) - registry.run() + registry.run() - assert a_called == 1, a_called - assert fn_called == 2, fn_called + assert a_called == 1, a_called + assert fn_called == 2, fn_called - del fn + del fn - assert len(registry._hooks) == 0, len(registry._hooks) + assert len(registry._hooks) == 0, len(registry._hooks) - registry.run() + registry.run() - assert a_called == 1, a_called - assert fn_called == 2, fn_called + assert a_called == 1, a_called + assert fn_called == 2, fn_called - assert len(registry._hooks) == 0, len(registry._hooks) + assert len(registry._hooks) == 0, len(registry._hooks) - def test_fork_weakref(self): - # Verify that the fork hooks are called correctly, and that gc works correctly. - _a_before_fork_called = 0 - _a_after_in_child_fork_called = 0 - _a_after_in_parent_fork_called = 0 +def test_fork_weakref(): + # Verify that the fork hooks are called correctly, and that gc works correctly. - class A(ForkMixin): - def __before_fork__(self): - nonlocal _a_before_fork_called - _a_before_fork_called += 1 + _a_before_fork_called = 0 + _a_after_in_child_fork_called = 0 + _a_after_in_parent_fork_called = 0 - def __after_in_child_fork__(self): - nonlocal _a_after_in_child_fork_called - _a_after_in_child_fork_called += 1 + class A(ForkMixin): + def __before_fork__(self): + nonlocal _a_before_fork_called + _a_before_fork_called += 1 - def __after_in_parent_fork__(self): - nonlocal _a_after_in_parent_fork_called - _a_after_in_parent_fork_called += 1 + def __after_in_child_fork__(self): + nonlocal _a_after_in_child_fork_called + _a_after_in_child_fork_called += 1 - _b_before_fork_called = 0 - _b_after_in_child_fork_called = 0 - _b_after_in_parent_fork_called = 0 + def __after_in_parent_fork__(self): + nonlocal _a_after_in_parent_fork_called + _a_after_in_parent_fork_called += 1 - @dataclass - class B(DataclassForkMixin): - def __before_fork__(self): - nonlocal _b_before_fork_called - _b_before_fork_called += 1 + _b_before_fork_called = 0 + _b_after_in_child_fork_called = 0 + _b_after_in_parent_fork_called = 0 - def __after_in_child_fork__(self): - nonlocal _b_after_in_child_fork_called - _b_after_in_child_fork_called += 1 + @dataclass + class B(DataclassForkMixin): + def __before_fork__(self): + nonlocal _b_before_fork_called + _b_before_fork_called += 1 - def __after_in_parent_fork__(self): - nonlocal _b_after_in_parent_fork_called - _b_after_in_parent_fork_called += 1 + def __after_in_child_fork__(self): + nonlocal _b_after_in_child_fork_called + _b_after_in_child_fork_called += 1 - a = A() - b = B() + def __after_in_parent_fork__(self): + nonlocal _b_after_in_parent_fork_called + _b_after_in_parent_fork_called += 1 - _before_fork_called = 0 - _after_in_child_fork_called = 0 - _after_in_parent_fork_called = 0 + a = A() + b = B() - def before_fork(): - nonlocal _before_fork_called - _before_fork_called += 1 + _before_fork_called = 0 + _after_in_child_fork_called = 0 + _after_in_parent_fork_called = 0 - def after_in_child_fork(): - nonlocal _after_in_child_fork_called - _after_in_child_fork_called += 1 + def before_fork(): + nonlocal _before_fork_called + _before_fork_called += 1 - def after_in_parent_fork(): - nonlocal _after_in_parent_fork_called - _after_in_parent_fork_called += 1 + def after_in_child_fork(): + nonlocal _after_in_child_fork_called + _after_in_child_fork_called += 1 - before_fork_hook(before_fork) - after_in_child_fork_hook(after_in_child_fork) - after_in_parent_fork_hook(after_in_parent_fork) + def after_in_parent_fork(): + nonlocal _after_in_parent_fork_called + _after_in_parent_fork_called += 1 - multiprocessing.set_start_method("fork", force=True) + before_fork_hook(before_fork) + after_in_child_fork_hook(after_in_child_fork) + after_in_parent_fork_hook(after_in_parent_fork) - def process_verify_fork_hooks_1(): - # Verify in the process that the fork hooks were called - assert _before_fork_called == 1, _before_fork_called - assert _after_in_child_fork_called == 1, _after_in_child_fork_called - # This was not called in the child process - assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called - - assert _a_before_fork_called == 1, _a_before_fork_called - assert _a_after_in_child_fork_called == 1, _a_after_in_child_fork_called - assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called - - assert _b_before_fork_called == 1, _b_before_fork_called - assert _b_after_in_child_fork_called == 1, _b_after_in_child_fork_called - assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called - - p1 = multiprocessing.Process(target=process_verify_fork_hooks_1) - p1.start() - p1.join() - assert p1.exitcode == 0, p1.exitcode + multiprocessing.set_start_method("fork", force=True) + def process_verify_fork_hooks_1(): + # Verify in the process that the fork hooks were called assert _before_fork_called == 1, _before_fork_called - assert _after_in_child_fork_called == 0, _after_in_child_fork_called - assert _after_in_parent_fork_called == 1, _after_in_parent_fork_called + assert _after_in_child_fork_called == 1, _after_in_child_fork_called + # This was not called in the child process + assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called assert _a_before_fork_called == 1, _a_before_fork_called - assert _a_after_in_child_fork_called == 0, _a_after_in_child_fork_called - assert _a_after_in_parent_fork_called == 1, _a_after_in_parent_fork_called + assert _a_after_in_child_fork_called == 1, _a_after_in_child_fork_called + assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called assert _b_before_fork_called == 1, _b_before_fork_called - assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called - assert _b_after_in_parent_fork_called == 1, _b_after_in_parent_fork_called + assert _b_after_in_child_fork_called == 1, _b_after_in_child_fork_called + assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called - _a_before_fork_called = 0 - _a_after_in_child_fork_called = 0 - _a_after_in_parent_fork_called = 0 + p1 = multiprocessing.Process(target=process_verify_fork_hooks_1) + p1.start() + p1.join() + assert p1.exitcode == 0, p1.exitcode - _b_before_fork_called = 0 - _b_after_in_child_fork_called = 0 - _b_after_in_parent_fork_called = 0 + assert _before_fork_called == 1, _before_fork_called + assert _after_in_child_fork_called == 0, _after_in_child_fork_called + assert _after_in_parent_fork_called == 1, _after_in_parent_fork_called - _before_fork_called = 0 - _after_in_child_fork_called = 0 - _after_in_parent_fork_called = 0 + assert _a_before_fork_called == 1, _a_before_fork_called + assert _a_after_in_child_fork_called == 0, _a_after_in_child_fork_called + assert _a_after_in_parent_fork_called == 1, _a_after_in_parent_fork_called - del a - del b - del before_fork - del after_in_child_fork - del after_in_parent_fork + assert _b_before_fork_called == 1, _b_before_fork_called + assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called + assert _b_after_in_parent_fork_called == 1, _b_after_in_parent_fork_called - def process_verify_fork_hooks_2(): - assert _before_fork_called == 0, _before_fork_called - assert _after_in_child_fork_called == 0, _after_in_child_fork_called - assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called + _a_before_fork_called = 0 + _a_after_in_child_fork_called = 0 + _a_after_in_parent_fork_called = 0 - assert _a_before_fork_called == 0, _a_before_fork_called - assert _a_after_in_child_fork_called == 0, _a_after_in_child_fork_called - assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called + _b_before_fork_called = 0 + _b_after_in_child_fork_called = 0 + _b_after_in_parent_fork_called = 0 - assert _b_before_fork_called == 0, _b_before_fork_called - assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called - assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called + _before_fork_called = 0 + _after_in_child_fork_called = 0 + _after_in_parent_fork_called = 0 - p2 = multiprocessing.Process(target=process_verify_fork_hooks_2) - p2.start() - p2.join() - assert p2.exitcode == 0, p2.exitcode + del a + del b + del before_fork + del after_in_child_fork + del after_in_parent_fork + def process_verify_fork_hooks_2(): assert _before_fork_called == 0, _before_fork_called assert _after_in_child_fork_called == 0, _after_in_child_fork_called assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called @@ -206,3 +188,20 @@ def process_verify_fork_hooks_2(): assert _b_before_fork_called == 0, _b_before_fork_called assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called + + p2 = multiprocessing.Process(target=process_verify_fork_hooks_2) + p2.start() + p2.join() + assert p2.exitcode == 0, p2.exitcode + + assert _before_fork_called == 0, _before_fork_called + assert _after_in_child_fork_called == 0, _after_in_child_fork_called + assert _after_in_parent_fork_called == 0, _after_in_parent_fork_called + + assert _a_before_fork_called == 0, _a_before_fork_called + assert _a_after_in_child_fork_called == 0, _a_after_in_child_fork_called + assert _a_after_in_parent_fork_called == 0, _a_after_in_parent_fork_called + + assert _b_before_fork_called == 0, _b_before_fork_called + assert _b_after_in_child_fork_called == 0, _b_after_in_child_fork_called + assert _b_after_in_parent_fork_called == 0, _b_after_in_parent_fork_called diff --git a/uv.lock b/uv.lock index 8c444d92..544cb73b 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= '3.13'", @@ -720,6 +719,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/cf/9dfc5616f103648f483c1595f05d2ac96df2dfec915351f507f7a500a38d/ebmlite-3.3.1-py3-none-any.whl", hash = "sha256:59285c472de1a6b92a4caf758b2b634a72a1468a94f12ebdb003202a07f01edf", size = 92152 }, ] +[[package]] +name = "exceptiongroup" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/0b/9f/a65090624ecf468cdca03533906e7c69ed7588582240cfe7cc9e770b50eb/exceptiongroup-1.3.0.tar.gz", hash = "sha256:b241f5885f560bc56a59ee63ca4c6a8bfa46ae4ad651af316d4e81817bb9fd88", size = 29749 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/36/f4/c6e662dade71f56cd2f3735141b265c3c79293c109549c1e6933b0651ffc/exceptiongroup-1.3.0-py3-none-any.whl", hash = "sha256:4d111e6e0c13d0644cad6ddaa7ed0261a0b36971f6d23e7ec9b4b9097da78a10", size = 16674 }, +] + [[package]] name = "filelock" version = "3.18.0" @@ -999,6 +1010,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/79/9d/0fb148dc4d6fa4a7dd1d8378168d9b4cd8d4560a6fbf6f0121c5fc34eb68/importlib_metadata-8.6.1-py3-none-any.whl", hash = "sha256:02a89390c1e15fdfdc0d7c6b25cb3e62650d0494005c97d6f148bf5b9787525e", size = 26971 }, ] +[[package]] +name = "iniconfig" +version = "2.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/97/ebf4da567aa6827c909642694d71c9fcf53e5b504f2d96afea02718862f3/iniconfig-2.1.0.tar.gz", hash = "sha256:3abbd2e30b36733fee78f9c7f7308f2d0050e88f0087fd25c2645f63c773e1c7", size = 4793 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/e1/e6716421ea10d38022b952c159d5161ca1193197fb744506875fbb87ea7b/iniconfig-2.1.0-py3-none-any.whl", hash = "sha256:9deba5723312380e77435581c6bf4935c94cbfab9b1ed33ef8d238ea168eb760", size = 6050 }, +] + [[package]] name = "isodate" version = "0.7.2" @@ -1182,6 +1202,7 @@ azure-storage-blob = [ dev = [ { name = "coverage" }, { name = "myst-parser" }, + { name = "pytest" }, { name = "ruff" }, { name = "soundfile" }, { name = "sphinx", version = "8.1.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, @@ -1227,6 +1248,7 @@ requires-dist = [ { name = "myst-parser", marker = "extra == 'dev'" }, { name = "numpy" }, { name = "pillow", specifier = ">=10.0.1" }, + { name = "pytest", marker = "extra == 'dev'" }, { name = "pyyaml" }, { name = "rapidyaml", specifier = "==0.9.0" }, { name = "ruff", marker = "extra == 'dev'" }, @@ -1242,7 +1264,6 @@ requires-dist = [ { name = "tqdm" }, { name = "webdataset" }, ] -provides-extras = ["aistore", "av-decode", "azure-storage-blob", "dev", "google-cloud-storage", "guess-content", "oci", "s3", "transforms"] [[package]] name = "mfusepy" @@ -1775,6 +1796,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/21/2c/5e05f58658cf49b6667762cca03d6e7d85cededde2caf2ab37b81f80e574/pillow-11.2.1-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:208653868d5c9ecc2b327f9b9ef34e0e42a4cdd172c2988fd81d62d2bc9bc044", size = 2674751 }, ] +[[package]] +name = "pluggy" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f9/e2/3e91f31a7d2b083fe6ef3fa267035b518369d9511ffab804f839851d2779/pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3", size = 69412 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538 }, +] + [[package]] name = "pockets" version = "0.9.1" @@ -2056,6 +2086,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9e/de/f8342b68fa9e981d348039954657bdf681b2ab93de27443be51865ffa310/pyOpenSSL-19.1.0-py2.py3-none-any.whl", hash = "sha256:621880965a720b8ece2f1b2f54ea2071966ab00e2970ad2ce11d596102063504", size = 53749 }, ] +[[package]] +name = "pytest" +version = "8.4.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "iniconfig" }, + { name = "packaging" }, + { name = "pluggy" }, + { name = "pygments" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/08/ba/45911d754e8eba3d5a841a5ce61a65a685ff1798421ac054f85aa8747dfb/pytest-8.4.1.tar.gz", hash = "sha256:7c67fd69174877359ed9371ec3af8a3d2b04741818c51e5e99cc1742251fa93c", size = 1517714 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/29/16/c8a903f4c4dffe7a12843191437d7cd8e32751d5de349d45d3fe69544e87/pytest-8.4.1-py3-none-any.whl", hash = "sha256:539c70ba6fcead8e78eebbf1115e8b589e7565830d7d006a8723f19ac8a0afb7", size = 365474 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" From 7a075669ae4303dd71b33baf83dee86a801457fa Mon Sep 17 00:00:00 2001 From: Philipp Fischer Date: Mon, 8 Dec 2025 04:14:22 -0800 Subject: [PATCH 30/36] Two import fixes --- src/megatron/energon/wrappers/map_dataset.py | 10 +++++++--- tests/test_dataloader.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/megatron/energon/wrappers/map_dataset.py b/src/megatron/energon/wrappers/map_dataset.py index 9fb4d385..0b7a4164 100644 --- a/src/megatron/energon/wrappers/map_dataset.py +++ b/src/megatron/energon/wrappers/map_dataset.py @@ -22,12 +22,16 @@ from megatron.energon.flavors.base_dataset import ( RestoreKey, SavableDataset, - WrappedRestoreKey, set_sample_restore_key, - wrap_sample_restore_key, ) from megatron.energon.worker import WorkerConfig -from megatron.energon.wrappers.base import BaseWrapperDataset, SampleIndex, get_sample_restore_key +from megatron.energon.wrappers.base import ( + BaseWrapperDataset, + SampleIndex, + WrappedRestoreKey, + get_sample_restore_key, + wrap_sample_restore_key, +) T_sample = TypeVar("T_sample") T_sample_out = TypeVar("T_sample_out") diff --git a/tests/test_dataloader.py b/tests/test_dataloader.py index c7c58361..1e134324 100644 --- a/tests/test_dataloader.py +++ b/tests/test_dataloader.py @@ -21,7 +21,7 @@ get_train_dataset, ) from megatron.energon.dataloader import DataLoader, ForkDataLoaderWorker, ThreadDataLoaderWorker -from megatron.energon.flavors.webdataset import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME from megatron.energon.task_encoder.base import DefaultTaskEncoder, stateless From 00deed0a93d0246a756b77e3ac62283f8c6d1150 Mon Sep 17 00:00:00 2001 From: Philipp Fischer Date: Mon, 8 Dec 2025 04:15:31 -0800 Subject: [PATCH 31/36] Ruff --- src/megatron/energon/wrappers/batch_dataset.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/megatron/energon/wrappers/batch_dataset.py b/src/megatron/energon/wrappers/batch_dataset.py index 406d4586..5466aad1 100644 --- a/src/megatron/energon/wrappers/batch_dataset.py +++ b/src/megatron/energon/wrappers/batch_dataset.py @@ -12,8 +12,6 @@ Iterator, List, Optional, - Sequence, - Tuple, TypeVar, Union, ) From 7db3f1bb23dbb57ff2806ffac0625368cf9d6535 Mon Sep 17 00:00:00 2001 From: Philipp Fischer Date: Mon, 8 Dec 2025 04:55:49 -0800 Subject: [PATCH 32/36] Fix mock file pool and wrapping file stores. --- src/megatron/energon/cache/base.py | 9 +++++++++ src/megatron/energon/cache/file_store.py | 9 --------- tests/test_file_cache_pool.py | 9 +++++++++ 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/src/megatron/energon/cache/base.py b/src/megatron/energon/cache/base.py index 31c4b127..3de61184 100644 --- a/src/megatron/energon/cache/base.py +++ b/src/megatron/energon/cache/base.py @@ -88,6 +88,15 @@ def _decode_raw(self, data: T, **kwargs) -> T: """ return self._inner._decode_raw(data, **kwargs) + def worker_init(self) -> None: + self._inner.worker_init() + + def worker_close(self) -> None: + self._inner.worker_close() + + def close(self) -> None: + self._inner.close() + @edataclass class Lazy(Generic[T]): diff --git a/src/megatron/energon/cache/file_store.py b/src/megatron/energon/cache/file_store.py index f8027e0c..df05ab75 100644 --- a/src/megatron/energon/cache/file_store.py +++ b/src/megatron/energon/cache/file_store.py @@ -33,15 +33,6 @@ def __init__( super().__init__(inner) self.decoder = decoder - def worker_init(self) -> None: - self.inner_reader.worker_init() - - def worker_close(self) -> None: - self.inner_reader.worker_close() - - def close(self) -> None: - self.inner_reader.close() - def __getitem__(self, fname: str) -> tuple[Any, SourceInfo]: data, source_info = self._inner[fname] return self.decoder.decode(fname, data), source_info diff --git a/tests/test_file_cache_pool.py b/tests/test_file_cache_pool.py index 3c140ca0..016f6e1b 100644 --- a/tests/test_file_cache_pool.py +++ b/tests/test_file_cache_pool.py @@ -33,6 +33,15 @@ def __getitem__(self, key: str) -> tuple[Any, SourceInfo]: def get_path(self) -> str: return self._path + def worker_init(self) -> None: + pass + + def worker_close(self) -> None: + pass + + def close(self) -> None: + pass + class MockDecoder(SampleDecoder): """Mock decoder for DecodeFileStore""" From e5a30ea10f2253ba618467c48275868527f1a07d Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Wed, 1 Oct 2025 17:49:38 +0200 Subject: [PATCH 33/36] Remove empty file --- tests/test_typedconverter_extended.py | 2 -- 1 file changed, 2 deletions(-) delete mode 100644 tests/test_typedconverter_extended.py diff --git a/tests/test_typedconverter_extended.py b/tests/test_typedconverter_extended.py deleted file mode 100644 index 43842921..00000000 --- a/tests/test_typedconverter_extended.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. -# SPDX-License-Identifier: BSD-3-Clause From a3e2980b302f18fabfcee4b3e073906b89ece84e Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 9 Dec 2025 13:33:45 +0100 Subject: [PATCH 34/36] Fix file store worker init --- src/megatron/energon/cache/file_cache_pool.py | 5 +++++ src/megatron/energon/cache/file_store.py | 4 ++-- src/megatron/energon/flavors/webdataset/itar_reader.py | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/megatron/energon/cache/file_cache_pool.py b/src/megatron/energon/cache/file_cache_pool.py index a2f50b08..68b5af9e 100644 --- a/src/megatron/energon/cache/file_cache_pool.py +++ b/src/megatron/energon/cache/file_cache_pool.py @@ -157,6 +157,8 @@ class FileStoreCachePool(CachePool, ForkMixin): # Whether the pool is shutting down _shutting_down: bool = False + _workers_initialized: dict[int, bool] = {} + def __init__( self, *, @@ -278,6 +280,9 @@ def _cache_out_task(self, ds: FileStore, fname: str, entry: _PendingTask) -> boo with self._lock: if self._shutting_down: return False + if not self._workers_initialized.get(threading.get_ident(), False): + ds.worker_init() + self._workers_initialized[threading.get_ident()] = True # Perform the data read if self.method == "raw": diff --git a/src/megatron/energon/cache/file_store.py b/src/megatron/energon/cache/file_store.py index df05ab75..82b1b052 100644 --- a/src/megatron/energon/cache/file_store.py +++ b/src/megatron/energon/cache/file_store.py @@ -166,7 +166,7 @@ def get_path(self) -> str: def get_media_metadata(self, key: str) -> MediaMetadataBase: if self._media_metadata_available is None: try: - has_metadata = self.sqlite_reader.db_has_media_metadata() + has_metadata = self._sqlite_reader.db_has_media_metadata() except sqlite3.Error as exc: # pragma: no cover - defensive raise RuntimeError( "Failed to inspect media metadata table. Re-run `energon prepare --media-metadata-by-...`." @@ -181,7 +181,7 @@ def get_media_metadata(self, key: str) -> MediaMetadataBase: self._media_metadata_available = True try: - row = self.sqlite_reader.get_media_metadata(key) + row = self._sqlite_reader.get_media_metadata(key) except sqlite3.Error as exc: # pragma: no cover - defensive raise RuntimeError( "Failed to load media metadata. Re-run `energon prepare --media-metadata-by-...`." diff --git a/src/megatron/energon/flavors/webdataset/itar_reader.py b/src/megatron/energon/flavors/webdataset/itar_reader.py index a661ae38..7ca59785 100644 --- a/src/megatron/energon/flavors/webdataset/itar_reader.py +++ b/src/megatron/energon/flavors/webdataset/itar_reader.py @@ -578,7 +578,7 @@ def worker_init(self): self.thread_local._sqlite_reader = SqliteIndexReader(self.sqlite_path) def worker_close(self): - if hasattr(self.thread_local, "_sqlite_reader"): + if getattr(self.thread_local, "_sqlite_reader", None) is not None: self.thread_local._sqlite_reader.close() del self.thread_local._sqlite_reader From 087a13d2347c77ac05ec0723d07c9530b60a2408 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 9 Dec 2025 13:38:48 +0100 Subject: [PATCH 35/36] Remove print --- src/megatron/energon/rng.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/megatron/energon/rng.py b/src/megatron/energon/rng.py index cbb61324..2773af08 100644 --- a/src/megatron/energon/rng.py +++ b/src/megatron/energon/rng.py @@ -237,5 +237,5 @@ def get_seed_from_args(*args: Any) -> int: @staticmethod def seed_args(*args: Any) -> None: """Seeds the global random generators deterministically from the given arguments.""" - print(f"Seeding with args: {args}") + # print(f"Seeding with args: {args}") SystemRng.seed(SystemRng.get_seed_from_args(*args)) From f82b86abb10ad7b2131dc8f1a65d23c281ad93c8 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Tue, 9 Dec 2025 13:41:18 +0100 Subject: [PATCH 36/36] Fix SqliteIndexReader shutdown --- src/megatron/energon/flavors/webdataset/indexing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/megatron/energon/flavors/webdataset/indexing.py b/src/megatron/energon/flavors/webdataset/indexing.py index b3ef98d1..fa524183 100644 --- a/src/megatron/energon/flavors/webdataset/indexing.py +++ b/src/megatron/energon/flavors/webdataset/indexing.py @@ -507,7 +507,7 @@ def get_media_metadata(self, entry_key: str) -> Tuple[str, str] | None: def close(self): """Close the database connection.""" - if self.db is not None: + if getattr(self, "db", None) is not None: self.db.thread_close() del self.db