From e2517b0ac30e88fdaec6102612e071e3ea62cdd7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ragnar=20Dahl=C3=A9n?= Date: Mon, 29 Dec 2025 20:22:11 +0100 Subject: [PATCH] fix: Validate SetStatisticsUpdate correctly (fixes #2865) Previously the pydantic @model_validator would fail because it assumed statistics was a model instance. In a "before" validator that is not necessarily the case. Check type explicitly with isinstance instead, and handle `dict` case too. --- pyiceberg/table/update/__init__.py | 12 ++++-- tests/table/test_init.py | 66 +++++++++++++++++++++++++++++- 2 files changed, 74 insertions(+), 4 deletions(-) diff --git a/pyiceberg/table/update/__init__.py b/pyiceberg/table/update/__init__.py index a79e2cb468..68e3d9f753 100644 --- a/pyiceberg/table/update/__init__.py +++ b/pyiceberg/table/update/__init__.py @@ -181,9 +181,15 @@ class SetStatisticsUpdate(IcebergBaseModel): @model_validator(mode="before") def validate_snapshot_id(cls, data: dict[str, Any]) -> dict[str, Any]: - stats = cast(StatisticsFile, data["statistics"]) - - data["snapshot_id"] = stats.snapshot_id + stats = data["statistics"] + if isinstance(stats, StatisticsFile): + snapshot_id = stats.snapshot_id + elif isinstance(stats, dict): + snapshot_id = cast(int, stats.get("snapshot-id")) + else: + snapshot_id = None + + data["snapshot_id"] = snapshot_id return data diff --git a/tests/table/test_init.py b/tests/table/test_init.py index 37d7f46e38..e40513fe86 100644 --- a/tests/table/test_init.py +++ b/tests/table/test_init.py @@ -21,7 +21,7 @@ from typing import Any import pytest -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from sortedcontainers import SortedList from pyiceberg.catalog.noop import NoopCatalog @@ -1391,6 +1391,8 @@ def test_set_statistics_update(table_v2_with_statistics: Table) -> None: statistics=statistics_file, ) + assert model_roundtrips(update) + new_metadata = update_table_metadata( table_v2_with_statistics.metadata, (update,), @@ 
-1425,6 +1427,57 @@ def test_set_statistics_update(table_v2_with_statistics: Table) -> None: assert json.loads(updated_statistics[0].model_dump_json()) == json.loads(expected) +def test_set_statistics_update_handles_deprecated_snapshot_id(table_v2_with_statistics: Table) -> None: + snapshot_id = table_v2_with_statistics.metadata.current_snapshot_id + + blob_metadata = BlobMetadata( + type="apache-datasketches-theta-v1", + snapshot_id=snapshot_id, + sequence_number=2, + fields=[1], + properties={"prop-key": "prop-value"}, + ) + + statistics_file = StatisticsFile( + snapshot_id=snapshot_id, + statistics_path="s3://bucket/warehouse/stats.puffin", + file_size_in_bytes=124, + file_footer_size_in_bytes=27, + blob_metadata=[blob_metadata], + ) + update_with_model = SetStatisticsUpdate(statistics=statistics_file) + assert model_roundtrips(update_with_model) + assert update_with_model.snapshot_id == snapshot_id + + update_with_dict = SetStatisticsUpdate.model_validate({"statistics": statistics_file.model_dump()}) + assert model_roundtrips(update_with_dict) + assert update_with_dict.snapshot_id == snapshot_id + + update_json = """ + { + "statistics": + { + "snapshot-id": 3055729675574597004, + "statistics-path": "s3://a/b/stats.puffin", + "file-size-in-bytes": 413, + "file-footer-size-in-bytes": 42, + "blob-metadata": [ + { + "type": "apache-datasketches-theta-v1", + "snapshot-id": 3055729675574597004, + "sequence-number": 1, + "fields": [1] + } + ] + } + } + """ + + update_with_json = SetStatisticsUpdate.model_validate_json(update_json) + assert model_roundtrips(update_with_json) + assert update_with_json.snapshot_id == snapshot_id + + def test_remove_statistics_update(table_v2_with_statistics: Table) -> None: update = RemoveStatisticsUpdate( snapshot_id=3055729675574597004, @@ -1575,3 +1628,14 @@ def test_add_snapshot_update_updates_next_row_id(table_v3: Table) -> None: new_metadata = update_table_metadata(table_v3.metadata, (AddSnapshotUpdate(snapshot=new_snapshot),)) 
assert new_metadata.next_row_id == 11 + + +def model_roundtrips(model: BaseModel) -> bool: + """Helper assertion that tests if a pydantic model roundtrips + successfully. + """ + __tracebackhide__ = True + model_data = model.model_dump() + if model != type(model).model_validate(model_data): + pytest.fail(f"model {type(model)} did not roundtrip successfully") + return True