diff --git a/src/datasets/fingerprint.py b/src/datasets/fingerprint.py
index 13b801621bb..57590d2a4a1 100644
--- a/src/datasets/fingerprint.py
+++ b/src/datasets/fingerprint.py
@@ -7,7 +7,8 @@
 from functools import wraps
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+import types
 
 import numpy as np
 import xxhash
 
@@ -208,9 +209,81 @@ def hash_bytes(cls, value: Union[bytes, list[bytes]]) -> str:
         for x in value:
             m.update(x)
         return m.hexdigest()
 
+    @staticmethod
+    def _is_plain_instance(val: Any) -> bool:
+        """Return True if `val` is an ordinary object whose attributes should be snapshotted.
+
+        Callables (functions, partials, classes, callable instances) and modules also expose
+        `__dict__`, but reducing them to their simple attributes would make, for example, every
+        function wrapped by the same decorator hash identically, so they are left to `dumps`.
+        """
+        return hasattr(val, "__dict__") and not callable(val) and not isinstance(val, types.ModuleType)
+
+    @classmethod
+    def _closes_over_plain_instance(cls, func: types.FunctionType) -> bool:
+        """Return True if any bound closure cell of `func` holds a plain instance."""
+        for cell in func.__closure__ or ():
+            try:
+                val = cell.cell_contents
+            except ValueError:
+                # Empty cell (free variable not bound yet). `hasattr(cell, "cell_contents")`
+                # cannot probe for this: it only swallows AttributeError, not ValueError.
+                continue
+            if cls._is_plain_instance(val):
+                return True
+        return False
+
+    @classmethod
+    def _hash_callable(cls, func: types.FunctionType) -> tuple:
+        """Summarize `func` as a tuple that is stable across equivalent enclosing instances.
+
+        Plain instances captured in the closure (typically `self`) are replaced by their class
+        name plus their `int`/`float`/`str`/`bool`/`None` attributes, so attributes that differ
+        on every instantiation (e.g. a UUID) do not leak into the fingerprint. Every other free
+        variable is kept as is and serialized by `dumps`.
+
+        NOTE(review): non-primitive attributes of captured instances and globals referenced by
+        `func` are not part of the summary, so changing them does not change the hash.
+        """
+        code = func.__code__
+        freevars = {}
+        for name, cell in zip(code.co_freevars, func.__closure__ or ()):
+            try:
+                val = cell.cell_contents
+            except ValueError:
+                # Empty cell: record a placeholder so the free variable still contributes.
+                freevars[name] = ""
+                continue
+            if not cls._is_plain_instance(val):
+                freevars[name] = val
+                continue
+            # Sort so the result does not depend on attribute assignment order.
+            simple_attrs = sorted(
+                (k, v)
+                for k, v in val.__dict__.items()
+                if isinstance(k, str) and isinstance(v, (int, float, str, type(None)))
+            )
+            freevars[name] = (type(val).__qualname__, simple_attrs)
+        # `co_names` holds the attribute/global names the bytecode refers to by index; without
+        # it `self.a` and `self.b` would compile to identical `co_code`. Defaults are part of
+        # the function's behavior as well.
+        return (
+            code.co_code,
+            code.co_consts,
+            code.co_names,
+            code.co_varnames,
+            func.__defaults__,
+            func.__kwdefaults__,
+            freevars,
+        )
+
     @classmethod
     def hash(cls, value: Any) -> str:
+        # Functions closing over a plain instance (e.g. `self`) are hashed through a reduced
+        # summary so that per-instance noise does not change the fingerprint.
+        if isinstance(value, types.FunctionType) and cls._closes_over_plain_instance(value):
+            value = cls._hash_callable(value)
         return cls.hash_bytes(dumps(value))
 
     def update(self, value: Any) -> None:
diff --git a/tests/test_fingerprint.py b/tests/test_fingerprint.py
index 393d0e5137c..278b7c984ee 100644
--- a/tests/test_fingerprint.py
+++ b/tests/test_fingerprint.py
@@ -573,3 +573,46 @@ def test_dependency_on_dill():
     # AttributeError: module 'dill._dill' has no attribute 'stack'
     hasher = Hasher()
     hasher.update(lambda x: x)
+
+
+def test_map_fingerprint_stable_with_nondeterministic_closure():
+    import uuid
+
+    from datasets import Dataset
+
+    class DataModule:
+        def __init__(self, max_length=512):
+            self.max_length = max_length
+            self._uid = uuid.uuid4()  # non-deterministic, changes every instantiation
+            self.ds = Dataset.from_dict({"text": ["hello", "world"]})
+
+        def process(self):
+            def fn(examples):
+                ml = self.max_length
+                return {"text": [t[:ml] for t in examples["text"]]}
+
+            return self.ds.map(fn, batched=True)
+
+    fp1 = DataModule().process()._fingerprint
+    fp2 = DataModule().process()._fingerprint
+    assert fp1 == fp2, f"Expected same fingerprint but got {fp1} != {fp2}"
+
+    # A change to an attribute the closure depends on must still invalidate the cache.
+    fp3 = DataModule(max_length=1).process()._fingerprint
+    assert fp1 != fp3, "Fingerprint must change when max_length changes"
+
+
+def test_hash_distinguishes_functions_sharing_a_wrapper():
+    def decorate(func):
+        def wrapper(x):
+            return func(x)
+
+        return wrapper
+
+    def add_one(x):
+        return x + 1
+
+    def add_two(x):
+        return x + 2
+
+    assert Hasher.hash(decorate(add_one)) != Hasher.hash(decorate(add_two))