pyiceberg/expressions/visitors.py: 33 additions & 2 deletions
@@ -14,9 +14,11 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import copy
 import math
+import threading
 from abc import ABC, abstractmethod
-from collections.abc import Callable
+from collections.abc import Callable, Hashable
 from functools import singledispatch
 from typing import (
     Any,
@@ -25,6 +27,9 @@
     TypeVar,
 )

+from cachetools import LRUCache, cached
+from cachetools.keys import hashkey
+
 from pyiceberg.conversions import from_bytes
 from pyiceberg.expressions import (
     AlwaysFalse,
@@ -1970,11 +1975,37 @@ def residual_for(self, partition_data: Record) -> BooleanExpression:
         return self.expr


-def residual_evaluator_of(
+_DEFAULT_RESIDUAL_EVALUATOR_CACHE_SIZE = 128
[Contributor] Why 128? I think this is pretty high, and would probably go a bit lower (32?)



+def _residual_evaluator_cache_key(
+    spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema
[Contributor] Why not pass in spec_id and schema_id here?

+) -> tuple[Hashable, ...]:
+    return hashkey(spec.spec_id, repr(expr), case_sensitive, schema.schema_id)
[Contributor] Building the repr of the expr is super expensive. I think it would make more sense to implement __hash__ on the Expression?



+@cached(
+    cache=LRUCache(maxsize=_DEFAULT_RESIDUAL_EVALUATOR_CACHE_SIZE),
+    key=_residual_evaluator_cache_key,
+    lock=threading.RLock(),
+)
+def _cached_residual_evaluator_template(
     spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema
 ) -> ResidualEvaluator:
     return (
         UnpartitionedResidualEvaluator(schema=schema, expr=expr)
         if spec.is_unpartitioned()
         else ResidualEvaluator(spec=spec, expr=expr, schema=schema, case_sensitive=case_sensitive)
     )
+
+
+def residual_evaluator_of(
+    spec: PartitionSpec, expr: BooleanExpression, case_sensitive: bool, schema: Schema
+) -> ResidualEvaluator:
+    """Create a residual evaluator.
+
+    Always returns a fresh evaluator instance because evaluators are stateful
+    (they set `self.struct` during evaluation) and may be used from multiple
+    threads.
+    """
+    return copy.copy(_cached_residual_evaluator_template(spec=spec, expr=expr, case_sensitive=case_sensitive, schema=schema))
[Contributor] Why do we need the copy here?
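
On the question about the copy, a standalone toy sketch (Evaluator and _template are illustrative names, not pyiceberg code) of how the cached-template-plus-copy.copy pattern behaves:

import copy
import threading

from cachetools import LRUCache, cached
from cachetools.keys import hashkey

class Evaluator:
    """Stand-in for ResidualEvaluator: evaluation mutates self.struct, so one
    instance must not be shared between callers."""
    def __init__(self, expr: str) -> None:
        self.expr = expr
        self.struct = None

@cached(cache=LRUCache(maxsize=4), key=hashkey, lock=threading.RLock())
def _template(expr: str) -> Evaluator:
    return Evaluator(expr)

def evaluator_of(expr: str) -> Evaluator:
    # copy.copy returns a new shallow-copied instance, so every caller gets its
    # own mutable state while construction happens only once per cache key.
    return copy.copy(_template(expr))

assert _template("x < 5") is _template("x < 5")            # cache hit: one shared template
assert evaluator_of("x < 5") is not evaluator_of("x < 5")  # callers never share an instance

The docstring in the diff gives the rationale: residual_for assigns self.struct, so handing the cached object itself to several callers or threads would let them overwrite each other's state.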

tests/expressions/test_residual_evaluator.py: 13 additions & 0 deletions
@@ -88,6 +88,19 @@ def test_identity_transform_residual() -> None:
     assert residual == AlwaysFalse()


+def test_residual_evaluator_of_returns_fresh_instance() -> None:
[Contributor] Thanks for adding the test 👍

+    schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()))
+    spec = PartitionSpec(PartitionField(50, 1050, IdentityTransform(), "dateint_part"))
+    predicate = LessThan("dateint", 20170815)
+
+    res_eval_1 = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
+    res_eval_2 = residual_evaluator_of(spec=spec, expr=predicate, case_sensitive=True, schema=schema)
+
+    assert res_eval_1 is not res_eval_2
+    assert res_eval_1.residual_for(Record(20170814)) == AlwaysTrue()
+    assert res_eval_2.residual_for(Record(20170816)) == AlwaysFalse()


 def test_case_insensitive_identity_transform_residuals() -> None:
     schema = Schema(NestedField(50, "dateint", IntegerType()), NestedField(51, "hour", IntegerType()))
