From f5b37e8d51d736975b3ed6cd50f588ff00e8918d Mon Sep 17 00:00:00 2001 From: Winston Kuo Date: Wed, 10 Jun 2026 13:35:43 +0800 Subject: [PATCH] Qualcomm AI Engine Direct - Fix HF Qwen --- .../oss_scripts/llm_utils/decoder_model_wrapper.py | 13 +++++++++---- examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py | 5 ++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/qualcomm/oss_scripts/llm_utils/decoder_model_wrapper.py b/examples/qualcomm/oss_scripts/llm_utils/decoder_model_wrapper.py index f59dc548c44..8dc334baf28 100644 --- a/examples/qualcomm/oss_scripts/llm_utils/decoder_model_wrapper.py +++ b/examples/qualcomm/oss_scripts/llm_utils/decoder_model_wrapper.py @@ -38,7 +38,10 @@ def save_config_to_constant_methods( # Check for cache_config and its attributes cache_config = getattr(generation_config, "cache_config", None) if cache_config is not None: - max_seq_len = getattr(cache_config, "max_cache_len", None) + if isinstance(cache_config, dict): + max_seq_len = cache_config.get("max_cache_len", None) + else: + max_seq_len = getattr(cache_config, "max_cache_len", None) if max_seq_len is not None: metadata["get_max_seq_len"] = max_seq_len @@ -115,7 +118,7 @@ def _qnn_attention_mask( # Simplest and most efficient way to obtain a causal mask causal_mask = kv_arange <= reshaped_cache_position - atten_mask = torch.full((causal_mask.shape[0], kv_length), torch.tensor(-65504.0)) + atten_mask = torch.full((causal_mask.shape[0], kv_length), -65504.0) atten_mask = atten_mask.masked_fill(causal_mask, 0) atten_mask = atten_mask[None, None, :, :].expand(batch_size, -1, -1, -1) @@ -133,7 +136,7 @@ def __init__(self, model): logging.info(f"Metadata to be recorded in PTE: {self._metadata}") self.exportable_module = TorchExportableModuleForDecoderOnlyLM( self.model, - max_batch_size=1, + batch_size=1, max_cache_len=self._metadata.get("get_max_seq_len"), ) self._register_attention_mask_for_4_53(self.exportable_module) @@ -154,7 +157,9 @@ def get_example_inputs(self): return (example_input_ids, example_cache_position) def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor): - return self.exportable_module(input_ids, cache_position) + return self.exportable_module( + input_ids=input_ids, cache_position=cache_position + ) def get_metadata(self): return self._metadata diff --git a/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py b/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py index 70641af8fb7..7876a5b54b3 100644 --- a/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py +++ b/examples/qualcomm/oss_scripts/qwen2_5/qwen2_5.py @@ -14,7 +14,6 @@ import torch from executorch.backends.qualcomm.export_utils import ( - get_backend_type, QnnConfig, setup_common_args_and_variables, SimpleADB, @@ -75,7 +74,7 @@ def compile(args: argparse.Namespace, qnn_config: QnnConfig): # noqa: C901 args.calibration_limit, args.prompt, tokenizer_json_path, - get_backend_type(qnn_config.backend), + qnn_config.backend, qnn_config.soc_model, ) @@ -158,7 +157,7 @@ def post_process(): runner="examples/models/llama/llama_main", ) # No pregen inputs, input_list is not required - adb.push(inputs=[], input_list="", files=[tokenizer_json_path]) + adb.push(inputs=[], files=[tokenizer_json_path]) adb.execute(custom_runner_cmd=runner_cmd) adb.pull(host_output_path=args.artifact, callback=post_process)