From 05c321a5890ea9131b4d18d3589eb7e717d9bea3 Mon Sep 17 00:00:00 2001 From: Begum Cig Date: Mon, 22 Jun 2026 09:50:58 +0000 Subject: [PATCH] test: fix the model dtype for non cpu supported type --- tests/algorithms/testers/torch_dynamic.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/algorithms/testers/torch_dynamic.py b/tests/algorithms/testers/torch_dynamic.py index a0c9c4da..3e376b1f 100644 --- a/tests/algorithms/testers/torch_dynamic.py +++ b/tests/algorithms/testers/torch_dynamic.py @@ -1,4 +1,7 @@ +from typing import Any + from pruna.algorithms.torch_dynamic import TorchDynamic +from pruna.engine.utils import get_device_type from .base_tester import AlgorithmTesterBase @@ -11,3 +14,8 @@ class TestTorchDynamic(AlgorithmTesterBase): allow_pickle_files = False algorithm_class = TorchDynamic metrics = ["latency"] + + def pre_smash_hook(self, model: Any) -> None: + """Ensure CPU dynamic quantized linear layers receive float32 activations.""" + if get_device_type(model) == "cpu": + model.float()