Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from .info import DatasetInfo
from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset
from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase
from .splits import Split, SplitDict, SplitGenerator, SplitInfo
from .splits import Split, SplitDict, SplitGenerator, SplitInfo, _check_split_names
from .streaming import extend_dataset_builder_for_streaming
from .table import CastError
from .utils import logging
Expand Down Expand Up @@ -1024,6 +1024,10 @@ def as_dataset(
if split is None:
split = {s: s for s in self.info.splits}

# Validate before doing any work so the error is clear rather than
# something cryptic bubbling up from inside arrow_reader.
_check_split_names(split, self.info.splits)

# Create a dataset for each of the given splits
datasets = map_nested(
partial(
Expand Down
8 changes: 7 additions & 1 deletion src/datasets/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@
_PACKAGED_DATASETS_MODULES,
)
from .packaged_modules.folder_based_builder.folder_based_builder import FolderBasedBuilder
from .splits import Split
from .splits import Split, _check_split_names
from .utils import _dataset_viewer
from .utils.file_utils import (
_raise_if_offline_mode_is_enabled,
Expand Down Expand Up @@ -1700,6 +1700,12 @@ def load_dataset(
**config_kwargs,
)

# If split info is already known (from Hub YAML metadata or a previously cached
# dataset_info.json) we can catch a bad split name right here, before starting
# what could be a very large download.
if split is not None and builder_instance.info.splits:
_check_split_names(split, builder_instance.info.splits)

# Return iterable dataset in case of streaming
if streaming:
return builder_instance.as_streaming_dataset(split=split)
Expand Down
35 changes: 35 additions & 0 deletions src/datasets/splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,6 +610,41 @@ def _from_yaml_list(cls, yaml_data: list) -> "SplitDict":
return cls.from_split_dict(yaml_data)


def _check_split_names(split, known_splits):
    """Raise ValueError if any requested split name isn't in the dataset's known splits.

    Handles composite specs like ``"train+test"`` and sliced specs like
    ``"train[:1000]"`` by extracting the base split name before checking.
    Containers are traversed recursively, so nested structures such as
    ``{"a": ["train", "test"]}`` are validated leaf by leaf.

    Args:
        split: The split argument passed by the user – a string, :class:`Split`,
            list/tuple, or dict (possibly nested). ``None`` entries and the
            special value ``"all"`` are accepted without checking.
        known_splits: Mapping of available split names, e.g. ``builder.info.splits``.
            When it is empty or ``None`` nothing can be validated and the call
            is a no-op.

    Raises:
        ValueError: If a base split name is not a key of ``known_splits``. Only
            the first unknown name encountered is reported.
    """
    if not known_splits:
        return

    for spec in _iter_split_specs(split):
        # Instruction-like objects that can render themselves as a spec string
        # are asked to do so; their ``str()`` may be a repr that does not parse
        # as "name[slice]+name".
        # NOTE(review): ``ReadInstruction`` is expected to be such an object - confirm.
        to_spec = getattr(spec, "to_spec", None)
        text = to_spec() if callable(to_spec) else str(spec)
        text = text.strip().strip("()")
        if not text or text == "all":
            continue
        # "train+test[:50%]" → check "train" and "test" separately
        for part in text.split("+"):
            name = part.partition("[")[0].strip()
            if name and name not in known_splits:
                raise ValueError(f'Unknown split "{name}". Should be one of {sorted(known_splits)}.')


def _iter_split_specs(split):
    """Yield every non-None leaf spec of a possibly nested dict/list/tuple split argument."""
    if isinstance(split, dict):
        for value in split.values():
            yield from _iter_split_specs(value)
    elif isinstance(split, (list, tuple)):
        for item in split:
            yield from _iter_split_specs(item)
    elif split is not None:
        yield split


@dataclass
class SplitGenerator:
"""Defines the split information for the generator.
Expand Down
24 changes: 24 additions & 0 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,30 @@ def test_builder_as_dataset(split, expected_dataset_class, expected_dataset_leng
dataset.column_names == ["text"]


def test_builder_as_dataset_unknown_split_raises(tmp_path):
    """``as_dataset`` raises ``ValueError`` naming the split that is not in ``builder.info.splits``."""
    builder = DummyBuilder(cache_dir=str(tmp_path))
    os.makedirs(builder.cache_dir)

    # Declare two known splits on the builder info.
    builder.info.splits = SplitDict()
    for known_name in ("train", "test"):
        builder.info.splits.add(SplitInfo(known_name, num_examples=10))

    # Write one arrow file per declared split into the builder's cache dir.
    for split_name in builder.info.splits:
        arrow_path = os.path.join(builder.cache_dir, f"{builder.dataset_name}-{split_name}.arrow")
        text_features = Features({"text": Value("string")})
        with ArrowWriter(path=arrow_path, features=text_features) as writer:
            writer.write_batch({"text": ["foo"] * 10})
            writer.finalize()

    failing_cases = [
        ("validation", 'Unknown split "validation"'),
        # compound spec – one part is valid, one isn't
        ("train+oops", 'Unknown split "oops"'),
    ]
    for requested_split, expected_message in failing_cases:
        with pytest.raises(ValueError, match=expected_message):
            builder.as_dataset(split=requested_split)


@pytest.mark.parametrize("in_memory", [False, True])
def test_generator_based_builder_as_dataset(in_memory, tmp_path):
cache_dir = tmp_path / "data"
Expand Down
54 changes: 53 additions & 1 deletion tests/test_splits.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from datasets.splits import Split, SplitDict, SplitInfo
from datasets.splits import Split, SplitDict, SplitInfo, _check_split_names
from datasets.utils.py_utils import asdict


Expand Down Expand Up @@ -41,3 +41,55 @@ def test_split_dict_asdict_has_dataset_name(split_info):
def test_named_split_inequality():
    """``Split.TRAIN`` must compare unequal to ``inspect.Parameter.empty``."""
    # This comparison is used while building the docs, when a split is set as a
    # default parameter value in a function signature.
    signature_empty = inspect.Parameter.empty
    assert Split.TRAIN != signature_empty


# ---------------------------------------------------------------------------
# _check_split_names
# ---------------------------------------------------------------------------

# Fixture with two known splits ("train" and "test") used as ``known_splits``.
_SPLITS = SplitDict(
    {
        split_name: SplitInfo(name=split_name, num_examples=example_count)
        for split_name, example_count in (("train", 100), ("test", 50))
    }
)


@pytest.mark.parametrize(
    "split",
    [
        # plain names
        "train",
        "test",
        # sliced specs
        "train[:50%]",
        "train[10:20]",
        # composite specs
        "train+test",
        "train[:50%]+test",
        # containers
        ["train", "test"],
        {"my_train": "train", "my_test": "test"},
        # values that are skipped by the check
        None,
        "all",
    ],
)
def test_check_split_names_valid(split):
    """Specs that only refer to splits present in ``_SPLITS`` must not raise."""
    known = _SPLITS
    _check_split_names(split, known)


@pytest.mark.parametrize(
    "split, bad_name",
    [
        # plain, composite and sliced string specs
        ("blabla", "blabla"),
        ("train+blabla", "blabla"),
        ("blabla[:50%]", "blabla"),
        # list and dict containers
        (["train", "blabla"], "blabla"),
        ({"a": "train", "b": "blabla"}, "blabla"),
    ],
)
def test_check_split_names_invalid(split, bad_name):
    """A spec mentioning a split absent from ``_SPLITS`` raises ``ValueError`` naming it."""
    expected_message = f'Unknown split "{bad_name}"'
    with pytest.raises(ValueError, match=expected_message):
        _check_split_names(split, _SPLITS)


def test_check_split_names_empty_known_splits():
    """With an empty ``SplitDict`` there is nothing to validate against, so the call is a no-op."""
    no_known_splits = SplitDict()
    _check_split_names("whatever", no_known_splits)