diff --git a/src/megatron/energon/epathlib/epath.py b/src/megatron/energon/epathlib/epath.py index f5c30973..6357262a 100644 --- a/src/megatron/energon/epathlib/epath.py +++ b/src/megatron/energon/epathlib/epath.py @@ -80,6 +80,7 @@ def __init__( elif protocol == "dss": # Profile corresponds to the dataset name and version assert profile is not None + self._split_dss_name_and_version(profile) assert NVDATASET_CACHE_DIR is not None, ( "Environment variable NVDATASET_CACHE_DIR is not set" ) @@ -111,6 +112,7 @@ def __setstate__(self, state: dict) -> None: assert NVDATASET_CACHE_DIR is not None, ( "Environment variable NVDATASET_CACHE_DIR is not set" ) + self._split_dss_name_and_version(self.internal_path.parts[1]) self.fs = NVDATASET_CACHE_DIR.fs else: self.fs, _ = msc.resolve_storage_client(f"msc://{self.profile}") @@ -148,6 +150,15 @@ def _split_protocol(path: str) -> Tuple[Optional[str], Optional[str], str]: inner_path = "" return m.group("protocol"), m.group("profile"), inner_path + @staticmethod + def _split_dss_name_and_version(dataset: str) -> Tuple[str, str]: + assert "@" in dataset, "DSS paths must include a dataset version separated by '@'" + dataset_name, dataset_version = dataset.rsplit("@", maxsplit=1) + assert dataset_name and dataset_version, ( + "DSS paths must include non-empty dataset name and version" + ) + return dataset_name, dataset_version + @property def _internal_str_path(self) -> str: """Return the path as used inside the file system, without the protocol and fs part. @@ -157,7 +168,16 @@ def _internal_str_path(self) -> str: "Environment variable NVDATASET_CACHE_DIR is not set" ) # The internal path is relative to the NVDATASET_CACHE_DIR (i.e. strip the leading /, then concat with /) - return NVDATASET_CACHE_DIR._internal_str_path + str(self.internal_path) + dss_dataset_name, dss_dataset_version = self._split_dss_name_and_version( + self.internal_path.parts[1] + ) + cache_path = PurePosixPath( + "/", + dss_dataset_name, + dss_dataset_version, + *self.internal_path.parts[2:], + ) + return NVDATASET_CACHE_DIR._internal_str_path + str(cache_path) else: return str(self.internal_path) diff --git a/tests/test_crudedataset.py b/tests/test_crudedataset.py index fa168e40..5cd38eab 100644 --- a/tests/test_crudedataset.py +++ b/tests/test_crudedataset.py @@ -631,6 +631,15 @@ def test_aux_random_access(self): assert all([a == b for a, b in zip(samples_after, samples_restored)]) def test_dss_path(self): + dss_dataset_name = "crude_text" + dss_dataset_version = "v1" + dss_dataset_root = self.dataset_path / dss_dataset_name + dss_dataset_root.mkdir(parents=True, exist_ok=True) + (dss_dataset_root / dss_dataset_version).symlink_to( + self.dataset_path / "ds1", + target_is_directory=True, + ) + dss_mds_path = self.dataset_path / "metadataset_dss.yaml" dss_mds_path.write_text( "\n".join( @@ -639,7 +648,7 @@ def test_dss_path(self): "__class__: MetadatasetV2", "splits:", " train:", - " path: dss://ds1", + f" path: dss://{dss_dataset_name}@{dss_dataset_version}", " subflavors:", " crude_type: txtpkl", ] diff --git a/tests/test_epathlib.py b/tests/test_epathlib.py index 2764e6f3..2e8ae432 100644 --- a/tests/test_epathlib.py +++ b/tests/test_epathlib.py @@ -286,6 +286,13 @@ def test_msc_s3(self): # assert not EPath("msc://s3test_msc/test").is_dir() assert not EPath("msc://s3test_msc/test/dir").is_dir() + def test_dss_path_requires_version(self): + with self.assertRaisesRegex( + AssertionError, + "DSS paths must include a dataset version separated by '@'", + ): + EPath("dss://charts1234") + def test_metadataset_v2_dss_path_parsing_str(self): """Parse a MetadatasetV2 config and ensure DSS URLs stringify correctly as EPath.""" @@ -319,15 +326,15 @@ def test_metadataset_v2_dss_path_parsing_str(self): # Create dummy DSS datasets in the cache dir so that `load_dataset()` can run # post-initialization without hitting missing-path errors. # - # - charts1234_zh@v0: minimal "webdataset" marker (presence of .nv-meta/.info.json) - # - charts1234@v0: folder with images (aux media source) - webdataset_root = cache_dir / "charts1234_zh@v0" + # - charts1234_zh/v0: minimal "webdataset" marker (presence of .nv-meta/.info.json) + # - charts1234/v0: folder with images (aux media source) + webdataset_root = cache_dir / "charts1234_zh" / "v0" (webdataset_root / MAIN_FOLDER_NAME).mkdir(parents=True, exist_ok=True) (webdataset_root / MAIN_FOLDER_NAME / INFO_JSON_FILENAME).write_text( "{}", encoding="utf-8" ) - media_root = cache_dir / "charts1234@v0" + media_root = cache_dir / "charts1234" / "v0" (media_root / "images").mkdir(parents=True, exist_ok=True) (media_root / "images" / "000.jpg").write_bytes(b"\xff\xd8\xff\xd9") (media_root / "images" / "001.jpg").write_bytes(b"\xff\xd8\xff\xd9") @@ -363,6 +370,8 @@ def test_metadataset_v2_dss_path_parsing_str(self): assert ds0.url == "dss://charts1234_zh@v0" assert aux0.url == "dss://charts1234@v0" + assert ds0.local_path() == cache_dir / "charts1234_zh" / "v0" + assert aux0.local_path() == cache_dir / "charts1234" / "v0" finally: if orig_env_cache_dir is None: os.environ.pop("NVDATASET_CACHE_DIR", None)