Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/megatron/energon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
CrudeJsonlDatasetFactory,
CrudeSample,
CrudeWebdataset,
DatasetSampleReader,
DefaultCrudeJsonlDatasetFactory,
DefaultDecoderWebdatasetFactory,
DefaultGenericWebdatasetFactory,
Expand All @@ -47,6 +48,7 @@
MultiChoiceVQAWebdataset,
OCRSample,
OCRWebdataset,
RawSample,
Sample,
SampleDecoder,
SavableDataset,
Expand All @@ -61,7 +63,7 @@
VQASample,
VQAWebdataset,
)
from megatron.energon.loader import get_loader, get_savable_loader
from megatron.energon.loader import get_dataset_reader, get_loader, get_savable_loader
from megatron.energon.metadataset import (
DatasetLoader,
DatasetLoaderInterface,
Expand Down Expand Up @@ -136,6 +138,7 @@
"CrudeWebdataset",
"DatasetLoader",
"DatasetLoaderInterface",
"DatasetSampleReader",
"DecodeFileStore",
"DefaultCrudeJsonlDatasetFactory",
"DefaultDecoderWebdatasetFactory",
Expand All @@ -152,6 +155,7 @@
"GcDataset",
"generic_batch",
"generic_concat",
"get_dataset_reader",
"get_loader",
"get_savable_loader",
"get_train_dataset",
Expand Down Expand Up @@ -187,6 +191,7 @@
"prepare_metadataset",
"RepeatDataset",
"reraise_exception",
"RawSample",
"Sample",
"SampleDecoder",
"SavableDataLoader",
Expand Down
4 changes: 4 additions & 0 deletions src/megatron/energon/flavors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from megatron.energon.flavors.base_dataset import (
BaseCoreDatasetFactory,
DatasetSampleReader,
PinMemoryMixin,
RawSample,
Sample,
SavableDataset,
)
Expand Down Expand Up @@ -55,6 +57,7 @@
"CrudeJsonlDatasetFactory",
"CrudeSample",
"CrudeWebdataset",
"DatasetSampleReader",
"DefaultCrudeJsonlDatasetFactory",
"DefaultDecoderWebdatasetFactory",
"DefaultGenericWebdatasetFactory",
Expand All @@ -72,6 +75,7 @@
"OCRSample",
"OCRWebdataset",
"PinMemoryMixin",
"RawSample",
"Sample",
"SampleDecoder",
"SavableDataset",
Expand Down
29 changes: 29 additions & 0 deletions src/megatron/energon/flavors/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Callable,
ClassVar,
Dict,
Generator,
Generic,
List,
Optional,
Expand All @@ -37,6 +38,10 @@
T_sample = TypeVar("T_sample", covariant=True)
T = TypeVar("T", covariant=True)

# Must contain at least the fields __key__, __restore_key__, __sources__.
# Other fields contain the data.
RawSample = Dict[str, Any]


class PinMemoryMixin:
"""A mixin class providing a generic `pin_memory` function."""
Expand Down Expand Up @@ -395,6 +400,30 @@ def restore_sample(self, restore_key: Tuple[Union[str, int, tuple], ...]) -> T_s
)


class DatasetSampleReader(ABC):
"""
Interface for dataset readers for random access to samples.
"""

@abstractmethod
def __getitem__(self, idx: int) -> RawSample | None: ...

@abstractmethod
def __len__(self) -> int: ...

@abstractmethod
def __iter__(self) -> Generator[RawSample | None, None, None]: ...

@abstractmethod
def close(self) -> None: ...

def __enter__(self) -> Self:
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
self.close()


class BaseCoreDatasetFactory(Generic[T_sample], ABC):
"""Base type for an inner dataset sample loader. This factory can be used to construct a sample loader, or for
joining in a joined dataset."""
Expand Down
23 changes: 21 additions & 2 deletions src/megatron/energon/flavors/jsonl/ijsonl_reader.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from abc import ABC
from typing import (
Callable,
Generator,
Optional,
Tuple,
TypeVar,
overload,
)

