From 0b0494126e13bde77e8517173ac50746f43ea114 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 30 Apr 2026 13:30:53 +0200 Subject: [PATCH 1/5] Implement energon prepare with a remote dataset (using local temp) --- src/megatron/energon/epathlib/epath.py | 17 +- .../energon/flavors/webdataset/indexing.py | 1 - .../energon/flavors/webdataset/prepare.py | 381 ++++++++++-------- src/megatron/energon/tools/prepare.py | 12 + src/megatron/energon/tools/prepare_media.py | 14 + tests/s3_emulator/handler.py | 193 +++++++-- tests/s3_emulator/state.py | 29 ++ tests/test_dataset.py | 55 ++- tests/test_epathlib.py | 57 ++- 9 files changed, 547 insertions(+), 212 deletions(-) diff --git a/src/megatron/energon/epathlib/epath.py b/src/megatron/energon/epathlib/epath.py index 6357262a..d11f0bcc 100644 --- a/src/megatron/energon/epathlib/epath.py +++ b/src/megatron/energon/epathlib/epath.py @@ -292,12 +292,27 @@ def mkdir(self, exist_ok: bool = True, parents: bool = False): def glob(self, pattern) -> Generator["EPath", None, None]: search_path_pattern = (self / pattern)._internal_str_path + # MSC S3 glob matches keys like ``bucket/key``; a leading ``/`` breaks wcmatch (pattern + # ``/b/**`` never matches ``b/parts/x``). Returned keys may repeat the bucket prefix; strip + # it before joining with ``internal_path`` so we do not get ``/b/b/parts/...``. + if not self.is_local() and search_path_pattern.startswith("/"): + search_path_pattern = search_path_pattern.lstrip("/") + + root_prefix = str(self.internal_path).lstrip("/") for path in self.fs.glob(search_path_pattern): assert isinstance(path, str) + rel = path + if not self.is_local() and root_prefix: + pfx = root_prefix + "/" + if rel.startswith(pfx): + rel = rel[len(pfx) :] + elif rel == root_prefix: + rel = "." + new_path = EPath(self) - new_path.internal_path = self._resolve(self.internal_path / PurePosixPath(path)) + new_path.internal_path = self._resolve(self.internal_path / PurePosixPath(rel)) yield new_path diff --git a/src/megatron/energon/flavors/webdataset/indexing.py b/src/megatron/energon/flavors/webdataset/indexing.py index 4e5cebcd..78b9f4f1 100644 --- a/src/megatron/energon/flavors/webdataset/indexing.py +++ b/src/megatron/energon/flavors/webdataset/indexing.py @@ -70,7 +70,6 @@ def __init__( # Initialize SQLite connection # Only supporting local file system, because sqlite does not support remote file systems. - # TODO: Implement remote file systems. Maybe create locally in tmp then upload? path = self.sqlite_path.local_path() path.parent.mkdir(parents=True, exist_ok=True) self.db = sqlite3.connect(path) diff --git a/src/megatron/energon/flavors/webdataset/prepare.py b/src/megatron/energon/flavors/webdataset/prepare.py index d3a16223..eeae5699 100644 --- a/src/megatron/energon/flavors/webdataset/prepare.py +++ b/src/megatron/energon/flavors/webdataset/prepare.py @@ -6,8 +6,10 @@ import logging import random import re +import shutil import sys import tarfile +import tempfile import uuid from pathlib import Path from typing import ( @@ -115,6 +117,7 @@ class SqliteIndexWriterAggregator( reset_tables: bool media_metadata_written: int progress_on_media: bool + sqlite_local_build_path: Optional[Path] def __init__( self, @@ -127,6 +130,7 @@ def __init__( media_filter: Optional[MediaFilterConfig] = None, reset_tables: bool = True, progress_on_media: bool = False, + sqlite_local_build_path: Optional[Path] = None, ): self.sqlite_path = sqlite_path self.total_tasks = total_tasks @@ -140,6 +144,7 @@ def __init__( self.reset_tables = reset_tables self.media_metadata_written = 0 self.progress_on_media = progress_on_media + self.sqlite_local_build_path = sqlite_local_build_path if progress_fn is not None: self.prog_iter = progress_fn(iter(range(self.total_tasks)), self.total_tasks) @@ -147,8 +152,11 @@ def __init__( self.prog_iter = iter(range(self.total_tasks)) def on_start(self, aggregator_pool: AggregatorPool) -> None: + local_sqlite = self.sqlite_path + if self.sqlite_local_build_path is not None: + local_sqlite = EPath(self.sqlite_local_build_path) self.writer = SqliteIndexWriter( - self.sqlite_path, + local_sqlite, enable_sample_tables=self.enable_sample_tables, enable_media_metadata=self.enable_media_metadata, reset_tables=self.reset_tables, @@ -213,6 +221,9 @@ def on_finish(self, aggregator_pool: AggregatorPool) -> None: patterns=",".join(self.media_filter.patterns), ) self.writer.close() + if self.sqlite_local_build_path is not None: + EPath(self.sqlite_local_build_path).copy(self.sqlite_path) + self.sqlite_local_build_path.unlink(missing_ok=True) def get_final_result_data( self, @@ -489,6 +500,7 @@ def prepare_dataset( tar_index_only: bool = False, media_filter: Optional[MediaFilterConfig] = None, fix_duplicates: bool = False, + index_sqlite_tmp_path: Optional[Path] = None, ) -> Tuple[Set[str], List[Tuple[str, int]]]: """ Preprocess the shards and write the split config. Preprocessing is done in parallel. @@ -507,6 +519,9 @@ def prepare_dataset( tar_index_only: Only create tar-index, then exit media_filter: Media filter configuration fix_duplicates: If True, fix duplicate keys in the dataset by renaming the files in the shards. + index_sqlite_tmp_path: When ``parent_path`` is remote, directory used to build ``index.sqlite`` + locally before upload. If omitted, a new directory under ``/tmp`` is created and removed + after a successful run. Returns: The set of all parts found in the shards. But at most 50. @@ -562,179 +577,195 @@ def prepare_dataset( else: print("No duplicate keys found, continuing.") - aggregator = SqliteIndexWriterAggregator( - parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME, - total_tasks=len(paths), - progress_fn=progress_fn, - enable_media_metadata=media_filter is not None, - media_filter=media_filter, - ) - - process_tar = functools.partial( - cls._preprocess_tar, - shard_to_idx=shard_to_idx, - parent_path=parent_path, - max_parts=50, - media_filter=media_filter, - ) - - pool = AggregatorPool( - num_workers=workers, - user_produce_data=process_tar, - aggregator=aggregator, - batch_size=INDEX_BATCH_SIZE, - ) - - for path in paths: - pool.submit_task(path) + owns_remote_sqlite_tmp = False + remote_sqlite_tmp_dir: Optional[Path] = None + if not parent_path.is_local(): + if index_sqlite_tmp_path is None: + remote_sqlite_tmp_dir = Path(tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-")) + index_sqlite_tmp_path = remote_sqlite_tmp_dir / INDEX_SQLITE_FILENAME + owns_remote_sqlite_tmp = True + else: + index_sqlite_tmp_path = None try: - shards, found_parts, had_update = pool.process() - except DuplicateSampleKeyError as error: - print("The data contains duplicate keys (e.g. same filename in different shards).") - print(f'Example duplicate key: "{error.sample_key}"') - print() - print( - "Energon does not support duplicate keys anymore, but we offer a tool to fix your dataset. " - "Run `energon prepare` with `--fix-duplicates` to fix your dataset. Inside each tar, it will " - "put each file in a subfolder with the shard name like `shard_0/filename.ext`." + aggregator = SqliteIndexWriterAggregator( + parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME, + total_tasks=len(paths), + progress_fn=progress_fn, + enable_media_metadata=media_filter is not None, + media_filter=media_filter, + sqlite_local_build_path=index_sqlite_tmp_path, ) - if (parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).is_file(): - (parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).unlink() + process_tar = functools.partial( + cls._preprocess_tar, + shard_to_idx=shard_to_idx, + parent_path=parent_path, + max_parts=50, + media_filter=media_filter, + ) - sys.exit(1) + pool = AggregatorPool( + num_workers=workers, + user_produce_data=process_tar, + aggregator=aggregator, + batch_size=INDEX_BATCH_SIZE, + ) + + for path in paths: + pool.submit_task(path) - # Fix permissions if needed - if fix_local_permissions: try: - Path(str(parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME)).chmod(file_perms) - except OSError: - pass + shards, found_parts, had_update = pool.process() + except DuplicateSampleKeyError as error: + print("The data contains duplicate keys (e.g. same filename in different shards).") + print(f'Example duplicate key: "{error.sample_key}"') + print() + print( + "Energon does not support duplicate keys anymore, but we offer a tool to fix your dataset. " + "Run `energon prepare` with `--fix-duplicates` to fix your dataset. Inside each tar, it will " + "put each file in a subfolder with the shard name like `shard_0/filename.ext`." + ) + + if (parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).is_file(): + (parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).unlink() - if had_update: - logger.info("Regenerating dataset UUID...") - with (parent_path / MAIN_FOLDER_NAME / INDEX_UUID_FILENAME).open("w") as f: - f.write(str(uuid.uuid4())) + sys.exit(1) # Fix permissions if needed if fix_local_permissions: try: - (parent_path / MAIN_FOLDER_NAME / INDEX_UUID_FILENAME).local_path().chmod( - file_perms - ) + Path(str(parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME)).chmod(file_perms) except OSError: pass - json_info_config = parent_path / MAIN_FOLDER_NAME / INFO_JSON_FILENAME - yaml_info_config = parent_path / MAIN_FOLDER_NAME / INFO_YAML_FILENAME - - if tar_index_only: - if yaml_info_config.is_file() and not json_info_config.is_file(): - # Convert legacy .info.yaml to .info.json - with json_info_config.open("w") as f: - json.dump(load_yaml(yaml_info_config.read_bytes()), f, indent=2) + if had_update: + logger.info("Regenerating dataset UUID...") + with (parent_path / MAIN_FOLDER_NAME / INDEX_UUID_FILENAME).open("w") as f: + f.write(str(uuid.uuid4())) + # Fix permissions if needed if fix_local_permissions: try: - json_info_config.local_path().chmod(file_perms) + (parent_path / MAIN_FOLDER_NAME / INDEX_UUID_FILENAME).local_path().chmod( + file_perms + ) except OSError: pass - return found_parts + json_info_config = parent_path / MAIN_FOLDER_NAME / INFO_JSON_FILENAME + yaml_info_config = parent_path / MAIN_FOLDER_NAME / INFO_YAML_FILENAME - assert len(shards) == len(shard_to_idx), ( - f"Lengths of shards and shard_to_idx do not match: {len(shards)} != {len(shard_to_idx)}" - ) + if tar_index_only: + if yaml_info_config.is_file() and not json_info_config.is_file(): + # Convert legacy .info.yaml to .info.json + with json_info_config.open("w") as f: + json.dump(load_yaml(yaml_info_config.read_bytes()), f, indent=2) - # Sort the shards according to the order in the input list - shards.sort(key=lambda shard: shard_to_idx[shard.name]) + if fix_local_permissions: + try: + json_info_config.local_path().chmod(file_perms) + except OSError: + pass - # Save info - assert [shard.name for shard in shards] == list(shard_to_idx.keys()), ( - "Shards are not in the same order as in the input list." - ) + return found_parts - info = WebdatasetInfo( - energon_version=__version__, - shard_counts={shard.name: shard.count for shard in shards}, - ) - print(f"Saving info to {json_info_config}") + assert len(shards) == len(shard_to_idx), ( + f"Lengths of shards and shard_to_idx do not match: {len(shards)} != {len(shard_to_idx)}" + ) - with json_info_config.open("w") as wf: - json.dump(to_json_object(info), wf, indent=2) + # Sort the shards according to the order in the input list + shards.sort(key=lambda shard: shard_to_idx[shard.name]) - # Fix permissions if needed - if fix_local_permissions: - try: - json_info_config.local_path().chmod(file_perms) - except OSError: - pass + # Save info + assert [shard.name for shard in shards] == list(shard_to_idx.keys()), ( + "Shards are not in the same order as in the input list." + ) - if yaml_info_config.is_file(): - # If a .info.yaml existed previously, let's also update it - # to keep them in sync - with yaml_info_config.open("w") as wf: - yaml.dump(to_json_object(info), wf) - - if split_parts_ratio is not None: - # Normalize ratio - total_ratio = sum(split_ratio for _, split_ratio in split_parts_ratio) - split_parts_ratio = [ - (split_part, split_ratio / total_ratio) - for split_part, split_ratio in split_parts_ratio - ] - # Sample from shards based on the split ratio from split parts - split_shards = {} - if shuffle_seed is not None: - random.Random(shuffle_seed).shuffle(shards) - split_total = 0 - split_offset = 0 - for split_part, split_ratio in split_parts_ratio: - split_total += split_ratio - split_end = int(len(shards) * split_total) - split_shards[split_part] = [shard.name for shard in shards[split_offset:split_end]] - split_offset = split_end - else: - assert split_parts_patterns is not None, ( - "Require either split_parts_ratio or split_parts_patterns" + info = WebdatasetInfo( + energon_version=__version__, + shard_counts={shard.name: shard.count for shard in shards}, ) - # Sample from shards based on the split patterns from split parts - split_shards = {} - for split_part, split_pattern in split_parts_patterns: - patterns = [ - re.compile(pattern) for pattern in braceexpand.braceexpand(split_pattern) - ] - split_shards[split_part] = [ - shard.name - for shard in shards - if any(pattern.match(shard.name) for pattern in patterns) - ] + print(f"Saving info to {json_info_config}") - # Optimize the split parts by trying to bracecollapse the shard names - print("Collapsing split parts... ", flush=True, end="") - for split_part in split_shards: - split_shards[split_part] = collapse(split_shards[split_part], keep_order=True) - print("Done", flush=True) - - # Save split config - splits_config = WebdatasetSplits(split_parts=split_shards) - with (parent_path / MAIN_FOLDER_NAME / split_config).open("w") as wf: - if split_config.endswith(".yaml"): - yaml.dump(to_json_object(splits_config), wf, sort_keys=False) - elif split_config.endswith(".json"): - json.dump(to_json_object(splits_config), wf, indent=2) + with json_info_config.open("w") as wf: + json.dump(to_json_object(info), wf, indent=2) + + # Fix permissions if needed + if fix_local_permissions: + try: + json_info_config.local_path().chmod(file_perms) + except OSError: + pass + + if yaml_info_config.is_file(): + # If a .info.yaml existed previously, let's also update it + # to keep them in sync + with yaml_info_config.open("w") as wf: + yaml.dump(to_json_object(info), wf) + + if split_parts_ratio is not None: + # Normalize ratio + total_ratio = sum(split_ratio for _, split_ratio in split_parts_ratio) + split_parts_ratio = [ + (split_part, split_ratio / total_ratio) + for split_part, split_ratio in split_parts_ratio + ] + # Sample from shards based on the split ratio from split parts + split_shards = {} + if shuffle_seed is not None: + random.Random(shuffle_seed).shuffle(shards) + split_total = 0 + split_offset = 0 + for split_part, split_ratio in split_parts_ratio: + split_total += split_ratio + split_end = int(len(shards) * split_total) + split_shards[split_part] = [shard.name for shard in shards[split_offset:split_end]] + split_offset = split_end else: - raise ValueError(f"Invalid split config extension: {split_config}") + assert split_parts_patterns is not None, ( + "Require either split_parts_ratio or split_parts_patterns" + ) + # Sample from shards based on the split patterns from split parts + split_shards = {} + for split_part, split_pattern in split_parts_patterns: + patterns = [ + re.compile(pattern) for pattern in braceexpand.braceexpand(split_pattern) + ] + split_shards[split_part] = [ + shard.name + for shard in shards + if any(pattern.match(shard.name) for pattern in patterns) + ] + + # Optimize the split parts by trying to bracecollapse the shard names + print("Collapsing split parts... ", flush=True, end="") + for split_part in split_shards: + split_shards[split_part] = collapse(split_shards[split_part], keep_order=True) + print("Done", flush=True) + + # Save split config + splits_config = WebdatasetSplits(split_parts=split_shards) + with (parent_path / MAIN_FOLDER_NAME / split_config).open("w") as wf: + if split_config.endswith(".yaml"): + yaml.dump(to_json_object(splits_config), wf, sort_keys=False) + elif split_config.endswith(".json"): + json.dump(to_json_object(splits_config), wf, indent=2) + else: + raise ValueError(f"Invalid split config extension: {split_config}") - # Fix permissions if needed - if fix_local_permissions: - try: - (parent_path / MAIN_FOLDER_NAME / split_config).local_path().chmod(file_perms) - except OSError: - pass + # Fix permissions if needed + if fix_local_permissions: + try: + (parent_path / MAIN_FOLDER_NAME / split_config).local_path().chmod(file_perms) + except OSError: + pass + + return found_parts + finally: + if owns_remote_sqlite_tmp and remote_sqlite_tmp_dir is not None: + shutil.rmtree(remote_sqlite_tmp_dir, ignore_errors=True) - return found_parts @classmethod def add_media_metadata( @@ -744,6 +775,7 @@ def add_media_metadata( media_filter: MediaFilterConfig, workers: int = 32, progress_fn: Callable[[Iterator[Any], int], Iterator[T]] = (lambda x, y: x), + index_sqlite_tmp_path: Optional[Path] = None, ) -> int: """Add or refresh media metadata in an existing WebDataset index.""" @@ -762,34 +794,49 @@ def add_media_metadata( if path not in shard_counts: raise ValueError(f"Shard '{path}' not present in dataset metadata") - aggregator = SqliteIndexWriterAggregator( - parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME, - total_tasks=len(expanded_paths), - progress_fn=progress_fn, - enable_sample_tables=False, - enable_media_metadata=True, - media_filter=media_filter, - reset_tables=False, - progress_on_media=False, - ) + owns_remote_sqlite_tmp = False + remote_sqlite_tmp_dir: Optional[Path] = None + if not parent_path.is_local(): + if index_sqlite_tmp_path is None: + remote_sqlite_tmp_dir = Path(tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-media-")) + index_sqlite_tmp_path = remote_sqlite_tmp_dir / INDEX_SQLITE_FILENAME + owns_remote_sqlite_tmp = True + else: + index_sqlite_tmp_path = None - process_tar = functools.partial( - cls._extract_media_from_tar, - parent_path=parent_path, - media_filter=media_filter, - shard_counts=shard_counts, - ) + try: + aggregator = SqliteIndexWriterAggregator( + parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME, + total_tasks=len(expanded_paths), + progress_fn=progress_fn, + enable_sample_tables=False, + enable_media_metadata=True, + media_filter=media_filter, + reset_tables=False, + progress_on_media=False, + sqlite_local_build_path=index_sqlite_tmp_path, + ) - pool = AggregatorPool( - num_workers=min(workers, len(expanded_paths)) or 1, - user_produce_data=process_tar, - aggregator=aggregator, - batch_size=INDEX_BATCH_SIZE, - ) + process_tar = functools.partial( + cls._extract_media_from_tar, + parent_path=parent_path, + media_filter=media_filter, + shard_counts=shard_counts, + ) - for path in expanded_paths: - pool.submit_task(path) + pool = AggregatorPool( + num_workers=min(workers, len(expanded_paths)) or 1, + user_produce_data=process_tar, + aggregator=aggregator, + batch_size=INDEX_BATCH_SIZE, + ) + + for path in expanded_paths: + pool.submit_task(path) - pool.process() + pool.process() - return aggregator.media_metadata_written + return aggregator.media_metadata_written + finally: + if owns_remote_sqlite_tmp and remote_sqlite_tmp_dir is not None: + shutil.rmtree(remote_sqlite_tmp_dir, ignore_errors=True) diff --git a/src/megatron/energon/tools/prepare.py b/src/megatron/energon/tools/prepare.py index 5e6643cf..67b2cc89 100644 --- a/src/megatron/energon/tools/prepare.py +++ b/src/megatron/energon/tools/prepare.py @@ -6,6 +6,7 @@ import json import re import typing +from pathlib import Path from types import FunctionType from typing import Any, List, Optional, Tuple, Type @@ -184,6 +185,15 @@ def printify_json(data: Any) -> Any: help='Field mapping in JSON format (e.g., \'{"image": "jpg", "caption": "txt"}\'). If not set in non-interactive mode, a sample loader from template will be created. Use with --sample-type. Only applies if sample_type is not set to CrudeWebdataset.', default=None, ) +@click.option( + "--tmp-path", + type=click.Path(path_type=Path), + default=None, + help=( + "When PATH is remote (e.g. S3), directory where index.sqlite is built locally before upload. " + "Default: a new directory under /tmp (removed after success)." + ), +) def command( path: EPath, progress: bool, @@ -203,6 +213,7 @@ def command( field_map: Optional[str], skip_dataset_yaml: bool, dataset_yaml_name: str, + tmp_path: Optional[Path], ): """Prepare WebDataset for use with energon. @@ -348,6 +359,7 @@ def progress_fn(els, length=None): workers=num_workers, media_filter=media_filter_config, fix_duplicates=fix_duplicates, + index_sqlite_tmp_path=tmp_path, ) found_types = list(found_types) diff --git a/src/megatron/energon/tools/prepare_media.py b/src/megatron/energon/tools/prepare_media.py index bcc99bf6..5640be78 100644 --- a/src/megatron/energon/tools/prepare_media.py +++ b/src/megatron/energon/tools/prepare_media.py @@ -2,6 +2,9 @@ # SPDX-License-Identifier: BSD-3-Clause +from pathlib import Path +from typing import Optional + import click from megatron.energon.epathlib import EPath @@ -44,6 +47,15 @@ is_flag=True, help="Media detection by standard file extensions.", ) +@click.option( + "--tmp-path", + type=click.Path(path_type=Path), + default=None, + help=( + "When PATH is remote, directory where index.sqlite is built locally before upload. " + "Default: a new directory under /tmp (removed after success)." + ), +) def command( path: EPath, progress: bool, @@ -51,6 +63,7 @@ def command( media_metadata_by_glob: str | None, media_metadata_by_header: bool, media_metadata_by_extension: bool, + tmp_path: Optional[Path], ): """Prepare a filesystem dataset by collecting media metadata.""" @@ -83,6 +96,7 @@ def progress_fn(els, length=None): media_filter=media_filter_config, workers=num_workers, progress_fn=progress_fn, + index_sqlite_tmp_path=tmp_path, ) click.echo(f"Done. Stored metadata for {count} files.") diff --git a/tests/s3_emulator/handler.py b/tests/s3_emulator/handler.py index 3a6d28a6..49bb198d 100644 --- a/tests/s3_emulator/handler.py +++ b/tests/s3_emulator/handler.py @@ -6,7 +6,7 @@ from hashlib import md5 from http import HTTPStatus from http.server import BaseHTTPRequestHandler -from typing import Protocol +from typing import Literal, Protocol from .auth import InvalidSignature, S3Auth from .state import S3State @@ -112,6 +112,37 @@ def _handle_write(self): qs = _up.parse_qs(parsed.query, keep_blank_values=True) + # S3 CopyObject: PUT to destination key with x-amz-copy-source and empty body. + copy_src = ( + self.headers.get("x-amz-copy-source") or self.headers.get("X-Amz-Copy-Source") or "" + ).strip() + if copy_src: + if not bucket: + self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified") + return + if key == "": + self._send_error(HTTPStatus.BAD_REQUEST, "CopyObject requires an object key") + return + try: + src_bucket, src_key = _parse_copy_source(copy_src) + except ValueError as err: + self._send_error(HTTPStatus.BAD_REQUEST, str(err)) + return + try: + data = self.server.state.copy_object(bucket, key, src_bucket, src_key) + except FileNotFoundError: + self._send_error(HTTPStatus.NOT_FOUND, "NoSuchKey") + return + xml = ( + '' + "" + f"{_escape_xml(formatdate(usegmt=True))}" + f""{_escape_xml(_etag(data))}"" + "" + ).encode() + self._send_bytes(xml, status=HTTPStatus.OK, content_type="application/xml") + return + # Multipart: upload part if "uploadId" in qs and "partNumber" in qs: upload_id = qs["uploadId"][0] @@ -160,15 +191,15 @@ def _handle_read(self, listing: bool, only_headers: bool = False): self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified") return - if key == "": # List bucket contents + if key == "": # List bucket contents (ListObjects / ListObjectsV2) if not listing: - # We treat listing with GET only try: objects = self.server.state.list_objects(bucket) except KeyError: self._send_error(HTTPStatus.NOT_FOUND, "Bucket not found") return - xml_body = self._render_bucket_list(bucket, objects) + qs = _up.parse_qs(parsed.query, keep_blank_values=True) + xml_body = self._render_list_bucket_result(bucket, objects, qs) self._send_bytes(xml_body, content_type="application/xml") else: self._send_error(HTTPStatus.NOT_IMPLEMENTED, "Listing not implemented") @@ -359,44 +390,99 @@ def _send_bytes( if self.command != "HEAD": self.wfile.write(data) - @staticmethod - def _render_bucket_list(bucket: str, objects: list[str]) -> bytes: - """Generate an XML listing of objects in a bucket. + def _render_list_bucket_result( + self, + bucket: str, + all_keys: list[str], + qs: dict[str, list[str]], + ) -> bytes: + """Build ListBucketResult XML (ListObjectsV2-compatible). + + Clients (e.g. MSC) send ``delimiter=/`` and ``prefix=`` and expect + ``CommonPrefixes`` for nested keys such as ``parts/data-0.tar``, not + only flat ``Contents``. + """ + prefix = (qs.get("prefix") or [""])[0] + delimiter = (qs.get("delimiter") or [None])[0] + max_keys_s = (qs.get("max-keys") or qs.get("maxkeys") or ["1000"])[0] + try: + max_keys = max(1, min(int(max_keys_s), 1000)) + except ValueError: + max_keys = 1000 - Args: - bucket: The bucket name. - objects: List of object keys in the bucket. + continuation = (qs.get("continuation-token") or [""])[0] + start_after = (qs.get("start-after") or [""])[0] + exclusive_after = continuation or start_after - Returns: - The XML document as bytes. - """ - entries = [] now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") - for key in objects: - try: - data = S3RequestHandler.server.state.get_object(bucket, key) # type: ignore[attr-defined] - size = len(data) - etag = _etag(data) - except Exception: # noqa: BLE001 - size = 0 - etag = '""' - entries.append( - "" - f"{_escape_xml(key)}" - f"{now}" - f"{etag}" - f"{size}" - "" - ) - obj_elems = "".join(entries) - xml = ( - '' - "" - f"{_escape_xml(bucket)}" - f"{obj_elems}" - "" - ) - return xml.encode() + state = self.server.state + + items: list[tuple[Literal["cp", "key"], str]] = [] + if not delimiter: + for k in sorted(all_keys): + if k.startswith(prefix): + items.append(("key", k)) + else: + common: set[str] = set() + contents: list[str] = [] + for k in sorted(all_keys): + if not k.startswith(prefix): + continue + relative = k[len(prefix) :] + if delimiter in relative: + idx = relative.index(delimiter) + common.add(prefix + relative[: idx + len(delimiter)]) + else: + contents.append(k) + for cp in sorted(common): + items.append(("cp", cp)) + for ck in sorted(contents): + items.append(("key", ck)) + items.sort(key=lambda x: x[1]) + + if exclusive_after: + items = [it for it in items if it[1] > exclusive_after] + + page = items[:max_keys] + truncated = len(items) > max_keys + next_token = page[-1][1] if truncated and page else "" + + fragments: list[str] = [ + '', + "", + f"{_escape_xml(bucket)}", + f"{_escape_xml(prefix)}", + f"{len(page)}", + f"{max_keys}", + f"{str(truncated).lower()}", + ] + if delimiter: + fragments.append(f"{_escape_xml(delimiter)}") + if truncated and next_token: + fragments.append(f"{_escape_xml(next_token)}") + + for kind, path in page: + if kind == "cp": + fragments.append(f"{_escape_xml(path)}") + else: + try: + data = state.get_object(bucket, path) + size = len(data) + etag = _etag(data) + except Exception: # noqa: BLE001 + size = 0 + etag = '""' + fragments.append( + "" + f"{_escape_xml(path)}" + f"{now}" + f"{etag}" + f"{size}" + "" + ) + + fragments.append("") + return "".join(fragments).encode() class S3ServerProtocol(Protocol): # noqa: D101 @@ -404,6 +490,35 @@ class S3ServerProtocol(Protocol): # noqa: D101 auth: S3Auth +def _parse_copy_source(raw: str) -> tuple[str, str]: + """Parse ``x-amz-copy-source`` into ``(bucket, key)``. + + Accepts ``/bucket/key``, ``bucket/key``, URL-encoded keys, and strips ``?versionId=``. + + Args: + raw: Raw header value. + + Returns: + Source bucket and object key. + + Raises: + ValueError: If the value cannot be parsed. + """ + s = raw.strip() + if not s: + raise ValueError("Empty x-amz-copy-source") + s = s.split("?", 1)[0] + s = _up.unquote(s) + if s.startswith("/"): + s = s[1:] + if "/" not in s: + raise ValueError("x-amz-copy-source must be /bucket/key") + src_bucket, src_key = s.split("/", 1) + if not src_bucket or not src_key: + raise ValueError("Invalid x-amz-copy-source") + return src_bucket, src_key + + def _escape_xml(text: str) -> str: # noqa: D401 """Escape special characters for XML. diff --git a/tests/s3_emulator/state.py b/tests/s3_emulator/state.py index 1493dcb1..236b7311 100644 --- a/tests/s3_emulator/state.py +++ b/tests/s3_emulator/state.py @@ -111,6 +111,35 @@ def get_object(self, bucket: str, key: str) -> bytes: except KeyError as exc: raise FileNotFoundError(f"{bucket}/{key}") from exc + def copy_object(self, dest_bucket: str, dest_key: str, src_bucket: str, src_key: str) -> bytes: + """Copy an object to another key (S3 CopyObject). + + Args: + dest_bucket: Destination bucket. + dest_key: Destination object key. + src_bucket: Source bucket. + src_key: Source object key. + + Returns: + Copied object bytes (for ETag in the CopyObject XML response). + + Raises: + FileNotFoundError: If the source object does not exist. + """ + with self._lock: + try: + payload = bytes(self._fs[src_bucket][src_key]) + except KeyError as exc: + raise FileNotFoundError(f"{src_bucket}/{src_key}") from exc + if dest_bucket not in self._fs: + self._fs[dest_bucket] = {} + self._fs[dest_bucket][dest_key] = payload + if self._root_dir is not None: + obj_path = (self._root_dir / dest_bucket / dest_key).resolve() + obj_path.parent.mkdir(parents=True, exist_ok=True) + obj_path.write_bytes(payload) + return payload + def delete_object(self, bucket: str, key: str) -> None: """Delete an object from a bucket. diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 01aa74ee..927d2efb 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -10,6 +10,7 @@ import logging import math import random +import shutil import sys import tempfile import unittest @@ -47,14 +48,16 @@ ) from megatron.energon.dataset_config import get_dataset_from_config from megatron.energon.edataclass import edataclass +from megatron.energon.epathlib import EPath from megatron.energon.flavors import BaseWebdatasetFactory -from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import INFO_JSON_FILENAME, MAIN_FOLDER_NAME from megatron.energon.task_encoder.base import stateless from megatron.energon.tools.analyze_debug import command as analyze_debug_command from megatron.energon.tools.info import command as info_command from megatron.energon.tools.lint import command as lint_command from megatron.energon.tools.prepare import command as prepare_command from megatron.energon.tools.preview import command as preview_command +from tests.epath_s3_emulator import setup_s3_emulator # Speed up tests significantly by reducing the torch status check interval for broken worker shutdown try: @@ -1803,6 +1806,56 @@ def test_prepare_dataset_noninteractive_crude(self): content = f.read() assert "CrudeWebdataset" in content + def test_prepare_dataset_s3_cmdline(self): + """Tar shards live on S3 (emulator); `energon prepare` writes `.nv-meta` to the same prefix.""" + bucket = "energon-prepare-s3-cmdline" + profile_name = "s3test_dataset_prepare_cmdline" + with tempfile.TemporaryDirectory() as staging_dir: + staging_path = Path(staging_dir) + TestDataset.create_captioning_test_dataset(staging_path, num_samples=20) + shutil.rmtree(staging_path / MAIN_FOLDER_NAME) + with setup_s3_emulator(profile_name=profile_name) as state: + state.create_bucket(bucket) + state.add_file(staging_path, dst=bucket) + s3_root = EPath(f"msc://{profile_name}/{bucket}") + assert s3_root.is_dir() + tar_paths = list(s3_root.glob("**/*.tar")) + assert len(tar_paths) >= 1, f"expected tar shards on S3, got {tar_paths!r}" + runner = CliRunner() + result = runner.invoke( + prepare_command, + [ + str(s3_root), + "--non-interactive", + "--no-progress", + "--num-workers", + "2", + "--split-ratio=1,0,0", + "--sample-type=CaptioningSample", + '--field-map={"image": "png", "caption": "txt"}', + ], + catch_exceptions=False, + ) + assert result.exit_code == 0, result.stdout + assert "Done" in result.stdout, result.stdout + info_path = s3_root / MAIN_FOLDER_NAME / INFO_JSON_FILENAME + assert info_path.is_file(), f"missing {info_path}" + + # Iterate over the new dataset + loader = get_loader( + get_train_dataset( + s3_root, + batch_size=5, + worker_config=WorkerConfig(rank=0, world_size=1, num_workers=0), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + ) + assert len(loader) == 4, f"len(loader) == {len(loader)}" + samples = list(d for _, d in zip(range(4), loader)) + assert len(samples) == 4, f"len(samples) == {len(samples)}" + assert all(isinstance(sample, CaptioningSample) for sample in samples) + def test_preview_captioning_dataset(self): runner = CliRunner() result = runner.invoke( diff --git a/tests/test_epathlib.py b/tests/test_epathlib.py index 2e8ae432..06d283e9 100644 --- a/tests/test_epathlib.py +++ b/tests/test_epathlib.py @@ -278,14 +278,65 @@ def test_msc_s3(self): assert p.is_file() assert p.size() > 0 assert p.read_text() == "dummy" - # TODO: Fix when fixed in MSC. - # assert EPath("msc://s3test_msc/test").is_dir() + assert EPath("msc://s3test_msc/test").is_dir() assert EPath("msc://s3test_msc/test/dir").is_dir() p.unlink() assert not p.is_file() - # assert not EPath("msc://s3test_msc/test").is_dir() + assert not EPath("msc://s3test_msc/test").is_dir() assert not EPath("msc://s3test_msc/test/dir").is_dir() + def test_msc_s3_dataprep_path_operations(self): + """EPath operations used by remote webdataset ``prepare`` (glob, rb, meta, sqlite upload, move). + + Mirrors ``tools/prepare.py`` shard discovery, ``WebdatasetPreparator._preprocess_tar`` + binary reads, and ``SqliteIndexWriterAggregator`` uploading ``index.sqlite`` from disk. + """ + profile = "s3test_msc_dataprep" + with setup_s3_emulator(profile_name=profile): + root = EPath(f"msc://{profile}/dataset_root") + parts = root / "parts" + (parts / "data-0.tar").write_bytes(b"shard0-bytes") + (parts / "data-1.tar").write_bytes(b"shard1-bytes") + + found = sorted(root.glob("**/*.tar")) + assert [p.name for p in found] == ["data-0.tar", "data-1.tar"] + + with (parts / "data-0.tar").open("rb") as f: + assert f.read(6) == b"shard0" + + meta_dir = root / MAIN_FOLDER_NAME + info = meta_dir / INFO_JSON_FILENAME + info.write_text('{"shard_counts": {"parts/data-0.tar": 1}}') + assert '"shard_counts"' in info.read_text() + + probe_idx = parts / "probe.idx" + probe_idx_tmp = parts / "probe.idx.tmp" + with probe_idx_tmp.open("wb") as out: + out.write(struct.pack("QQ", 0, 512)) + assert probe_idx_tmp.size() == 16 + assert struct.unpack("QQ", probe_idx_tmp.read_bytes()) == (0, 512) + + probe_idx_tmp.move(probe_idx) + + assert probe_idx.size() == 16 + assert struct.unpack("QQ", probe_idx.read_bytes()) == (0, 512) + probe_idx.unlink() + + with tempfile.NamedTemporaryFile(delete=False) as lf: + lf.write(b"sqlite-placeholder") + local_sqlite = lf.name + try: + EPath(local_sqlite).copy(meta_dir / INDEX_SQLITE_FILENAME) + remote_db = meta_dir / INDEX_SQLITE_FILENAME + assert remote_db.read_bytes() == b"sqlite-placeholder" + remote_db.unlink() + finally: + Path(local_sqlite).unlink(missing_ok=True) + + info.unlink() + for p in found: + p.unlink() + def test_dss_path_requires_version(self): with self.assertRaisesRegex( AssertionError, From c7740b79e6a7d93196acbc9ac6959ec92cc2e7f7 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 30 Apr 2026 13:30:53 +0200 Subject: [PATCH 2/5] Fix outdated local copy. Implement prepare media for filesystem paths on remote. --- src/megatron/energon/cache/file_store.py | 15 +- src/megatron/energon/epathlib/epath.py | 3 + .../energon/flavors/webdataset/indexing.py | 13 ++ .../energon/flavors/webdataset/prepare.py | 33 +++- src/megatron/energon/local_copy.py | 11 +- src/megatron/energon/media/extractor.py | 10 +- .../energon/media/filesystem_prepare.py | 161 ++++++++++-------- tests/s3_emulator/handler.py | 8 +- tests/test_crudedataset.py | 87 +++++++++- tests/test_dataset.py | 3 +- 10 files changed, 249 insertions(+), 95 deletions(-) diff --git a/src/megatron/energon/cache/file_store.py b/src/megatron/energon/cache/file_store.py index 2e924206..4a79e788 100644 --- a/src/megatron/energon/cache/file_store.py +++ b/src/megatron/energon/cache/file_store.py @@ -161,19 +161,18 @@ def get_path(self) -> str: def get_media_metadata(self, key: str) -> MediaMetadataBase: if self._media_metadata_available is None: try: - has_metadata = self.sqlite_reader.db_has_media_metadata() + self._media_metadata_available = self.sqlite_reader.db_has_media_metadata() except sqlite3.Error as exc: # pragma: no cover - defensive + self._media_metadata_available = False raise RuntimeError( "Failed to inspect media metadata table. Re-run `energon prepare --media-metadata-by-...`." ) from exc - if not has_metadata: - raise RuntimeError( - "Media metadata is not available for this dataset. " - "Run `energon prepare --media-metadata-by-...` to generate it." - ) - - self._media_metadata_available = True + if not self._media_metadata_available: + raise RuntimeError( + "Media metadata is not available for this dataset. " + "Run `energon prepare --media-metadata-by-...` to generate it." + ) try: row = self.sqlite_reader.get_media_metadata(key) diff --git a/src/megatron/energon/epathlib/epath.py b/src/megatron/energon/epathlib/epath.py index d11f0bcc..487866e3 100644 --- a/src/megatron/energon/epathlib/epath.py +++ b/src/megatron/energon/epathlib/epath.py @@ -336,6 +336,9 @@ def relative_to(self, other: "EPath") -> str: return str(self.internal_path.relative_to(other.internal_path)) + def stat(self) -> msc.types.ObjectMetadata: + return self.fs.info(self._internal_str_path) + @property def display_name(self) -> str: if self.profile == "dss": diff --git a/src/megatron/energon/flavors/webdataset/indexing.py b/src/megatron/energon/flavors/webdataset/indexing.py index 78b9f4f1..4d09d242 100644 --- a/src/megatron/energon/flavors/webdataset/indexing.py +++ b/src/megatron/energon/flavors/webdataset/indexing.py @@ -55,6 +55,16 @@ def __init__( part_name TEXT, content_byte_offset INTEGER, content_byte_size INTEGER) + if enable_media_metadata is True, it also creates the media_metadata table: + - media_metadata(entry_key TEXT PRIMARY KEY, + metadata_type TEXT NOT NULL, + metadata_json TEXT NOT NULL) + if enable_media_metadata is True, it also creates the media_filters table: + - media_filters(filter_id INTEGER PRIMARY KEY AUTOINCREMENT, + strategy TEXT NOT NULL, + patterns TEXT, + created_at_utc TEXT DEFAULT CURRENT_TIMESTAMP, + UNIQUE(strategy, patterns)) Also creates indexes: - samples(sample_key) - samples(tar_file_id, sample_index) @@ -354,6 +364,9 @@ class SqliteIndexReader: part_name TEXT, content_byte_offset INTEGER, content_byte_size INTEGER) + - media_metadata(entry_key TEXT PRIMARY KEY, + metadata_type TEXT NOT NULL, + metadata_json TEXT NOT NULL) """ sqlite_path: EPath diff --git a/src/megatron/energon/flavors/webdataset/prepare.py b/src/megatron/energon/flavors/webdataset/prepare.py index eeae5699..82fff51b 100644 --- a/src/megatron/energon/flavors/webdataset/prepare.py +++ b/src/megatron/energon/flavors/webdataset/prepare.py @@ -155,6 +155,8 @@ def on_start(self, aggregator_pool: AggregatorPool) -> None: local_sqlite = self.sqlite_path if self.sqlite_local_build_path is not None: local_sqlite = EPath(self.sqlite_local_build_path) + if self.sqlite_path.is_file(): + self.sqlite_path.copy(local_sqlite) self.writer = SqliteIndexWriter( local_sqlite, enable_sample_tables=self.enable_sample_tables, @@ -581,7 +583,9 @@ def prepare_dataset( remote_sqlite_tmp_dir: Optional[Path] = None if not parent_path.is_local(): if index_sqlite_tmp_path is None: - remote_sqlite_tmp_dir = Path(tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-")) + remote_sqlite_tmp_dir = Path( + tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-") + ) index_sqlite_tmp_path = remote_sqlite_tmp_dir / INDEX_SQLITE_FILENAME owns_remote_sqlite_tmp = True else: @@ -635,7 +639,9 @@ def prepare_dataset( # Fix permissions if needed if fix_local_permissions: try: - Path(str(parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME)).chmod(file_perms) + Path(str(parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME)).chmod( + file_perms + ) except OSError: pass @@ -720,7 +726,9 @@ def prepare_dataset( for split_part, split_ratio in split_parts_ratio: split_total += split_ratio split_end = int(len(shards) * split_total) - split_shards[split_part] = [shard.name for shard in shards[split_offset:split_end]] + split_shards[split_part] = [ + shard.name for shard in shards[split_offset:split_end] + ] split_offset = split_end else: assert split_parts_patterns is not None, ( @@ -766,7 +774,6 @@ def prepare_dataset( if owns_remote_sqlite_tmp and remote_sqlite_tmp_dir is not None: shutil.rmtree(remote_sqlite_tmp_dir, ignore_errors=True) - @classmethod def add_media_metadata( cls, @@ -798,12 +805,16 @@ def add_media_metadata( remote_sqlite_tmp_dir: Optional[Path] = None if not parent_path.is_local(): if index_sqlite_tmp_path is None: - remote_sqlite_tmp_dir = Path(tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-media-")) + remote_sqlite_tmp_dir = Path( + tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-media-") + ) index_sqlite_tmp_path = remote_sqlite_tmp_dir / INDEX_SQLITE_FILENAME owns_remote_sqlite_tmp = True else: index_sqlite_tmp_path = None + sqlite_path = parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME + try: aggregator = SqliteIndexWriterAggregator( parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME, @@ -836,6 +847,18 @@ def add_media_metadata( pool.process() + if sqlite_path.is_local(): + try: + meta_dir = (parent_path / MAIN_FOLDER_NAME).local_path() + # Copy group permissions from the parent dir + meta_dir.chmod((parent_path.local_path().stat().st_mode | 0o700)) + # Just read/write, no execute + sqlite_path.local_path().chmod( + (parent_path.local_path().stat().st_mode | 0o600) & 0o666 + ) + except OSError: + pass + return aggregator.media_metadata_written finally: if owns_remote_sqlite_tmp and remote_sqlite_tmp_dir is not None: diff --git a/src/megatron/energon/local_copy.py b/src/megatron/energon/local_copy.py index 2708f3e1..d51e69c3 100644 --- a/src/megatron/energon/local_copy.py +++ b/src/megatron/energon/local_copy.py @@ -3,18 +3,20 @@ import hashlib +import os from pathlib import Path from filelock import FileLock from megatron.energon.epathlib import EPath -LOCAL_COPY_TMP_DIR = Path("/tmp/energon_local_copy") +LOCAL_COPY_TMP_DIR = Path(os.environ.get("ENERGON_LOCAL_COPY_TMP_DIR", "/tmp/energon_local_copy")) def ensure_local_copy(path: EPath) -> EPath: """If the path is not local, copy it to a temporary directory and return the - path to the temporary directory. + path to the temporary directory. Assuming that the local file is never modified + after it is copied. Will re-sync if the remote file is newer. Args: path: The path to the file to copy. @@ -33,10 +35,13 @@ def ensure_local_copy(path: EPath) -> EPath: final_path = LOCAL_COPY_TMP_DIR / f"{digest}.bin" lock_path = final_path.with_suffix(".lock") tmp_path = final_path.with_suffix(".part") + mod_time = path.stat().last_modified.timestamp() # Block until lock is free with FileLock(lock_path, timeout=60 * 5): - if final_path.exists(): # someone else already produced it + # someone else already produced it + if final_path.exists() and final_path.stat().st_mtime >= mod_time: + # The local file is already newer than the remote file return EPath(final_path) # We are the downloader diff --git a/src/megatron/energon/media/extractor.py b/src/megatron/energon/media/extractor.py index 1126584e..4ebeadda 100644 --- a/src/megatron/energon/media/extractor.py +++ b/src/megatron/energon/media/extractor.py @@ -15,12 +15,13 @@ from PIL import Image, UnidentifiedImageError from megatron.energon.av import AVDecoder +from megatron.energon.epathlib import EPath from megatron.energon.media.metadata import ImageMetadata, MediaMetadataBase, MediaMetadataType logger = logging.getLogger(__name__) -SourceData = Union[bytes, Path, BinaryIO] +SourceData = Union[bytes, EPath, BinaryIO] class MediaFilterStrategy(str, Enum): @@ -210,9 +211,13 @@ def _build_metadata( def _build_image_metadata(source: SourceData) -> ImageMetadata | None: + should_close = False try: if isinstance(source, (bytes, bytearray)): source = io.BytesIO(source) + elif isinstance(source, EPath): + source = source.open("rb") + should_close = True with Image.open(source) as image: image.load() @@ -225,6 +230,9 @@ def _build_image_metadata(source: SourceData) -> ImageMetadata | None: except UnidentifiedImageError: logger.debug("Failed to parse image metadata", exc_info=True) return None + finally: + if should_close: + source.close() def _build_av_metadata(source: SourceData) -> MediaMetadataBase | None: diff --git a/src/megatron/energon/media/filesystem_prepare.py b/src/megatron/energon/media/filesystem_prepare.py index afa492d2..7ef3e2a7 100644 --- a/src/megatron/energon/media/filesystem_prepare.py +++ b/src/megatron/energon/media/filesystem_prepare.py @@ -3,7 +3,8 @@ from __future__ import annotations -import os +import shutil +import tempfile from functools import partial from pathlib import Path from typing import Callable, Iterator @@ -31,6 +32,7 @@ def prepare_filesystem_dataset( *, progress: bool, workers: int = 16, + index_sqlite_tmp_path: Path | None = None, ) -> int: """Scan a filesystem dataset and materialize media metadata into SQLite. @@ -43,18 +45,21 @@ def prepare_filesystem_dataset( Number of metadata entries written to the database. """ - # Only supporting local file system, because sqlite does not support remote file systems. - # TODO: Implement remote file systems. Maybe create locally in tmp then upload? - root = root_path.local_path() - assert root.is_dir(), f"Expected directory for filesystem dataset, got {root}" - assert root.is_absolute(), f"Filesystem dataset path must be absolute: {root}" + assert not root_path.is_file(), f"Expected directory for filesystem dataset, got {root_path}" - meta_dir = root / MAIN_FOLDER_NAME - meta_dir.mkdir(exist_ok=True, parents=True) + files = _collect_media_files(root=root_path, media_filter=media_filter, progress=progress) - files = _collect_media_files(root=root, media_filter=media_filter, progress=progress) - - sqlite_path = EPath(meta_dir / INDEX_SQLITE_FILENAME) + owns_remote_sqlite_tmp = False + remote_sqlite_tmp_dir: Path | None = None + if not root_path.is_local(): + if index_sqlite_tmp_path is None: + remote_sqlite_tmp_dir = Path( + tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-media-") + ) + index_sqlite_tmp_path = remote_sqlite_tmp_dir / INDEX_SQLITE_FILENAME + owns_remote_sqlite_tmp = True + else: + index_sqlite_tmp_path = None agg_progress_fn: Callable[[Iterator[int], int], Iterator[int]] | None = None if progress: @@ -64,55 +69,66 @@ def agg_progress_fn(iterator: Iterator[int], total: int) -> Iterator[int]: with tqdm(iterator, total=total, unit="file", desc="Processing media files") as bar: yield from bar - aggregator = SqliteIndexWriterAggregator( - sqlite_path, - total_tasks=len(files), - progress_fn=agg_progress_fn, - enable_media_metadata=True, - media_filter=media_filter, - reset_tables=False, - enable_sample_tables=False, - progress_on_media=progress, - ) - - pool = AggregatorPool[ - Path, - IndexAggregatable, - tuple[list[ShardInfo], set[str], bool, list[tuple[str, int]]], - ]( - num_workers=min(workers, len(files)) or 1, - user_produce_data=partial( - _process_filesystem_entry, - root=root, - media_filter=media_filter, - ), - aggregator=aggregator, - batch_size=INDEX_BATCH_SIZE, - ) - - for file_path in files: - pool.submit_task(file_path) - - pool.process() + sqlite_path = root_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME try: - # Copy group permissions from the parent dir - meta_dir.chmod((root.stat().st_mode | 0o700)) - # Just read/write, no execute - sqlite_path.local_path().chmod((root.stat().st_mode | 0o600) & 0o666) - except OSError: - pass - - return aggregator.media_metadata_written + aggregator = SqliteIndexWriterAggregator( + sqlite_path, + total_tasks=len(files), + progress_fn=agg_progress_fn, + enable_media_metadata=True, + media_filter=media_filter, + reset_tables=False, + enable_sample_tables=False, + progress_on_media=progress, + sqlite_local_build_path=index_sqlite_tmp_path, + ) + + pool = AggregatorPool[ + EPath, + IndexAggregatable, + tuple[list[ShardInfo], set[str], bool, list[tuple[str, int]]], + ]( + num_workers=min(workers, len(files)) or 1, + user_produce_data=partial( + _process_filesystem_entry, + root=root_path, + media_filter=media_filter, + ), + aggregator=aggregator, + batch_size=INDEX_BATCH_SIZE, + ) + + for file_path in files: + pool.submit_task(file_path) + + pool.process() + + if sqlite_path.is_local(): + try: + meta_dir = (root_path / MAIN_FOLDER_NAME).local_path() + # Copy group permissions from the parent dir + meta_dir.chmod((root_path.local_path().stat().st_mode | 0o700)) + # Just read/write, no execute + sqlite_path.local_path().chmod( + (root_path.local_path().stat().st_mode | 0o600) & 0o666 + ) + except OSError: + pass + + return aggregator.media_metadata_written + finally: + if owns_remote_sqlite_tmp and remote_sqlite_tmp_dir is not None: + shutil.rmtree(remote_sqlite_tmp_dir, ignore_errors=True) def _collect_media_files( - *, root: Path, media_filter: MediaFilterConfig, progress: bool = False -) -> list[Path]: + *, root: EPath, media_filter: MediaFilterConfig, progress: bool = False +) -> list[EPath]: """Return a sorted list of files to process based on the media filter.""" consider_all = media_filter.should_consider_all() - files: list[Path] = [] + files: list[EPath] = [] progress_bar = None if progress: @@ -120,28 +136,26 @@ def _collect_media_files( progress_bar = tqdm(total=None, unit="file", desc="Collecting media files") - for dirpath, dirnames, filenames in os.walk(root, followlinks=False): - current_dir = Path(dirpath) - - if current_dir.name == MAIN_FOLDER_NAME: - dirnames[:] = [] - continue + if root.is_local(): + paths = ( + EPath(path / file) + for path, _dirs, files in root.local_path().walk(follow_symlinks=False) + for file in files + ) + else: + paths = root.glob("**/*") - dirnames[:] = [d for d in dirnames if d != MAIN_FOLDER_NAME] + for file in paths: + if progress_bar is not None: + progress_bar.update() - for filename in filenames: - if progress_bar is not None: - progress_bar.update() - - if not consider_all and not media_filter.should_consider_media(filename): - continue + if not consider_all and not media_filter.should_consider_media(file.name): + continue - candidate = current_dir / filename - if MAIN_FOLDER_NAME in candidate.parts: - continue + if ("/" + MAIN_FOLDER_NAME + "/") in file.url: + continue - if candidate.is_file(): - files.append(candidate) + files.append(file) if progress_bar is not None: progress_bar.close() @@ -151,18 +165,17 @@ def _collect_media_files( def _process_filesystem_entry( - file_path: Path | str, + file_path: EPath, *, - root: Path, + root: EPath, media_filter: MediaFilterConfig, ) -> Iterator[IndexMediaMetadata]: - file_path = Path(file_path) extracted_metadata = media_filter.extract_metadata(file_path) if extracted_metadata is None: return stored_type, metadata_json = serialize_media_metadata(extracted_metadata) - entry_key = file_path.relative_to(root).as_posix() + entry_key = file_path.relative_to(root) yield IndexMediaMetadata( entry_key=entry_key, diff --git a/tests/s3_emulator/handler.py b/tests/s3_emulator/handler.py index 49bb198d..90464fc2 100644 --- a/tests/s3_emulator/handler.py +++ b/tests/s3_emulator/handler.py @@ -459,11 +459,15 @@ def _render_list_bucket_result( if delimiter: fragments.append(f"{_escape_xml(delimiter)}") if truncated and next_token: - fragments.append(f"{_escape_xml(next_token)}") + fragments.append( + f"{_escape_xml(next_token)}" + ) for kind, path in page: if kind == "cp": - fragments.append(f"{_escape_xml(path)}") + fragments.append( + f"{_escape_xml(path)}" + ) else: try: data = state.get_object(bucket, path) diff --git a/tests/test_crudedataset.py b/tests/test_crudedataset.py index 5cd38eab..2e16df22 100644 --- a/tests/test_crudedataset.py +++ b/tests/test_crudedataset.py @@ -19,6 +19,7 @@ import torch import webdataset as wds +from click.testing import CliRunner import megatron.energon.epathlib.epath as epath_mod from megatron.energon import ( @@ -34,11 +35,12 @@ ) from megatron.energon.cache import FileStore, FileStoreCachePool, Lazy from megatron.energon.cache.base import CachePool +from megatron.energon.cli.main import prepare_media_command from megatron.energon.edataclass import edataclass from megatron.energon.epathlib.epath import EPath from megatron.energon.errors import reraise_exception from megatron.energon.flavors.base_dataset import Sample -from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME +from megatron.energon.flavors.webdataset.config import INDEX_SQLITE_FILENAME, MAIN_FOLDER_NAME from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder from megatron.energon.media.extractor import MediaFilterConfig, MediaFilterStrategy from megatron.energon.media.filesystem_prepare import prepare_filesystem_dataset @@ -1014,6 +1016,89 @@ def test_media_metadata_webdataset(self): "VIDEO-192x108@30.0fps-63.0s|VIDEO-192x108@30.0fps-63.0s", ] + def test_prepare_dataset_s3_cmdline(self): + """Tar shards live on S3 (emulator); `energon prepare` writes `.nv-meta` to the same prefix.""" + bucket = "energon-prepare-s3-cmdline-media-metadata" + profile_name = "s3test_dataset_prepare_cmdline_media_metadata" + with setup_s3_emulator(profile_name=profile_name) as state: + state.create_bucket(bucket) + state.add_file(self.multimedia_fs_path, dst=f"{bucket}/multimedia_fs/") + state.add_file(self.multimedia_wds_path, dst=f"{bucket}/multimedia_wds/") + s3_root = EPath(f"msc://{profile_name}/{bucket}") + info_path = s3_root / "multimedia_fs" / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME + # assert not info_path.is_file(), f"index.sqlite already exists: {info_path}" + print(f"S3 root: {s3_root}") + runner = CliRunner() + result = runner.invoke( + prepare_media_command, + [ + str(s3_root / "multimedia_fs"), + "--num-workers", + "2", + "--media-metadata-by-extension", + ], + catch_exceptions=False, + ) + print(result.stdout) + assert result.exit_code == 0, result.stdout + assert "Done" in result.stdout, result.stdout + assert info_path.is_file(), f"missing {info_path}" + + print("Done preparing media metadata, iterating now") + + state.put_object( + bucket, + "s3_media_metadataset.yaml", + "\n".join( + [ + "__module__: megatron.energon", + "__class__: MetadatasetV2", + "splits:", + " train:", + " path: multimedia_wds", + " aux:", + f" media: filesystem+msc://{profile_name}/{bucket}/multimedia_fs", + " subflavors:", + " crude_type: media_metadata", + ] + ).encode("utf-8"), + ) + + # Iterate over the new dataset + worker_config = WorkerConfig( + rank=0, + world_size=1, + num_workers=0, + ) + + loader = get_savable_loader( + get_train_dataset( + s3_root / "s3_media_metadataset.yaml", + batch_size=1, + worker_config=worker_config, + task_encoder=CookingTaskEncoder(), + shuffle_buffer_size=None, + max_samples_per_sequence=None, + ) + ) + + descriptions = [] + for _, batch in zip(range(4), loader): + descriptions.extend(batch.txts) + + # from pprint import pprint + # pprint(descriptions, indent=4) + + # The descriptions are like "A|B", where A is the format + # in the WebDataset and B is the format in the auxiliary dataset. + + assert descriptions == [ + "IMG-32x16-JPEG|IMG-32x16-JPEG", + "IMG-24x24-PNG|IMG-24x24-PNG", + "AUDIO-10.0s@32000Hz|AUDIO-10.0s@32000Hz", + "VIDEO-192x108@30.0fps-63.0s|VIDEO-192x108@30.0fps-63.0s", + ] + def test_nomds(self): torch.manual_seed(42) worker_config = WorkerConfig( diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 927d2efb..b05cb401 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1833,6 +1833,7 @@ def test_prepare_dataset_s3_cmdline(self): "--split-ratio=1,0,0", "--sample-type=CaptioningSample", '--field-map={"image": "png", "caption": "txt"}', + "--media-metadata-by-extension", ], catch_exceptions=False, ) @@ -1840,7 +1841,7 @@ def test_prepare_dataset_s3_cmdline(self): assert "Done" in result.stdout, result.stdout info_path = s3_root / MAIN_FOLDER_NAME / INFO_JSON_FILENAME assert info_path.is_file(), f"missing {info_path}" - + # Iterate over the new dataset loader = get_loader( get_train_dataset( From c0c153157e5e53f2e98245b8cf261514e794f6e4 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:04:45 +0200 Subject: [PATCH 3/5] Fix for older python --- src/megatron/energon/media/filesystem_prepare.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/megatron/energon/media/filesystem_prepare.py b/src/megatron/energon/media/filesystem_prepare.py index 7ef3e2a7..44edc207 100644 --- a/src/megatron/energon/media/filesystem_prepare.py +++ b/src/megatron/energon/media/filesystem_prepare.py @@ -3,6 +3,7 @@ from __future__ import annotations +import os import shutil import tempfile from functools import partial @@ -138,8 +139,8 @@ def _collect_media_files( if root.is_local(): paths = ( - EPath(path / file) - for path, _dirs, files in root.local_path().walk(follow_symlinks=False) + EPath(f"{path}/{file}") + for path, _dirs, files in os.walk(root.local_path(), followlinks=False) for file in files ) else: From 598db74a246f2b2e7d12fe8da0f102ac4a7af8de Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 30 Apr 2026 14:24:31 +0200 Subject: [PATCH 4/5] Fix setting timestamp of local copies of files --- src/megatron/energon/local_copy.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/megatron/energon/local_copy.py b/src/megatron/energon/local_copy.py index d51e69c3..414aeede 100644 --- a/src/megatron/energon/local_copy.py +++ b/src/megatron/energon/local_copy.py @@ -47,6 +47,7 @@ def ensure_local_copy(path: EPath) -> EPath: # We are the downloader try: path.copy(EPath(tmp_path)) + os.utime(tmp_path, (tmp_path.stat().st_atime, mod_time)) tmp_path.rename(final_path) finally: tmp_path.unlink(missing_ok=True) From ac7dc960a96871628901ffb57bd0480167fea5a1 Mon Sep 17 00:00:00 2001 From: Lukas Voegtle <5764745+voegtlel@users.noreply.github.com> Date: Thu, 21 May 2026 15:15:58 +0200 Subject: [PATCH 5/5] Fix: Epath implements walk and glob, including tests. Fix sqlite tmp path handling. Fix S3 emulator timestamp handling --- src/megatron/energon/epathlib/epath.py | 55 ++++++--- .../energon/flavors/webdataset/prepare.py | 14 ++- .../energon/media/filesystem_prepare.py | 26 +++-- src/megatron/energon/tools/prepare.py | 7 +- src/megatron/energon/tools/prepare_media.py | 8 +- tests/s3_emulator/handler.py | 33 ++++-- tests/s3_emulator/state.py | 55 ++++++++- tests/test_epathlib.py | 109 ++++++++++++++++++ 8 files changed, 263 insertions(+), 44 deletions(-) diff --git a/src/megatron/energon/epathlib/epath.py b/src/megatron/energon/epathlib/epath.py index 487866e3..9ebbfbc8 100644 --- a/src/megatron/energon/epathlib/epath.py +++ b/src/megatron/energon/epathlib/epath.py @@ -266,17 +266,15 @@ def url(self) -> str: return f"msc://{self.profile}{int_path_str}" def is_local(self) -> bool: - if self.profile == "dss": - # For now, a DSS path is always considered local. - # Note that this does not mean it exists on the local filesystem. - return True - else: - return self.profile == DEFAULT_PROFILE_NAME + # It will return a posix path if the fs is local, otherwise None + return self.fs.get_posix_path(self._internal_str_path) is not None def local_path(self) -> PathlibPath: - if not self.is_local(): + # This resolves the path if it exists, probably ok. + posix_path = self.fs.get_posix_path(self._internal_str_path) + if posix_path is None: raise ValueError(f"Path {self} is not local") - return PathlibPath(self._internal_str_path) + return PathlibPath(posix_path) def is_dir(self) -> bool: try: @@ -290,25 +288,46 @@ def is_file(self) -> bool: def mkdir(self, exist_ok: bool = True, parents: bool = False): pass - def glob(self, pattern) -> Generator["EPath", None, None]: + def walk(self) -> Generator["EPath", None, None]: + """Returns all files within this path (no folders).""" + # Prefix to be removed from found paths to remap to relative paths + root_prefix = self._internal_str_path.lstrip("/") + + for obj in self.fs.list_recursive(self._internal_str_path): + rel = obj.key + if root_prefix: + if rel.startswith(root_prefix + "/"): + rel = rel[len(root_prefix) + 1 :] + elif rel.startswith("/" + root_prefix + "/"): + rel = rel[len(root_prefix) + 2 :] + elif rel == root_prefix or rel == "/" + root_prefix: + rel = "." + + path = EPath(self) + path.internal_path = self._resolve(self.internal_path / PurePosixPath(rel)) + yield path + + def glob(self, pattern: str) -> Generator["EPath", None, None]: + """Returns all files matching the pattern within this path (no folders).""" search_path_pattern = (self / pattern)._internal_str_path - # MSC S3 glob matches keys like ``bucket/key``; a leading ``/`` breaks wcmatch (pattern + # MSC glob matches keys like ``bucket/key``; a leading ``/`` breaks wcmatch (pattern # ``/b/**`` never matches ``b/parts/x``). Returned keys may repeat the bucket prefix; strip # it before joining with ``internal_path`` so we do not get ``/b/b/parts/...``. - if not self.is_local() and search_path_pattern.startswith("/"): - search_path_pattern = search_path_pattern.lstrip("/") + search_path_pattern = search_path_pattern.lstrip("/") - root_prefix = str(self.internal_path).lstrip("/") + # Prefix to be removed from found paths to remap to relative paths + root_prefix = self._internal_str_path.lstrip("/") for path in self.fs.glob(search_path_pattern): assert isinstance(path, str) rel = path - if not self.is_local() and root_prefix: - pfx = root_prefix + "/" - if rel.startswith(pfx): - rel = rel[len(pfx) :] - elif rel == root_prefix: + if root_prefix: + if rel.startswith(root_prefix + "/"): + rel = rel[len(root_prefix) + 1 :] + elif rel.startswith("/" + root_prefix + "/"): + rel = rel[len(root_prefix) + 2 :] + elif rel == root_prefix or rel == "/" + root_prefix: rel = "." new_path = EPath(self) diff --git a/src/megatron/energon/flavors/webdataset/prepare.py b/src/megatron/energon/flavors/webdataset/prepare.py index 82fff51b..62e6388e 100644 --- a/src/megatron/energon/flavors/webdataset/prepare.py +++ b/src/megatron/energon/flavors/webdataset/prepare.py @@ -521,7 +521,7 @@ def prepare_dataset( tar_index_only: Only create tar-index, then exit media_filter: Media filter configuration fix_duplicates: If True, fix duplicate keys in the dataset by renaming the files in the shards. - index_sqlite_tmp_path: When ``parent_path`` is remote, directory used to build ``index.sqlite`` + index_sqlite_tmp_path: When ``parent_path`` is remote, temp file path used to build ``index.sqlite`` locally before upload. If omitted, a new directory under ``/tmp`` is created and removed after a successful run. @@ -784,7 +784,17 @@ def add_media_metadata( progress_fn: Callable[[Iterator[Any], int], Iterator[T]] = (lambda x, y: x), index_sqlite_tmp_path: Optional[Path] = None, ) -> int: - """Add or refresh media metadata in an existing WebDataset index.""" + """Add or refresh media metadata in an existing WebDataset index. + + Args: + parent_path: WebDataset root path. + media_filter: Media filtering configuration. + workers: Number of parallel workers. + progress_fn: Callback for progress updates. + index_sqlite_tmp_path: When ``parent_path`` is remote, sqlite file path used to build + ``index.sqlite`` locally before upload. If omitted, a new directory under + ``/tmp`` is created and removed after a successful run. + """ parent_path = EPath(parent_path) diff --git a/src/megatron/energon/media/filesystem_prepare.py b/src/megatron/energon/media/filesystem_prepare.py index 44edc207..d9b8c3bb 100644 --- a/src/megatron/energon/media/filesystem_prepare.py +++ b/src/megatron/energon/media/filesystem_prepare.py @@ -3,7 +3,6 @@ from __future__ import annotations -import os import shutil import tempfile from functools import partial @@ -41,6 +40,9 @@ def prepare_filesystem_dataset( root_path: Dataset root directory. media_filter: Media filtering configuration. progress: Whether to display a tqdm progress bar. + index_sqlite_tmp_path: When ``root_path`` is remote, temp file path used to build + ``index.sqlite`` locally before upload. If omitted, a new directory under + ``/tmp`` is created and removed after a successful run. Returns: Number of metadata entries written to the database. @@ -137,23 +139,23 @@ def _collect_media_files( progress_bar = tqdm(total=None, unit="file", desc="Collecting media files") - if root.is_local(): - paths = ( - EPath(f"{path}/{file}") - for path, _dirs, files in os.walk(root.local_path(), followlinks=False) - for file in files - ) - else: - paths = root.glob("**/*") + # if root.is_local() and not root.profile == "dss": + # paths = ( + # EPath(path) / file + # for path, _dirs, files in os.walk(root.local_path(), followlinks=False) + # for file in files + # ) + # else: + # paths = root.glob("**/*") - for file in paths: + for file in root.walk(): if progress_bar is not None: progress_bar.update() - if not consider_all and not media_filter.should_consider_media(file.name): + if ("/" + MAIN_FOLDER_NAME + "/") in file.url: continue - if ("/" + MAIN_FOLDER_NAME + "/") in file.url: + if not consider_all and not media_filter.should_consider_media(file.name): continue files.append(file) diff --git a/src/megatron/energon/tools/prepare.py b/src/megatron/energon/tools/prepare.py index 67b2cc89..8a0a82e0 100644 --- a/src/megatron/energon/tools/prepare.py +++ b/src/megatron/energon/tools/prepare.py @@ -348,6 +348,11 @@ def progress_fn(els, length=None): def progress_fn(els, length=None): return els + if tmp_path is not None: + index_sqlite_tmp_path = tmp_path / "index.sqlite" + else: + index_sqlite_tmp_path = None + found_types = BaseWebdatasetFactory.prepare_dataset( path, all_tars, @@ -359,7 +364,7 @@ def progress_fn(els, length=None): workers=num_workers, media_filter=media_filter_config, fix_duplicates=fix_duplicates, - index_sqlite_tmp_path=tmp_path, + index_sqlite_tmp_path=index_sqlite_tmp_path, ) found_types = list(found_types) diff --git a/src/megatron/energon/tools/prepare_media.py b/src/megatron/energon/tools/prepare_media.py index 5640be78..57ce4ed5 100644 --- a/src/megatron/energon/tools/prepare_media.py +++ b/src/megatron/energon/tools/prepare_media.py @@ -71,6 +71,11 @@ def command( media_metadata_by_glob, media_metadata_by_header, media_metadata_by_extension ) + if tmp_path is not None: + index_sqlite_tmp_path = tmp_path / "index.sqlite" + else: + index_sqlite_tmp_path = None + ds_type = get_dataset_type(path) if ds_type == EnergonDatasetType.WEBDATASET: click.echo("Preparing webdataset and computing media metadata...") @@ -96,7 +101,7 @@ def progress_fn(els, length=None): media_filter=media_filter_config, workers=num_workers, progress_fn=progress_fn, - index_sqlite_tmp_path=tmp_path, + index_sqlite_tmp_path=index_sqlite_tmp_path, ) click.echo(f"Done. Stored metadata for {count} files.") @@ -112,6 +117,7 @@ def progress_fn(els, length=None): media_filter_config, progress=progress, workers=num_workers, + index_sqlite_tmp_path=index_sqlite_tmp_path, ) click.echo(f"Done. Stored metadata for {stored} files.") diff --git a/tests/s3_emulator/handler.py b/tests/s3_emulator/handler.py index 90464fc2..7329825e 100644 --- a/tests/s3_emulator/handler.py +++ b/tests/s3_emulator/handler.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: BSD-3-Clause import urllib.parse as _up from datetime import datetime, timezone -from email.utils import formatdate +from email.utils import format_datetime from hashlib import md5 from http import HTTPStatus from http.server import BaseHTTPRequestHandler @@ -130,13 +130,14 @@ def _handle_write(self): return try: data = self.server.state.copy_object(bucket, key, src_bucket, src_key) + last_modified = self.server.state.get_object_last_modified(bucket, key) except FileNotFoundError: self._send_error(HTTPStatus.NOT_FOUND, "NoSuchKey") return xml = ( '' "" - f"{_escape_xml(formatdate(usegmt=True))}" + f"{_escape_xml(_s3_datetime(last_modified))}" f""{_escape_xml(_etag(data))}"" "" ).encode() @@ -207,6 +208,7 @@ def _handle_read(self, listing: bool, only_headers: bool = False): try: data = self.server.state.get_object(bucket, key) + last_modified = self.server.state.get_object_last_modified(bucket, key) except FileNotFoundError: self._send_error(HTTPStatus.NOT_FOUND, "Not found") return @@ -234,10 +236,10 @@ def _handle_read(self, listing: bool, only_headers: bool = False): "Accept-Ranges": "bytes", "Content-Length": str(len(slice_data)), "ETag": _etag(data), + "Last-Modified": _http_datetime(last_modified), } if only_headers: headers.setdefault("Content-Type", "application/octet-stream") - headers.setdefault("Last-Modified", formatdate(usegmt=True)) self._send_status(HTTPStatus.PARTIAL_CONTENT, extra_headers=headers) else: self._send_bytes( @@ -254,7 +256,7 @@ def _handle_read(self, listing: bool, only_headers: bool = False): "Content-Length": str(len(data)), "Accept-Ranges": "bytes", "Content-Type": "application/octet-stream", - "Last-Modified": formatdate(usegmt=True), + "Last-Modified": _http_datetime(last_modified), "ETag": _etag(data), }, ) @@ -262,7 +264,11 @@ def _handle_read(self, listing: bool, only_headers: bool = False): self._send_bytes( data, content_type="application/octet-stream", - extra_headers={"Accept-Ranges": "bytes"}, + extra_headers={ + "Accept-Ranges": "bytes", + "Last-Modified": _http_datetime(last_modified), + "ETag": _etag(data), + }, ) def _handle_delete(self): @@ -414,7 +420,6 @@ def _render_list_bucket_result( start_after = (qs.get("start-after") or [""])[0] exclusive_after = continuation or start_after - now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") state = self.server.state items: list[tuple[Literal["cp", "key"], str]] = [] @@ -471,15 +476,17 @@ def _render_list_bucket_result( else: try: data = state.get_object(bucket, path) + last_modified = state.get_object_last_modified(bucket, path) size = len(data) etag = _etag(data) except Exception: # noqa: BLE001 size = 0 etag = '""' + last_modified = datetime.fromtimestamp(0, tz=timezone.utc) fragments.append( "" f"{_escape_xml(path)}" - f"{now}" + f"{_s3_datetime(last_modified)}" f"{etag}" f"{size}" "" @@ -541,6 +548,18 @@ def _escape_xml(text: str) -> str: # noqa: D401 ) +def _http_datetime(value: datetime) -> str: + """Format an aware datetime for HTTP Last-Modified headers.""" + + return format_datetime(value.astimezone(timezone.utc), usegmt=True) + + +def _s3_datetime(value: datetime) -> str: + """Format an aware datetime for S3 XML LastModified fields.""" + + return value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z") + + def _etag(data: bytes) -> str: # noqa: D401 """Generate an ETag for binary data. diff --git a/tests/s3_emulator/state.py b/tests/s3_emulator/state.py index 236b7311..c8d22cee 100644 --- a/tests/s3_emulator/state.py +++ b/tests/s3_emulator/state.py @@ -1,6 +1,8 @@ # Copyright (c) 2025, NVIDIA CORPORATION. # SPDX-License-Identifier: BSD-3-Clause import json +import os +from datetime import datetime, timezone from pathlib import Path from threading import RLock from typing import Dict, Optional @@ -27,6 +29,7 @@ def __init__(self, root_dir: Optional[Path] = None) -> None: root_dir: Path to persist the store on disk. """ self._fs: Dict[str, Dict[str, bytes]] = {} + self._last_modified: Dict[str, Dict[str, datetime]] = {} self._uploads: Dict[str, _MultipartUpload] = {} self._lock = RLock() self._root_dir = root_dir @@ -54,6 +57,7 @@ def create_bucket(self, bucket: str) -> None: print(f"Bucket '{bucket}' already exists") return self._fs[bucket] = {} + self._last_modified[bucket] = {} if self._root_dir is not None: (self._root_dir / bucket).mkdir(parents=True, exist_ok=True) @@ -69,6 +73,7 @@ def delete_bucket(self, bucket: str) -> None: if self._fs[bucket]: raise RuntimeError("Bucket not empty") del self._fs[bucket] + del self._last_modified[bucket] if self._root_dir is not None: bucket_path = self._root_dir / bucket if bucket_path.exists(): @@ -76,24 +81,34 @@ def delete_bucket(self, bucket: str) -> None: p.unlink() bucket_path.rmdir() - def put_object(self, bucket: str, key: str, data: bytes) -> None: + def put_object( + self, bucket: str, key: str, data: bytes, *, last_modified: datetime | None = None + ) -> None: """Store an object in a bucket. Args: bucket: Name of the bucket. key: Object key. data: Object data. + last_modified: Optional timestamp for the object. Defaults to now. """ if not bucket: raise ValueError("Bucket name must be given") + if last_modified is None: + last_modified = datetime.now(timezone.utc) + else: + last_modified = last_modified.astimezone(timezone.utc) with self._lock: if bucket not in self._fs: self._fs[bucket] = {} + self._last_modified[bucket] = {} self._fs[bucket][key] = data + self._last_modified[bucket][key] = last_modified if self._root_dir is not None: obj_path = (self._root_dir / bucket / key).resolve() obj_path.parent.mkdir(parents=True, exist_ok=True) obj_path.write_bytes(data) + os.utime(obj_path, (last_modified.timestamp(), last_modified.timestamp())) def get_object(self, bucket: str, key: str) -> bytes: """Retrieve an object from a bucket. @@ -111,6 +126,15 @@ def get_object(self, bucket: str, key: str) -> bytes: except KeyError as exc: raise FileNotFoundError(f"{bucket}/{key}") from exc + def get_object_last_modified(self, bucket: str, key: str) -> datetime: + """Return the stored Last-Modified timestamp for an object.""" + + with self._lock: + try: + return self._last_modified[bucket][key] + except KeyError as exc: + raise FileNotFoundError(f"{bucket}/{key}") from exc + def copy_object(self, dest_bucket: str, dest_key: str, src_bucket: str, src_key: str) -> bytes: """Copy an object to another key (S3 CopyObject). @@ -133,11 +157,15 @@ def copy_object(self, dest_bucket: str, dest_key: str, src_bucket: str, src_key: raise FileNotFoundError(f"{src_bucket}/{src_key}") from exc if dest_bucket not in self._fs: self._fs[dest_bucket] = {} + self._last_modified[dest_bucket] = {} + last_modified = datetime.now(timezone.utc) self._fs[dest_bucket][dest_key] = payload + self._last_modified[dest_bucket][dest_key] = last_modified if self._root_dir is not None: obj_path = (self._root_dir / dest_bucket / dest_key).resolve() obj_path.parent.mkdir(parents=True, exist_ok=True) obj_path.write_bytes(payload) + os.utime(obj_path, (last_modified.timestamp(), last_modified.timestamp())) return payload def delete_object(self, bucket: str, key: str) -> None: @@ -150,6 +178,7 @@ def delete_object(self, bucket: str, key: str) -> None: with self._lock: try: del self._fs[bucket][key] + del self._last_modified[bucket][key] except KeyError as exc: raise FileNotFoundError(f"{bucket}/{key}") from exc if self._root_dir is not None: @@ -191,7 +220,21 @@ def _load_from_disk(self) -> None: print(f"Failed to read persisted state: {err}") return with self._lock: - self._fs = {bucket: {key: b"" for key in keys} for bucket, keys in mapping.items()} + self._fs = {} + self._last_modified = {} + for bucket, keys in mapping.items(): + bucket_path = self._root_dir / bucket + self._fs[bucket] = {} + self._last_modified[bucket] = {} + for key in keys: + object_path = bucket_path / key + self._fs[bucket][key] = ( + object_path.read_bytes() if object_path.is_file() else b"" + ) + self._last_modified[bucket][key] = datetime.fromtimestamp( + object_path.stat().st_mtime if object_path.exists() else 0, + tz=timezone.utc, + ) def flush(self) -> None: """Persist only the structure of the store to disk.""" @@ -215,6 +258,7 @@ def initiate_multipart(self, bucket: str, key: str) -> str: self._uploads[upload_id] = _MultipartUpload(bucket, key) if bucket not in self._fs: self._fs[bucket] = {} + self._last_modified[bucket] = {} return upload_id def upload_part(self, upload_id: str, part_number: int, data: bytes) -> None: @@ -244,11 +288,15 @@ def complete_multipart(self, upload_id: str) -> None: data = mp.assemble() if mp.bucket not in self._fs: self._fs[mp.bucket] = {} + self._last_modified[mp.bucket] = {} + last_modified = datetime.now(timezone.utc) self._fs[mp.bucket][mp.key] = data + self._last_modified[mp.bucket][mp.key] = last_modified if self._root_dir is not None: obj_path = (self._root_dir / mp.bucket / mp.key).resolve() obj_path.parent.mkdir(parents=True, exist_ok=True) obj_path.write_bytes(data) + os.utime(obj_path, (last_modified.timestamp(), last_modified.timestamp())) def abort_multipart(self, upload_id: str) -> None: """Abort a multipart upload. @@ -272,7 +320,8 @@ def add_file(self, src: Path, dst: str): self.add_file(file, dst=f"{dst}/{file.name}") elif src.is_file(): bucket, key = dst.removeprefix("/").split("/", 1) - self.put_object(bucket, key, src.read_bytes()) + last_modified = datetime.fromtimestamp(src.stat().st_mtime, tz=timezone.utc) + self.put_object(bucket, key, src.read_bytes(), last_modified=last_modified) else: raise ValueError(f"Invalid file: {src}") diff --git a/tests/test_epathlib.py b/tests/test_epathlib.py index 06d283e9..7aad01ca 100644 --- a/tests/test_epathlib.py +++ b/tests/test_epathlib.py @@ -11,6 +11,7 @@ import sys import tempfile import unittest +from datetime import datetime, timezone from pathlib import Path from megatron.energon.epathlib import DEFAULT_PROFILE_NAME, EPath @@ -121,6 +122,94 @@ def test_glob(self): p = EPath("/tmp").glob("epathtestfile_[0-3].bin") assert len(list(p)) == 4 + def test_dss_glob_walk(self): + import megatron.energon.epathlib.epath as epath_mod + + orig_env_cache_dir = os.environ.get("NVDATASET_CACHE_DIR") + orig_mod_cache_dir = epath_mod.NVDATASET_CACHE_DIR + + with tempfile.TemporaryDirectory() as td: + cache_dir = Path(td) / "nvds_cache" + media_dir = cache_dir / "charts1234" / "v0" / "images" + media_dir.mkdir(parents=True) + (media_dir / "000.jpg").write_bytes(b"\xff\xd8\xff\xd9") + (media_dir / "001.jpg").write_bytes(b"\xff\xd8\xff\xd9") + (media_dir / "002.txt").write_bytes(b"dummy") + + try: + os.environ["NVDATASET_CACHE_DIR"] = str(cache_dir) + epath_mod.NVDATASET_CACHE_DIR = EPath(cache_dir) + + root = EPath("dss://charts1234@v0") + found = sorted(root.glob("**/*.jpg")) + + print(found) + + assert [str(path) for path in found] == [ + "dss://charts1234@v0/images/000.jpg", + "dss://charts1234@v0/images/001.jpg", + ] + assert [path.relative_to(root) for path in found] == [ + "images/000.jpg", + "images/001.jpg", + ] + + found = sorted(root.walk()) + print(found) + assert [str(p) for p in found] == [ + "dss://charts1234@v0/images/000.jpg", + "dss://charts1234@v0/images/001.jpg", + "dss://charts1234@v0/images/002.txt", + ] + finally: + if orig_env_cache_dir is None: + os.environ.pop("NVDATASET_CACHE_DIR", None) + else: + os.environ["NVDATASET_CACHE_DIR"] = orig_env_cache_dir + epath_mod.NVDATASET_CACHE_DIR = orig_mod_cache_dir + + def test_s3_glob_walk(self): + with setup_s3_emulator(profile_name="s3test_dss_walk") as s3_emulator: + s3_emulator.put_object("test", "dir/file.txt", b"dummy") + s3_emulator.put_object("test", "dir/subdir/file2.txt", b"dummy") + s3_emulator.put_object("test", "dir/subdir/file3.blob", b"dummy") + root = EPath("msc://s3test_dss_walk/test/dir") + found = sorted(root.walk()) + print(found) + assert [str(p) for p in found] == [ + "msc://s3test_dss_walk/test/dir/file.txt", + "msc://s3test_dss_walk/test/dir/subdir/file2.txt", + "msc://s3test_dss_walk/test/dir/subdir/file3.blob", + ] + + found = sorted(root.glob("**/*.txt")) + print(found) + assert [str(p) for p in found] == [ + "msc://s3test_dss_walk/test/dir/file.txt", + "msc://s3test_dss_walk/test/dir/subdir/file2.txt", + ] + + def test_local_glob_walk(self): + with tempfile.TemporaryDirectory() as td: + td_path = EPath(td) + (td_path / "file.txt").write_text("dummy") + (td_path / "subdir" / "file2.txt").write_text("dummy") + (td_path / "subdir" / "file3.blob").write_text("dummy") + root = EPath(td_path) + found = sorted(root.walk()) + print(found) + assert [str(p) for p in found] == [ + str(td_path / "file.txt"), + str(td_path / "subdir" / "file2.txt"), + str(td_path / "subdir" / "file3.blob"), + ] + found = sorted(root.glob("**/*.txt")) + print(found) + assert [str(p) for p in found] == [ + str(td_path / "file.txt"), + str(td_path / "subdir" / "file2.txt"), + ] + def test_s3_path_resolution(self): """Test s3 path resolution""" rclone_config_path = EPath("/tmp/XDG_CONFIG_HOME/.config/rclone/rclone.conf") @@ -337,6 +426,26 @@ def test_msc_s3_dataprep_path_operations(self): for p in found: p.unlink() + def test_msc_s3_stat_uses_stored_last_modified(self): + profile = "s3test_msc_timestamps" + with setup_s3_emulator(profile_name=profile) as state: + remote_path = EPath(f"msc://{profile}/bucket/data.bin") + original_mtime = datetime(2024, 1, 2, 3, 4, 5, tzinfo=timezone.utc) + updated_mtime = datetime(2024, 1, 3, 3, 4, 5, tzinfo=timezone.utc) + + state.put_object("bucket", "data.bin", b"first", last_modified=original_mtime) + + first_stat = remote_path.stat() + second_stat = remote_path.stat() + assert int(first_stat.last_modified.timestamp()) == int(original_mtime.timestamp()) + assert second_stat.last_modified == first_stat.last_modified + + state.put_object("bucket", "data.bin", b"second", last_modified=updated_mtime) + + updated_stat = remote_path.stat() + assert int(updated_stat.last_modified.timestamp()) == int(updated_mtime.timestamp()) + assert updated_stat.content_length == len(b"second") + def test_dss_path_requires_version(self): with self.assertRaisesRegex( AssertionError,