diff --git a/tests/algorithms/testers/torch_dynamic.py b/tests/algorithms/testers/torch_dynamic.py index a0c9c4da..3e376b1f 100644 --- a/tests/algorithms/testers/torch_dynamic.py +++ b/tests/algorithms/testers/torch_dynamic.py @@ -1,4 +1,7 @@ +from typing import Any + from pruna.algorithms.torch_dynamic import TorchDynamic +from pruna.engine.utils import get_device_type from .base_tester import AlgorithmTesterBase @@ -11,3 +14,8 @@ class TestTorchDynamic(AlgorithmTesterBase): allow_pickle_files = False algorithm_class = TorchDynamic metrics = ["latency"] + + def pre_smash_hook(self, model: Any) -> None: + """Ensure CPU dynamic quantized linear layers receive float32 activations.""" + if get_device_type(model) == "cpu": + model.float()