diff --git a/src/bindings/python/src/openvino/frontend/pytorch/gptq.py b/src/bindings/python/src/openvino/frontend/pytorch/gptq.py
index f185109b9945c2..329a64e891c51f 100644
--- a/src/bindings/python/src/openvino/frontend/pytorch/gptq.py
+++ b/src/bindings/python/src/openvino/frontend/pytorch/gptq.py
@@ -94,7 +94,7 @@ def patched_forward_sym(self, *args, **kwargs):
 
 
 # All the following AutoGPTQ/GPTQModel quant types are supposed to have the same weights packing schema
-supported_quant_types = ["triton", "exllama", "exllamav2", "cuda-old", "hf_kernel"]
+supported_quant_types = ["triton", "exllama", "exllamav2", "cuda-old", "hf_kernel", "torch_fused"]
 
 
 def patch_model(model):
diff --git a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
index 47c0296d09c079..cda59c2a7bdf34 100644
--- a/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
+++ b/tests/layer_tests/py_frontend_tests/test_torch_frontend.py
@@ -2317,6 +2317,104 @@ def forward(self, x):
     assert not hasattr(m, "_openvino_quantized_patch_orig_forward")
 
 
+def _make_torch_fused_gptq_model(in_features=32, out_features=64, group_size=32):
+    """Build a minimal GPTQ model whose linear layer mimics gptqmodel's
+    ``TorchFusedQuantLinear`` backend (``QUANT_TYPE == "torch_fused"``), using the
+    standard 4-bit/int32 weight packing the OpenVINO GPTQ patcher expects. The
+    layer's own ``forward`` is a placeholder — OpenVINO replaces it with its
+    decompression forward before tracing/export, so only the packed buffers and
+    attributes need to be realistic.
+    """
+    bits = 4
+    pack_num = 32 // bits  # 8 nibbles per int32
+
+    class FakeQuantConfig:
+        quant_method = "gptq"
+        sym = True
+
+    class FakeConfig:
+        quantization_config = FakeQuantConfig()
+
+    class TorchFusedLinear(torch.nn.Module):
+        QUANT_TYPE = "torch_fused"
+
+        def __init__(self):
+            super().__init__()
+            self.bits = bits
+            self.group_size = group_size
+            # Real GPTQ backends register the packed tensors as buffers (not
+            # parameters); the OpenVINO patcher re-assigns plain tensors to them.
+            self.register_buffer("qweight", torch.randint(
+                0, 2 ** 31, (in_features // pack_num, out_features),
+                dtype=torch.int32))
+            self.register_buffer("qzeros", torch.randint(
+                0, 2 ** 31, (in_features // group_size, out_features // pack_num),
+                dtype=torch.int32))
+            self.register_buffer("scales", torch.randn(
+                in_features // group_size, out_features, dtype=torch.float16))
+            self.bias = None
+
+        def forward(self, x):
+            return torch.zeros(*x.shape[:-1], out_features, dtype=x.dtype, device=x.device)
+
+    class GPTQModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.config = FakeConfig()
+            self.linear = TorchFusedLinear()
+
+        def forward(self, x):
+            return self.linear(x)
+
+    return GPTQModel(), torch.randn(2, in_features)
+
+
+def test_gptq_torch_fused_convert_keeps_u4():
+    """A GPTQ model whose layers report ``QUANT_TYPE == "torch_fused"`` must convert
+    via the TorchScript path and keep its 4-bit weight packing: the resulting
+    ov::Model must contain a 4-bit (i4/u4) Constant and no live ``BitwiseRightShift``
+    weight-unpacking op."""
+    from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
+
+    model, x = _make_torch_fused_gptq_model()
+    model.eval()
+
+    # Convert through the frontend directly: TorchScriptPythonDecoder traces the
+    # model and auto-applies the GPTQ patch, and FrontEnd.convert keeps the u4
+    # weight constant produced by the u4_compression_stack fold. The full
+    # openvino.convert_model MOC pipeline would constant-fold the all-constant
+    # dequant subgraph of this tiny fixture, hiding the packing under test.
+    decoder = TorchScriptPythonDecoder(model, example_input=(x,))
+    fe = FrontEndManager().load_by_framework("pytorch")
+    ov_model = fe.convert(fe.load(decoder))
+    assert ov_model
+
+    ops = ov_model.get_ops()
+    type_names = [o.get_type_name() for o in ops]
+    # The GPTQ unpacking must have been folded away (no runtime bit-shift unpacking).
+    assert "BitwiseRightShift" not in type_names
+    # ...and the weights must be stored as a packed 4-bit constant.
+    four_bit_consts = [o for o in ops
+                       if o.get_type_name() == "Constant"
+                       and o.get_output_element_type(0) in (Type.i4, Type.u4)]
+    assert four_bit_consts, "expected a packed 4-bit (i4/u4) weight constant"
+
+
+def test_gptq_torch_fused_export_supported():
+    """``patch_quantized_for_export`` must accept ``QUANT_TYPE == "torch_fused"``
+    rather than raising ``ValueError`` for the unsupported quant type."""
+    from openvino.frontend.pytorch.quantized import (
+        patch_quantized_for_export, unpatch_quantized_for_export)
+
+    model, _ = _make_torch_fused_gptq_model()
+
+    patch_quantized_for_export(model)  # must not raise
+    try:
+        assert hasattr(model.linear, "_openvino_quantized_patch_orig_forward")
+    finally:
+        unpatch_quantized_for_export(model)
+
+
 # ──────────────────────────────────────────────────────────────────────
 # Tests for dynamo=True auto-patching of quantized models
 # ──────────────────────────────────────────────────────────────────────
diff --git a/tests/model_hub_tests/pytorch/envs/llm.txt b/tests/model_hub_tests/pytorch/envs/llm.txt
index d83136f8536d56..aa1de5617bbc3c 100644
--- a/tests/model_hub_tests/pytorch/envs/llm.txt
+++ b/tests/model_hub_tests/pytorch/envs/llm.txt
@@ -1,16 +1,26 @@
 # Extra dependencies for test_llm.py (LLM quantized models)
 # These are NOT needed by test_hf_transformers.py
+#
+# Versions below are hard-pinned (==) intentionally. transformers/gptqmodel are bumped
+# deliberately, not automatically: a silent upgrade changes the GPTQ backend selection
+# (e.g. gptqmodel auto-selecting TorchFusedQuantLinear -> QUANT_TYPE "torch_fused") and the
+# generated graph, which previously broke conversion. Bump these together and re-validate
+# the opt_gptq entry in test_llm.py before raising them.
 
 transformers==5.5.3
 huggingface-hub==1.10.1
 
+# kernels (and kernels-data) are pulled transitively by transformers and gptqmodel, neither of
+# which caps the version. kernels>=0.15 made LayerRepository require a version/revision, which
+# transformers 5.5.3's hub_kernels.py constructs without -> ImportError at "import transformers".
+# Keep at the validated 0.14.1.
+kernels==0.14.1
+kernels-data==0.14.1
+
 # quantized model deps
 autoawq==0.2.9; platform_system == "Linux" and platform_machine == "x86_64"
 triton==3.6.0; platform_system == "Linux" and platform_machine == "x86_64"
 gptqmodel==6.0.3; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12"
-# `gptqmodel` depends on `kernels`, but doesn't have upper boundary for version, which caused test failures after
-# `kernels` was updated to 0.15.1
-kernels==0.14.1; platform_machine == "x86_64" and python_version < "3.12"
 peft==0.18.1; platform_system == "Linux" and platform_machine == "x86_64" and python_version < "3.12"
 
 
diff --git a/tests/model_hub_tests/pytorch/test_llm.py b/tests/model_hub_tests/pytorch/test_llm.py
index c3277efe5e98c5..f401824fc5a5cf 100644
--- a/tests/model_hub_tests/pytorch/test_llm.py
+++ b/tests/model_hub_tests/pytorch/test_llm.py
@@ -596,7 +596,7 @@ def get_supported_precommit_models():
     ]
     if platform.machine() not in ['arm', 'armv7l', 'aarch64', 'arm64', 'ARM64']:
         models.extend([
-            #("opt_gptq", "katuni4ka/opt-125m-gptq"),
+            ("opt_gptq", "katuni4ka/opt-125m-gptq"),
             ("llama", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"),
             ("llama_awq", "casperhansen/tinyllama-1b-awq"),
         ])
@@ -658,7 +658,7 @@ def get_supported_export_precommit_models():
         return []
     return [
         ("llama_awq", "casperhansen/tinyllama-1b-awq"),
-        #("opt_gptq", "katuni4ka/opt-125m-gptq"),
+        ("opt_gptq", "katuni4ka/opt-125m-gptq"),
     ]
 
 @pytest.mark.parametrize("type,name", get_supported_export_precommit_models())