From b80ff63299fa5d8c66f7e6bb267c80354ba69f59 Mon Sep 17 00:00:00 2001 From: Justin Chu <11205048+justinchuby@users.noreply.github.com> Date: Mon, 1 Jun 2026 19:02:53 +0000 Subject: [PATCH 1/3] Add UnstackGatherSqueezeWeight surgeon for static-unroll MoE quantization Mobius emits static-unroll Mixture-of-Experts (MoE) blocks as: fc1 = Squeeze(Gather(W_3d, [const_idx], axis=0), [0]) proj = MatMul(x, Transpose(fc1)) because the fused 'com.microsoft::MoE' op only supports GPT-OSS-style SwiGLU (interleaved, alpha=1.702, limit=7.0), which is incompatible with the standard concatenated SwiGLU used by models like Gemma 4 26B-A4B. See microsoft/onnxruntime#28738. When the per-expert MatMul.B is the result of a Gather/Squeeze chain (plus optional Transpose), it is not a graph-level initializer, so weight-quantization passes such as OnnxKQuantQuantization skip it. This leaves the per-expert weights at the model's compute dtype (e.g. fp16), yielding only a ~6% size reduction for Gemma 4 26B-A4B instead of the expected 4-bit ratio. Add an UnstackGatherSqueezeWeight surgeon that: 1. Detects 'Gather(W_3d_init, [const_int], axis=0) -> Squeeze([0])' subgraphs. 2. Materialises the slice 'W_3d[const_int]' as a new 2-D initializer named '{W_3d_name}__slice_{idx}'. 3. Folds any subsequent Transpose into a transposed initializer ('{W_3d_name}__slice_{idx}__T'). 4. Rewires downstream consumers and drops the now-orphaned 3-D initializer and Gather/Squeeze/Transpose nodes. After this surgery runs, OnnxKQuantQuantization picks up every per-expert MatMul without any changes to the quantization pass itself. Numerical parity with the original graph is verified by the accompanying unit test. 
Fixes: microsoft/Olive#2489 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com> --- olive/passes/onnx/graph_surgeries.py | 246 +++++++++++++++++++++++ test/passes/onnx/test_graph_surgeries.py | 70 +++++++ 2 files changed, 316 insertions(+) diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 61f80a47f..e63f152cc 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -3259,6 +3259,252 @@ def call_ir(self, model: ir.Model) -> ir.Model: return model +class UnstackGatherSqueezeWeight(Surgeon): + """Materialise per-expert (or per-slice) 2-D weights from stacked 3-D initializers. + + Detects the pattern emitted by static-unroll Mixture-of-Experts (MoE) + blocks:: + + W_3d : ir.Tensor = [E, K, N] (initializer) + idx = [k] (1-element int64 initializer) + gathered = Gather(W_3d, idx, axis=0) # [1, K, N] + unstacked = Squeeze(gathered, [0]) # [K, N] + ... = MatMul(x, unstacked) # B input is not an initializer + + and rewrites each match by: + + 1. Slicing ``W_3d`` along axis 0 at the constant index ``k``. + 2. Materialising the resulting ``[K, N]`` slice as a brand-new + initializer named ``__expert_``. + 3. Replacing the ``Squeeze`` output's downstream uses with that + initializer directly. + 4. Cleaning up the orphaned ``Gather`` and ``Squeeze`` nodes; the + 3-D weight itself is dropped only if all its uses have been + rewritten. + + After this pass, the downstream ``MatMul`` nodes have a 2-D static + initializer as their B input, which means weight-quantization passes + such as :class:`OnnxKQuantQuantization` and + :class:`OnnxBlockWiseRtnQuantization` pick them up automatically. + + Use case: + Run before any weight-quantization pass on a model that emits + unrolled per-expert MoE dispatch (e.g. 
Gemma 4 26B-A4B from + mobius, where the fused ``com.microsoft::MoE`` op is bypassed + due to a SwiGLU-kernel incompatibility). Without this pass the + per-expert MatMuls remain at the model's compute dtype because + the quantizer requires a 2-D static initializer as the B input. + + Example surgery list entry:: + + {"surgeon": "UnstackGatherSqueezeWeight"} + """ + + def call_ir(self, model: ir.Model) -> ir.Model: + graph = model.graph + # Build name → ir.Value initializer map for quick lookup. + initializers: dict[str, ir.Value] = dict(graph.initializers) + + new_initializers: dict[str, ir.Value] = {} + rewrites: list[tuple[ir.Node, ir.Node, ir.Value]] = [] + # (gather_node, squeeze_node, replacement_initializer_value) + + for node in list(graph.all_nodes()): + if node.op_type != "Squeeze": + continue + if len(node.inputs) < 1 or node.inputs[0] is None: + continue + squeeze_axes = self._scalar_int_list(node, idx=1, initializers=initializers) + if squeeze_axes != [0]: + continue + + gather_value = node.inputs[0] + gather_node = gather_value.producer() + if gather_node is None or gather_node.op_type != "Gather": + continue + if int(gather_node.attributes.get("axis", ir.AttrInt64("axis", 0)).value) != 0: + continue + if len(gather_node.inputs) < 2 or any(i is None for i in gather_node.inputs[:2]): + continue + + data_value = gather_node.inputs[0] + index_value = gather_node.inputs[1] + if data_value.name not in initializers: + continue + data_init = initializers[data_value.name] + if data_init.const_value is None: + continue + data_array = data_init.const_value.numpy() + if data_array.ndim != 3: + continue + + idx_list = self._scalar_int_list(gather_node, idx=1, initializers=initializers) + if idx_list is None or len(idx_list) != 1: + continue + slice_idx = idx_list[0] + if not 0 <= slice_idx < data_array.shape[0]: + continue + + slice_array = data_array[slice_idx] + slice_name = f"{data_value.name}__slice_{slice_idx}" + if slice_name in new_initializers: + 
replacement = new_initializers[slice_name] + else: + slice_tensor = ir.Tensor(slice_array, name=slice_name) + replacement = ir.Value( + name=slice_name, + type=ir.TensorType(slice_tensor.dtype), + shape=ir.Shape(slice_array.shape), + const_value=slice_tensor, + ) + new_initializers[slice_name] = replacement + + rewrites.append((gather_node, node, replacement)) + + if not rewrites: + return model + + source_init_names: set[str] = { + g.inputs[0].name for g, _, _ in rewrites if g.inputs[0] is not None + } + + # Register all new initializers first. + for value in new_initializers.values(): + graph.register_initializer(value) + + # Rewire each Squeeze output to the new initializer and remove + # the now-unused Gather + Squeeze nodes. Indices initializer is + # left in place: it may be tiny, and downstream cleanup passes + # (e.g. RemoveUnusedOpsetsPass) handle orphan removal. + for gather_node, squeeze_node, replacement in rewrites: + ir.convenience.replace_all_uses_with(squeeze_node.outputs[0], replacement) + graph.remove(squeeze_node, safe=True) + # Gather may still be referenced if another Squeeze hadn't been + # rewritten yet (different downstream); only remove if it's now + # orphaned. + gather_uses = [ + u for u in graph.all_nodes() if any(i is gather_node.outputs[0] for i in u.inputs) + ] + if not gather_uses: + graph.remove(gather_node, safe=True) + + # Drop any 3-D source initializers that no longer have any + # consumers in the graph after rewrite. + used_init_names: set[str] = set() + for node in graph.all_nodes(): + for inp in node.inputs: + if inp is not None and inp.name: + used_init_names.add(inp.name) + for out in graph.outputs: + if out is not None and out.name: + used_init_names.add(out.name) + for name in [n for n in graph.initializers if n not in used_init_names]: + del graph.initializers[name] + + # Fold any Transpose(initializer) chain that consumes a newly + # materialised slice. 
This keeps the downstream MatMul.B as a + # static 2-D initializer so weight-quantization passes such as + # OnnxKQuantQuantization (which require ``inputs[1].is_initializer()``) + # pick it up. The most common producer is the mobius MoE + # fallback emitting ``MatMul(x, Transpose(Squeeze(Gather(...))))``. + folded_count = self._fold_transpose_of_initializer(graph, new_initializers) + + # Final orphan-initializer cleanup: slices that are now only + # referenced through a folded Transpose's pre-image become + # unused and should be dropped. + used_init_names = set() + for node in graph.all_nodes(): + for inp in node.inputs: + if inp is not None and inp.name: + used_init_names.add(inp.name) + for out in graph.outputs: + if out is not None and out.name: + used_init_names.add(out.name) + for name in [n for n in graph.initializers if n not in used_init_names]: + del graph.initializers[name] + + logger.info( + "UnstackGatherSqueezeWeight: materialised %d per-slice initializers from " + "%d stacked source(s); rewrote %d Gather→Squeeze chains; folded %d Transpose ops.", + len(new_initializers), + len(source_init_names), + len(rewrites), + folded_count, + ) + return model + + @staticmethod + def _fold_transpose_of_initializer( + graph: ir.Graph, candidate_initializers: dict[str, ir.Value] + ) -> int: + """Fold ``Transpose(initializer) → matmul.B`` into a new transposed + initializer for any Transpose whose input is in ``candidate_initializers``. + + Returns the number of Transpose nodes folded. 
+ """ + folded = 0 + for node in list(graph.all_nodes()): + if node.op_type != "Transpose": + continue + if len(node.inputs) != 1 or node.inputs[0] is None: + continue + src = node.inputs[0] + if src.name not in candidate_initializers: + continue + if src.const_value is None: + continue + perm_attr = node.attributes.get("perm") + arr = src.const_value.numpy() + perm = list(perm_attr.value) if perm_attr is not None else list(range(arr.ndim))[::-1] + transposed = np.transpose(arr, perm) + new_name = f"{src.name}__T" + if new_name in graph.initializers: + new_value = graph.initializers[new_name] + else: + new_tensor = ir.Tensor(transposed, name=new_name) + new_value = ir.Value( + name=new_name, + type=ir.TensorType(new_tensor.dtype), + shape=ir.Shape(transposed.shape), + const_value=new_tensor, + ) + graph.register_initializer(new_value) + ir.convenience.replace_all_uses_with(node.outputs[0], new_value) + graph.remove(node, safe=True) + folded += 1 + return folded + + @staticmethod + def _scalar_int_list( + node: ir.Node, idx: int, initializers: dict[str, ir.Value] + ) -> list[int] | None: + """Resolve ``node.inputs[idx]`` to a list[int] if it's a static + 1-D int initializer; return None otherwise. + + Handles both standalone initializers and ``Constant`` nodes + feeding the input. 
+ """ + if idx >= len(node.inputs) or node.inputs[idx] is None: + return None + value = node.inputs[idx] + # Initializer path + if value.name in initializers and initializers[value.name].const_value is not None: + arr = initializers[value.name].const_value.numpy() + return arr.reshape(-1).astype(np.int64).tolist() + # Inline Constant node path + producer = value.producer() + if producer is not None and producer.op_type == "Constant": + for attr in producer.attributes.values(): + if attr.name == "value" and isinstance(attr, ir.AttrTensor): + arr = attr.value.numpy() + return arr.reshape(-1).astype(np.int64).tolist() + if attr.name == "value_ints" and isinstance(attr, ir.AttrInt64s): + return list(attr.value) + if attr.name == "value_int" and isinstance(attr, ir.AttrInt64): + return [int(attr.value)] + return None + + class GraphSurgeries(Pass): """ONNX graph surgeries collections. diff --git a/test/passes/onnx/test_graph_surgeries.py b/test/passes/onnx/test_graph_surgeries.py index ad44db040..d36487f3a 100644 --- a/test/passes/onnx/test_graph_surgeries.py +++ b/test/passes/onnx/test_graph_surgeries.py @@ -3050,3 +3050,73 @@ def test_remove_memcpy_chained(tmp_path): x = np.random.randn(4, 8).astype(np.float32) result = sess.run(["output"], {"input": x})[0] np.testing.assert_allclose(result, np.maximum(x, 0), atol=1e-6) + + +def test_unstack_gather_squeeze_weight(tmp_path): + """UnstackGatherSqueezeWeight should materialise per-slice 2-D + initializers from ``Gather(W_3d, [const], axis=0) → Squeeze([0])`` chains + and fold an optional downstream ``Transpose``. 
+ """ + input_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 6]) + out0 = helper.make_tensor_value_info("y0", TensorProto.FLOAT, [4, 8]) + out2 = helper.make_tensor_value_info("y2", TensorProto.FLOAT, [4, 8]) + + w3d = np.random.RandomState(0).randn(3, 8, 6).astype(np.float32) + w3d_init = numpy_helper.from_array(w3d, name="W3d") + idx0_init = numpy_helper.from_array(np.array([0], dtype=np.int64), name="idx0") + idx2_init = numpy_helper.from_array(np.array([2], dtype=np.int64), name="idx2") + ax_init = numpy_helper.from_array(np.array([0], dtype=np.int64), name="ax0") + + def chain(suffix, idx_name, y_name): + return [ + helper.make_node("Gather", ["W3d", idx_name], [f"g_{suffix}"], axis=0), + helper.make_node("Squeeze", [f"g_{suffix}", "ax0"], [f"s_{suffix}"]), + helper.make_node("Transpose", [f"s_{suffix}"], [f"t_{suffix}"], perm=[1, 0]), + helper.make_node("MatMul", ["x", f"t_{suffix}"], [y_name]), + ] + + nodes = chain("a", "idx0", "y0") + chain("b", "idx2", "y2") + graph = helper.make_graph( + nodes=nodes, + name="UnstackTest", + inputs=[input_tensor], + outputs=[out0, out2], + initializer=[w3d_init, idx0_init, idx2_init, ax_init], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)]) + model.ir_version = 10 + model_path = tmp_path / "model.onnx" + onnx.save(model, model_path) + + input_model = ONNXModelHandler(model_path=str(model_path)) + output_folder = str(tmp_path / "out") + p = create_pass_from_dict( + GraphSurgeries, + {"surgeries": [{"surgeon": "UnstackGatherSqueezeWeight"}]}, + disable_search=True, + ) + output_model = p.run(input_model, output_folder) + g = output_model.load_model().graph + + op_types = [n.op_type for n in g.node] + assert op_types == ["MatMul", "MatMul"], op_types + init_names = {i.name for i in g.initializer} + assert "W3d" not in init_names + assert "W3d__slice_0__T" in init_names + assert "W3d__slice_2__T" in init_names + + # Verify the materialised initializers match np.transpose 
of the + # original slices. + inits = {i.name: numpy_helper.to_array(i) for i in g.initializer} + np.testing.assert_array_equal(inits["W3d__slice_0__T"], w3d[0].T) + np.testing.assert_array_equal(inits["W3d__slice_2__T"], w3d[2].T) + + # Numerical parity: rewritten model should produce the same output + # as the original. + sess_orig = InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) + sess_new = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"]) + x = np.random.RandomState(1).randn(4, 6).astype(np.float32) + y0_orig, y2_orig = sess_orig.run(None, {"x": x}) + y0_new, y2_new = sess_new.run(None, {"x": x}) + np.testing.assert_allclose(y0_new, y0_orig, atol=1e-6) + np.testing.assert_allclose(y2_new, y2_orig, atol=1e-6) From 97b176c6e2f3fd1be4bb39ea5bee7b5d8669d0a2 Mon Sep 17 00:00:00 2001 From: Justin Chu <11205048+justinchuby@users.noreply.github.com> Date: Mon, 1 Jun 2026 19:25:33 +0000 Subject: [PATCH 2/3] Address PR review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Remove unused index_value local (PYLINT W0612 / RUFF F841). - Replace isinstance(attr, ir.AttrTensor/AttrInt64s/AttrInt64) with attr.type == ir.AttributeType.{TENSOR,INTS,INT}; the ir.Attr* callables are factory functions, not types (PYLINT W1116). - Drop the redundant first orphan-initializer sweep; the post-Transpose sweep already covers every initializer that becomes unused at any point during the rewrite, and runs once at the end. - Update the misleading 'left in place' comment about index initializers — they are in fact dropped by the orphan sweep, not by a downstream pass. - Sync the docstring with the implementation: name slices '__slice_' (not '__expert_') and document the Transpose-folding step explicitly. - Replace ir.AttrInt64(...) default-construction trick with a None-check on attributes.get('axis'), removing a needless object allocation per Gather inspection. 
- Fix RUFF D205 docstring summary/description spacing in three places. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com> --- olive/passes/onnx/graph_surgeries.py | 90 ++++++++++-------------- test/passes/onnx/test_graph_surgeries.py | 9 ++- 2 files changed, 45 insertions(+), 54 deletions(-) diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index e63f152cc..82614a1bf 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -3275,12 +3275,17 @@ class UnstackGatherSqueezeWeight(Surgeon): 1. Slicing ``W_3d`` along axis 0 at the constant index ``k``. 2. Materialising the resulting ``[K, N]`` slice as a brand-new - initializer named ``__expert_``. + initializer named ``__slice_``. 3. Replacing the ``Squeeze`` output's downstream uses with that initializer directly. - 4. Cleaning up the orphaned ``Gather`` and ``Squeeze`` nodes; the - 3-D weight itself is dropped only if all its uses have been - rewritten. + 4. Folding any downstream ``Transpose(slice_initializer)`` into a + pre-transposed initializer named ``__slice___T``. + This is required for the common mobius emission pattern + ``MatMul(x, Transpose(Squeeze(Gather(...))))``. + 5. Removing the orphaned ``Gather``, ``Squeeze``, and ``Transpose`` + nodes plus any initializer (3-D weight, per-Gather index, shared + squeeze-axes constant, intermediate slice) that no longer has any + node consumer. 
After this pass, the downstream ``MatMul`` nodes have a 2-D static initializer as their B input, which means weight-quantization passes @@ -3322,13 +3327,14 @@ def call_ir(self, model: ir.Model) -> ir.Model: gather_node = gather_value.producer() if gather_node is None or gather_node.op_type != "Gather": continue - if int(gather_node.attributes.get("axis", ir.AttrInt64("axis", 0)).value) != 0: + axis_attr = gather_node.attributes.get("axis") + axis_val = int(axis_attr.value) if axis_attr is not None else 0 + if axis_val != 0: continue if len(gather_node.inputs) < 2 or any(i is None for i in gather_node.inputs[:2]): continue data_value = gather_node.inputs[0] - index_value = gather_node.inputs[1] if data_value.name not in initializers: continue data_init = initializers[data_value.name] @@ -3364,43 +3370,23 @@ def call_ir(self, model: ir.Model) -> ir.Model: if not rewrites: return model - source_init_names: set[str] = { - g.inputs[0].name for g, _, _ in rewrites if g.inputs[0] is not None - } + source_init_names: set[str] = {g.inputs[0].name for g, _, _ in rewrites if g.inputs[0] is not None} # Register all new initializers first. for value in new_initializers.values(): graph.register_initializer(value) # Rewire each Squeeze output to the new initializer and remove - # the now-unused Gather + Squeeze nodes. Indices initializer is - # left in place: it may be tiny, and downstream cleanup passes - # (e.g. RemoveUnusedOpsetsPass) handle orphan removal. + # the now-unused Gather + Squeeze nodes. for gather_node, squeeze_node, replacement in rewrites: ir.convenience.replace_all_uses_with(squeeze_node.outputs[0], replacement) graph.remove(squeeze_node, safe=True) - # Gather may still be referenced if another Squeeze hadn't been - # rewritten yet (different downstream); only remove if it's now - # orphaned. 
- gather_uses = [ - u for u in graph.all_nodes() if any(i is gather_node.outputs[0] for i in u.inputs) - ] + # A single Gather output may feed multiple Squeeze chains; only + # remove the Gather node once all its consumers are gone. + gather_uses = [u for u in graph.all_nodes() if any(i is gather_node.outputs[0] for i in u.inputs)] if not gather_uses: graph.remove(gather_node, safe=True) - # Drop any 3-D source initializers that no longer have any - # consumers in the graph after rewrite. - used_init_names: set[str] = set() - for node in graph.all_nodes(): - for inp in node.inputs: - if inp is not None and inp.name: - used_init_names.add(inp.name) - for out in graph.outputs: - if out is not None and out.name: - used_init_names.add(out.name) - for name in [n for n in graph.initializers if n not in used_init_names]: - del graph.initializers[name] - # Fold any Transpose(initializer) chain that consumes a newly # materialised slice. This keeps the downstream MatMul.B as a # static 2-D initializer so weight-quantization passes such as @@ -3409,10 +3395,12 @@ def call_ir(self, model: ir.Model) -> ir.Model: # fallback emitting ``MatMul(x, Transpose(Squeeze(Gather(...))))``. folded_count = self._fold_transpose_of_initializer(graph, new_initializers) - # Final orphan-initializer cleanup: slices that are now only - # referenced through a folded Transpose's pre-image become - # unused and should be dropped. - used_init_names = set() + # Orphan-initializer sweep: drop the original 3-D source weights, + # the per-Gather index initializers, the shared squeeze-axes + # constant, and any intermediate per-slice initializer that was + # consumed only by a folded Transpose. Anything still referenced + # by a remaining node or graph output stays. 
+ used_init_names: set[str] = set() for node in graph.all_nodes(): for inp in node.inputs: if inp is not None and inp.name: @@ -3434,11 +3422,12 @@ def call_ir(self, model: ir.Model) -> ir.Model: return model @staticmethod - def _fold_transpose_of_initializer( - graph: ir.Graph, candidate_initializers: dict[str, ir.Value] - ) -> int: - """Fold ``Transpose(initializer) → matmul.B`` into a new transposed - initializer for any Transpose whose input is in ``candidate_initializers``. + def _fold_transpose_of_initializer(graph: ir.Graph, candidate_initializers: dict[str, ir.Value]) -> int: + """Fold ``Transpose(initializer) → matmul.B`` chains into transposed initializers. + + Replaces any ``Transpose`` node whose input is in + ``candidate_initializers`` with a pre-transposed initializer and + rewires the downstream consumers to use it directly. Returns the number of Transpose nodes folded. """ @@ -3475,14 +3464,13 @@ def _fold_transpose_of_initializer( return folded @staticmethod - def _scalar_int_list( - node: ir.Node, idx: int, initializers: dict[str, ir.Value] - ) -> list[int] | None: - """Resolve ``node.inputs[idx]`` to a list[int] if it's a static - 1-D int initializer; return None otherwise. - - Handles both standalone initializers and ``Constant`` nodes - feeding the input. + def _scalar_int_list(node: ir.Node, idx: int, initializers: dict[str, ir.Value]) -> list[int] | None: + """Resolve ``node.inputs[idx]`` to a ``list[int]`` if it is a static 1-D int source. + + Returns the resolved list when ``node.inputs[idx]`` is either a + static integer initializer or the output of an inline + ``Constant`` node carrying ``value`` / ``value_ints`` / + ``value_int``; otherwise returns ``None``. 
""" if idx >= len(node.inputs) or node.inputs[idx] is None: return None @@ -3495,12 +3483,12 @@ def _scalar_int_list( producer = value.producer() if producer is not None and producer.op_type == "Constant": for attr in producer.attributes.values(): - if attr.name == "value" and isinstance(attr, ir.AttrTensor): + if attr.name == "value" and attr.type == ir.AttributeType.TENSOR: arr = attr.value.numpy() return arr.reshape(-1).astype(np.int64).tolist() - if attr.name == "value_ints" and isinstance(attr, ir.AttrInt64s): + if attr.name == "value_ints" and attr.type == ir.AttributeType.INTS: return list(attr.value) - if attr.name == "value_int" and isinstance(attr, ir.AttrInt64): + if attr.name == "value_int" and attr.type == ir.AttributeType.INT: return [int(attr.value)] return None diff --git a/test/passes/onnx/test_graph_surgeries.py b/test/passes/onnx/test_graph_surgeries.py index d36487f3a..b5c4ce432 100644 --- a/test/passes/onnx/test_graph_surgeries.py +++ b/test/passes/onnx/test_graph_surgeries.py @@ -3053,9 +3053,12 @@ def test_remove_memcpy_chained(tmp_path): def test_unstack_gather_squeeze_weight(tmp_path): - """UnstackGatherSqueezeWeight should materialise per-slice 2-D - initializers from ``Gather(W_3d, [const], axis=0) → Squeeze([0])`` chains - and fold an optional downstream ``Transpose``. + """Verify per-slice 2-D weight materialisation and Transpose folding. + + ``UnstackGatherSqueezeWeight`` should rewrite each + ``Gather(W_3d, [const], axis=0) → Squeeze([0])`` chain into a 2-D + initializer and fold any downstream ``Transpose`` into a + pre-transposed initializer. 
""" input_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 6]) out0 = helper.make_tensor_value_info("y0", TensorProto.FLOAT, [4, 8]) From 0c073ff456a64016693a1d675738979b85cbc977 Mon Sep 17 00:00:00 2001 From: Justin Chu <11205048+justinchuby@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:18:05 +0000 Subject: [PATCH 3/3] Refactor UnstackGatherSqueezeWeight: free functions + onnx_ir DCE MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Move `_fold_transpose_of_initializer` and `_scalar_int_list` from `@staticmethod`s on `UnstackGatherSqueezeWeight` to module-level private functions, per Google's Python style guide preference for free functions over static methods when no class state is involved. - Extract the Gather→Squeeze match logic into a new `_try_match_gather_squeeze` free function so `call_ir` becomes a thin orchestrator that scans the graph and dispatches to the helpers. - Replace the open-coded orphan-initializer sweep with `onnx_ir.passes.common.RemoveUnusedNodesPass`. That pass also removes dead nodes for free, and using the upstream IR pass keeps this surgeon consistent with other IR consumers. No behaviour change: `test_unstack_gather_squeeze_weight` still passes and the full 83-test `test_graph_surgeries` suite is green. Lintrunner clean. 
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com> --- olive/passes/onnx/graph_surgeries.py | 292 ++++++++++++++------------- 1 file changed, 156 insertions(+), 136 deletions(-) diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 82614a1bf..4edfd3db5 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -18,7 +18,12 @@ import onnxscript from onnx import ModelProto, TensorProto from onnx.helper import make_tensor -from onnx_ir.passes.common import DeduplicateHashedInitializersPass, InlinePass, RemoveUnusedOpsetsPass +from onnx_ir.passes.common import ( + DeduplicateHashedInitializersPass, + InlinePass, + RemoveUnusedNodesPass, + RemoveUnusedOpsetsPass, +) from onnxscript import ir, rewriter from onnxscript.rewriter import pattern @@ -3306,66 +3311,20 @@ class UnstackGatherSqueezeWeight(Surgeon): """ def call_ir(self, model: ir.Model) -> ir.Model: + # Defer to module-level free functions (Google style preference + # for free functions over static methods when no class state is + # involved). graph = model.graph - # Build name → ir.Value initializer map for quick lookup. 
initializers: dict[str, ir.Value] = dict(graph.initializers) new_initializers: dict[str, ir.Value] = {} - rewrites: list[tuple[ir.Node, ir.Node, ir.Value]] = [] # (gather_node, squeeze_node, replacement_initializer_value) + rewrites: list[tuple[ir.Node, ir.Node, ir.Value]] = [] for node in list(graph.all_nodes()): - if node.op_type != "Squeeze": - continue - if len(node.inputs) < 1 or node.inputs[0] is None: - continue - squeeze_axes = self._scalar_int_list(node, idx=1, initializers=initializers) - if squeeze_axes != [0]: - continue - - gather_value = node.inputs[0] - gather_node = gather_value.producer() - if gather_node is None or gather_node.op_type != "Gather": - continue - axis_attr = gather_node.attributes.get("axis") - axis_val = int(axis_attr.value) if axis_attr is not None else 0 - if axis_val != 0: - continue - if len(gather_node.inputs) < 2 or any(i is None for i in gather_node.inputs[:2]): - continue - - data_value = gather_node.inputs[0] - if data_value.name not in initializers: - continue - data_init = initializers[data_value.name] - if data_init.const_value is None: - continue - data_array = data_init.const_value.numpy() - if data_array.ndim != 3: - continue - - idx_list = self._scalar_int_list(gather_node, idx=1, initializers=initializers) - if idx_list is None or len(idx_list) != 1: - continue - slice_idx = idx_list[0] - if not 0 <= slice_idx < data_array.shape[0]: - continue - - slice_array = data_array[slice_idx] - slice_name = f"{data_value.name}__slice_{slice_idx}" - if slice_name in new_initializers: - replacement = new_initializers[slice_name] - else: - slice_tensor = ir.Tensor(slice_array, name=slice_name) - replacement = ir.Value( - name=slice_name, - type=ir.TensorType(slice_tensor.dtype), - shape=ir.Shape(slice_array.shape), - const_value=slice_tensor, - ) - new_initializers[slice_name] = replacement - - rewrites.append((gather_node, node, replacement)) + rewrite = _try_match_gather_squeeze(node, initializers, new_initializers) + if 
rewrite is not None: + rewrites.append(rewrite) if not rewrites: return model @@ -3393,23 +3352,14 @@ def call_ir(self, model: ir.Model) -> ir.Model: # OnnxKQuantQuantization (which require ``inputs[1].is_initializer()``) # pick it up. The most common producer is the mobius MoE # fallback emitting ``MatMul(x, Transpose(Squeeze(Gather(...))))``. - folded_count = self._fold_transpose_of_initializer(graph, new_initializers) - - # Orphan-initializer sweep: drop the original 3-D source weights, - # the per-Gather index initializers, the shared squeeze-axes - # constant, and any intermediate per-slice initializer that was - # consumed only by a folded Transpose. Anything still referenced - # by a remaining node or graph output stays. - used_init_names: set[str] = set() - for node in graph.all_nodes(): - for inp in node.inputs: - if inp is not None and inp.name: - used_init_names.add(inp.name) - for out in graph.outputs: - if out is not None and out.name: - used_init_names.add(out.name) - for name in [n for n in graph.initializers if n not in used_init_names]: - del graph.initializers[name] + folded_count = _fold_transpose_of_initializer(graph, new_initializers) + + # Defer orphan cleanup to onnx_ir's standard dead-code-elimination + # pass instead of re-implementing the sweep here. This drops the + # original 3-D source weights, the per-Gather index initializers, + # the shared squeeze-axes constant, and any intermediate per-slice + # initializer that was consumed only by a folded Transpose. + RemoveUnusedNodesPass()(model) logger.info( "UnstackGatherSqueezeWeight: materialised %d per-slice initializers from " @@ -3421,76 +3371,146 @@ def call_ir(self, model: ir.Model) -> ir.Model: ) return model - @staticmethod - def _fold_transpose_of_initializer(graph: ir.Graph, candidate_initializers: dict[str, ir.Value]) -> int: - """Fold ``Transpose(initializer) → matmul.B`` chains into transposed initializers. 
- Replaces any ``Transpose`` node whose input is in - ``candidate_initializers`` with a pre-transposed initializer and - rewires the downstream consumers to use it directly. +# Free functions used by ``UnstackGatherSqueezeWeight``. Kept at module +# scope (not as class methods) per Google's Python style guide: prefer +# module-level helpers when no class state is involved. - Returns the number of Transpose nodes folded. - """ - folded = 0 - for node in list(graph.all_nodes()): - if node.op_type != "Transpose": - continue - if len(node.inputs) != 1 or node.inputs[0] is None: - continue - src = node.inputs[0] - if src.name not in candidate_initializers: - continue - if src.const_value is None: - continue - perm_attr = node.attributes.get("perm") - arr = src.const_value.numpy() - perm = list(perm_attr.value) if perm_attr is not None else list(range(arr.ndim))[::-1] - transposed = np.transpose(arr, perm) - new_name = f"{src.name}__T" - if new_name in graph.initializers: - new_value = graph.initializers[new_name] - else: - new_tensor = ir.Tensor(transposed, name=new_name) - new_value = ir.Value( - name=new_name, - type=ir.TensorType(new_tensor.dtype), - shape=ir.Shape(transposed.shape), - const_value=new_tensor, - ) - graph.register_initializer(new_value) - ir.convenience.replace_all_uses_with(node.outputs[0], new_value) - graph.remove(node, safe=True) - folded += 1 - return folded - @staticmethod - def _scalar_int_list(node: ir.Node, idx: int, initializers: dict[str, ir.Value]) -> list[int] | None: - """Resolve ``node.inputs[idx]`` to a ``list[int]`` if it is a static 1-D int source. +def _try_match_gather_squeeze( + node: ir.Node, + initializers: dict[str, ir.Value], + new_initializers: dict[str, ir.Value], +) -> tuple[ir.Node, ir.Node, ir.Value] | None: + """Match a ``Gather(W_3d, [const], axis=0) → Squeeze([0])`` chain at ``node``. 
- Returns the resolved list when ``node.inputs[idx]`` is either a - static integer initializer or the output of an inline - ``Constant`` node carrying ``value`` / ``value_ints`` / - ``value_int``; otherwise returns ``None``. - """ - if idx >= len(node.inputs) or node.inputs[idx] is None: - return None - value = node.inputs[idx] - # Initializer path - if value.name in initializers and initializers[value.name].const_value is not None: - arr = initializers[value.name].const_value.numpy() - return arr.reshape(-1).astype(np.int64).tolist() - # Inline Constant node path - producer = value.producer() - if producer is not None and producer.op_type == "Constant": - for attr in producer.attributes.values(): - if attr.name == "value" and attr.type == ir.AttributeType.TENSOR: - arr = attr.value.numpy() - return arr.reshape(-1).astype(np.int64).tolist() - if attr.name == "value_ints" and attr.type == ir.AttributeType.INTS: - return list(attr.value) - if attr.name == "value_int" and attr.type == ir.AttributeType.INT: - return [int(attr.value)] + Returns ``(gather_node, squeeze_node, replacement_initializer)`` on a + successful match, or ``None`` otherwise. On match, the + ``replacement_initializer`` is also recorded in + ``new_initializers`` so callers can register it later. 
+ """ + if node.op_type != "Squeeze": + return None + if len(node.inputs) < 1 or node.inputs[0] is None: + return None + squeeze_axes = _scalar_int_list(node, idx=1, initializers=initializers) + if squeeze_axes != [0]: + return None + + gather_value = node.inputs[0] + gather_node = gather_value.producer() + if gather_node is None or gather_node.op_type != "Gather": + return None + axis_attr = gather_node.attributes.get("axis") + axis_val = int(axis_attr.value) if axis_attr is not None else 0 + if axis_val != 0: + return None + if len(gather_node.inputs) < 2 or any(i is None for i in gather_node.inputs[:2]): + return None + + data_value = gather_node.inputs[0] + if data_value.name not in initializers: + return None + data_init = initializers[data_value.name] + if data_init.const_value is None: + return None + data_array = data_init.const_value.numpy() + if data_array.ndim != 3: + return None + + idx_list = _scalar_int_list(gather_node, idx=1, initializers=initializers) + if idx_list is None or len(idx_list) != 1: + return None + slice_idx = idx_list[0] + if not 0 <= slice_idx < data_array.shape[0]: + return None + + slice_array = data_array[slice_idx] + slice_name = f"{data_value.name}__slice_{slice_idx}" + if slice_name in new_initializers: + replacement = new_initializers[slice_name] + else: + slice_tensor = ir.Tensor(slice_array, name=slice_name) + replacement = ir.Value( + name=slice_name, + type=ir.TensorType(slice_tensor.dtype), + shape=ir.Shape(slice_array.shape), + const_value=slice_tensor, + ) + new_initializers[slice_name] = replacement + + return gather_node, node, replacement + + +def _fold_transpose_of_initializer(graph: ir.Graph, candidate_initializers: dict[str, ir.Value]) -> int: + """Fold ``Transpose(initializer) → matmul.B`` chains into transposed initializers. + + Replaces any ``Transpose`` node whose input is in + ``candidate_initializers`` with a pre-transposed initializer and + rewires the downstream consumers to use it directly. 
+ + Returns the number of Transpose nodes folded. + """ + folded = 0 + for node in list(graph.all_nodes()): + if node.op_type != "Transpose": + continue + if len(node.inputs) != 1 or node.inputs[0] is None: + continue + src = node.inputs[0] + if src.name not in candidate_initializers: + continue + if src.const_value is None: + continue + perm_attr = node.attributes.get("perm") + arr = src.const_value.numpy() + perm = list(perm_attr.value) if perm_attr is not None else list(range(arr.ndim))[::-1] + transposed = np.transpose(arr, perm) + new_name = f"{src.name}__T" + if new_name in graph.initializers: + new_value = graph.initializers[new_name] + else: + new_tensor = ir.Tensor(transposed, name=new_name) + new_value = ir.Value( + name=new_name, + type=ir.TensorType(new_tensor.dtype), + shape=ir.Shape(transposed.shape), + const_value=new_tensor, + ) + graph.register_initializer(new_value) + ir.convenience.replace_all_uses_with(node.outputs[0], new_value) + graph.remove(node, safe=True) + folded += 1 + return folded + + +def _scalar_int_list(node: ir.Node, idx: int, initializers: dict[str, ir.Value]) -> list[int] | None: + """Resolve ``node.inputs[idx]`` to a ``list[int]`` if it is a static 1-D int source. + + Returns the resolved list when ``node.inputs[idx]`` is either a + static integer initializer or the output of an inline + ``Constant`` node carrying ``value`` / ``value_ints`` / + ``value_int``; otherwise returns ``None``. 
+ """ + if idx >= len(node.inputs) or node.inputs[idx] is None: return None + value = node.inputs[idx] + # Initializer path + if value.name in initializers and initializers[value.name].const_value is not None: + arr = initializers[value.name].const_value.numpy() + return arr.reshape(-1).astype(np.int64).tolist() + # Inline Constant node path + producer = value.producer() + if producer is not None and producer.op_type == "Constant": + for attr in producer.attributes.values(): + if attr.name == "value" and attr.type == ir.AttributeType.TENSOR: + arr = attr.value.numpy() + return arr.reshape(-1).astype(np.int64).tolist() + if attr.name == "value_ints" and attr.type == ir.AttributeType.INTS: + return list(attr.value) + if attr.name == "value_int" and attr.type == ir.AttributeType.INT: + return [int(attr.value)] + return None class GraphSurgeries(Pass):