from typing_extensions import Self

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import DatasetSampleReader, RawSample
from megatron.energon.flavors.jsonl.ijsonl import (
CachedIJsonlOffsetReader,
IJsonlFile,
Expand All @@ -23,7 +26,7 @@
T_index = TypeVar("T_index", covariant=False)


class IJsonlReader(ABC):
class IJsonlReader(DatasetSampleReader):
"""
Class for reading indexed jsonl files containing json samples.

Expand Down Expand Up @@ -102,6 +105,12 @@ def _get_item_by_sample_pointer(
json=json_data,
)

@overload
def __getitem__(self, idx: int) -> FilteredSample | None: ...

@overload
def __getitem__(self, idx: str) -> tuple[bytes, SourceInfo] | None: ...

def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInfo] | None:
"""
Get a sample from the dataset.
Expand Down Expand Up @@ -136,6 +145,10 @@ def __getitem__(self, idx: int | str) -> FilteredSample | tuple[bytes, SourceInf
else:
return sample

def __iter__(self) -> Generator[RawSample | None, None, None]:
for idx in range(len(self)):
yield self[idx]

def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
"""List all samples in the jsonl file.

Expand Down Expand Up @@ -184,6 +197,12 @@ def list_sample_parts(self, sample_key: str) -> Generator[Tuple[str, int, int],
def get_total_size(self) -> int:
return self.cached_offset_reader.get_total_size()

def __enter__(self) -> Self:
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
self.close()

def close(self):
if self.ijsonl_file is not None:
self.ijsonl_file.close()
Expand Down
44 changes: 26 additions & 18 deletions src/megatron/energon/flavors/webdataset/itar_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
Tuple,
TypeVar,
Union,
overload,
)

from typing_extensions import Self

from megatron.energon.epathlib import EPath
from megatron.energon.flavors.base_dataset import DatasetSampleReader, RawSample
from megatron.energon.flavors.webdataset.config import (
INDEX_SQLITE_FILENAME,
skip_meta_re,
Expand Down Expand Up @@ -275,7 +277,7 @@ def __getitem__(self, idx: T_index) -> FilteredSample | None:
return self._get_item_by_sample_pointer(sample_pointer, idx)


class JoinIndexFileITarReader(ITarReader[int]):
class JoinIndexFileITarReader(ITarReader[int], DatasetSampleReader):
"""
A concrete ITarReader that reads samples from a join index file (via JoinIndexReader).
"""
Expand Down Expand Up @@ -369,6 +371,16 @@ def __len__(self) -> int:

return len(index_reader)

def __iter__(self) -> Generator[RawSample | None, None, None]:
for idx in range(len(self)):
yield self[idx]

def __enter__(self) -> Self:
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
self.close()

def __str__(self) -> str:
return (
f"JoinIndexFileITarReader("
Expand All @@ -378,7 +390,7 @@ def __str__(self) -> str:
)


class ShardInfosITarReader(ITarReader[int]):
class ShardInfosITarReader(ITarReader[int], DatasetSampleReader):
"""
A concrete ITarReader that constructs its internal sample list from a list of ShardInfos.
"""
Expand Down Expand Up @@ -469,6 +481,16 @@ def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer:
def __len__(self) -> int:
return self.shard_count_cumsum[-1]

def __iter__(self) -> Generator[RawSample | None, None, None]:
for idx in range(len(self)):
yield self[idx]

def __enter__(self) -> Self:
return self

def __exit__(self, exc_type, exc_value, traceback) -> None:
self.close()

def __str__(self) -> str:
return (
f"ShardInfosITarReader("
Expand Down Expand Up @@ -524,7 +546,6 @@ def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer:
"""
Get the ITarSample object for the given index.
"""

return self.sqlite_reader.get_sample_pointer_by_key(sample_key)

def list_all_samples(self) -> Generator[Tuple[str, int, int], None, None]:
Expand Down Expand Up @@ -577,25 +598,12 @@ def list_sample_parts(
def get_total_size(self) -> int:
return self.sqlite_reader.get_total_size()

@overload
def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]: ...

@overload
def __getitem__(self, key: slice) -> "ITarReader": ...

def __getitem__(
self, key: Union[slice, str]
) -> Union[FilteredSample, tuple[bytes, SourceInfo], ITarReader]:
def __getitem__(self, key: str) -> Union[FilteredSample, tuple[bytes, SourceInfo]]:
"""
Either get a sample from the dataset by the sample key including all its entries,
or get the bytes of a specific entry by the full filename of the entry inside the tar.
"""

if isinstance(key, slice):
# Return a new reader with a sliced samples tensor
raise NotImplementedError("Slicing is not yet implemented")
assert isinstance(key, str), "Invalid argument type for __getitem__"

if self.key_is_full_entryname:
m = split_name_re.match(key)
if not m:
Expand Down
9 changes: 4 additions & 5 deletions src/megatron/energon/flavors/webdataset/sample_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
import torch

from megatron.energon.edataclass import edataclass
from megatron.energon.flavors.base_dataset import FlexState, SavableDataset
from megatron.energon.flavors.webdataset.itar_reader import ITarReader
from megatron.energon.flavors.base_dataset import DatasetSampleReader, FlexState, SavableDataset
from megatron.energon.flavors.webdataset.structs import FilteredSample
from megatron.energon.rng import WorkerRng
from megatron.energon.worker import WorkerConfig
Expand All @@ -34,10 +33,10 @@ class SliceState:


class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]):
"""Internal class for loading samples from webdataset slices"""
"""Internal class for sampling from random access datasets efficiently (the "core sampler")."""

#: The readers for each joined dataset
join_readers: Sequence[ITarReader]
join_readers: Sequence[DatasetSampleReader]

#: The offsets of the slice slices to iterate over for the current worker
slice_offsets: Optional[Sequence[int]]
Expand Down Expand Up @@ -83,7 +82,7 @@ class WebdatasetSampleLoaderDataset(SavableDataset[RawSampleData]):

def __init__(
self,
join_readers: Sequence[ITarReader],
join_readers: Sequence[DatasetSampleReader],
workers_sample_slice_offsets: Sequence[Sequence[int]],
*,
worker_config: WorkerConfig,
Expand Down
Loading
Loading