From 4c9fa14cf59ab004a6499d267094bae72ba822c1 Mon Sep 17 00:00:00 2001 From: Songhao Jia Date: Wed, 17 Jun 2026 01:26:06 -0700 Subject: [PATCH] Skip LpaiPartitionFallbackSupport when compiler_specs is None (forward fix for D108707519) (#20324) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Forward fix for D108707519 (Qualcomm AI Engine Direct - LPAI Support Partition, #19835). The new LpaiPartitionFallbackSupport pass crashes when run without QNN compiler specs: lpai_partition_fallback_support.py:153 get_unsupported_nodes op_validator = QnnOperatorSupport(compiler_specs=self.compiler_specs, ...) TypeError: 'NoneType' object is not iterable The edge-transform pass framework injects compiler_specs into passes but defaults it to None (QnnPassManager.get_to_edge_transform_passes(compiler_specs= None)). Lowering paths that run the edge transform without partitioner-supplied specs — e.g. the modai LPAI offline-compile path (modai/test:test_qualcomm_lpai_recipes::test_export_and_lower_8a8w_writes_pte) — reach this pass with compiler_specs=None, and QnnOperatorSupport then iterates the None specs and aborts the whole to_edge_transform_and_lower. The pass fundamentally needs compiler specs to decide which nodes are LPAI- unsupported, so with no specs there is nothing to validate against. Skip the pass (return PassResult(.., modified=False)) in that case, restoring the behavior from before this pass existed for spec-less paths. On-device paths that do supply compiler_specs are unaffected. Differential Revision: D108853837 --- backends/qualcomm/_passes/lpai_partition_fallback_support.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/backends/qualcomm/_passes/lpai_partition_fallback_support.py b/backends/qualcomm/_passes/lpai_partition_fallback_support.py index 50270e167a7..02c17f92c20 100644 --- a/backends/qualcomm/_passes/lpai_partition_fallback_support.py +++ b/backends/qualcomm/_passes/lpai_partition_fallback_support.py @@ -314,6 +314,8 @@ def handle_back_to_back_nodes(self, graph_module: torch.fx.GraphModule): graph_module.graph.eliminate_dead_code() def call(self, graph_module: torch.fx.GraphModule) -> PassResult: + if self.compiler_specs is None: + return PassResult(graph_module, False) self.preserve_io_qdq(graph_module) unsupported_nodes = self.get_unsupported_nodes(graph_module) for node in unsupported_nodes: