From 38b7727107dcd7852ca3777ae935fbe09ae51edc Mon Sep 17 00:00:00 2001 From: ParamChordiya Date: Tue, 26 May 2026 20:51:14 -0500 Subject: [PATCH] Validate split name before download when split info is already known When load_dataset() is called with an invalid split name, the error is currently raised deep inside arrow_reader after the full download has already completed. For large datasets this wastes a lot of time. Add _check_split_names() to splits.py and call it from two places: * load.py: immediately after load_dataset_builder(), before download_and_prepare(). If the builder already has split info (from Hub YAML metadata or a cached dataset_info.json) we can bail out early. * builder.as_dataset(): after the default-split expansion, before map_nested(). This guarantees a clear ValueError with the list of available splits instead of a confusing error bubbling up from arrow_reader, even in cases where the early check wasn't possible. The helper handles composite specs ("train+test"), sliced specs ("train[:1000]"), lists, and dicts transparently. Fixes #5523 --- src/datasets/builder.py | 6 ++++- src/datasets/load.py | 8 +++++- src/datasets/splits.py | 35 ++++++++++++++++++++++++++ tests/test_builder.py | 24 ++++++++++++++++++ tests/test_splits.py | 54 ++++++++++++++++++++++++++++++++++++++++- 5 files changed, 124 insertions(+), 3 deletions(-) diff --git a/src/datasets/builder.py b/src/datasets/builder.py index 5b702df44a9..135546fb62f 100644 --- a/src/datasets/builder.py +++ b/src/datasets/builder.py @@ -59,7 +59,7 @@ from .info import DatasetInfo from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase -from .splits import Split, SplitDict, SplitGenerator, SplitInfo +from .splits import Split, SplitDict, SplitGenerator, SplitInfo, _check_split_names from .streaming import extend_dataset_builder_for_streaming from .table import CastError from .utils import logging @@ -1024,6 +1024,10 @@ def as_dataset( if split is None: split = {s: s for s in self.info.splits} + # Validate before doing any work so the error is clear rather than + # something cryptic bubbling up from inside arrow_reader. + _check_split_names(split, self.info.splits) + # Create a dataset for each of the given splits datasets = map_nested( partial( diff --git a/src/datasets/load.py b/src/datasets/load.py index 560bcad3a44..9048bb2a7c8 100644 --- a/src/datasets/load.py +++ b/src/datasets/load.py @@ -75,7 +75,7 @@ _PACKAGED_DATASETS_MODULES, ) from .packaged_modules.folder_based_builder.folder_based_builder import FolderBasedBuilder -from .splits import Split +from .splits import Split, _check_split_names from .utils import _dataset_viewer from .utils.file_utils import ( _raise_if_offline_mode_is_enabled, @@ -1700,6 +1700,12 @@ def load_dataset( **config_kwargs, ) + # If split info is already known (from Hub YAML metadata or a previously cached + # dataset_info.json) we can catch a bad split name right here, before starting + # what could be a very large download. + if split is not None and builder_instance.info.splits: + _check_split_names(split, builder_instance.info.splits) + # Return iterable dataset in case of streaming if streaming: return builder_instance.as_streaming_dataset(split=split) diff --git a/src/datasets/splits.py b/src/datasets/splits.py index 7e8ea953afd..2a158c6551a 100644 --- a/src/datasets/splits.py +++ b/src/datasets/splits.py @@ -610,6 +610,41 @@ def _from_yaml_list(cls, yaml_data: list) -> "SplitDict": return cls.from_split_dict(yaml_data) +def _check_split_names(split, known_splits): + """Raise ValueError if any requested split name isn't in the dataset's known splits. + + Handles composite specs like ``"train+test"`` and sliced specs like + ``"train[:1000]"`` by extracting the base split name before checking. + + Args: + split: The split argument passed by the user – a string, :class:`Split`, + list, or dict (as accepted by ``load_dataset``). + known_splits: Mapping of available split names, e.g. ``builder.info.splits``. + """ + if not known_splits: + return + + if isinstance(split, dict): + specs = list(split.values()) + elif isinstance(split, (list, tuple)): + specs = list(split) + else: + specs = [split] + + available = sorted(known_splits) + for spec in specs: + if spec is None: + continue + spec = str(spec).strip().strip("()") + if not spec or spec == "all": + continue + # "train+test[:50%]" → check "train" and "test" separately + for part in spec.split("+"): + name = part.strip().split("[")[0].strip() + if name and name not in known_splits: + raise ValueError(f'Unknown split "{name}". Should be one of {available}.') + + @dataclass class SplitGenerator: """Defines the split information for the generator. diff --git a/tests/test_builder.py b/tests/test_builder.py index 14d44fae7d3..75965775f83 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -552,6 +552,30 @@ def test_builder_as_dataset(split, expected_dataset_class, expected_dataset_leng dataset.column_names == ["text"] +def test_builder_as_dataset_unknown_split_raises(tmp_path): + builder = DummyBuilder(cache_dir=str(tmp_path)) + os.makedirs(builder.cache_dir) + + builder.info.splits = SplitDict() + builder.info.splits.add(SplitInfo("train", num_examples=10)) + builder.info.splits.add(SplitInfo("test", num_examples=10)) + + for split_name in builder.info.splits: + with ArrowWriter( + path=os.path.join(builder.cache_dir, f"{builder.dataset_name}-{split_name}.arrow"), + features=Features({"text": Value("string")}), + ) as writer: + writer.write_batch({"text": ["foo"] * 10}) + writer.finalize() + + with pytest.raises(ValueError, match='Unknown split "validation"'): + builder.as_dataset(split="validation") + + # compound spec – one part is valid, one isn't + with pytest.raises(ValueError, match='Unknown split "oops"'): + builder.as_dataset(split="train+oops") + + @pytest.mark.parametrize("in_memory", [False, True]) def test_generator_based_builder_as_dataset(in_memory, tmp_path): cache_dir = tmp_path / "data" diff --git a/tests/test_splits.py b/tests/test_splits.py index 364880ec686..b9894f2a456 100644 --- a/tests/test_splits.py +++ b/tests/test_splits.py @@ -2,7 +2,7 @@ import pytest -from datasets.splits import Split, SplitDict, SplitInfo +from datasets.splits import Split, SplitDict, SplitInfo, _check_split_names from datasets.utils.py_utils import asdict @@ -41,3 +41,55 @@ def test_split_dict_asdict_has_dataset_name(split_info): def test_named_split_inequality(): # Used while building the docs, when set as a default parameter value in a function signature assert Split.TRAIN != inspect.Parameter.empty + + +# --------------------------------------------------------------------------- +# _check_split_names +# --------------------------------------------------------------------------- + +_SPLITS = SplitDict( + { + "train": SplitInfo(name="train", num_examples=100), + "test": SplitInfo(name="test", num_examples=50), + } +) + + +@pytest.mark.parametrize( + "split", + [ + "train", + "test", + "train[:50%]", + "train[10:20]", + "train+test", + "train[:50%]+test", + ["train", "test"], + {"my_train": "train", "my_test": "test"}, + None, + "all", + ], +) +def test_check_split_names_valid(split): + # should not raise + _check_split_names(split, _SPLITS) + + +@pytest.mark.parametrize( + "split, bad_name", + [ + ("blabla", "blabla"), + ("train+blabla", "blabla"), + ("blabla[:50%]", "blabla"), + (["train", "blabla"], "blabla"), + ({"a": "train", "b": "blabla"}, "blabla"), + ], +) +def test_check_split_names_invalid(split, bad_name): + with pytest.raises(ValueError, match=f'Unknown split "{bad_name}"'): + _check_split_names(split, _SPLITS) + + +def test_check_split_names_empty_known_splits(): + # can't validate anything without known splits – should be a no-op + _check_split_names("whatever", SplitDict())