diff --git a/src/megatron/energon/cache/file_store.py b/src/megatron/energon/cache/file_store.py
index 2e924206..4a79e788 100644
--- a/src/megatron/energon/cache/file_store.py
+++ b/src/megatron/energon/cache/file_store.py
@@ -161,19 +161,18 @@ def get_path(self) -> str:
def get_media_metadata(self, key: str) -> MediaMetadataBase:
if self._media_metadata_available is None:
try:
- has_metadata = self.sqlite_reader.db_has_media_metadata()
+ self._media_metadata_available = self.sqlite_reader.db_has_media_metadata()
except sqlite3.Error as exc: # pragma: no cover - defensive
+ self._media_metadata_available = False
raise RuntimeError(
"Failed to inspect media metadata table. Re-run `energon prepare --media-metadata-by-...`."
) from exc
- if not has_metadata:
- raise RuntimeError(
- "Media metadata is not available for this dataset. "
- "Run `energon prepare --media-metadata-by-...` to generate it."
- )
-
- self._media_metadata_available = True
+ if not self._media_metadata_available:
+ raise RuntimeError(
+ "Media metadata is not available for this dataset. "
+ "Run `energon prepare --media-metadata-by-...` to generate it."
+ )
try:
row = self.sqlite_reader.get_media_metadata(key)
diff --git a/src/megatron/energon/epathlib/epath.py b/src/megatron/energon/epathlib/epath.py
index 6357262a..9ebbfbc8 100644
--- a/src/megatron/energon/epathlib/epath.py
+++ b/src/megatron/energon/epathlib/epath.py
@@ -266,17 +266,15 @@ def url(self) -> str:
return f"msc://{self.profile}{int_path_str}"
def is_local(self) -> bool:
- if self.profile == "dss":
- # For now, a DSS path is always considered local.
- # Note that this does not mean it exists on the local filesystem.
- return True
- else:
- return self.profile == DEFAULT_PROFILE_NAME
+ # get_posix_path returns a POSIX path if the filesystem is local, otherwise None.
+ return self.fs.get_posix_path(self._internal_str_path) is not None
def local_path(self) -> PathlibPath:
- if not self.is_local():
+ # NOTE(review): get_posix_path resolves the path if it exists; assumed acceptable here - confirm.
+ posix_path = self.fs.get_posix_path(self._internal_str_path)
+ if posix_path is None:
raise ValueError(f"Path {self} is not local")
- return PathlibPath(self._internal_str_path)
+ return PathlibPath(posix_path)
def is_dir(self) -> bool:
try:
@@ -290,14 +288,50 @@ def is_file(self) -> bool:
def mkdir(self, exist_ok: bool = True, parents: bool = False):
pass
- def glob(self, pattern) -> Generator["EPath", None, None]:
+ def walk(self) -> Generator["EPath", None, None]:
+ """Returns all files within this path (no folders)."""
+ # Prefix to be removed from found paths to remap to relative paths
+ root_prefix = self._internal_str_path.lstrip("/")
+
+ for obj in self.fs.list_recursive(self._internal_str_path):
+ rel = obj.key
+ if root_prefix:
+ if rel.startswith(root_prefix + "/"):
+ rel = rel[len(root_prefix) + 1 :]
+ elif rel.startswith("/" + root_prefix + "/"):
+ rel = rel[len(root_prefix) + 2 :]
+ elif rel == root_prefix or rel == "/" + root_prefix:
+ rel = "."
+
+ path = EPath(self)
+ path.internal_path = self._resolve(self.internal_path / PurePosixPath(rel))
+ yield path
+
+ def glob(self, pattern: str) -> Generator["EPath", None, None]:
+ """Returns all files matching the pattern within this path (no folders)."""
search_path_pattern = (self / pattern)._internal_str_path
+ # MSC glob matches keys like ``bucket/key``; a leading ``/`` breaks wcmatch (pattern
+ # ``/b/**`` never matches ``b/parts/x``). Returned keys may repeat the bucket prefix; strip
+ # it before joining with ``internal_path`` so we do not get ``/b/b/parts/...``.
+ search_path_pattern = search_path_pattern.lstrip("/")
+
+ # Prefix to be removed from found paths to remap to relative paths
+ root_prefix = self._internal_str_path.lstrip("/")
for path in self.fs.glob(search_path_pattern):
assert isinstance(path, str)
+ rel = path
+ if root_prefix:
+ if rel.startswith(root_prefix + "/"):
+ rel = rel[len(root_prefix) + 1 :]
+ elif rel.startswith("/" + root_prefix + "/"):
+ rel = rel[len(root_prefix) + 2 :]
+ elif rel == root_prefix or rel == "/" + root_prefix:
+ rel = "."
+
new_path = EPath(self)
- new_path.internal_path = self._resolve(self.internal_path / PurePosixPath(path))
+ new_path.internal_path = self._resolve(self.internal_path / PurePosixPath(rel))
yield new_path
@@ -321,6 +355,9 @@ def relative_to(self, other: "EPath") -> str:
return str(self.internal_path.relative_to(other.internal_path))
+ def stat(self) -> msc.types.ObjectMetadata:
+ return self.fs.info(self._internal_str_path)
+
@property
def display_name(self) -> str:
if self.profile == "dss":
diff --git a/src/megatron/energon/flavors/webdataset/indexing.py b/src/megatron/energon/flavors/webdataset/indexing.py
index 4e5cebcd..4d09d242 100644
--- a/src/megatron/energon/flavors/webdataset/indexing.py
+++ b/src/megatron/energon/flavors/webdataset/indexing.py
@@ -55,6 +55,16 @@ def __init__(
part_name TEXT,
content_byte_offset INTEGER,
content_byte_size INTEGER)
+ If enable_media_metadata is True, also creates the media_metadata table:
+ - media_metadata(entry_key TEXT PRIMARY KEY,
+ metadata_type TEXT NOT NULL,
+ metadata_json TEXT NOT NULL)
+ If enable_media_metadata is True, also creates the media_filters table:
+ - media_filters(filter_id INTEGER PRIMARY KEY AUTOINCREMENT,
+ strategy TEXT NOT NULL,
+ patterns TEXT,
+ created_at_utc TEXT DEFAULT CURRENT_TIMESTAMP,
+ UNIQUE(strategy, patterns))
Also creates indexes:
- samples(sample_key)
- samples(tar_file_id, sample_index)
@@ -70,7 +80,6 @@ def __init__(
# Initialize SQLite connection
# Only supporting local file system, because sqlite does not support remote file systems.
- # TODO: Implement remote file systems. Maybe create locally in tmp then upload?
path = self.sqlite_path.local_path()
path.parent.mkdir(parents=True, exist_ok=True)
self.db = sqlite3.connect(path)
@@ -355,6 +364,9 @@ class SqliteIndexReader:
part_name TEXT,
content_byte_offset INTEGER,
content_byte_size INTEGER)
+ - media_metadata(entry_key TEXT PRIMARY KEY,
+ metadata_type TEXT NOT NULL,
+ metadata_json TEXT NOT NULL)
"""
sqlite_path: EPath
diff --git a/src/megatron/energon/flavors/webdataset/prepare.py b/src/megatron/energon/flavors/webdataset/prepare.py
index d3a16223..62e6388e 100644
--- a/src/megatron/energon/flavors/webdataset/prepare.py
+++ b/src/megatron/energon/flavors/webdataset/prepare.py
@@ -6,8 +6,10 @@
import logging
import random
import re
+import shutil
import sys
import tarfile
+import tempfile
import uuid
from pathlib import Path
from typing import (
@@ -115,6 +117,7 @@ class SqliteIndexWriterAggregator(
reset_tables: bool
media_metadata_written: int
progress_on_media: bool
+ sqlite_local_build_path: Optional[Path]
def __init__(
self,
@@ -127,6 +130,7 @@ def __init__(
media_filter: Optional[MediaFilterConfig] = None,
reset_tables: bool = True,
progress_on_media: bool = False,
+ sqlite_local_build_path: Optional[Path] = None,
):
self.sqlite_path = sqlite_path
self.total_tasks = total_tasks
@@ -140,6 +144,7 @@ def __init__(
self.reset_tables = reset_tables
self.media_metadata_written = 0
self.progress_on_media = progress_on_media
+ self.sqlite_local_build_path = sqlite_local_build_path
if progress_fn is not None:
self.prog_iter = progress_fn(iter(range(self.total_tasks)), self.total_tasks)
@@ -147,8 +152,13 @@ def __init__(
self.prog_iter = iter(range(self.total_tasks))
def on_start(self, aggregator_pool: AggregatorPool) -> None:
+ local_sqlite = self.sqlite_path
+ if self.sqlite_local_build_path is not None:
+ local_sqlite = EPath(self.sqlite_local_build_path)
+ if self.sqlite_path.is_file():
+ self.sqlite_path.copy(local_sqlite)
self.writer = SqliteIndexWriter(
- self.sqlite_path,
+ local_sqlite,
enable_sample_tables=self.enable_sample_tables,
enable_media_metadata=self.enable_media_metadata,
reset_tables=self.reset_tables,
@@ -213,6 +223,9 @@ def on_finish(self, aggregator_pool: AggregatorPool) -> None:
patterns=",".join(self.media_filter.patterns),
)
self.writer.close()
+ if self.sqlite_local_build_path is not None:
+ EPath(self.sqlite_local_build_path).copy(self.sqlite_path)
+ self.sqlite_local_build_path.unlink(missing_ok=True)
def get_final_result_data(
self,
@@ -489,6 +502,7 @@ def prepare_dataset(
tar_index_only: bool = False,
media_filter: Optional[MediaFilterConfig] = None,
fix_duplicates: bool = False,
+ index_sqlite_tmp_path: Optional[Path] = None,
) -> Tuple[Set[str], List[Tuple[str, int]]]:
"""
Preprocess the shards and write the split config. Preprocessing is done in parallel.
@@ -507,6 +521,9 @@ def prepare_dataset(
tar_index_only: Only create tar-index, then exit
media_filter: Media filter configuration
fix_duplicates: If True, fix duplicate keys in the dataset by renaming the files in the shards.
+ index_sqlite_tmp_path: When ``parent_path`` is remote, temp file path used to build ``index.sqlite``
+ locally before upload. If omitted, a new directory under ``/tmp`` is created and removed
+ again when the run ends, whether it succeeds or fails.
Returns:
The set of all parts found in the shards. But at most 50.
@@ -562,179 +579,200 @@ def prepare_dataset(
else:
print("No duplicate keys found, continuing.")
- aggregator = SqliteIndexWriterAggregator(
- parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME,
- total_tasks=len(paths),
- progress_fn=progress_fn,
- enable_media_metadata=media_filter is not None,
- media_filter=media_filter,
- )
-
- process_tar = functools.partial(
- cls._preprocess_tar,
- shard_to_idx=shard_to_idx,
- parent_path=parent_path,
- max_parts=50,
- media_filter=media_filter,
- )
-
- pool = AggregatorPool(
- num_workers=workers,
- user_produce_data=process_tar,
- aggregator=aggregator,
- batch_size=INDEX_BATCH_SIZE,
- )
-
- for path in paths:
- pool.submit_task(path)
+ owns_remote_sqlite_tmp = False
+ remote_sqlite_tmp_dir: Optional[Path] = None
+ if not parent_path.is_local():
+ if index_sqlite_tmp_path is None:
+ remote_sqlite_tmp_dir = Path(
+ tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-")
+ )
+ index_sqlite_tmp_path = remote_sqlite_tmp_dir / INDEX_SQLITE_FILENAME
+ owns_remote_sqlite_tmp = True
+ else:
+ index_sqlite_tmp_path = None
try:
- shards, found_parts, had_update = pool.process()
- except DuplicateSampleKeyError as error:
- print("The data contains duplicate keys (e.g. same filename in different shards).")
- print(f'Example duplicate key: "{error.sample_key}"')
- print()
- print(
- "Energon does not support duplicate keys anymore, but we offer a tool to fix your dataset. "
- "Run `energon prepare` with `--fix-duplicates` to fix your dataset. Inside each tar, it will "
- "put each file in a subfolder with the shard name like `shard_0/filename.ext`."
+ aggregator = SqliteIndexWriterAggregator(
+ parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME,
+ total_tasks=len(paths),
+ progress_fn=progress_fn,
+ enable_media_metadata=media_filter is not None,
+ media_filter=media_filter,
+ sqlite_local_build_path=index_sqlite_tmp_path,
)
- if (parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).is_file():
- (parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).unlink()
+ process_tar = functools.partial(
+ cls._preprocess_tar,
+ shard_to_idx=shard_to_idx,
+ parent_path=parent_path,
+ max_parts=50,
+ media_filter=media_filter,
+ )
- sys.exit(1)
+ pool = AggregatorPool(
+ num_workers=workers,
+ user_produce_data=process_tar,
+ aggregator=aggregator,
+ batch_size=INDEX_BATCH_SIZE,
+ )
+
+ for path in paths:
+ pool.submit_task(path)
- # Fix permissions if needed
- if fix_local_permissions:
try:
- Path(str(parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME)).chmod(file_perms)
- except OSError:
- pass
+ shards, found_parts, had_update = pool.process()
+ except DuplicateSampleKeyError as error:
+ print("The data contains duplicate keys (e.g. same filename in different shards).")
+ print(f'Example duplicate key: "{error.sample_key}"')
+ print()
+ print(
+ "Energon does not support duplicate keys anymore, but we offer a tool to fix your dataset. "
+ "Run `energon prepare` with `--fix-duplicates` to fix your dataset. Inside each tar, it will "
+ "put each file in a subfolder with the shard name like `shard_0/filename.ext`."
+ )
+
+ if (parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).is_file():
+ (parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME).unlink()
- if had_update:
- logger.info("Regenerating dataset UUID...")
- with (parent_path / MAIN_FOLDER_NAME / INDEX_UUID_FILENAME).open("w") as f:
- f.write(str(uuid.uuid4()))
+ sys.exit(1)
# Fix permissions if needed
if fix_local_permissions:
try:
- (parent_path / MAIN_FOLDER_NAME / INDEX_UUID_FILENAME).local_path().chmod(
+ Path(str(parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME)).chmod(
file_perms
)
except OSError:
pass
- json_info_config = parent_path / MAIN_FOLDER_NAME / INFO_JSON_FILENAME
- yaml_info_config = parent_path / MAIN_FOLDER_NAME / INFO_YAML_FILENAME
-
- if tar_index_only:
- if yaml_info_config.is_file() and not json_info_config.is_file():
- # Convert legacy .info.yaml to .info.json
- with json_info_config.open("w") as f:
- json.dump(load_yaml(yaml_info_config.read_bytes()), f, indent=2)
+ if had_update:
+ logger.info("Regenerating dataset UUID...")
+ with (parent_path / MAIN_FOLDER_NAME / INDEX_UUID_FILENAME).open("w") as f:
+ f.write(str(uuid.uuid4()))
+ # Fix permissions if needed
if fix_local_permissions:
try:
- json_info_config.local_path().chmod(file_perms)
+ (parent_path / MAIN_FOLDER_NAME / INDEX_UUID_FILENAME).local_path().chmod(
+ file_perms
+ )
except OSError:
pass
- return found_parts
+ json_info_config = parent_path / MAIN_FOLDER_NAME / INFO_JSON_FILENAME
+ yaml_info_config = parent_path / MAIN_FOLDER_NAME / INFO_YAML_FILENAME
- assert len(shards) == len(shard_to_idx), (
- f"Lengths of shards and shard_to_idx do not match: {len(shards)} != {len(shard_to_idx)}"
- )
+ if tar_index_only:
+ if yaml_info_config.is_file() and not json_info_config.is_file():
+ # Convert legacy .info.yaml to .info.json
+ with json_info_config.open("w") as f:
+ json.dump(load_yaml(yaml_info_config.read_bytes()), f, indent=2)
- # Sort the shards according to the order in the input list
- shards.sort(key=lambda shard: shard_to_idx[shard.name])
+ if fix_local_permissions:
+ try:
+ json_info_config.local_path().chmod(file_perms)
+ except OSError:
+ pass
- # Save info
- assert [shard.name for shard in shards] == list(shard_to_idx.keys()), (
- "Shards are not in the same order as in the input list."
- )
+ return found_parts
- info = WebdatasetInfo(
- energon_version=__version__,
- shard_counts={shard.name: shard.count for shard in shards},
- )
- print(f"Saving info to {json_info_config}")
+ assert len(shards) == len(shard_to_idx), (
+ f"Lengths of shards and shard_to_idx do not match: {len(shards)} != {len(shard_to_idx)}"
+ )
- with json_info_config.open("w") as wf:
- json.dump(to_json_object(info), wf, indent=2)
+ # Sort the shards according to the order in the input list
+ shards.sort(key=lambda shard: shard_to_idx[shard.name])
- # Fix permissions if needed
- if fix_local_permissions:
- try:
- json_info_config.local_path().chmod(file_perms)
- except OSError:
- pass
+ # Save info
+ assert [shard.name for shard in shards] == list(shard_to_idx.keys()), (
+ "Shards are not in the same order as in the input list."
+ )
- if yaml_info_config.is_file():
- # If a .info.yaml existed previously, let's also update it
- # to keep them in sync
- with yaml_info_config.open("w") as wf:
- yaml.dump(to_json_object(info), wf)
-
- if split_parts_ratio is not None:
- # Normalize ratio
- total_ratio = sum(split_ratio for _, split_ratio in split_parts_ratio)
- split_parts_ratio = [
- (split_part, split_ratio / total_ratio)
- for split_part, split_ratio in split_parts_ratio
- ]
- # Sample from shards based on the split ratio from split parts
- split_shards = {}
- if shuffle_seed is not None:
- random.Random(shuffle_seed).shuffle(shards)
- split_total = 0
- split_offset = 0
- for split_part, split_ratio in split_parts_ratio:
- split_total += split_ratio
- split_end = int(len(shards) * split_total)
- split_shards[split_part] = [shard.name for shard in shards[split_offset:split_end]]
- split_offset = split_end
- else:
- assert split_parts_patterns is not None, (
- "Require either split_parts_ratio or split_parts_patterns"
+ info = WebdatasetInfo(
+ energon_version=__version__,
+ shard_counts={shard.name: shard.count for shard in shards},
)
- # Sample from shards based on the split patterns from split parts
- split_shards = {}
- for split_part, split_pattern in split_parts_patterns:
- patterns = [
- re.compile(pattern) for pattern in braceexpand.braceexpand(split_pattern)
- ]
- split_shards[split_part] = [
- shard.name
- for shard in shards
- if any(pattern.match(shard.name) for pattern in patterns)
- ]
+ print(f"Saving info to {json_info_config}")
+
+ with json_info_config.open("w") as wf:
+ json.dump(to_json_object(info), wf, indent=2)
- # Optimize the split parts by trying to bracecollapse the shard names
- print("Collapsing split parts... ", flush=True, end="")
- for split_part in split_shards:
- split_shards[split_part] = collapse(split_shards[split_part], keep_order=True)
- print("Done", flush=True)
-
- # Save split config
- splits_config = WebdatasetSplits(split_parts=split_shards)
- with (parent_path / MAIN_FOLDER_NAME / split_config).open("w") as wf:
- if split_config.endswith(".yaml"):
- yaml.dump(to_json_object(splits_config), wf, sort_keys=False)
- elif split_config.endswith(".json"):
- json.dump(to_json_object(splits_config), wf, indent=2)
+ # Fix permissions if needed
+ if fix_local_permissions:
+ try:
+ json_info_config.local_path().chmod(file_perms)
+ except OSError:
+ pass
+
+ if yaml_info_config.is_file():
+ # If a .info.yaml existed previously, let's also update it
+ # to keep them in sync
+ with yaml_info_config.open("w") as wf:
+ yaml.dump(to_json_object(info), wf)
+
+ if split_parts_ratio is not None:
+ # Normalize ratio
+ total_ratio = sum(split_ratio for _, split_ratio in split_parts_ratio)
+ split_parts_ratio = [
+ (split_part, split_ratio / total_ratio)
+ for split_part, split_ratio in split_parts_ratio
+ ]
+ # Sample from shards based on the split ratio from split parts
+ split_shards = {}
+ if shuffle_seed is not None:
+ random.Random(shuffle_seed).shuffle(shards)
+ split_total = 0
+ split_offset = 0
+ for split_part, split_ratio in split_parts_ratio:
+ split_total += split_ratio
+ split_end = int(len(shards) * split_total)
+ split_shards[split_part] = [
+ shard.name for shard in shards[split_offset:split_end]
+ ]
+ split_offset = split_end
else:
- raise ValueError(f"Invalid split config extension: {split_config}")
+ assert split_parts_patterns is not None, (
+ "Require either split_parts_ratio or split_parts_patterns"
+ )
+ # Sample from shards based on the split patterns from split parts
+ split_shards = {}
+ for split_part, split_pattern in split_parts_patterns:
+ patterns = [
+ re.compile(pattern) for pattern in braceexpand.braceexpand(split_pattern)
+ ]
+ split_shards[split_part] = [
+ shard.name
+ for shard in shards
+ if any(pattern.match(shard.name) for pattern in patterns)
+ ]
+
+ # Optimize the split parts by trying to bracecollapse the shard names
+ print("Collapsing split parts... ", flush=True, end="")
+ for split_part in split_shards:
+ split_shards[split_part] = collapse(split_shards[split_part], keep_order=True)
+ print("Done", flush=True)
+
+ # Save split config
+ splits_config = WebdatasetSplits(split_parts=split_shards)
+ with (parent_path / MAIN_FOLDER_NAME / split_config).open("w") as wf:
+ if split_config.endswith(".yaml"):
+ yaml.dump(to_json_object(splits_config), wf, sort_keys=False)
+ elif split_config.endswith(".json"):
+ json.dump(to_json_object(splits_config), wf, indent=2)
+ else:
+ raise ValueError(f"Invalid split config extension: {split_config}")
- # Fix permissions if needed
- if fix_local_permissions:
- try:
- (parent_path / MAIN_FOLDER_NAME / split_config).local_path().chmod(file_perms)
- except OSError:
- pass
+ # Fix permissions if needed
+ if fix_local_permissions:
+ try:
+ (parent_path / MAIN_FOLDER_NAME / split_config).local_path().chmod(file_perms)
+ except OSError:
+ pass
- return found_parts
+ return found_parts
+ finally:
+ if owns_remote_sqlite_tmp and remote_sqlite_tmp_dir is not None:
+ shutil.rmtree(remote_sqlite_tmp_dir, ignore_errors=True)
@classmethod
def add_media_metadata(
@@ -744,8 +782,19 @@ def add_media_metadata(
media_filter: MediaFilterConfig,
workers: int = 32,
progress_fn: Callable[[Iterator[Any], int], Iterator[T]] = (lambda x, y: x),
+ index_sqlite_tmp_path: Optional[Path] = None,
) -> int:
- """Add or refresh media metadata in an existing WebDataset index."""
+ """Add or refresh media metadata in an existing WebDataset index.
+
+ Args:
+ parent_path: WebDataset root path.
+ media_filter: Media filtering configuration.
+ workers: Number of parallel workers.
+ progress_fn: Callback for progress updates.
+ index_sqlite_tmp_path: When ``parent_path`` is remote, sqlite file path used to build
+ ``index.sqlite`` locally before upload. If omitted, a new directory under
+ ``/tmp`` is created and removed again when the run ends, whether it succeeds or fails.
+ """
parent_path = EPath(parent_path)
@@ -762,34 +811,65 @@ def add_media_metadata(
if path not in shard_counts:
raise ValueError(f"Shard '{path}' not present in dataset metadata")
- aggregator = SqliteIndexWriterAggregator(
- parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME,
- total_tasks=len(expanded_paths),
- progress_fn=progress_fn,
- enable_sample_tables=False,
- enable_media_metadata=True,
- media_filter=media_filter,
- reset_tables=False,
- progress_on_media=False,
- )
+ owns_remote_sqlite_tmp = False
+ remote_sqlite_tmp_dir: Optional[Path] = None
+ if not parent_path.is_local():
+ if index_sqlite_tmp_path is None:
+ remote_sqlite_tmp_dir = Path(
+ tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-media-")
+ )
+ index_sqlite_tmp_path = remote_sqlite_tmp_dir / INDEX_SQLITE_FILENAME
+ owns_remote_sqlite_tmp = True
+ else:
+ index_sqlite_tmp_path = None
- process_tar = functools.partial(
- cls._extract_media_from_tar,
- parent_path=parent_path,
- media_filter=media_filter,
- shard_counts=shard_counts,
- )
+ sqlite_path = parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME
- pool = AggregatorPool(
- num_workers=min(workers, len(expanded_paths)) or 1,
- user_produce_data=process_tar,
- aggregator=aggregator,
- batch_size=INDEX_BATCH_SIZE,
- )
+ try:
+ aggregator = SqliteIndexWriterAggregator(
+ parent_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME,
+ total_tasks=len(expanded_paths),
+ progress_fn=progress_fn,
+ enable_sample_tables=False,
+ enable_media_metadata=True,
+ media_filter=media_filter,
+ reset_tables=False,
+ progress_on_media=False,
+ sqlite_local_build_path=index_sqlite_tmp_path,
+ )
- for path in expanded_paths:
- pool.submit_task(path)
+ process_tar = functools.partial(
+ cls._extract_media_from_tar,
+ parent_path=parent_path,
+ media_filter=media_filter,
+ shard_counts=shard_counts,
+ )
- pool.process()
+ pool = AggregatorPool(
+ num_workers=min(workers, len(expanded_paths)) or 1,
+ user_produce_data=process_tar,
+ aggregator=aggregator,
+ batch_size=INDEX_BATCH_SIZE,
+ )
+
+ for path in expanded_paths:
+ pool.submit_task(path)
+
+ pool.process()
+
+ if sqlite_path.is_local():
+ try:
+ meta_dir = (parent_path / MAIN_FOLDER_NAME).local_path()
+ # Copy group permissions from the parent dir
+ meta_dir.chmod((parent_path.local_path().stat().st_mode | 0o700))
+ # Just read/write, no execute
+ sqlite_path.local_path().chmod(
+ (parent_path.local_path().stat().st_mode | 0o600) & 0o666
+ )
+ except OSError:
+ pass
- return aggregator.media_metadata_written
+ return aggregator.media_metadata_written
+ finally:
+ if owns_remote_sqlite_tmp and remote_sqlite_tmp_dir is not None:
+ shutil.rmtree(remote_sqlite_tmp_dir, ignore_errors=True)
diff --git a/src/megatron/energon/local_copy.py b/src/megatron/energon/local_copy.py
index 2708f3e1..414aeede 100644
--- a/src/megatron/energon/local_copy.py
+++ b/src/megatron/energon/local_copy.py
@@ -3,18 +3,20 @@
import hashlib
+import os
from pathlib import Path
from filelock import FileLock
from megatron.energon.epathlib import EPath
-LOCAL_COPY_TMP_DIR = Path("/tmp/energon_local_copy")
+LOCAL_COPY_TMP_DIR = Path(os.environ.get("ENERGON_LOCAL_COPY_TMP_DIR", "/tmp/energon_local_copy"))
def ensure_local_copy(path: EPath) -> EPath:
"""If the path is not local, copy it to a temporary directory and return the
- path to the temporary directory.
+ path to the temporary directory. Assumes that the local copy is never modified
+ after it is copied. Re-syncs if the remote file is newer than the local copy.
Args:
path: The path to the file to copy.
@@ -33,15 +35,19 @@ def ensure_local_copy(path: EPath) -> EPath:
final_path = LOCAL_COPY_TMP_DIR / f"{digest}.bin"
lock_path = final_path.with_suffix(".lock")
tmp_path = final_path.with_suffix(".part")
+ mod_time = path.stat().last_modified.timestamp()
# Block until lock is free
with FileLock(lock_path, timeout=60 * 5):
- if final_path.exists(): # someone else already produced it
+ # someone else already produced it
+ if final_path.exists() and final_path.stat().st_mtime >= mod_time:
+ # The local copy is at least as new as the remote file; reuse it.
return EPath(final_path)
# We are the downloader
try:
path.copy(EPath(tmp_path))
+ os.utime(tmp_path, (tmp_path.stat().st_atime, mod_time))
tmp_path.rename(final_path)
finally:
tmp_path.unlink(missing_ok=True)
diff --git a/src/megatron/energon/media/extractor.py b/src/megatron/energon/media/extractor.py
index 1126584e..4ebeadda 100644
--- a/src/megatron/energon/media/extractor.py
+++ b/src/megatron/energon/media/extractor.py
@@ -15,12 +15,13 @@
from PIL import Image, UnidentifiedImageError
from megatron.energon.av import AVDecoder
+from megatron.energon.epathlib import EPath
from megatron.energon.media.metadata import ImageMetadata, MediaMetadataBase, MediaMetadataType
logger = logging.getLogger(__name__)
-SourceData = Union[bytes, Path, BinaryIO]
+SourceData = Union[bytes, EPath, BinaryIO]
class MediaFilterStrategy(str, Enum):
@@ -210,9 +211,13 @@ def _build_metadata(
def _build_image_metadata(source: SourceData) -> ImageMetadata | None:
+ should_close = False
try:
if isinstance(source, (bytes, bytearray)):
source = io.BytesIO(source)
+ elif isinstance(source, EPath):
+ source = source.open("rb")
+ should_close = True
with Image.open(source) as image:
image.load()
@@ -225,6 +230,9 @@ def _build_image_metadata(source: SourceData) -> ImageMetadata | None:
except UnidentifiedImageError:
logger.debug("Failed to parse image metadata", exc_info=True)
return None
+ finally:
+ if should_close:
+ source.close()
def _build_av_metadata(source: SourceData) -> MediaMetadataBase | None:
diff --git a/src/megatron/energon/media/filesystem_prepare.py b/src/megatron/energon/media/filesystem_prepare.py
index afa492d2..d9b8c3bb 100644
--- a/src/megatron/energon/media/filesystem_prepare.py
+++ b/src/megatron/energon/media/filesystem_prepare.py
@@ -3,7 +3,8 @@
from __future__ import annotations
-import os
+import shutil
+import tempfile
from functools import partial
from pathlib import Path
from typing import Callable, Iterator
@@ -31,6 +32,7 @@ def prepare_filesystem_dataset(
*,
progress: bool,
workers: int = 16,
+ index_sqlite_tmp_path: Path | None = None,
) -> int:
"""Scan a filesystem dataset and materialize media metadata into SQLite.
@@ -38,23 +40,29 @@ def prepare_filesystem_dataset(
root_path: Dataset root directory.
media_filter: Media filtering configuration.
progress: Whether to display a tqdm progress bar.
+ index_sqlite_tmp_path: When ``root_path`` is remote, temp file path used to build
+ ``index.sqlite`` locally before upload. If omitted, a new directory under
+ ``/tmp`` is created and removed again when the run ends, whether it succeeds or fails.
Returns:
Number of metadata entries written to the database.
"""
- # Only supporting local file system, because sqlite does not support remote file systems.
- # TODO: Implement remote file systems. Maybe create locally in tmp then upload?
- root = root_path.local_path()
- assert root.is_dir(), f"Expected directory for filesystem dataset, got {root}"
- assert root.is_absolute(), f"Filesystem dataset path must be absolute: {root}"
+ assert not root_path.is_file(), f"Expected directory for filesystem dataset, got {root_path}"
- meta_dir = root / MAIN_FOLDER_NAME
- meta_dir.mkdir(exist_ok=True, parents=True)
+ files = _collect_media_files(root=root_path, media_filter=media_filter, progress=progress)
- files = _collect_media_files(root=root, media_filter=media_filter, progress=progress)
-
- sqlite_path = EPath(meta_dir / INDEX_SQLITE_FILENAME)
+ owns_remote_sqlite_tmp = False
+ remote_sqlite_tmp_dir: Path | None = None
+ if not root_path.is_local():
+ if index_sqlite_tmp_path is None:
+ remote_sqlite_tmp_dir = Path(
+ tempfile.mkdtemp(dir="/tmp", prefix="energon-prepare-media-")
+ )
+ index_sqlite_tmp_path = remote_sqlite_tmp_dir / INDEX_SQLITE_FILENAME
+ owns_remote_sqlite_tmp = True
+ else:
+ index_sqlite_tmp_path = None
agg_progress_fn: Callable[[Iterator[int], int], Iterator[int]] | None = None
if progress:
@@ -64,55 +72,66 @@ def agg_progress_fn(iterator: Iterator[int], total: int) -> Iterator[int]:
with tqdm(iterator, total=total, unit="file", desc="Processing media files") as bar:
yield from bar
- aggregator = SqliteIndexWriterAggregator(
- sqlite_path,
- total_tasks=len(files),
- progress_fn=agg_progress_fn,
- enable_media_metadata=True,
- media_filter=media_filter,
- reset_tables=False,
- enable_sample_tables=False,
- progress_on_media=progress,
- )
-
- pool = AggregatorPool[
- Path,
- IndexAggregatable,
- tuple[list[ShardInfo], set[str], bool, list[tuple[str, int]]],
- ](
- num_workers=min(workers, len(files)) or 1,
- user_produce_data=partial(
- _process_filesystem_entry,
- root=root,
- media_filter=media_filter,
- ),
- aggregator=aggregator,
- batch_size=INDEX_BATCH_SIZE,
- )
-
- for file_path in files:
- pool.submit_task(file_path)
-
- pool.process()
+ sqlite_path = root_path / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME
try:
- # Copy group permissions from the parent dir
- meta_dir.chmod((root.stat().st_mode | 0o700))
- # Just read/write, no execute
- sqlite_path.local_path().chmod((root.stat().st_mode | 0o600) & 0o666)
- except OSError:
- pass
-
- return aggregator.media_metadata_written
+ aggregator = SqliteIndexWriterAggregator(
+ sqlite_path,
+ total_tasks=len(files),
+ progress_fn=agg_progress_fn,
+ enable_media_metadata=True,
+ media_filter=media_filter,
+ reset_tables=False,
+ enable_sample_tables=False,
+ progress_on_media=progress,
+ sqlite_local_build_path=index_sqlite_tmp_path,
+ )
+
+ pool = AggregatorPool[
+ EPath,
+ IndexAggregatable,
+ tuple[list[ShardInfo], set[str], bool, list[tuple[str, int]]],
+ ](
+ num_workers=min(workers, len(files)) or 1,
+ user_produce_data=partial(
+ _process_filesystem_entry,
+ root=root_path,
+ media_filter=media_filter,
+ ),
+ aggregator=aggregator,
+ batch_size=INDEX_BATCH_SIZE,
+ )
+
+ for file_path in files:
+ pool.submit_task(file_path)
+
+ pool.process()
+
+ if sqlite_path.is_local():
+ try:
+ meta_dir = (root_path / MAIN_FOLDER_NAME).local_path()
+ # Copy group permissions from the parent dir
+ meta_dir.chmod((root_path.local_path().stat().st_mode | 0o700))
+ # Just read/write, no execute
+ sqlite_path.local_path().chmod(
+ (root_path.local_path().stat().st_mode | 0o600) & 0o666
+ )
+ except OSError:
+ pass
+
+ return aggregator.media_metadata_written
+ finally:
+ if owns_remote_sqlite_tmp and remote_sqlite_tmp_dir is not None:
+ shutil.rmtree(remote_sqlite_tmp_dir, ignore_errors=True)
def _collect_media_files(
- *, root: Path, media_filter: MediaFilterConfig, progress: bool = False
-) -> list[Path]:
+ *, root: EPath, media_filter: MediaFilterConfig, progress: bool = False
+) -> list[EPath]:
"""Return a sorted list of files to process based on the media filter."""
consider_all = media_filter.should_consider_all()
- files: list[Path] = []
+ files: list[EPath] = []
progress_bar = None
if progress:
@@ -120,28 +139,26 @@ def _collect_media_files(
progress_bar = tqdm(total=None, unit="file", desc="Collecting media files")
- for dirpath, dirnames, filenames in os.walk(root, followlinks=False):
- current_dir = Path(dirpath)
-
- if current_dir.name == MAIN_FOLDER_NAME:
- dirnames[:] = []
- continue
+ # if root.is_local() and not root.profile == "dss":
+ # paths = (
+ # EPath(path) / file
+ # for path, _dirs, files in os.walk(root.local_path(), followlinks=False)
+ # for file in files
+ # )
+ # else:
+ # paths = root.glob("**/*")
- dirnames[:] = [d for d in dirnames if d != MAIN_FOLDER_NAME]
+ for file in root.walk():
+ if progress_bar is not None:
+ progress_bar.update()
- for filename in filenames:
- if progress_bar is not None:
- progress_bar.update()
-
- if not consider_all and not media_filter.should_consider_media(filename):
- continue
+ if ("/" + MAIN_FOLDER_NAME + "/") in file.url:
+ continue
- candidate = current_dir / filename
- if MAIN_FOLDER_NAME in candidate.parts:
- continue
+ if not consider_all and not media_filter.should_consider_media(file.name):
+ continue
- if candidate.is_file():
- files.append(candidate)
+ files.append(file)
if progress_bar is not None:
progress_bar.close()
@@ -151,18 +168,17 @@ def _collect_media_files(
def _process_filesystem_entry(
- file_path: Path | str,
+ file_path: EPath,
*,
- root: Path,
+ root: EPath,
media_filter: MediaFilterConfig,
) -> Iterator[IndexMediaMetadata]:
- file_path = Path(file_path)
extracted_metadata = media_filter.extract_metadata(file_path)
if extracted_metadata is None:
return
stored_type, metadata_json = serialize_media_metadata(extracted_metadata)
- entry_key = file_path.relative_to(root).as_posix()
+ entry_key = file_path.relative_to(root)
yield IndexMediaMetadata(
entry_key=entry_key,
diff --git a/src/megatron/energon/tools/prepare.py b/src/megatron/energon/tools/prepare.py
index 5e6643cf..8a0a82e0 100644
--- a/src/megatron/energon/tools/prepare.py
+++ b/src/megatron/energon/tools/prepare.py
@@ -6,6 +6,7 @@
import json
import re
import typing
+from pathlib import Path
from types import FunctionType
from typing import Any, List, Optional, Tuple, Type
@@ -184,6 +185,15 @@ def printify_json(data: Any) -> Any:
help='Field mapping in JSON format (e.g., \'{"image": "jpg", "caption": "txt"}\'). If not set in non-interactive mode, a sample loader from template will be created. Use with --sample-type. Only applies if sample_type is not set to CrudeWebdataset.',
default=None,
)
+@click.option(
+ "--tmp-path",
+ type=click.Path(path_type=Path),
+ default=None,
+ help=(
+ "When PATH is remote (e.g. S3), directory where index.sqlite is built locally before upload. "
+ "Default: a new directory under /tmp (removed after success)."
+ ),
+)
def command(
path: EPath,
progress: bool,
@@ -203,6 +213,7 @@ def command(
field_map: Optional[str],
skip_dataset_yaml: bool,
dataset_yaml_name: str,
+ tmp_path: Optional[Path],
):
"""Prepare WebDataset for use with energon.
@@ -337,6 +348,11 @@ def progress_fn(els, length=None):
def progress_fn(els, length=None):
return els
+ if tmp_path is not None:
+ index_sqlite_tmp_path = tmp_path / "index.sqlite"
+ else:
+ index_sqlite_tmp_path = None
+
found_types = BaseWebdatasetFactory.prepare_dataset(
path,
all_tars,
@@ -348,6 +364,7 @@ def progress_fn(els, length=None):
workers=num_workers,
media_filter=media_filter_config,
fix_duplicates=fix_duplicates,
+ index_sqlite_tmp_path=index_sqlite_tmp_path,
)
found_types = list(found_types)
diff --git a/src/megatron/energon/tools/prepare_media.py b/src/megatron/energon/tools/prepare_media.py
index bcc99bf6..57ce4ed5 100644
--- a/src/megatron/energon/tools/prepare_media.py
+++ b/src/megatron/energon/tools/prepare_media.py
@@ -2,6 +2,9 @@
# SPDX-License-Identifier: BSD-3-Clause
+from pathlib import Path
+from typing import Optional
+
import click
from megatron.energon.epathlib import EPath
@@ -44,6 +47,15 @@
is_flag=True,
help="Media detection by standard file extensions.",
)
+@click.option(
+ "--tmp-path",
+ type=click.Path(path_type=Path),
+ default=None,
+ help=(
+ "When PATH is remote, directory where index.sqlite is built locally before upload. "
+ "Default: a new directory under /tmp (removed after success)."
+ ),
+)
def command(
path: EPath,
progress: bool,
@@ -51,6 +63,7 @@ def command(
media_metadata_by_glob: str | None,
media_metadata_by_header: bool,
media_metadata_by_extension: bool,
+ tmp_path: Optional[Path],
):
"""Prepare a filesystem dataset by collecting media metadata."""
@@ -58,6 +71,11 @@ def command(
media_metadata_by_glob, media_metadata_by_header, media_metadata_by_extension
)
+ if tmp_path is not None:
+ index_sqlite_tmp_path = tmp_path / "index.sqlite"
+ else:
+ index_sqlite_tmp_path = None
+
ds_type = get_dataset_type(path)
if ds_type == EnergonDatasetType.WEBDATASET:
click.echo("Preparing webdataset and computing media metadata...")
@@ -83,6 +101,7 @@ def progress_fn(els, length=None):
media_filter=media_filter_config,
workers=num_workers,
progress_fn=progress_fn,
+ index_sqlite_tmp_path=index_sqlite_tmp_path,
)
click.echo(f"Done. Stored metadata for {count} files.")
@@ -98,6 +117,7 @@ def progress_fn(els, length=None):
media_filter_config,
progress=progress,
workers=num_workers,
+ index_sqlite_tmp_path=index_sqlite_tmp_path,
)
click.echo(f"Done. Stored metadata for {stored} files.")
diff --git a/tests/s3_emulator/handler.py b/tests/s3_emulator/handler.py
index 3a6d28a6..7329825e 100644
--- a/tests/s3_emulator/handler.py
+++ b/tests/s3_emulator/handler.py
@@ -2,11 +2,11 @@
# SPDX-License-Identifier: BSD-3-Clause
import urllib.parse as _up
from datetime import datetime, timezone
-from email.utils import formatdate
+from email.utils import format_datetime
from hashlib import md5
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler
-from typing import Protocol
+from typing import Literal, Protocol
from .auth import InvalidSignature, S3Auth
from .state import S3State
@@ -112,6 +112,38 @@ def _handle_write(self):
qs = _up.parse_qs(parsed.query, keep_blank_values=True)
+ # S3 CopyObject: PUT to destination key with x-amz-copy-source and empty body.
+ copy_src = (
+ self.headers.get("x-amz-copy-source") or self.headers.get("X-Amz-Copy-Source") or ""
+ ).strip()
+ if copy_src:
+ if not bucket:
+ self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified")
+ return
+ if key == "":
+ self._send_error(HTTPStatus.BAD_REQUEST, "CopyObject requires an object key")
+ return
+ try:
+ src_bucket, src_key = _parse_copy_source(copy_src)
+ except ValueError as err:
+ self._send_error(HTTPStatus.BAD_REQUEST, str(err))
+ return
+ try:
+ data = self.server.state.copy_object(bucket, key, src_bucket, src_key)
+ last_modified = self.server.state.get_object_last_modified(bucket, key)
+ except FileNotFoundError:
+ self._send_error(HTTPStatus.NOT_FOUND, "NoSuchKey")
+ return
+ xml = (
+ ''
+ ""
+ f"{_escape_xml(_s3_datetime(last_modified))}"
+ f""{_escape_xml(_etag(data))}""
+ ""
+ ).encode()
+ self._send_bytes(xml, status=HTTPStatus.OK, content_type="application/xml")
+ return
+
# Multipart: upload part
if "uploadId" in qs and "partNumber" in qs:
upload_id = qs["uploadId"][0]
@@ -160,15 +192,15 @@ def _handle_read(self, listing: bool, only_headers: bool = False):
self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified")
return
- if key == "": # List bucket contents
+ if key == "": # List bucket contents (ListObjects / ListObjectsV2)
if not listing:
- # We treat listing with GET only
try:
objects = self.server.state.list_objects(bucket)
except KeyError:
self._send_error(HTTPStatus.NOT_FOUND, "Bucket not found")
return
- xml_body = self._render_bucket_list(bucket, objects)
+ qs = _up.parse_qs(parsed.query, keep_blank_values=True)
+ xml_body = self._render_list_bucket_result(bucket, objects, qs)
self._send_bytes(xml_body, content_type="application/xml")
else:
self._send_error(HTTPStatus.NOT_IMPLEMENTED, "Listing not implemented")
@@ -176,6 +208,7 @@ def _handle_read(self, listing: bool, only_headers: bool = False):
try:
data = self.server.state.get_object(bucket, key)
+ last_modified = self.server.state.get_object_last_modified(bucket, key)
except FileNotFoundError:
self._send_error(HTTPStatus.NOT_FOUND, "Not found")
return
@@ -203,10 +236,10 @@ def _handle_read(self, listing: bool, only_headers: bool = False):
"Accept-Ranges": "bytes",
"Content-Length": str(len(slice_data)),
"ETag": _etag(data),
+ "Last-Modified": _http_datetime(last_modified),
}
if only_headers:
headers.setdefault("Content-Type", "application/octet-stream")
- headers.setdefault("Last-Modified", formatdate(usegmt=True))
self._send_status(HTTPStatus.PARTIAL_CONTENT, extra_headers=headers)
else:
self._send_bytes(
@@ -223,7 +256,7 @@ def _handle_read(self, listing: bool, only_headers: bool = False):
"Content-Length": str(len(data)),
"Accept-Ranges": "bytes",
"Content-Type": "application/octet-stream",
- "Last-Modified": formatdate(usegmt=True),
+ "Last-Modified": _http_datetime(last_modified),
"ETag": _etag(data),
},
)
@@ -231,7 +264,11 @@ def _handle_read(self, listing: bool, only_headers: bool = False):
self._send_bytes(
data,
content_type="application/octet-stream",
- extra_headers={"Accept-Ranges": "bytes"},
+ extra_headers={
+ "Accept-Ranges": "bytes",
+ "Last-Modified": _http_datetime(last_modified),
+ "ETag": _etag(data),
+ },
)
def _handle_delete(self):
@@ -359,44 +396,104 @@ def _send_bytes(
if self.command != "HEAD":
self.wfile.write(data)
- @staticmethod
- def _render_bucket_list(bucket: str, objects: list[str]) -> bytes:
- """Generate an XML listing of objects in a bucket.
+ def _render_list_bucket_result(
+ self,
+ bucket: str,
+ all_keys: list[str],
+ qs: dict[str, list[str]],
+ ) -> bytes:
+ """Build ListBucketResult XML (ListObjectsV2-compatible).
+
+ Clients (e.g. MSC) send ``delimiter=/`` and ``prefix=`` and expect
+ ``CommonPrefixes`` for nested keys such as ``parts/data-0.tar``, not
+ only flat ``Contents``.
+ """
+ prefix = (qs.get("prefix") or [""])[0]
+ delimiter = (qs.get("delimiter") or [None])[0]
+ max_keys_s = (qs.get("max-keys") or qs.get("maxkeys") or ["1000"])[0]
+ try:
+ max_keys = max(1, min(int(max_keys_s), 1000))
+ except ValueError:
+ max_keys = 1000
- Args:
- bucket: The bucket name.
- objects: List of object keys in the bucket.
+ continuation = (qs.get("continuation-token") or [""])[0]
+ start_after = (qs.get("start-after") or [""])[0]
+ exclusive_after = continuation or start_after
- Returns:
- The XML document as bytes.
- """
- entries = []
- now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
- for key in objects:
- try:
- data = S3RequestHandler.server.state.get_object(bucket, key) # type: ignore[attr-defined]
- size = len(data)
- etag = _etag(data)
- except Exception: # noqa: BLE001
- size = 0
- etag = '""'
- entries.append(
- ""
- f"{_escape_xml(key)}"
- f"{now}"
- f"{etag}"
- f"{size}"
- ""
+ state = self.server.state
+
+ items: list[tuple[Literal["cp", "key"], str]] = []
+ if not delimiter:
+ for k in sorted(all_keys):
+ if k.startswith(prefix):
+ items.append(("key", k))
+ else:
+ common: set[str] = set()
+ contents: list[str] = []
+ for k in sorted(all_keys):
+ if not k.startswith(prefix):
+ continue
+ relative = k[len(prefix) :]
+ if delimiter in relative:
+ idx = relative.index(delimiter)
+ common.add(prefix + relative[: idx + len(delimiter)])
+ else:
+ contents.append(k)
+ for cp in sorted(common):
+ items.append(("cp", cp))
+ for ck in sorted(contents):
+ items.append(("key", ck))
+ items.sort(key=lambda x: x[1])
+
+ if exclusive_after:
+ items = [it for it in items if it[1] > exclusive_after]
+
+ page = items[:max_keys]
+ truncated = len(items) > max_keys
+ next_token = page[-1][1] if truncated and page else ""
+
+ fragments: list[str] = [
+ '',
+ "",
+ f"{_escape_xml(bucket)}",
+ f"{_escape_xml(prefix)}",
+ f"{len(page)}",
+ f"{max_keys}",
+ f"{str(truncated).lower()}",
+ ]
+ if delimiter:
+ fragments.append(f"{_escape_xml(delimiter)}")
+ if truncated and next_token:
+ fragments.append(
+ f"{_escape_xml(next_token)}"
)
- obj_elems = "".join(entries)
- xml = (
- ''
- ""
- f"{_escape_xml(bucket)}"
- f"{obj_elems}"
- ""
- )
- return xml.encode()
+
+ for kind, path in page:
+ if kind == "cp":
+ fragments.append(
+ f"{_escape_xml(path)}"
+ )
+ else:
+ try:
+ data = state.get_object(bucket, path)
+ last_modified = state.get_object_last_modified(bucket, path)
+ size = len(data)
+ etag = _etag(data)
+ except Exception: # noqa: BLE001
+ size = 0
+ etag = '""'
+ last_modified = datetime.fromtimestamp(0, tz=timezone.utc)
+ fragments.append(
+ ""
+ f"{_escape_xml(path)}"
+ f"{_s3_datetime(last_modified)}"
+ f"{etag}"
+ f"{size}"
+ ""
+ )
+
+ fragments.append("")
+ return "".join(fragments).encode()
class S3ServerProtocol(Protocol): # noqa: D101
@@ -404,6 +501,35 @@ class S3ServerProtocol(Protocol): # noqa: D101
auth: S3Auth
+def _parse_copy_source(raw: str) -> tuple[str, str]:
+ """Parse ``x-amz-copy-source`` into ``(bucket, key)``.
+
+ Accepts ``/bucket/key``, ``bucket/key``, URL-encoded keys, and strips ``?versionId=``.
+
+ Args:
+ raw: Raw header value.
+
+ Returns:
+ Source bucket and object key.
+
+ Raises:
+ ValueError: If the value cannot be parsed.
+ """
+ s = raw.strip()
+ if not s:
+ raise ValueError("Empty x-amz-copy-source")
+ s = s.split("?", 1)[0]
+ s = _up.unquote(s)
+ if s.startswith("/"):
+ s = s[1:]
+ if "/" not in s:
+ raise ValueError("x-amz-copy-source must be /bucket/key")
+ src_bucket, src_key = s.split("/", 1)
+ if not src_bucket or not src_key:
+ raise ValueError("Invalid x-amz-copy-source")
+ return src_bucket, src_key
+
+
def _escape_xml(text: str) -> str: # noqa: D401
"""Escape special characters for XML.
@@ -422,6 +548,18 @@ def _escape_xml(text: str) -> str: # noqa: D401
)
+def _http_datetime(value: datetime) -> str:
+ """Format an aware datetime for HTTP Last-Modified headers."""
+
+ return format_datetime(value.astimezone(timezone.utc), usegmt=True)
+
+
+def _s3_datetime(value: datetime) -> str:
+ """Format an aware datetime for S3 XML LastModified fields."""
+
+ return value.astimezone(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
+
+
def _etag(data: bytes) -> str: # noqa: D401
"""Generate an ETag for binary data.
diff --git a/tests/s3_emulator/state.py b/tests/s3_emulator/state.py
index 1493dcb1..c8d22cee 100644
--- a/tests/s3_emulator/state.py
+++ b/tests/s3_emulator/state.py
@@ -1,6 +1,8 @@
# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause
import json
+import os
+from datetime import datetime, timezone
from pathlib import Path
from threading import RLock
from typing import Dict, Optional
@@ -27,6 +29,7 @@ def __init__(self, root_dir: Optional[Path] = None) -> None:
root_dir: Path to persist the store on disk.
"""
self._fs: Dict[str, Dict[str, bytes]] = {}
+ self._last_modified: Dict[str, Dict[str, datetime]] = {}
self._uploads: Dict[str, _MultipartUpload] = {}
self._lock = RLock()
self._root_dir = root_dir
@@ -54,6 +57,7 @@ def create_bucket(self, bucket: str) -> None:
print(f"Bucket '{bucket}' already exists")
return
self._fs[bucket] = {}
+ self._last_modified[bucket] = {}
if self._root_dir is not None:
(self._root_dir / bucket).mkdir(parents=True, exist_ok=True)
@@ -69,6 +73,7 @@ def delete_bucket(self, bucket: str) -> None:
if self._fs[bucket]:
raise RuntimeError("Bucket not empty")
del self._fs[bucket]
+ del self._last_modified[bucket]
if self._root_dir is not None:
bucket_path = self._root_dir / bucket
if bucket_path.exists():
@@ -76,24 +81,34 @@ def delete_bucket(self, bucket: str) -> None:
p.unlink()
bucket_path.rmdir()
- def put_object(self, bucket: str, key: str, data: bytes) -> None:
+ def put_object(
+ self, bucket: str, key: str, data: bytes, *, last_modified: datetime | None = None
+ ) -> None:
"""Store an object in a bucket.
Args:
bucket: Name of the bucket.
key: Object key.
data: Object data.
+ last_modified: Optional timestamp for the object. Defaults to now.
"""
if not bucket:
raise ValueError("Bucket name must be given")
+ if last_modified is None:
+ last_modified = datetime.now(timezone.utc)
+ else:
+ last_modified = last_modified.astimezone(timezone.utc)
with self._lock:
if bucket not in self._fs:
self._fs[bucket] = {}
+ self._last_modified[bucket] = {}
self._fs[bucket][key] = data
+ self._last_modified[bucket][key] = last_modified
if self._root_dir is not None:
obj_path = (self._root_dir / bucket / key).resolve()
obj_path.parent.mkdir(parents=True, exist_ok=True)
obj_path.write_bytes(data)
+ os.utime(obj_path, (last_modified.timestamp(), last_modified.timestamp()))
def get_object(self, bucket: str, key: str) -> bytes:
"""Retrieve an object from a bucket.
@@ -111,6 +126,48 @@ def get_object(self, bucket: str, key: str) -> bytes:
except KeyError as exc:
raise FileNotFoundError(f"{bucket}/{key}") from exc
+ def get_object_last_modified(self, bucket: str, key: str) -> datetime:
+ """Return the stored Last-Modified timestamp for an object."""
+
+ with self._lock:
+ try:
+ return self._last_modified[bucket][key]
+ except KeyError as exc:
+ raise FileNotFoundError(f"{bucket}/{key}") from exc
+
+ def copy_object(self, dest_bucket: str, dest_key: str, src_bucket: str, src_key: str) -> bytes:
+ """Copy an object to another key (S3 CopyObject).
+
+ Args:
+ dest_bucket: Destination bucket.
+ dest_key: Destination object key.
+ src_bucket: Source bucket.
+ src_key: Source object key.
+
+ Returns:
+ Copied object bytes (for ETag in the CopyObject XML response).
+
+ Raises:
+ FileNotFoundError: If the source object does not exist.
+ """
+ with self._lock:
+ try:
+ payload = bytes(self._fs[src_bucket][src_key])
+ except KeyError as exc:
+ raise FileNotFoundError(f"{src_bucket}/{src_key}") from exc
+ if dest_bucket not in self._fs:
+ self._fs[dest_bucket] = {}
+ self._last_modified[dest_bucket] = {}
+ last_modified = datetime.now(timezone.utc)
+ self._fs[dest_bucket][dest_key] = payload
+ self._last_modified[dest_bucket][dest_key] = last_modified
+ if self._root_dir is not None:
+ obj_path = (self._root_dir / dest_bucket / dest_key).resolve()
+ obj_path.parent.mkdir(parents=True, exist_ok=True)
+ obj_path.write_bytes(payload)
+ os.utime(obj_path, (last_modified.timestamp(), last_modified.timestamp()))
+ return payload
+
def delete_object(self, bucket: str, key: str) -> None:
"""Delete an object from a bucket.
@@ -121,6 +178,7 @@ def delete_object(self, bucket: str, key: str) -> None:
with self._lock:
try:
del self._fs[bucket][key]
+ del self._last_modified[bucket][key]
except KeyError as exc:
raise FileNotFoundError(f"{bucket}/{key}") from exc
if self._root_dir is not None:
@@ -162,7 +220,21 @@ def _load_from_disk(self) -> None:
print(f"Failed to read persisted state: {err}")
return
with self._lock:
- self._fs = {bucket: {key: b"" for key in keys} for bucket, keys in mapping.items()}
+ self._fs = {}
+ self._last_modified = {}
+ for bucket, keys in mapping.items():
+ bucket_path = self._root_dir / bucket
+ self._fs[bucket] = {}
+ self._last_modified[bucket] = {}
+ for key in keys:
+ object_path = bucket_path / key
+ self._fs[bucket][key] = (
+ object_path.read_bytes() if object_path.is_file() else b""
+ )
+ self._last_modified[bucket][key] = datetime.fromtimestamp(
+ object_path.stat().st_mtime if object_path.exists() else 0,
+ tz=timezone.utc,
+ )
def flush(self) -> None:
"""Persist only the structure of the store to disk."""
@@ -186,6 +258,7 @@ def initiate_multipart(self, bucket: str, key: str) -> str:
self._uploads[upload_id] = _MultipartUpload(bucket, key)
if bucket not in self._fs:
self._fs[bucket] = {}
+ self._last_modified[bucket] = {}
return upload_id
def upload_part(self, upload_id: str, part_number: int, data: bytes) -> None:
@@ -215,11 +288,15 @@ def complete_multipart(self, upload_id: str) -> None:
data = mp.assemble()
if mp.bucket not in self._fs:
self._fs[mp.bucket] = {}
+ self._last_modified[mp.bucket] = {}
+ last_modified = datetime.now(timezone.utc)
self._fs[mp.bucket][mp.key] = data
+ self._last_modified[mp.bucket][mp.key] = last_modified
if self._root_dir is not None:
obj_path = (self._root_dir / mp.bucket / mp.key).resolve()
obj_path.parent.mkdir(parents=True, exist_ok=True)
obj_path.write_bytes(data)
+ os.utime(obj_path, (last_modified.timestamp(), last_modified.timestamp()))
def abort_multipart(self, upload_id: str) -> None:
"""Abort a multipart upload.
@@ -243,7 +320,8 @@ def add_file(self, src: Path, dst: str):
self.add_file(file, dst=f"{dst}/{file.name}")
elif src.is_file():
bucket, key = dst.removeprefix("/").split("/", 1)
- self.put_object(bucket, key, src.read_bytes())
+ last_modified = datetime.fromtimestamp(src.stat().st_mtime, tz=timezone.utc)
+ self.put_object(bucket, key, src.read_bytes(), last_modified=last_modified)
else:
raise ValueError(f"Invalid file: {src}")
diff --git a/tests/test_crudedataset.py b/tests/test_crudedataset.py
index 5cd38eab..2e16df22 100644
--- a/tests/test_crudedataset.py
+++ b/tests/test_crudedataset.py
@@ -19,6 +19,7 @@
import torch
import webdataset as wds
+from click.testing import CliRunner
import megatron.energon.epathlib.epath as epath_mod
from megatron.energon import (
@@ -34,11 +35,12 @@
)
from megatron.energon.cache import FileStore, FileStoreCachePool, Lazy
from megatron.energon.cache.base import CachePool
+from megatron.energon.cli.main import prepare_media_command
from megatron.energon.edataclass import edataclass
from megatron.energon.epathlib.epath import EPath
from megatron.energon.errors import reraise_exception
from megatron.energon.flavors.base_dataset import Sample
-from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
+from megatron.energon.flavors.webdataset.config import INDEX_SQLITE_FILENAME, MAIN_FOLDER_NAME
from megatron.energon.flavors.webdataset.sample_decoder import SampleDecoder
from megatron.energon.media.extractor import MediaFilterConfig, MediaFilterStrategy
from megatron.energon.media.filesystem_prepare import prepare_filesystem_dataset
@@ -1014,6 +1016,89 @@ def test_media_metadata_webdataset(self):
"VIDEO-192x108@30.0fps-63.0s|VIDEO-192x108@30.0fps-63.0s",
]
+ def test_prepare_dataset_s3_cmdline(self):
+ """Media files live on S3 (emulator); `prepare_media_command` writes the media metadata index into that remote dataset, which is then read back as an aux media source."""
+ bucket = "energon-prepare-s3-cmdline-media-metadata"
+ profile_name = "s3test_dataset_prepare_cmdline_media_metadata"
+ with setup_s3_emulator(profile_name=profile_name) as state:
+ state.create_bucket(bucket)
+ state.add_file(self.multimedia_fs_path, dst=f"{bucket}/multimedia_fs/")
+ state.add_file(self.multimedia_wds_path, dst=f"{bucket}/multimedia_wds/")
+ s3_root = EPath(f"msc://{profile_name}/{bucket}")
+ info_path = s3_root / "multimedia_fs" / MAIN_FOLDER_NAME / INDEX_SQLITE_FILENAME
+ # assert not info_path.is_file(), f"index.sqlite already exists: {info_path}"
+ print(f"S3 root: {s3_root}")
+ runner = CliRunner()
+ result = runner.invoke(
+ prepare_media_command,
+ [
+ str(s3_root / "multimedia_fs"),
+ "--num-workers",
+ "2",
+ "--media-metadata-by-extension",
+ ],
+ catch_exceptions=False,
+ )
+ print(result.stdout)
+ assert result.exit_code == 0, result.stdout
+ assert "Done" in result.stdout, result.stdout
+ assert info_path.is_file(), f"missing {info_path}"
+
+ print("Done preparing media metadata, iterating now")
+
+ state.put_object(
+ bucket,
+ "s3_media_metadataset.yaml",
+ "\n".join(
+ [
+ "__module__: megatron.energon",
+ "__class__: MetadatasetV2",
+ "splits:",
+ " train:",
+ " path: multimedia_wds",
+ " aux:",
+ f" media: filesystem+msc://{profile_name}/{bucket}/multimedia_fs",
+ " subflavors:",
+ " crude_type: media_metadata",
+ ]
+ ).encode("utf-8"),
+ )
+
+ # Iterate over the new dataset
+ worker_config = WorkerConfig(
+ rank=0,
+ world_size=1,
+ num_workers=0,
+ )
+
+ loader = get_savable_loader(
+ get_train_dataset(
+ s3_root / "s3_media_metadataset.yaml",
+ batch_size=1,
+ worker_config=worker_config,
+ task_encoder=CookingTaskEncoder(),
+ shuffle_buffer_size=None,
+ max_samples_per_sequence=None,
+ )
+ )
+
+ descriptions = []
+ for _, batch in zip(range(4), loader):
+ descriptions.extend(batch.txts)
+
+ # from pprint import pprint
+ # pprint(descriptions, indent=4)
+
+ # The descriptions are like "A|B", where A is the format
+ # in the WebDataset and B is the format in the auxiliary dataset.
+
+ assert descriptions == [
+ "IMG-32x16-JPEG|IMG-32x16-JPEG",
+ "IMG-24x24-PNG|IMG-24x24-PNG",
+ "AUDIO-10.0s@32000Hz|AUDIO-10.0s@32000Hz",
+ "VIDEO-192x108@30.0fps-63.0s|VIDEO-192x108@30.0fps-63.0s",
+ ]
+
def test_nomds(self):
torch.manual_seed(42)
worker_config = WorkerConfig(
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
index 01aa74ee..b05cb401 100644
--- a/tests/test_dataset.py
+++ b/tests/test_dataset.py
@@ -10,6 +10,7 @@
import logging
import math
import random
+import shutil
import sys
import tempfile
import unittest
@@ -47,14 +48,16 @@
)
from megatron.energon.dataset_config import get_dataset_from_config
from megatron.energon.edataclass import edataclass
+from megatron.energon.epathlib import EPath
from megatron.energon.flavors import BaseWebdatasetFactory
-from megatron.energon.flavors.webdataset.config import MAIN_FOLDER_NAME
+from megatron.energon.flavors.webdataset.config import INFO_JSON_FILENAME, MAIN_FOLDER_NAME
from megatron.energon.task_encoder.base import stateless
from megatron.energon.tools.analyze_debug import command as analyze_debug_command
from megatron.energon.tools.info import command as info_command
from megatron.energon.tools.lint import command as lint_command
from megatron.energon.tools.prepare import command as prepare_command
from megatron.energon.tools.preview import command as preview_command
+from tests.epath_s3_emulator import setup_s3_emulator
# Speed up tests significantly by reducing the torch status check interval for broken worker shutdown
try:
@@ -1803,6 +1806,57 @@ def test_prepare_dataset_noninteractive_crude(self):
content = f.read()
assert "CrudeWebdataset" in content
+ def test_prepare_dataset_s3_cmdline(self):
+ """Tar shards live on S3 (emulator); `energon prepare` writes `.nv-meta` to the same prefix."""
+ bucket = "energon-prepare-s3-cmdline"
+ profile_name = "s3test_dataset_prepare_cmdline"
+ with tempfile.TemporaryDirectory() as staging_dir:
+ staging_path = Path(staging_dir)
+ TestDataset.create_captioning_test_dataset(staging_path, num_samples=20)
+ shutil.rmtree(staging_path / MAIN_FOLDER_NAME)
+ with setup_s3_emulator(profile_name=profile_name) as state:
+ state.create_bucket(bucket)
+ state.add_file(staging_path, dst=bucket)
+ s3_root = EPath(f"msc://{profile_name}/{bucket}")
+ assert s3_root.is_dir()
+ tar_paths = list(s3_root.glob("**/*.tar"))
+ assert len(tar_paths) >= 1, f"expected tar shards on S3, got {tar_paths!r}"
+ runner = CliRunner()
+ result = runner.invoke(
+ prepare_command,
+ [
+ str(s3_root),
+ "--non-interactive",
+ "--no-progress",
+ "--num-workers",
+ "2",
+ "--split-ratio=1,0,0",
+ "--sample-type=CaptioningSample",
+ '--field-map={"image": "png", "caption": "txt"}',
+ "--media-metadata-by-extension",
+ ],
+ catch_exceptions=False,
+ )
+ assert result.exit_code == 0, result.stdout
+ assert "Done" in result.stdout, result.stdout
+ info_path = s3_root / MAIN_FOLDER_NAME / INFO_JSON_FILENAME
+ assert info_path.is_file(), f"missing {info_path}"
+
+ # Iterate over the new dataset
+ loader = get_loader(
+ get_train_dataset(
+ s3_root,
+ batch_size=5,
+ worker_config=WorkerConfig(rank=0, world_size=1, num_workers=0),
+ shuffle_buffer_size=None,
+ max_samples_per_sequence=None,
+ )
+ )
+ assert len(loader) == 4, f"len(loader) == {len(loader)}"
+ samples = list(d for _, d in zip(range(4), loader))
+ assert len(samples) == 4, f"len(samples) == {len(samples)}"
+ assert all(isinstance(sample, CaptioningSample) for sample in samples)
+
def test_preview_captioning_dataset(self):
runner = CliRunner()
result = runner.invoke(
diff --git a/tests/test_epathlib.py b/tests/test_epathlib.py
index 2e8ae432..7aad01ca 100644
--- a/tests/test_epathlib.py
+++ b/tests/test_epathlib.py
@@ -11,6 +11,7 @@
import sys
import tempfile
import unittest
+from datetime import datetime, timezone
from pathlib import Path
from megatron.energon.epathlib import DEFAULT_PROFILE_NAME, EPath
@@ -121,6 +122,94 @@ def test_glob(self):
p = EPath("/tmp").glob("epathtestfile_[0-3].bin")
assert len(list(p)) == 4
+ def test_dss_glob_walk(self):
+ import megatron.energon.epathlib.epath as epath_mod
+
+ orig_env_cache_dir = os.environ.get("NVDATASET_CACHE_DIR")
+ orig_mod_cache_dir = epath_mod.NVDATASET_CACHE_DIR
+
+ with tempfile.TemporaryDirectory() as td:
+ cache_dir = Path(td) / "nvds_cache"
+ media_dir = cache_dir / "charts1234" / "v0" / "images"
+ media_dir.mkdir(parents=True)
+ (media_dir / "000.jpg").write_bytes(b"\xff\xd8\xff\xd9")
+ (media_dir / "001.jpg").write_bytes(b"\xff\xd8\xff\xd9")
+ (media_dir / "002.txt").write_bytes(b"dummy")
+
+ try:
+ os.environ["NVDATASET_CACHE_DIR"] = str(cache_dir)
+ epath_mod.NVDATASET_CACHE_DIR = EPath(cache_dir)
+
+ root = EPath("dss://charts1234@v0")
+ found = sorted(root.glob("**/*.jpg"))
+
+ print(found)
+
+ assert [str(path) for path in found] == [
+ "dss://charts1234@v0/images/000.jpg",
+ "dss://charts1234@v0/images/001.jpg",
+ ]
+ assert [path.relative_to(root) for path in found] == [
+ "images/000.jpg",
+ "images/001.jpg",
+ ]
+
+ found = sorted(root.walk())
+ print(found)
+ assert [str(p) for p in found] == [
+ "dss://charts1234@v0/images/000.jpg",
+ "dss://charts1234@v0/images/001.jpg",
+ "dss://charts1234@v0/images/002.txt",
+ ]
+ finally:
+ if orig_env_cache_dir is None:
+ os.environ.pop("NVDATASET_CACHE_DIR", None)
+ else:
+ os.environ["NVDATASET_CACHE_DIR"] = orig_env_cache_dir
+ epath_mod.NVDATASET_CACHE_DIR = orig_mod_cache_dir
+
+ def test_s3_glob_walk(self):
+ with setup_s3_emulator(profile_name="s3test_dss_walk") as s3_emulator:
+ s3_emulator.put_object("test", "dir/file.txt", b"dummy")
+ s3_emulator.put_object("test", "dir/subdir/file2.txt", b"dummy")
+ s3_emulator.put_object("test", "dir/subdir/file3.blob", b"dummy")
+ root = EPath("msc://s3test_dss_walk/test/dir")
+ found = sorted(root.walk())
+ print(found)
+ assert [str(p) for p in found] == [
+ "msc://s3test_dss_walk/test/dir/file.txt",
+ "msc://s3test_dss_walk/test/dir/subdir/file2.txt",
+ "msc://s3test_dss_walk/test/dir/subdir/file3.blob",
+ ]
+
+ found = sorted(root.glob("**/*.txt"))
+ print(found)
+ assert [str(p) for p in found] == [
+ "msc://s3test_dss_walk/test/dir/file.txt",
+ "msc://s3test_dss_walk/test/dir/subdir/file2.txt",
+ ]
+
+ def test_local_glob_walk(self):
+ with tempfile.TemporaryDirectory() as td:
+ td_path = EPath(td)
+ (td_path / "file.txt").write_text("dummy")
+ (td_path / "subdir" / "file2.txt").write_text("dummy")
+ (td_path / "subdir" / "file3.blob").write_text("dummy")
+ root = EPath(td_path)
+ found = sorted(root.walk())
+ print(found)
+ assert [str(p) for p in found] == [
+ str(td_path / "file.txt"),
+ str(td_path / "subdir" / "file2.txt"),
+ str(td_path / "subdir" / "file3.blob"),
+ ]
+ found = sorted(root.glob("**/*.txt"))
+ print(found)
+ assert [str(p) for p in found] == [
+ str(td_path / "file.txt"),
+ str(td_path / "subdir" / "file2.txt"),
+ ]
+
def test_s3_path_resolution(self):
"""Test s3 path resolution"""
rclone_config_path = EPath("/tmp/XDG_CONFIG_HOME/.config/rclone/rclone.conf")
@@ -278,14 +367,85 @@ def test_msc_s3(self):
assert p.is_file()
assert p.size() > 0
assert p.read_text() == "dummy"
- # TODO: Fix when fixed in MSC.
- # assert EPath("msc://s3test_msc/test").is_dir()
+ assert EPath("msc://s3test_msc/test").is_dir()
assert EPath("msc://s3test_msc/test/dir").is_dir()
p.unlink()
assert not p.is_file()
- # assert not EPath("msc://s3test_msc/test").is_dir()
+ assert not EPath("msc://s3test_msc/test").is_dir()
assert not EPath("msc://s3test_msc/test/dir").is_dir()
+ def test_msc_s3_dataprep_path_operations(self):
+ """EPath operations used by remote webdataset ``prepare`` (glob, rb, meta, sqlite upload, move).
+
+ Mirrors ``tools/prepare.py`` shard discovery, ``WebdatasetPreparator._preprocess_tar``
+ binary reads, and ``SqliteIndexWriterAggregator`` uploading ``index.sqlite`` from disk.
+ """
+ # All paths below are msc:// EPaths under this profile's emulator.
+ profile = "s3test_msc_dataprep"
+ with setup_s3_emulator(profile_name=profile):
+ root = EPath(f"msc://{profile}/dataset_root")
+ parts = root / "parts"
+ # Create two shard files under parts/ by writing bytes through EPath.
+ (parts / "data-0.tar").write_bytes(b"shard0-bytes")
+ (parts / "data-1.tar").write_bytes(b"shard1-bytes")
+
+ # Shard discovery: a recursive glob from the dataset root must find both tars.
+ found = sorted(root.glob("**/*.tar"))
+ assert [p.name for p in found] == ["data-0.tar", "data-1.tar"]
+
+ # Binary read through a file handle; only the first 6 bytes are read.
+ with (parts / "data-0.tar").open("rb") as f:
+ assert f.read(6) == b"shard0"
+
+ # Metadata file: text written below MAIN_FOLDER_NAME must be readable back.
+ meta_dir = root / MAIN_FOLDER_NAME
+ info = meta_dir / INFO_JSON_FILENAME
+ info.write_text('{"shard_counts": {"parts/data-0.tar": 1}}')
+ assert '"shard_counts"' in info.read_text()
+
+ # Write two packed "Q" values (16 bytes, as asserted below) to a temporary name.
+ probe_idx = parts / "probe.idx"
+ probe_idx_tmp = parts / "probe.idx.tmp"
+ with probe_idx_tmp.open("wb") as out:
+ out.write(struct.pack("QQ", 0, 512))
+ assert probe_idx_tmp.size() == 16
+ assert struct.unpack("QQ", probe_idx_tmp.read_bytes()) == (0, 512)
+
+ # Move the temporary file to its final name; size and content must be unchanged.
+ # NOTE(review): nothing asserts that probe.idx.tmp is gone after move().
+ probe_idx_tmp.move(probe_idx)
+
+ assert probe_idx.size() == 16
+ assert struct.unpack("QQ", probe_idx.read_bytes()) == (0, 512)
+ probe_idx.unlink()
+
+ # Upload from local disk: delete=False keeps the file after the handle is
+ # closed so it can be copied by path; the ``finally`` below removes it.
+ with tempfile.NamedTemporaryFile(delete=False) as lf:
+ lf.write(b"sqlite-placeholder")
+ local_sqlite = lf.name
+ try:
+ EPath(local_sqlite).copy(meta_dir / INDEX_SQLITE_FILENAME)
+ remote_db = meta_dir / INDEX_SQLITE_FILENAME
+ assert remote_db.read_bytes() == b"sqlite-placeholder"
+ remote_db.unlink()
+ finally:
+ Path(local_sqlite).unlink(missing_ok=True)
+
+ # Remove the remaining remote objects created by this test.
+ info.unlink()
+ for p in found:
+ p.unlink()
+
+ def test_msc_s3_stat_uses_stored_last_modified(self):
+     """Check that ``stat()`` reports the ``last_modified`` passed to ``put_object``.
+
+     The object is stored with an explicit timestamp and stat'ed twice; both calls
+     must agree. It is then overwritten with a later timestamp and different
+     content, and a fresh ``stat()`` must reflect the new timestamp and length.
+     """
+     profile = "s3test_msc_timestamps"
+     with setup_s3_emulator(profile_name=profile) as emulator:
+         target = EPath(f"msc://{profile}/bucket/data.bin")
+         stamp_initial = datetime(2024, 1, 2, 3, 4, 5, tzinfo=timezone.utc)
+         stamp_rewritten = datetime(2024, 1, 3, 3, 4, 5, tzinfo=timezone.utc)
+
+         emulator.put_object("bucket", "data.bin", b"first", last_modified=stamp_initial)
+
+         stat_a = target.stat()
+         stat_b = target.stat()
+         # Timestamps are compared at whole-second resolution.
+         reported_seconds = int(stat_a.last_modified.timestamp())
+         assert reported_seconds == int(stamp_initial.timestamp())
+         assert stat_b.last_modified == stat_a.last_modified
+
+         emulator.put_object("bucket", "data.bin", b"second", last_modified=stamp_rewritten)
+
+         stat_after = target.stat()
+         reported_seconds_after = int(stat_after.last_modified.timestamp())
+         assert reported_seconds_after == int(stamp_rewritten.timestamp())
+         assert stat_after.content_length == len(b"second")
+
def test_dss_path_requires_version(self):
with self.assertRaisesRegex(
AssertionError,