Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/datasets/fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Union

import types
import numpy as np
import xxhash

Expand Down Expand Up @@ -208,9 +208,39 @@ def hash_bytes(cls, value: Union[bytes, list[bytes]]) -> str:
for x in value:
m.update(x)
return m.hexdigest()

@staticmethod
def _hash_callable(func):
    """Describe a function by its code and a simplified view of its closure.

    The returned tuple is meant to be serialized in place of ``func`` so that
    captured objects holding non-deterministic state do not leak into the
    result.

    Args:
        func: A Python function, i.e. an object exposing ``__code__`` and
            ``__closure__``.

    Returns:
        A tuple ``(co_code, co_consts, co_varnames, freevars, co_names)``.
        ``freevars`` maps every free-variable name to:

        * ``"<empty cell>"`` when the closure cell has not been filled yet;
        * a dict of the ``int``/``float``/``str``/``bool``/``None`` attributes
          of the captured object when it is a plain instance with a
          ``__dict__``;
        * the captured value itself otherwise (including functions, methods,
          classes and modules, whose identity would be lost if they were
          reduced to their mostly empty ``__dict__``).
    """
    simple_types = (int, float, str, bool, type(None))
    # These expose ``__dict__`` too, but it does not describe what they do;
    # flattening them would make different helpers indistinguishable.
    kept_whole = (type, types.FunctionType, types.MethodType, types.ModuleType)

    code = func.__code__
    freevars = {}
    for name, cell in zip(code.co_freevars, func.__closure__ or ()):
        # Only the cell read can legitimately raise ValueError (empty cell),
        # so keep the ``try`` limited to it.
        try:
            captured = cell.cell_contents
        except ValueError:
            freevars[name] = "<empty cell>"
            continue
        if hasattr(captured, "__dict__") and not isinstance(captured, kept_whole):
            freevars[name] = {
                attr: attr_value
                for attr, attr_value in captured.__dict__.items()
                if isinstance(attr_value, simple_types)
            }
        else:
            freevars[name] = captured
    # ``co_code`` refers to attributes and globals by index into ``co_names``;
    # without the names, ``obj.a`` and ``obj.b`` compile to the same bytes.
    return (code.co_code, code.co_consts, code.co_varnames, freevars, code.co_names)

@classmethod
def hash(cls, value: Any) -> str:
    """Return the hash of ``value``.

    ``value`` is serialized with ``dumps`` and the bytes are passed to
    ``hash_bytes``. A function whose closure captures at least one plain
    instance (an object with a ``__dict__`` that is not itself a function,
    method, class or module) is first replaced by the description built by
    ``_hash_callable``.

    Args:
        value: Any object to hash.

    Returns:
        The hash string produced by ``hash_bytes``.
    """
    if isinstance(value, types.FunctionType) and value.__closure__:
        # Functions, methods, classes and modules expose ``__dict__`` as well,
        # but they are left on the regular serialization path.
        code_like = (type, types.FunctionType, types.MethodType, types.ModuleType)
        for cell in value.__closure__:
            # Reading an unfilled cell raises ValueError; ``hasattr`` only
            # swallows AttributeError, so it cannot be used to probe for it.
            try:
                captured = cell.cell_contents
            except ValueError:
                continue
            if hasattr(captured, "__dict__") and not isinstance(captured, code_like):
                value = cls._hash_callable(value)
                break
    return cls.hash_bytes(dumps(value))

def update(self, value: Any) -> None:
Expand Down
23 changes: 23 additions & 0 deletions tests/test_fingerprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,26 @@ def test_dependency_on_dill():
# AttributeError: module 'dill._dill' has no attribute 'stack'
hasher = Hasher()
hasher.update(lambda x: x)



def test_map_fingerprint_stable_with_nondeterministic_closure():
    """Two equivalent ``map`` calls must yield the same dataset fingerprint.

    The mapped function closes over ``self``, and each ``DataModule`` instance
    carries a fresh random UUID; the resulting fingerprints are expected to be
    equal regardless.
    """
    import uuid

    from datasets import Dataset

    class DataModule:
        def __init__(self):
            self.max_length = 512
            # Differs on every instantiation.
            self._uid = uuid.uuid4()
            self.ds = Dataset.from_dict({"text": ["hello", "world"]})

        def process(self):
            def fn(examples):
                limit = self.max_length
                truncated = [item[:limit] for item in examples["text"]]
                return {"text": truncated}

            mapped = self.ds.map(fn, batched=True)
            return mapped

    # Each iteration builds a brand-new module, hence a brand-new UUID.
    fp1, fp2 = [DataModule().process()._fingerprint for _ in range(2)]

    assert fp1 == fp2, f"Expected same fingerprint but got {fp1} != {fp2}"