From 0651412bf46eb4d5b6f1118dd80159cb4f0f79ca Mon Sep 17 00:00:00 2001 From: frankzfli Date: Sat, 2 May 2026 22:52:53 +0800 Subject: [PATCH 1/2] Add Apache Iceberg format support Co-Authored-By: Claude Opus 4.6 --- setup.py | 4 + src/datasets/packaged_modules/__init__.py | 2 + .../packaged_modules/iceberg/__init__.py | 0 .../packaged_modules/iceberg/iceberg.py | 173 ++++++++++++++++++ tests/packaged_modules/test_iceberg.py | 157 ++++++++++++++++ 5 files changed, 336 insertions(+) create mode 100644 src/datasets/packaged_modules/iceberg/__init__.py create mode 100644 src/datasets/packaged_modules/iceberg/iceberg.py create mode 100644 tests/packaged_modules/test_iceberg.py diff --git a/setup.py b/setup.py index 2910e4ea930..e72d6d6a128 100644 --- a/setup.py +++ b/setup.py @@ -168,6 +168,7 @@ "faiss-cpu>=1.8.0.post1", # Pins numpy < 2 "h5py", "pylance", + "pyiceberg[sql-sqlite,pyarrow]", "jax>=0.3.14; sys_platform != 'win32'", "jaxlib>=0.3.14; sys_platform != 'win32'", "lz4; python_version < '3.14'", # python 3.14 gives ImportError: cannot import name '_compression' from partially initialized module 'lz4.frame @@ -211,6 +212,8 @@ NIBABEL_REQUIRE = ["nibabel>=5.3.2", "ipyniivue==2.4.2"] +ICEBERG_REQUIRE = ["pyiceberg>=0.7.0"] + EXTRAS_REQUIRE = { "audio": AUDIO_REQUIRE, "vision": VISION_REQUIRE, @@ -229,6 +232,7 @@ "docs": DOCS_REQUIRE, "pdfs": PDFS_REQUIRE, "nibabel": NIBABEL_REQUIRE, + "iceberg": ICEBERG_REQUIRE, } setup( diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py index f0ebcb79693..efcfcf4b643 100644 --- a/src/datasets/packaged_modules/__init__.py +++ b/src/datasets/packaged_modules/__init__.py @@ -10,6 +10,7 @@ from .csv import csv from .eval import eval from .hdf5 import hdf5 +from .iceberg import iceberg from .imagefolder import imagefolder from .json import json from .lance import lance @@ -55,6 +56,7 @@ def _hash_python_lines(lines: list[str]) -> str: "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())), "eval": (eval.__name__, _hash_python_lines(inspect.getsource(eval).splitlines())), "lance": (lance.__name__, _hash_python_lines(inspect.getsource(lance).splitlines())), + "iceberg": (iceberg.__name__, _hash_python_lines(inspect.getsource(iceberg).splitlines())), } # get importable module names and hash for caching diff --git a/src/datasets/packaged_modules/iceberg/__init__.py b/src/datasets/packaged_modules/iceberg/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/src/datasets/packaged_modules/iceberg/iceberg.py b/src/datasets/packaged_modules/iceberg/iceberg.py new file mode 100644 index 00000000000..447ed10fceb --- /dev/null +++ b/src/datasets/packaged_modules/iceberg/iceberg.py @@ -0,0 +1,173 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Union + +import pyarrow as pa + +import datasets +from datasets.builder import Key +from datasets.features import Features +from datasets.table import table_cast + + +if TYPE_CHECKING: + from pyiceberg.catalog import Catalog + from pyiceberg.expressions import BooleanExpression + from pyiceberg.table import FileScanTask + +logger = datasets.utils.logging.get_logger(__name__) + + +@dataclass +class IcebergConfig(datasets.BuilderConfig): + """BuilderConfig for Apache Iceberg format. + + Args: + catalog (`pyiceberg.catalog.Catalog`): + A pre-configured pyiceberg Catalog object. + table (`str` or `Dict[str, str]`): + Iceberg table identifier, e.g. ``"db.my_table"``. + Pass a dict to map split names to table identifiers, + e.g. ``{"train": "db.train", "test": "db.test"}``. + features (`Features`, *optional*): + Cast the data to these features. + columns (`List[str]`, *optional*): + List of columns to load; others are ignored. + filters (`str` or `BooleanExpression`, *optional*): + Row filter with predicate pushdown. Accepts a SQL-style string + (``"col > 1 AND col2 == 'foo'"``), or a pyiceberg + ``BooleanExpression`` object. Parsed by pyiceberg internally. + batch_size (`int`, defaults to ``131072``): + Number of rows per RecordBatch when reading. + snapshot_id (`int`, *optional*): + Load a specific snapshot for time-travel queries. + """ + + catalog: Optional["Catalog"] = None + table: Optional[Union[str, Dict[str, str]]] = None + features: Optional[datasets.Features] = None + columns: Optional[List[str]] = None + filters: Optional[Union[str, "BooleanExpression"]] = None + batch_size: int = 131072 + snapshot_id: Optional[int] = None + + def __post_init__(self): + super().__post_init__() + if self.catalog is None: + raise ValueError("`catalog` must be a pyiceberg Catalog object, but got None.") + if self.table is None: + raise ValueError("`table` must be specified, e.g. table='db.my_table'") + # Normalize table to Dict[split_name, table_identifier] + if isinstance(self.table, str): + self.table = {"train": self.table} + # Generate a stable config name for caching + if self.name == "default": + catalog_id = f"{self.catalog.__class__.__name__}_{self.catalog.name}" + table_id = "_".join(sorted(self.table.values())) + self.name = f"{catalog_id}_{table_id}" + + def create_config_id( + self, + config_kwargs: dict, + custom_features: Optional[Features] = None, + ) -> str: + # The catalog object is not picklable (contains SQLAlchemy engines, etc.), + # so we replace it with a hashable string representation before the + # parent class hashes config_kwargs via dill. + config_kwargs = config_kwargs.copy() + catalog = config_kwargs.pop("catalog", None) + if catalog is not None: + config_kwargs["_catalog_id"] = f"{catalog.__class__.__name__}_{catalog.name}" + # filters may contain pyiceberg Expression objects that are not picklable + filters = config_kwargs.pop("filters", None) + if filters is not None: + config_kwargs["_filters_repr"] = repr(filters) + return super().create_config_id(config_kwargs, custom_features=custom_features) + + +class Iceberg(datasets.ArrowBasedBuilder, datasets.builder._CountableBuilderMixin): + BUILDER_CONFIG_CLASS = IcebergConfig + + def _info(self): + return datasets.DatasetInfo(features=self.config.features) + + def _split_generators(self, dl_manager): + splits = [] + for split_name, table_id in self.config.table.items(): + iceberg_table = self.config.catalog.load_table(table_id) + + scan_kwargs = {} + if self.config.filters is not None: + scan_kwargs["row_filter"] = self.config.filters + if self.config.columns: + scan_kwargs["selected_fields"] = tuple(self.config.columns) + if self.config.snapshot_id is not None: + scan_kwargs["snapshot_id"] = self.config.snapshot_id + + scan = iceberg_table.scan(**scan_kwargs) + + # Infer features from Arrow schema if not user-provided + if self.info.features is None: + arrow_schema = scan.projection().as_arrow() + self.info.features = datasets.Features.from_arrow_schema(arrow_schema) + + # Plan files for parallel processing: passing a list in gen_kwargs + # enables _split_gen_kwargs to distribute tasks across num_proc workers. + tasks = list(scan.plan_files()) + + # Extract picklable scan context for multiprocessing compatibility. + # The scan object itself is not picklable (holds catalog connections), + # but these components are individually serializable. + scan_context = ( + scan.table_metadata, + scan.io, + scan.projection(), + scan.row_filter, + scan.case_sensitive, + scan.limit, + ) + + splits.append( + datasets.SplitGenerator( + name=split_name, + gen_kwargs={"tasks": tasks, "scan_context": scan_context}, + ) + ) + + # Drop the catalog reference so the builder becomes picklable for num_proc > 1. + # All data needed for reading has been extracted into scan_context above. + self.config.catalog = None + self.config_kwargs.pop("catalog", None) + + return splits + + def _cast_table(self, pa_table: pa.Table) -> pa.Table: + if self.info.features is not None: + # More expensive cast to support nested features with keys in a different order + # allows str <-> int/float or str to Audio for example + pa_table = table_cast(pa_table, self.info.features.arrow_schema) + return pa_table + + def _generate_shards(self, tasks: List["FileScanTask"], scan_context): + for task in tasks: + yield task.file.file_path + + def _generate_num_examples(self, tasks: List["FileScanTask"], scan_context): + for task in tasks: + yield task.file.record_count + + def _generate_tables(self, tasks: List["FileScanTask"], scan_context): + from pyiceberg.io.pyarrow import ArrowScan + + table_metadata, io, projected_schema, row_filter, case_sensitive, limit = scan_context + arrow_scan = ArrowScan( + table_metadata, + io, + projected_schema, + row_filter, + case_sensitive=case_sensitive, + limit=limit, + ) + for task_idx, task in enumerate(tasks): + for batch_idx, batch in enumerate(arrow_scan.to_record_batches([task])): + pa_table = pa.Table.from_batches([batch]) + yield Key(task_idx, batch_idx), self._cast_table(pa_table) diff --git a/tests/packaged_modules/test_iceberg.py b/tests/packaged_modules/test_iceberg.py new file mode 100644 index 00000000000..c98df97c358 --- /dev/null +++ b/tests/packaged_modules/test_iceberg.py @@ -0,0 +1,157 @@ +import numpy as np +import pyarrow as pa +import pytest +from pyiceberg.catalog.sql import SqlCatalog +from pyiceberg.schema import Schema +from pyiceberg.types import DoubleType, FloatType, ListType, LongType, NestedField, StringType + +from datasets import IterableDataset, load_dataset + + +@pytest.fixture +def catalog(tmp_path): + cat = SqlCatalog( + "test_catalog", + **{ + "uri": f"sqlite:///{tmp_path}/catalog.db", + "warehouse": str(tmp_path / "warehouse"), + }, + ) + cat.create_namespace("test_db") + return cat + + +@pytest.fixture +def sample_table(catalog): + schema = Schema( + NestedField(1, "id", LongType()), + NestedField(2, "name", StringType()), + NestedField(3, "value", DoubleType()), + NestedField(4, "vector", ListType(element_id=5, element_type=FloatType(), element_required=False)), + ) + table = catalog.create_table("test_db.sample", schema=schema) + table.append( + pa.table( + { + "id": pa.array([1, 2, 3], type=pa.int64()), + "name": pa.array(["alice", "bob", "carol"], type=pa.large_string()), + "value": pa.array([1.1, 2.2, 3.3], type=pa.float64()), + "vector": pa.FixedSizeListArray.from_arrays(pa.array([0.1] * 12, pa.float32()), list_size=4), + } + ) + ) + return table + + +def test_load_iceberg_basic(catalog, sample_table): + ds = load_dataset("iceberg", catalog=catalog, table="test_db.sample") + assert "train" in ds + dataset = ds["train"] + assert dataset.num_rows == 3 + assert "id" in dataset.column_names + assert "name" in dataset.column_names + assert "value" in dataset.column_names + assert "vector" in dataset.column_names + assert list(dataset["id"]) == [1, 2, 3] + assert list(dataset["name"]) == ["alice", "bob", "carol"] + + +def test_load_vectors(catalog, sample_table): + ds = load_dataset("iceberg", catalog=catalog, table="test_db.sample", columns=["vector"]) + dataset = ds["train"] + assert "vector" in dataset.column_names + vectors = dataset.data["vector"].combine_chunks().values.to_numpy(zero_copy_only=False) + assert np.allclose(vectors, np.full(12, 0.1), atol=1e-6) + + +def test_load_iceberg_columns(catalog, sample_table): + ds = load_dataset("iceberg", catalog=catalog, table="test_db.sample", columns=["id", "name"]) + dataset = ds["train"] + assert "id" in dataset.column_names + assert "name" in dataset.column_names + assert "value" not in dataset.column_names + + +def test_load_iceberg_filters(catalog, sample_table): + ds = load_dataset("iceberg", catalog=catalog, table="test_db.sample", filters="value > 2.0") + dataset = ds["train"] + assert dataset.num_rows == 2 + assert list(dataset["name"]) == ["bob", "carol"] + + +def test_load_iceberg_multi_split(catalog): + schema = Schema( + NestedField(1, "x", LongType()), + ) + train_table = catalog.create_table("test_db.train_split", schema=schema) + train_table.append(pa.table({"x": pa.array([1, 2, 3], type=pa.int64())})) + + test_table = catalog.create_table("test_db.test_split", schema=schema) + test_table.append(pa.table({"x": pa.array([10, 20], type=pa.int64())})) + + ds = load_dataset( + "iceberg", + catalog=catalog, + table={"train": "test_db.train_split", "test": "test_db.test_split"}, + ) + assert "train" in ds + assert "test" in ds + assert ds["train"].num_rows == 3 + assert ds["test"].num_rows == 2 + + +@pytest.mark.parametrize("streaming", [False, True]) +def test_load_iceberg_streaming(catalog, sample_table, streaming): + ds = load_dataset("iceberg", catalog=catalog, table="test_db.sample", split="train", streaming=streaming) + if streaming: + assert isinstance(ds, IterableDataset) + items = list(ds) + assert len(items) == 3 + assert all("id" in item for item in items) + + +def test_load_iceberg_snapshot(catalog): + schema = Schema( + NestedField(1, "id", LongType()), + ) + table = catalog.create_table("test_db.versioned", schema=schema) + table.append(pa.table({"id": pa.array([1, 2], type=pa.int64())})) + + # Capture snapshot after first append + first_snapshot_id = table.current_snapshot().snapshot_id + + # Append more data + table.append(pa.table({"id": pa.array([3, 4, 5], type=pa.int64())})) + + # Load at latest: should have 5 rows + ds_latest = load_dataset("iceberg", catalog=catalog, table="test_db.versioned") + assert ds_latest["train"].num_rows == 5 + + # Load at first snapshot: should have 2 rows + ds_old = load_dataset("iceberg", catalog=catalog, table="test_db.versioned", snapshot_id=first_snapshot_id) + assert ds_old["train"].num_rows == 2 + + +def test_load_iceberg_num_proc(catalog): + """Test that num_proc > 1 works for parallel processing.""" + schema = Schema( + NestedField(1, "id", LongType()), + ) + table = catalog.create_table("test_db.parallel", schema=schema) + table.append(pa.table({"id": pa.array([1, 2, 3], type=pa.int64())})) + table.append(pa.table({"id": pa.array([4, 5, 6], type=pa.int64())})) + + ds = load_dataset("iceberg", catalog=catalog, table="test_db.parallel", num_proc=2) + dataset = ds["train"] + assert dataset.num_rows == 6 + assert sorted(dataset["id"]) == [1, 2, 3, 4, 5, 6] + + +def test_load_iceberg_missing_catalog_raises(): + with pytest.raises(ValueError, match="catalog"): + load_dataset("iceberg", catalog=None, table="db.table") + + +def test_load_iceberg_missing_table_raises(catalog): + with pytest.raises(ValueError, match="table"): + load_dataset("iceberg", catalog=catalog, table=None) From e9811ff26451d8054397d30f3e3174be7d49cae7 Mon Sep 17 00:00:00 2001 From: frankzfli Date: Wed, 27 May 2026 12:22:02 +0800 Subject: [PATCH 2/2] fix comment --- tests/packaged_modules/test_iceberg.py | 39 ++++++++++++++++++++++++-- tests/utils.py | 14 +++++++++ 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/tests/packaged_modules/test_iceberg.py b/tests/packaged_modules/test_iceberg.py index c98df97c358..4ca12c5a53f 100644 --- a/tests/packaged_modules/test_iceberg.py +++ b/tests/packaged_modules/test_iceberg.py @@ -1,15 +1,16 @@ import numpy as np import pyarrow as pa import pytest -from pyiceberg.catalog.sql import SqlCatalog -from pyiceberg.schema import Schema -from pyiceberg.types import DoubleType, FloatType, ListType, LongType, NestedField, StringType from datasets import IterableDataset, load_dataset +from ..utils import require_not_windows, require_pyiceberg + @pytest.fixture def catalog(tmp_path): + from pyiceberg.catalog.sql import SqlCatalog + cat = SqlCatalog( "test_catalog", **{ @@ -23,6 +24,9 @@ def catalog(tmp_path): @pytest.fixture def sample_table(catalog): + from pyiceberg.schema import Schema + from pyiceberg.types import DoubleType, FloatType, ListType, LongType, NestedField, StringType + schema = Schema( NestedField(1, "id", LongType()), NestedField(2, "name", StringType()), @@ -43,6 +47,8 @@ def sample_table(catalog): return table +@require_not_windows +@require_pyiceberg def test_load_iceberg_basic(catalog, sample_table): ds = load_dataset("iceberg", catalog=catalog, table="test_db.sample") assert "train" in ds @@ -56,6 +62,8 @@ def test_load_iceberg_basic(catalog, sample_table): assert list(dataset["name"]) == ["alice", "bob", "carol"] +@require_not_windows +@require_pyiceberg def test_load_vectors(catalog, sample_table): ds = load_dataset("iceberg", catalog=catalog, table="test_db.sample", columns=["vector"]) dataset = ds["train"] @@ -64,6 +72,8 @@ def test_load_vectors(catalog, sample_table): assert np.allclose(vectors, np.full(12, 0.1), atol=1e-6) +@require_not_windows +@require_pyiceberg def test_load_iceberg_columns(catalog, sample_table): ds = load_dataset("iceberg", catalog=catalog, table="test_db.sample", columns=["id", "name"]) dataset = ds["train"] @@ -72,6 +82,8 @@ def test_load_iceberg_columns(catalog, sample_table): assert "value" not in dataset.column_names +@require_not_windows +@require_pyiceberg def test_load_iceberg_filters(catalog, sample_table): ds = load_dataset("iceberg", catalog=catalog, table="test_db.sample", filters="value > 2.0") dataset = ds["train"] @@ -79,7 +91,12 @@ def test_load_iceberg_filters(catalog, sample_table): assert list(dataset["name"]) == ["bob", "carol"] +@require_not_windows +@require_pyiceberg def test_load_iceberg_multi_split(catalog): + from pyiceberg.schema import Schema + from pyiceberg.types import LongType, NestedField + schema = Schema( NestedField(1, "x", LongType()), ) @@ -100,6 +117,8 @@ def test_load_iceberg_multi_split(catalog): assert ds["test"].num_rows == 2 +@require_not_windows +@require_pyiceberg @pytest.mark.parametrize("streaming", [False, True]) def test_load_iceberg_streaming(catalog, sample_table, streaming): ds = load_dataset("iceberg", catalog=catalog, table="test_db.sample", split="train", streaming=streaming) @@ -110,7 +129,12 @@ def test_load_iceberg_streaming(catalog, sample_table, streaming): assert all("id" in item for item in items) +@require_not_windows +@require_pyiceberg def test_load_iceberg_snapshot(catalog): + from pyiceberg.schema import Schema + from pyiceberg.types import LongType, NestedField + schema = Schema( NestedField(1, "id", LongType()), ) @@ -132,8 +156,13 @@ def test_load_iceberg_snapshot(catalog): assert ds_old["train"].num_rows == 2 +@require_not_windows +@require_pyiceberg def test_load_iceberg_num_proc(catalog): """Test that num_proc > 1 works for parallel processing.""" + from pyiceberg.schema import Schema + from pyiceberg.types import LongType, NestedField + schema = Schema( NestedField(1, "id", LongType()), ) @@ -147,11 +176,15 @@ def test_load_iceberg_num_proc(catalog): assert sorted(dataset["id"]) == [1, 2, 3, 4, 5, 6] +@require_not_windows +@require_pyiceberg def test_load_iceberg_missing_catalog_raises(): with pytest.raises(ValueError, match="catalog"): load_dataset("iceberg", catalog=None, table="db.table") +@require_not_windows +@require_pyiceberg def test_load_iceberg_missing_table_raises(catalog): with pytest.raises(ValueError, match="table"): load_dataset("iceberg", catalog=catalog, table=None) diff --git a/tests/utils.py b/tests/utils.py index 88bff466297..e9cceb2d730 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -127,6 +127,20 @@ def require_sqlalchemy(test_case): return test_case +def require_pyiceberg(test_case): + """ + Decorator marking a test that requires PyIceberg. + + These tests are skipped when PyIceberg isn't installed. + + """ + try: + import pyiceberg # noqa F401 + except ImportError: + test_case = unittest.skip("test requires pyiceberg")(test_case) + return test_case + + def require_torch(test_case): """ Decorator marking a test that requires PyTorch.