diff --git a/pyproject.toml b/pyproject.toml
index 793fb161..60803102 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -23,6 +23,7 @@ classifiers = [
 dependencies = [
     "braceexpand",
     "click",
+    "filetype>=1.0.0",
     "mfusepy",
     "multi-storage-client>=0.33.0",
     "numpy",
@@ -78,10 +79,6 @@ av_decode = [
     "ebmlite>=3.3.1",
     "av>=14.4.0",
 ]
-# If using guess_content=True for decoding
-guess_content = [
-    "filetype>=1.0.0",
-]
 tar_patcher = [
     "numba",
 ]
diff --git a/src/megatron/energon/av/av_init.py b/src/megatron/energon/av/av_init.py
index 28c8a943..893d8de0 100644
--- a/src/megatron/energon/av/av_init.py
+++ b/src/megatron/energon/av/av_init.py
@@ -1,7 +1,12 @@
 # Copyright (c) 2025, NVIDIA CORPORATION.
 # SPDX-License-Identifier: BSD-3-Clause
 
-from bitstring.bits import BitsType
+from typing import Any, TypeAlias
+
+try:
+    from bitstring.bits import BitsType
+except ImportError:
+    BitsType: TypeAlias = Any
 
 try:
     # Try importing optional dependencies
diff --git a/src/megatron/energon/cache/file_store.py b/src/megatron/energon/cache/file_store.py
index 556b0bdd..2e924206 100644
--- a/src/megatron/energon/cache/file_store.py
+++ b/src/megatron/energon/cache/file_store.py
@@ -148,7 +148,11 @@ def __init__(
         self,
         dataset_path: EPath,
     ):
-        super().__init__(base_path=dataset_path, key_is_full_entryname=True)
+        super().__init__(
+            base_path=dataset_path,
+            key_is_full_entryname=True,
+            disable_cache=True,
+        )
         self._media_metadata_available: Optional[bool] = None
 
     def get_path(self) -> str:
diff --git a/src/megatron/energon/flavors/webdataset/itar_reader.py b/src/megatron/energon/flavors/webdataset/itar_reader.py
index 0134f338..b6c837aa 100644
--- a/src/megatron/energon/flavors/webdataset/itar_reader.py
+++ b/src/megatron/energon/flavors/webdataset/itar_reader.py
@@ -1,6 +1,7 @@
 # Copyright (c) 2025, NVIDIA CORPORATION.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import contextlib
 from abc import ABC, abstractmethod
 from bisect import bisect_right
 from typing import (
@@ -49,6 +50,9 @@ class ITarReader(ABC, Generic[T_index]):
         part_filter: An optional filter function to select parts of the samples.
         itar_cache_size: The number of tar readers to keep open at the same time.
         sample_filter: An optional filter function to select samples by their key.
+        disable_cache: If True, disables caching of open tar files and opens a fresh tar file
+            for every read. This mode avoids sharing tar file handles across threads and is
+            intended to be thread-safe.
     """
 
     base_path: EPath
@@ -57,6 +61,7 @@ class ITarReader(ABC, Generic[T_index]):
     part_filter: Optional[Callable[[str], bool]]
     itar_files_cache: Dict[int, ITarFile]
     sample_filter: Optional[Callable[[str], bool]]
+    disable_cache: bool
 
     def __init__(
         self,
@@ -66,6 +71,7 @@ def __init__(
         part_filter: Optional[Callable[[str], bool]] = None,
         itar_cache_size: int = 5,
         sample_filter: Optional[Callable[[str], bool]] = None,
+        disable_cache: bool = False,
     ):
         assert len(tar_filenames) == len(tar_filepaths), (
             f"tar_filenames length ({len(tar_filenames)}) does not match "
@@ -78,6 +84,7 @@ def __init__(
         self.itar_files_cache = {}
         self.itar_cache_size = itar_cache_size
         self.sample_filter = sample_filter
+        self.disable_cache = disable_cache
 
     @abstractmethod
     def __len__(self) -> int:
@@ -124,6 +131,31 @@ def _get_itarfile_cached(self, tar_file_id: int) -> ITarFile:
 
         return self.itar_files_cache[tar_file_id]
 
+    @contextlib.contextmanager
+    def _open_itarfile(self, tar_file_id: int) -> Generator[ITarFile, None, None]:
+        """
+        Context manager to access an ITarFile for a shard.
+
+        - If caching is enabled, yields the cached ITarFile and does not close it.
+        - If caching is disabled, opens a fresh ITarFile for the duration of the context.
+
+        Args:
+            tar_file_id: Index of the shard in `tar_filepaths`.
+
+        Yields:
+            The ITarFile for the shard. When caching is disabled, it must not be used
+            after the context exits, because it is closed on exit.
+        """
+        if not self.disable_cache:
+            yield self._get_itarfile_cached(tar_file_id)
+            return
+
+        # Open a fresh tar file handle for this access. This avoids sharing file positions
+        # and tarfile internal state across threads.
+        with self.tar_filepaths[tar_file_id].open(mode="rb") as file_object:
+            with ITarFile.open(fileobj=file_object, mode="r:") as tar_file:
+                yield tar_file
+
     def _get_part_by_raw_sample_pointer(
         self,
         raw_sample_pointer: ITarRawSamplePartPointer,
@@ -139,15 +171,15 @@ def _get_part_by_raw_sample_pointer(
             The raw data bytes.
         """
 
-        # Open the tar file (cached)
-        tar_file = self._get_itarfile_cached(raw_sample_pointer.tar_file_id)
         shard_name = self.tar_filenames[raw_sample_pointer.tar_file_id]
 
-        # Get the raw data from the tar file
-        rest = tar_file.fileobj.tell()
-        tar_file.fileobj.seek(raw_sample_pointer.raw_byte_offset)
-        raw_data = tar_file.fileobj.read(raw_sample_pointer.raw_byte_size)
-        tar_file.fileobj.seek(rest)
+        with self._open_itarfile(raw_sample_pointer.tar_file_id) as tar_file:
+            # Get the raw data from the tar file
+            rest = tar_file.fileobj.tell()
+            tar_file.fileobj.seek(raw_sample_pointer.raw_byte_offset)
+            raw_data = tar_file.fileobj.read(raw_sample_pointer.raw_byte_size)
+            # Restore the position so that a cached (reused) handle is left unchanged
+            tar_file.fileobj.seek(rest)
 
         return raw_data, SourceInfo(
             dataset_path=self.base_path,
@@ -174,57 +206,56 @@ def _get_item_by_sample_pointer(
             The sample or None if the sample is not found.
         """
 
-        # Open the tar file (cached)
-        tar_file = self._get_itarfile_cached(sample_pointer.tar_file_id)
         shard_name = self.tar_filenames[sample_pointer.tar_file_id]
         sample_base_name = None
         sample_name = None
         group_parts: Dict[str, bytes] = {}
         file_names: list[str] = []
 
-        # Position the tar file at the correct offset
-        tar_file.offset = sample_pointer.byte_offset
+        with self._open_itarfile(sample_pointer.tar_file_id) as tar_file:
+            # Position the tar file at the correct offset
+            tar_file.offset = sample_pointer.byte_offset
 
-        while tar_file.offset < sample_pointer.byte_offset + sample_pointer.byte_size:
-            tarinfo = tar_file.next()
-            if tarinfo is None:
-                raise ValueError(
-                    f"Unexpected end of tar file: {self.tar_filenames[sample_pointer.tar_file_id]}"
-                )
-            fname = tarinfo.name
-            if not tarinfo.isfile() or fname is None:
-                continue
-            if skip_meta_re.match(fname):
-                continue
-
-            # Extract the base_name and extension
-            m = split_name_re.match(fname)
-            if not m:
-                continue
-            cur_base_name, cur_ext = m.groups()
-
-            if sample_base_name is None:
-                sample_base_name = cur_base_name
-                sample_name = f"{shard_name}/{cur_base_name}"
-                if self.sample_filter is not None and not self.sample_filter(sample_name):
-                    return None
-            else:
-                if sample_base_name != cur_base_name:
+            while tar_file.offset < sample_pointer.byte_offset + sample_pointer.byte_size:
+                tarinfo = tar_file.next()
+                if tarinfo is None:
                     raise ValueError(
-                        f"Inconsistent sample base name: {sample_base_name} vs {cur_base_name}"
+                        f"Unexpected end of tar file: {self.tar_filenames[sample_pointer.tar_file_id]}"
                     )
-
-            if entry_match_fn is not None:
-                # If entry_match_fn is provided, use it to determine if we should take this entry
-                take_entry = entry_match_fn(fname)
-            else:
-                # If no entry_match_fn is provided, use the part_filter to determine if we should take this entry
-                take_entry = self.part_filter is None or self.part_filter(cur_ext)
-
-            if take_entry:
-                member_bytes = tar_file.extractfile(tarinfo).read()
-                group_parts[cur_ext] = member_bytes
-                file_names.append(fname)
+                fname = tarinfo.name
+                if not tarinfo.isfile() or fname is None:
+                    continue
+                if skip_meta_re.match(fname):
+                    continue
+
+                # Extract the base_name and extension
+                m = split_name_re.match(fname)
+                if not m:
+                    continue
+                cur_base_name, cur_ext = m.groups()
+
+                if sample_base_name is None:
+                    sample_base_name = cur_base_name
+                    sample_name = f"{shard_name}/{cur_base_name}"
+                    if self.sample_filter is not None and not self.sample_filter(sample_name):
+                        return None
+                else:
+                    if sample_base_name != cur_base_name:
+                        raise ValueError(
+                            f"Inconsistent sample base name: {sample_base_name} vs {cur_base_name}"
+                        )
+
+                if entry_match_fn is not None:
+                    # If entry_match_fn is provided, use it to determine if we should take this entry
+                    take_entry = entry_match_fn(fname)
+                else:
+                    # If no entry_match_fn is provided, use the part_filter to determine if we should take this entry
+                    take_entry = self.part_filter is None or self.part_filter(cur_ext)
+
+                if take_entry:
+                    member_bytes = tar_file.extractfile(tarinfo).read()
+                    group_parts[cur_ext] = member_bytes
+                    file_names.append(fname)
 
         if sample_base_name is None:
             raise ValueError(f"No valid files found in sample {sample_pointer}")
@@ -271,7 +302,12 @@ def __init__(
         part_filter: Optional[Callable[[str], bool]] = None,
         itar_cache_size: int = 5,
         sample_filter: Optional[Callable[[str], bool]] = None,
+        disable_cache: bool = False,
     ):
+        if disable_cache:
+            raise NotImplementedError(
+                "disable_cache is not supported for JoinIndexFileITarReader yet"
+            )
         self.index_file = index_file
         self.column = column
 
@@ -288,6 +324,7 @@ def __init__(
             part_filter=part_filter,
             itar_cache_size=itar_cache_size,
             sample_filter=sample_filter,
+            disable_cache=disable_cache,
         )
 
     def _get_join_index_reader_cached(self, sample_idx: int) -> JoinIndexReader:
@@ -366,7 +403,10 @@ def __init__(
         part_filter: Optional[Callable[[str], bool]] = None,
         itar_cache_size: int = 5,
         sample_filter: Optional[Callable[[str], bool]] = None,
+        disable_cache: bool = False,
     ):
+        if disable_cache:
+            raise NotImplementedError("disable_cache is not supported for ShardInfosITarReader yet")
         # Build the tar_filenames and tar_filepaths from shard_infos,
         # constructing the samples tensor as we go.
         cur_tar_files: Dict[str, Tuple[int, EPath]] = {}
@@ -404,6 +444,7 @@ def __init__(
             part_filter=part_filter,
             itar_cache_size=itar_cache_size,
             sample_filter=sample_filter,
+            disable_cache=disable_cache,
         )
 
     def _get_itar_sample_pointer(self, idx: int) -> ITarSamplePointer:
@@ -460,6 +501,7 @@ def __init__(
         itar_cache_size: int = 5,
         sample_filter: Optional[Callable[[str], bool]] = None,
         key_is_full_entryname: bool = False,
+        disable_cache: bool = False,
     ):
         from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
         from megatron.energon.flavors.webdataset.indexing import SqliteIndexReader
@@ -483,6 +525,7 @@ def __init__(
             part_filter=part_filter,
             itar_cache_size=itar_cache_size,
             sample_filter=sample_filter,
+            disable_cache=disable_cache,
         )
 
     def _get_itar_sample_pointer(self, sample_key: str) -> ITarSamplePointer: