Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ classifiers = [
dependencies = [
"braceexpand",
"click",
"filetype>=1.0.0",
"mfusepy",
"multi-storage-client>=0.33.0",
"numpy",
Expand Down Expand Up @@ -78,10 +79,6 @@ av_decode = [
"ebmlite>=3.3.1",
"av>=14.4.0",
]
# If using guess_content=True for decoding
guess_content = [
"filetype>=1.0.0",
]
tar_patcher = [
"numba",
]
Expand Down
7 changes: 6 additions & 1 deletion src/megatron/energon/av/av_init.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

from bitstring.bits import BitsType
from typing import Any, TypeAlias

try:
from bitstring.bits import BitsType
except ImportError:
BitsType: TypeAlias = Any

try:
# Try importing optional dependencies
Expand Down
6 changes: 5 additions & 1 deletion src/megatron/energon/cache/file_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,11 @@ def __init__(
self,
dataset_path: EPath,
):
super().__init__(base_path=dataset_path, key_is_full_entryname=True)
super().__init__(
base_path=dataset_path,
key_is_full_entryname=True,
disable_cache=True,
)
self._media_metadata_available: Optional[bool] = None

def get_path(self) -> str:
Expand Down
133 changes: 84 additions & 49 deletions src/megatron/energon/flavors/webdataset/itar_reader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import contextlib
from abc import ABC, abstractmethod
from bisect import bisect_right
from typing import (
Expand Down Expand Up @@ -49,6 +50,9 @@ class ITarReader(ABC, Generic[T_index]):
part_filter: An optional filter function to select parts of the samples.
itar_cache_size: The number of tar readers to keep open at the same time.
sample_filter: An optional filter function to select samples by their key.
disable_cache: If True, disables caching of open tar files and opens a fresh tar file
for every read. This mode avoids sharing tar file handles across threads and is
intended to be thread-safe.
"""

base_path: EPath
Expand All @@ -57,6 +61,7 @@ class ITarReader(ABC, Generic[T_index]):
part_filter: Optional[Callable[[str], bool]]
itar_files_cache: Dict[int, ITarFile]
sample_filter: Optional[Callable[[str], bool]]
disable_cache: bool

def __init__(
self,
Expand All @@ -66,6 +71,7 @@ def __init__(
part_filter: Optional[Callable[[str], bool]] = None,
itar_cache_size: int = 5,
sample_filter: Optional[Callable[[str], bool]] = None,
disable_cache: bool = False,
):
assert len(tar_filenames) == len(tar_filepaths), (
f"tar_filenames length ({len(tar_filenames)}) does not match "
Expand All @@ -78,6 +84,7 @@ def __init__(
self.itar_files_cache = {}
self.itar_cache_size = itar_cache_size
self.sample_filter = sample_filter
self.disable_cache = disable_cache

@abstractmethod
def __len__(self) -> int:
Expand Down Expand Up @@ -124,6 +131,24 @@ def _get_itarfile_cached(self, tar_file_id: int) -> ITarFile:

return self.itar_files_cache[tar_file_id]

@contextlib.contextmanager
def _open_itarfile(self, tar_file_id: int) -> Generator[ITarFile, None, None]:
    """
    Yield an ITarFile for the shard identified by ``tar_file_id``.

    When ``self.disable_cache`` is False, the ITarFile returned by
    ``_get_itarfile_cached`` is yielded and is left open when the context exits.
    When ``self.disable_cache`` is True, the shard at
    ``self.tar_filepaths[tar_file_id]`` is opened and wrapped in a new ITarFile;
    both are closed again when the context exits.

    Args:
        tar_file_id: Identifier of the shard; passed to ``_get_itarfile_cached``
            or used as the index into ``self.tar_filepaths``.

    Yields:
        The ITarFile to read from for the duration of the ``with`` block.
    """
    if self.disable_cache:
        # A dedicated handle per access means the file position and the tarfile's
        # internal state are not shared with any other caller (e.g. other threads).
        shard_path = self.tar_filepaths[tar_file_id]
        with shard_path.open(mode="rb") as raw_stream, ITarFile.open(
            fileobj=raw_stream, mode="r:"
        ) as fresh_tar:
            yield fresh_tar
    else:
        # The cached reader stays owned by the cache, so nothing is closed here.
        yield self._get_itarfile_cached(tar_file_id)

def _get_part_by_raw_sample_pointer(
self,
raw_sample_pointer: ITarRawSamplePartPointer,
Expand All @@ -139,15 +164,14 @@ def _get_part_by_raw_sample_pointer(
The raw data bytes.
"""

# Open the tar file (cached)
tar_file = self._get_itarfile_cached(raw_sample_pointer.tar_file_id)
shard_name = self.tar_filenames[raw_sample_pointer.tar_file_id]

# Get the raw data from the tar file
rest = tar_file.fileobj.tell()
tar_file.fileobj.seek(raw_sample_pointer.raw_byte_offset)
raw_data = tar_file.fileobj.read(raw_sample_pointer.raw_byte_size)
tar_file.fileobj.seek(rest)
with self._open_itarfile(raw_sample_pointer.tar_file_id) as tar_file:
# Get the raw data from the tar file
rest = tar_file.fileobj.tell()
tar_file.fileobj.seek(raw_sample_pointer.raw_byte_offset)
raw_data = tar_file.fileobj.read(raw_sample_pointer.raw_byte_size)
tar_file.fileobj.seek(rest)

return raw_data, SourceInfo(
dataset_path=self.base_path,
Expand All @@ -174,57 +198,56 @@ def _get_item_by_sample_pointer(
The sample or None if the sample is not found.
"""

# Open the tar file (cached)
tar_file = self._get_itarfile_cached(sample_pointer.tar_file_id)
shard_name = self.tar_filenames[sample_pointer.tar_file_id]
sample_base_name = None
sample_name = None
group_parts: Dict[str, bytes] = {}
file_names: list[str] = []

# Position the tar file at the correct offset
tar_file.offset = sample_pointer.byte_offset
with self._open_itarfile(sample_pointer.tar_file_id) as tar_file:
# Position the tar file at the correct offset
tar_file.offset = sample_pointer.byte_offset

while tar_file.offset < sample_pointer.byte_offset + sample_pointer.byte_size:
tarinfo = tar_file.next()
if tarinfo is None:
raise ValueError(
f"Unexpected end of tar file: {self.tar_filenames[sample_pointer.tar_file_id]}"
)
fname = tarinfo.name
if not tarinfo.isfile() or fname is None:
continue
if skip_meta_re.match(fname):
continue

# Extract the base_name and extension
m = split_name_re.match(fname)
if not m:
continue
cur_base_name, cur_ext = m.groups()

if sample_base_name is None:
sample_base_name = cur_base_name
sample_name = f"{shard_name}/{cur_base_name}"
if self.sample_filter is not None and not self.sample_filter(sample_name):
return None
else:
if sample_base_name != cur_base_name:
while tar_file.offset < sample_pointer.byte_offset + sample_pointer.byte_size:
tarinfo = tar_file.next()
if tarinfo is None:
raise ValueError(
f"Inconsistent sample base name: {sample_base_name} vs {cur_base_name}"
f"Unexpected end of tar file: {self.tar_filenames[sample_pointer.tar_file_id]}"
)

if entry_match_fn is not None:
# If entry_match_fn is provided, use it to determine if we should take this entry
take_entry = entry_match_fn(fname)
else:
# If no entry_match_fn is provided, use the part_filter to determine if we should take this entry
take_entry = self.part_filter is None or self.part_filter(cur_ext)

if take_entry:
member_bytes = tar_file.extractfile(tarinfo).read()
group_parts[cur_ext] = member_bytes
file_names.append(fname)
fname = tarinfo.name
if not tarinfo.isfile() or fname is None:
continue
if skip_meta_re.match(fname):
continue

# Extract the base_name and extension
m = split_name_re.match(fname)
if not m:
continue
cur_base_name, cur_ext = m.groups()

if sample_base_name is None:
sample_base_name = cur_base_name
sample_name = f"{shard_name}/{cur_base_name}"
if self.sample_filter is not None and not self.sample_filter(sample_name):
return None
else:
if sample_base_name != cur_base_name:
raise ValueError(
f"Inconsistent sample base name: {sample_base_name} vs {cur_base_name}"
)

if entry_match_fn is not None:
# If entry_match_fn is provided, use it to determine if we should take this entry
take_entry = entry_match_fn(fname)
else:
# If no entry_match_fn is provided, use the part_filter to determine if we should take this entry
take_entry = self.part_filter is None or self.part_filter(cur_ext)

if take_entry:
member_bytes = tar_file.extractfile(tarinfo).read()
group_parts[cur_ext] = member_bytes
file_names.append(fname)
if sample_base_name is None:
raise ValueError(f"No valid files found in sample {sample_pointer}")

Expand Down Expand Up @@ -271,7 +294,12 @@ def __init__(
part_filter: Optional[Callable[[str], bool]] = None,
itar_cache_size: int = 5,
sample_filter: Optional[Callable[[str], bool]] = None,
disable_cache: bool = False,
):
if disable_cache:
raise NotImplementedError(
"disable_cache is not supported for JoinIndexFileITarReader yet"
)
self.index_file = index_file
self.column = column

Expand All @@ -288,6 +316,7 @@ def __init__(
part_filter=part_filter,
itar_cache_size=itar_cache_size,
sample_filter=sample_filter,
disable_cache=disable_cache,
)

def _get_join_index_reader_cached(self, sample_idx: int) -> JoinIndexReader:
Expand Down Expand Up @@ -366,7 +395,10 @@ def __init__(
part_filter: Optional[Callable[[str], bool]] = None,
itar_cache_size: int = 5,
sample_filter: Optional[Callable[[str], bool]] = None,
disable_cache: bool = False,
):
if disable_cache:
raise NotImplementedError("disable_cache is not supported for ShardInfosITarReader yet")
# Build the tar_filenames and tar_filepaths from shard_infos,
# constructing the samples tensor as we go.
cur_tar_files: Dict[str, Tuple[int, EPath]] = {}
Expand Down Expand Up @@ -404,6 +436,7 @@ def __init__(
part_filter=part_filter,
itar_cache_size=itar_cache_size,
sample_filter=sample_filter,
disable_cache=disable_cache,
)

def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer:
Expand Down Expand Up @@ -460,6 +493,7 @@ def __init__(
itar_cache_size: int = 5,
sample_filter: Optional[Callable[[str], bool]] = None,
key_is_full_entryname: bool = False,
disable_cache: bool = False,
):
from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.indexing import SqliteIndexReader
Expand All @@ -483,6 +517,7 @@ def __init__(
part_filter=part_filter,
itar_cache_size=itar_cache_size,
sample_filter=sample_filter,
disable_cache=disable_cache,
)

def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer:
Expand Down