diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index e4ab3befa3..1e3fd64840 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -14,9 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import copy import math +import threading from abc import ABC, abstractmethod -from collections.abc import Callable +from collections.abc import Callable, Hashable from functools import singledispatch from typing import ( Any, @@ -25,6 +27,9 @@ TypeVar, ) +from cachetools import LRUCache, cached +from cachetools.keys import hashkey + from pyiceberg.conversions import from_bytes from pyiceberg.expressions import ( AlwaysFalse, @@ -1970,7 +1975,21 @@ def residual_for(self, partition_data: Record) -> BooleanExpression: return self.expr -def residual_evaluator_of( +_DEFAULT_RESIDUAL_EVALUATOR_CACHE_SIZE = 128 + + +def _residual_evaluator_cache_key( + spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema +) -> tuple[Hashable, ...]: + return hashkey(spec.spec_id, repr(expr), case_sensitive, schema.schema_id) + + +@cached( + cache=LRUCache(maxsize=_DEFAULT_RESIDUAL_EVALUATOR_CACHE_SIZE), + key=_residual_evaluator_cache_key, + lock=threading.RLock(), +) +def _cached_residual_evaluator_template( spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema ) -> ResidualEvaluator: return ( @@ -1978,3 +1997,15 @@ def residual_evaluator_of( if spec.is_unpartitioned() else ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive) ) + + +def residual_evaluator_of( + spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema +) -> ResidualEvaluator: + """Create a residual evaluator. + + Always returns a fresh evaluator instance because evaluators are stateful + (they set `self.struct` during evaluation) and may be used from multiple + threads. + """ + return copy.copy(_cached_residual_evaluator_template(spec=spec, expr=expr, case_sensitive=case_sensitive, schema=schema)) diff --git a/tests/expressions/test_residual_evaluator.py b/tests/expressions/test_residual_evaluator.py index ba0a0da2e5..f175e49f66 100644 --- a/tests/expressions/test_residual_evaluator.py +++ b/tests/expressions/test_residual_evaluator.py @@ -88,6 +88,19 @@ def test_identity_transform_residual() -> None: assert residual == AlwaysFalse() +def test_residual_evaluator_of_returns_fresh_instance() -> None: + schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType())) + spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part")) + predicate = LessThan("dateint", 20170815) + + res_eval_1 = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) + res_eval_2 = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema) + + assert res_eval_1 is not res_eval_2 + assert res_eval_1.residual_for(Record(20170814)) == AlwaysTrue() + assert res_eval_2.residual_for(Record(20170816)) == AlwaysFalse() + + def test_case_insensitive_identity_transform_residuals() -> None: schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()))