Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/megatron/energon/epathlib/epath.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def __init__(
elif protocol == "dss":
# Profile corresponds to the dataset name and version
assert profile is not None
self._split_dss_name_and_version(profile)
assert NVDATASET_CACHE_DIR is not None, (
"Environment variable NVDATASET_CACHE_DIR is not set"
)
Expand Down Expand Up @@ -111,6 +112,7 @@ def __setstate__(self, state: dict) -> None:
assert NVDATASET_CACHE_DIR is not None, (
"Environment variable NVDATASET_CACHE_DIR is not set"
)
self._split_dss_name_and_version(self.internal_path.parts[1])
self.fs = NVDATASET_CACHE_DIR.fs
else:
self.fs, _ = msc.resolve_storage_client(f"msc://{self.profile}")
Expand Down Expand Up @@ -148,6 +150,15 @@ def _split_protocol(path: str) -> Tuple[Optional[str], Optional[str], str]:
inner_path = ""
return m.group("protocol"), m.group("profile"), inner_path

@staticmethod
def _split_dss_name_and_version(dataset: str) -> Tuple[str, str]:
    """Split a DSS dataset identifier of the form ``name@version``.

    The identifier is split at the last ``@``, so everything before it is
    treated as the dataset name and everything after it as the version.

    Args:
        dataset: The dataset identifier, e.g. ``"charts1234@v0"``.

    Returns:
        The ``(dataset_name, dataset_version)`` pair.

    Raises:
        AssertionError: If ``dataset`` contains no ``@``, or if the name or
            the version part is empty.
    """
    # Raise explicitly rather than via ``assert`` statements: those are removed
    # when Python runs with ``-O``, which would silently disable this
    # validation. Exception type and messages are kept unchanged.
    if "@" not in dataset:
        raise AssertionError("DSS paths must include a dataset version separated by '@'")
    dataset_name, dataset_version = dataset.rsplit("@", maxsplit=1)
    if not (dataset_name and dataset_version):
        raise AssertionError("DSS paths must include non-empty dataset name and version")
    return dataset_name, dataset_version

@property
def _internal_str_path(self) -> str:
"""Return the path as used inside the file system, without the protocol and fs part.
Expand All @@ -157,7 +168,16 @@ def _internal_str_path(self) -> str:
"Environment variable NVDATASET_CACHE_DIR is not set"
)
# The internal path is relative to the NVDATASET_CACHE_DIR (i.e. strip the leading /, then concat with /)
return NVDATASET_CACHE_DIR._internal_str_path + str(self.internal_path)
dss_dataset_name, dss_dataset_version = self._split_dss_name_and_version(
self.internal_path.parts[1]
)
cache_path = PurePosixPath(
"/",
dss_dataset_name,
dss_dataset_version,
*self.internal_path.parts[2:],
)
return NVDATASET_CACHE_DIR._internal_str_path + str(cache_path)
else:
return str(self.internal_path)

Expand Down
11 changes: 10 additions & 1 deletion tests/test_crudedataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,15 @@ def test_aux_random_access(self):
assert all([a == b for a, b in zip(samples_after, samples_restored)])

def test_dss_path(self):
dss_dataset_name = "crude_text"
dss_dataset_version = "v1"
dss_dataset_root = self.dataset_path / dss_dataset_name
dss_dataset_root.mkdir(parents=True, exist_ok=True)
(dss_dataset_root / dss_dataset_version).symlink_to(
self.dataset_path / "ds1",
target_is_directory=True,
)

dss_mds_path = self.dataset_path / "metadataset_dss.yaml"
dss_mds_path.write_text(
"\n".join(
Expand All @@ -639,7 +648,7 @@ def test_dss_path(self):
"__class__: MetadatasetV2",
"splits:",
" train:",
" path: dss://ds1",
f" path: dss://{dss_dataset_name}@{dss_dataset_version}",
" subflavors:",
" crude_type: txtpkl",
]
Expand Down
17 changes: 13 additions & 4 deletions tests/test_epathlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,13 @@ def test_msc_s3(self):
# assert not EPath("msc://s3test_msc/test").is_dir()
assert not EPath("msc://s3test_msc/test/dir").is_dir()

def test_dss_path_requires_version(self):
    """Constructing an EPath from a DSS URL without ``@`` must raise an AssertionError."""
    expected_message = "DSS paths must include a dataset version separated by '@'"
    with self.assertRaisesRegex(AssertionError, expected_message):
        EPath("dss://charts1234")

def test_metadataset_v2_dss_path_parsing_str(self):
"""Parse a MetadatasetV2 config and ensure DSS URLs stringify correctly as EPath."""

Expand Down Expand Up @@ -319,15 +326,15 @@ def test_metadataset_v2_dss_path_parsing_str(self):
# Create dummy DSS datasets in the cache dir so that `load_dataset()` can run
# post-initialization without hitting missing-path errors.
#
# - charts1234_zh@v0: minimal "webdataset" marker (presence of .nv-meta/.info.json)
# - charts1234@v0: folder with images (aux media source)
webdataset_root = cache_dir / "charts1234_zh@v0"
# - charts1234_zh/v0: minimal "webdataset" marker (presence of .nv-meta/.info.json)
# - charts1234/v0: folder with images (aux media source)
webdataset_root = cache_dir / "charts1234_zh" / "v0"
(webdataset_root / MAIN_FOLDER_NAME).mkdir(parents=True, exist_ok=True)
(webdataset_root / MAIN_FOLDER_NAME / INFO_JSON_FILENAME).write_text(
"{}", encoding="utf-8"
)

media_root = cache_dir / "charts1234@v0"
media_root = cache_dir / "charts1234" / "v0"
(media_root / "images").mkdir(parents=True, exist_ok=True)
(media_root / "images" / "000.jpg").write_bytes(b"\xff\xd8\xff\xd9")
(media_root / "images" / "001.jpg").write_bytes(b"\xff\xd8\xff\xd9")
Expand Down Expand Up @@ -363,6 +370,8 @@ def test_metadataset_v2_dss_path_parsing_str(self):

assert ds0.url == "dss://charts1234_zh@v0"
assert aux0.url == "dss://charts1234@v0"
assert ds0.local_path() == cache_dir / "charts1234_zh" / "v0"
assert aux0.local_path() == cache_dir / "charts1234" / "v0"
finally:
if orig_env_cache_dir is None:
os.environ.pop("NVDATASET_CACHE_DIR", None)
Expand Down
Loading