diff --git a/src/megatron/energon/__init__.py b/src/megatron/energon/__init__.py index e5204230..26810e82 100644 --- a/src/megatron/energon/__init__.py +++ b/src/megatron/energon/__init__.py @@ -33,6 +33,7 @@ CrudeJsonlDatasetFactory, CrudeSample, CrudeWebdataset, + DatasetSampleReader, DefaultCrudeJsonlDatasetFactory, DefaultDecoderWebdatasetFactory, DefaultGenericWebdatasetFactory, @@ -47,6 +48,7 @@ MultiChoiceVQAWebdataset, OCRSample, OCRWebdataset, + RawSample, Sample, SampleDecoder, SavableDataset, @@ -61,7 +63,7 @@ VQASample, VQAWebdataset, ) -from megatron.energon.loader import get_loader, get_savable_loader +from megatron.energon.loader import get_dataset_reader, get_loader, get_savable_loader from megatron.energon.metadataset import ( DatasetLoader, DatasetLoaderInterface, @@ -136,6 +138,7 @@ "CrudeWebdataset", "DatasetLoader", "DatasetLoaderInterface", + "DatasetSampleReader", "DecodeFileStore", "DefaultCrudeJsonlDatasetFactory", "DefaultDecoderWebdatasetFactory", @@ -152,6 +155,7 @@ "GcDataset", "generic_batch", "generic_concat", + "get_dataset_reader", "get_loader", "get_savable_loader", "get_train_dataset", @@ -187,6 +191,7 @@ "prepare_metadataset", "RepeatDataset", "reraise_exception", + "RawSample", "Sample", "SampleDecoder", "SavableDataLoader", diff --git a/src/megatron/energon/flavors/__init__.py b/src/megatron/energon/flavors/__init__.py index 17ffb4dc..4c2fabe9 100644 --- a/src/megatron/energon/flavors/__init__.py +++ b/src/megatron/energon/flavors/__init__.py @@ -3,7 +3,9 @@ from megatron.energon.flavors.base_dataset import ( BaseCoreDatasetFactory, + DatasetSampleReader, PinMemoryMixin, + RawSample, Sample, SavableDataset, ) @@ -55,6 +57,7 @@ "CrudeJsonlDatasetFactory", "CrudeSample", "CrudeWebdataset", + "DatasetSampleReader", "DefaultCrudeJsonlDatasetFactory", "DefaultDecoderWebdatasetFactory", "DefaultGenericWebdatasetFactory", @@ -72,6 +75,7 @@ "OCRSample", "OCRWebdataset", "PinMemoryMixin", + "RawSample", "Sample", "SampleDecoder", "SavableDataset", diff --git a/src/megatron/energon/flavors/base_dataset.py b/src/megatron/energon/flavors/base_dataset.py index f51aef01..07f73dce 100644 --- a/src/megatron/energon/flavors/base_dataset.py +++ b/src/megatron/energon/flavors/base_dataset.py @@ -12,6 +12,7 @@ Callable, ClassVar, Dict, + Generator, Generic, List, Optional, @@ -37,6 +38,10 @@ T_sample = TypeVar("T_sample", covariant=True) T = TypeVar("T", covariant=True) +# Must contain at least the fields __key__, __restore_key__, __sources__. +# Other fields contain the data. +RawSample = Dict[str, Any] + class PinMemoryMixin: """A mixin class providing a generic `pin_memory` function.""" @@ -395,6 +400,30 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s ) +class DatasetSampleReader(ABC): + """ + Interface for dataset readers for random access to samples. + """ + + @abstractmethod + def __getitem__(self, idx: int) -> RawSample | None: ... + + @abstractmethod + def __len__(self) -> int: ... + + @abstractmethod + def __iter__(self) -> Generator[RawSample | None, None, None]: ... + + @abstractmethod + def close(self) -> None: ... + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.close() + + class BaseCoreDatasetFactory(Generic[T_sample], ABC): """Base type for an inner dataset sample loader. This factory can be used to construct a sample loader, or for joining in a joined dataset.""" diff --git a/src/megatron/energon/flavors/jsonl/ijsonl_reader.py b/src/megatron/energon/flavors/jsonl/ijsonl_reader.py index 202f1926..6761bc17 100644 --- a/src/megatron/energon/flavors/jsonl/ijsonl_reader.py +++ b/src/megatron/energon/flavors/jsonl/ijsonl_reader.py @@ -1,16 +1,19 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from abc import ABC from typing import ( Callable, Generator, Optional, Tuple, TypeVar, + overload, ) +from typing_extensions import Self + from megatron.energon.epathlib import EPath +from megatron.energon.flavors.base_dataset import DatasetSampleReader, RawSample from megatron.energon.flavors.jsonl.ijsonl import ( CachedIJsonlOffsetReader, IJsonlFile, @@ -23,7 +26,7 @@ T_index = TypeVar("T_index", covariant=False) -class IJsonlReader(ABC): +class IJsonlReader(DatasetSampleReader): """ Class for reading indexed jsonl files containing json samples. @@ -102,6 +105,12 @@ def _get_item_by_sample_pointer( json=json_data, ) + @overload + def __getitem__(self, idx: int) -> FilteredSample | None: ... + + @overload + def __getitem__(self, idx: str) -> tuple[bytes, SourceInfo] | None: ... + def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInfo] | None: """ Get a sample from the dataset. @@ -136,6 +145,10 @@ def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInf else: return sample + def __iter__(self) -> Generator[RawSample | None, None, None]: + for idx in range(len(self)): + yield self[idx] + def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]: """List all samples in the jsonl file. @@ -184,6 +197,12 @@ def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int], def get_total_size(self) -> int: return self.cached_offset_reader.get_total_size() + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.close() + def close(self): if self.ijsonl_file is not None: self.ijsonl_file.close() diff --git a/src/megatron/energon/flavors/webdataset/itar_reader.py b/src/megatron/energon/flavors/webdataset/itar_reader.py index b6c837aa..8a11a7e2 100644 --- a/src/megatron/energon/flavors/webdataset/itar_reader.py +++ b/src/megatron/energon/flavors/webdataset/itar_reader.py @@ -14,10 +14,12 @@ Tuple, TypeVar, Union, - overload, ) +from typing_extensions import Self + from megatron.energon.epathlib import EPath +from megatron.energon.flavors.base_dataset import DatasetSampleReader, RawSample from megatron.energon.flavors.webdataset.config import ( INDEX_SQLITE_FILENAME, skip_meta_re, @@ -275,7 +277,7 @@ def __getitem__(self, idx: T_index) -> FilteredSample | None: return self._get_item_by_sample_pointer(sample_pointer, idx) -class JoinIndexFileITarReader(ITarReader[int]): +class JoinIndexFileITarReader(ITarReader[int], DatasetSampleReader): """ A concrete ITarReader that reads samples from a join index file (via JoinIndexReader). """ @@ -369,6 +371,16 @@ def __len__(self) -> int: return len(index_reader) + def __iter__(self) -> Generator[RawSample | None, None, None]: + for idx in range(len(self)): + yield self[idx] + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.close() + def __str__(self) -> str: return ( f"JoinIndexFileITarReader(" @@ -378,7 +390,7 @@ def __str__(self) -> str: ) -class ShardInfosITarReader(ITarReader[int]): +class ShardInfosITarReader(ITarReader[int], DatasetSampleReader): """ A concrete ITarReader that constructs its internal sample list from a list of ShardInfos. """ @@ -469,6 +481,16 @@ def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer: def __len__(self) -> int: return self.shard_count_cumsum[-1] + def __iter__(self) -> Generator[RawSample | None, None, None]: + for idx in range(len(self)): + yield self[idx] + + def __enter__(self) -> Self: + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: + self.close() + def __str__(self) -> str: return ( f"ShardInfosITarReader(" @@ -524,7 +546,6 @@ def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer: """ Get the ITarSample object for the given index. """ - return self.sqlite_reader.get_sample_pointer_by_key(sample_key) def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]: @@ -577,25 +598,12 @@ def list_sample_parts( def get_total_size(self) -> int: return self.sqlite_reader.get_total_size() - @overload - def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]: ... - - @overload - def __getitem__(self, key: slice) -> "ITarReader": ... - - def __getitem__( - self, key: Union[slice, str] - ) -> Union[FilteredSample, tuple[bytes, SourceInfo], ITarReader]: + def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]: """ Either get a sample from the dataset by the sample key including all its entries, or get the bytes of a specific entry by the full filename of the entry inside the tar. """ - if isinstance(key, slice): - # Return a new reader with a sliced samples tensor - raise NotImplementedError("Slicing is not yet implemented") - assert isinstance(key, str), "Invalid argument type for __getitem__" - if self.key_is_full_entryname: m = split_name_re.match(key) if not m: diff --git a/src/megatron/energon/flavors/webdataset/sample_loader.py b/src/megatron/energon/flavors/webdataset/sample_loader.py index df4ad103..da75e3ca 100644 --- a/src/megatron/energon/flavors/webdataset/sample_loader.py +++ b/src/megatron/energon/flavors/webdataset/sample_loader.py @@ -6,8 +6,7 @@ import torch from megatron.energon.edataclass import edataclass -from megatron.energon.flavors.base_dataset import FlexState, SavableDataset -from megatron.energon.flavors.webdataset.itar_reader import ITarReader +from megatron.energon.flavors.base_dataset import DatasetSampleReader, FlexState, SavableDataset from megatron.energon.flavors.webdataset.structs import FilteredSample from megatron.energon.rng import WorkerRng from megatron.energon.worker import WorkerConfig @@ -34,10 +33,10 @@ class SliceState: class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): - """Internal class for loading samples from webdataset slices""" + """Internal class for sampling from random access datasets efficiently (the "core sampler").""" #: The readers for each joined dataset - join_readers: Sequence[ITarReader] + join_readers: Sequence[DatasetSampleReader] #: The offsets of the slice slices to iterate over for the current worker slice_offsets: Optional[Sequence[int]] @@ -83,7 +82,7 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]): def __init__( self, - join_readers: Sequence[ITarReader], + join_readers: Sequence[DatasetSampleReader], workers_sample_slice_offsets: Sequence[Sequence[int]], *, worker_config: WorkerConfig, diff --git a/src/megatron/energon/loader.py b/src/megatron/energon/loader.py index b973b672..02bfbf83 100644 --- a/src/megatron/energon/loader.py +++ b/src/megatron/energon/loader.py @@ -1,11 +1,18 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause -from typing import Optional, TypeVar +import re +from typing import Callable, Optional, TypeVar -from megatron.energon.cache import CachePool +from megatron.energon.cache import CachePool, FileStore, SystemFileStore +from megatron.energon.cache.file_store import JsonlFileStore, WebdatasetFileStore from megatron.energon.deprecation import warn_deprecated -from megatron.energon.flavors import SavableDataset +from megatron.energon.epathlib import EPath +from megatron.energon.flavors import SavableDataset, WebdatasetMeta +from megatron.energon.flavors.base_dataset import DatasetSampleReader +from megatron.energon.flavors.jsonl.ijsonl_reader import IJsonlReader +from megatron.energon.flavors.webdataset.itar_reader import ShardInfosITarReader +from megatron.energon.flavors.webdataset.metadata import EnergonDatasetType, get_dataset_type from megatron.energon.savable_loader import BasicDataLoader, SavableDataLoader from megatron.energon.worker import WorkerConfig from megatron.energon.wrappers.gc_dataset import GC_DEFAULT_EVERY_N_ITER @@ -117,3 +124,79 @@ def get_loader( watchdog_initial_timeout_seconds=watchdog_initial_timeout_seconds, fail_on_timeout=fail_on_timeout, ) + + +# Regex for any URL-like string (any protocol) +_url_regex = re.compile(r"^(?P[a-z][a-z0-9+.-]*)://(?P.*)", re.IGNORECASE) + + +def get_file_store( + path: str | EPath, +) -> FileStore[bytes]: + """ + Get a file store for the given path. + + Args: + path: The path to the file store. + + Returns: + The instantiated :class:`megatron.energon.FileStore`. + """ + if isinstance(path, str) and (m := _url_regex.match(path)): + prot = m.group("protocol") + if prot.count("+") == 1: + # filesystem+fs_prot:// + fs_type, fs_prot = prot.split("+") + assert fs_type == "filesystem" + return SystemFileStore(f"{fs_prot}://{m.group('path')}") + elif prot == "filesystem": + # filesystem:// (may be relative or absolute) + fs_type = "filesystem" + return SystemFileStore(m.group("path")) + path = EPath(path) + ds_type = get_dataset_type(path) + if ds_type == EnergonDatasetType.WEBDATASET: + return WebdatasetFileStore(path) + if ds_type == EnergonDatasetType.JSONL: + return JsonlFileStore(path) + if ds_type == EnergonDatasetType.FILESYSTEM: + return SystemFileStore(path) + raise NotImplementedError(f"Unsupported dataset type: {ds_type}") + + +def get_dataset_reader( + path: str | EPath, + *, + split_part: str = "train", + split_config: str | None = None, + part_filter: Callable[[str], bool] | None = None, +) -> DatasetSampleReader: + """ + Get a dataset reader for the given (core) dataset path. + + Args: + path: The path to the (core) dataset. + + Returns: + The instantiated :class:`megatron.energon.DatasetSampleReader`. + """ + path = EPath(path) + ds_type = get_dataset_type(path) + if ds_type == EnergonDatasetType.WEBDATASET: + wds_meta = WebdatasetMeta.from_config( + path=path, + split_part=split_part, + split_config=split_config, + ) + return ShardInfosITarReader( + path, + shard_infos=wds_meta.shards, + part_filter=part_filter, + itar_cache_size=1, + ) + if ds_type == EnergonDatasetType.JSONL: + return IJsonlReader( + path, + index_cache_size=1, + ) + raise NotImplementedError(f"Unsupported dataset type for indexed access: {ds_type}")