diff --git a/olive/passes/onnx/graph_surgeries.py b/olive/passes/onnx/graph_surgeries.py index 61f80a47f..4edfd3db5 100644 --- a/olive/passes/onnx/graph_surgeries.py +++ b/olive/passes/onnx/graph_surgeries.py @@ -18,7 +18,12 @@ import onnxscript from onnx import ModelProto, TensorProto from onnx.helper import make_tensor -from onnx_ir.passes.common import DeduplicateHashedInitializersPass, InlinePass, RemoveUnusedOpsetsPass +from onnx_ir.passes.common import ( + DeduplicateHashedInitializersPass, + InlinePass, + RemoveUnusedNodesPass, + RemoveUnusedOpsetsPass, +) from onnxscript import ir, rewriter from onnxscript.rewriter import pattern @@ -3259,6 +3264,255 @@ def call_ir(self, model: ir.Model) -> ir.Model: return model +class UnstackGatherSqueezeWeight(Surgeon): + """Materialise per-expert (or per-slice) 2-D weights from stacked 3-D initializers. + + Detects the pattern emitted by static-unroll Mixture-of-Experts (MoE) + blocks:: + + W_3d : ir.Tensor = [E, K, N] (initializer) + idx = [k] (1-element int64 initializer) + gathered = Gather(W_3d, idx, axis=0) # [1, K, N] + unstacked = Squeeze(gathered, [0]) # [K, N] + ... = MatMul(x, unstacked) # B input is not an initializer + + and rewrites each match by: + + 1. Slicing ``W_3d`` along axis 0 at the constant index ``k``. + 2. Materialising the resulting ``[K, N]`` slice as a brand-new + initializer named ``__slice_``. + 3. Replacing the ``Squeeze`` output's downstream uses with that + initializer directly. + 4. Folding any downstream ``Transpose(slice_initializer)`` into a + pre-transposed initializer named ``__slice___T``. + This is required for the common mobius emission pattern + ``MatMul(x, Transpose(Squeeze(Gather(...))))``. + 5. Removing the orphaned ``Gather``, ``Squeeze``, and ``Transpose`` + nodes plus any initializer (3-D weight, per-Gather index, shared + squeeze-axes constant, intermediate slice) that no longer has any + node consumer. + + After this pass, the downstream ``MatMul`` nodes have a 2-D static + initializer as their B input, which means weight-quantization passes + such as :class:`OnnxKQuantQuantization` and + :class:`OnnxBlockWiseRtnQuantization` pick them up automatically. + + Use case: + Run before any weight-quantization pass on a model that emits + unrolled per-expert MoE dispatch (e.g. Gemma 4 26B-A4B from + mobius, where the fused ``com.microsoft::MoE`` op is bypassed + due to a SwiGLU-kernel incompatibility). Without this pass the + per-expert MatMuls remain at the model's compute dtype because + the quantizer requires a 2-D static initializer as the B input. + + Example surgery list entry:: + + {"surgeon": "UnstackGatherSqueezeWeight"} + """ + + def call_ir(self, model: ir.Model) -> ir.Model: + # Defer to module-level free functions (Google style preference + # for free functions over static methods when no class state is + # involved). + graph = model.graph + initializers: dict[str, ir.Value] = dict(graph.initializers) + + new_initializers: dict[str, ir.Value] = {} + # (gather_node, squeeze_node, replacement_initializer_value) + rewrites: list[tuple[ir.Node, ir.Node, ir.Value]] = [] + + for node in list(graph.all_nodes()): + rewrite = _try_match_gather_squeeze(node, initializers, new_initializers) + if rewrite is not None: + rewrites.append(rewrite) + + if not rewrites: + return model + + source_init_names: set[str] = {g.inputs[0].name for g, _, _ in rewrites if g.inputs[0] is not None} + + # Register all new initializers first. + for value in new_initializers.values(): + graph.register_initializer(value) + + # Rewire each Squeeze output to the new initializer and remove + # the now-unused Gather + Squeeze nodes. + for gather_node, squeeze_node, replacement in rewrites: + ir.convenience.replace_all_uses_with(squeeze_node.outputs[0], replacement) + graph.remove(squeeze_node, safe=True) + # A single Gather output may feed multiple Squeeze chains; only + # remove the Gather node once all its consumers are gone. + gather_uses = [u for u in graph.all_nodes() if any(i is gather_node.outputs[0] for i in u.inputs)] + if not gather_uses: + graph.remove(gather_node, safe=True) + + # Fold any Transpose(initializer) chain that consumes a newly + # materialised slice. This keeps the downstream MatMul.B as a + # static 2-D initializer so weight-quantization passes such as + # OnnxKQuantQuantization (which require ``inputs[1].is_initializer()``) + # pick it up. The most common producer is the mobius MoE + # fallback emitting ``MatMul(x, Transpose(Squeeze(Gather(...))))``. + folded_count = _fold_transpose_of_initializer(graph, new_initializers) + + # Defer orphan cleanup to onnx_ir's standard dead-code-elimination + # pass instead of re-implementing the sweep here. This drops the + # original 3-D source weights, the per-Gather index initializers, + # the shared squeeze-axes constant, and any intermediate per-slice + # initializer that was consumed only by a folded Transpose. + RemoveUnusedNodesPass()(model) + + logger.info( + "UnstackGatherSqueezeWeight: materialised %d per-slice initializers from " + "%d stacked source(s); rewrote %d Gather→Squeeze chains; folded %d Transpose ops.", + len(new_initializers), + len(source_init_names), + len(rewrites), + folded_count, + ) + return model + + +# Free functions used by ``UnstackGatherSqueezeWeight``. Kept at module +# scope (not as class methods) per Google's Python style guide: prefer +# module-level helpers when no class state is involved. + + +def _try_match_gather_squeeze( + node: ir.Node, + initializers: dict[str, ir.Value], + new_initializers: dict[str, ir.Value], +) -> tuple[ir.Node, ir.Node, ir.Value] | None: + """Match a ``Gather(W_3d, [const], axis=0) → Squeeze([0])`` chain at ``node``. + + Returns ``(gather_node, squeeze_node, replacement_initializer)`` on a + successful match, or ``None`` otherwise. On match, the + ``replacement_initializer`` is also recorded in + ``new_initializers`` so callers can register it later. + """ + if node.op_type != "Squeeze": + return None + if len(node.inputs) < 1 or node.inputs[0] is None: + return None + squeeze_axes = _scalar_int_list(node, idx=1, initializers=initializers) + if squeeze_axes != [0]: + return None + + gather_value = node.inputs[0] + gather_node = gather_value.producer() + if gather_node is None or gather_node.op_type != "Gather": + return None + axis_attr = gather_node.attributes.get("axis") + axis_val = int(axis_attr.value) if axis_attr is not None else 0 + if axis_val != 0: + return None + if len(gather_node.inputs) < 2 or any(i is None for i in gather_node.inputs[:2]): + return None + + data_value = gather_node.inputs[0] + if data_value.name not in initializers: + return None + data_init = initializers[data_value.name] + if data_init.const_value is None: + return None + data_array = data_init.const_value.numpy() + if data_array.ndim != 3: + return None + + idx_list = _scalar_int_list(gather_node, idx=1, initializers=initializers) + if idx_list is None or len(idx_list) != 1: + return None + slice_idx = idx_list[0] + if not 0 <= slice_idx < data_array.shape[0]: + return None + + slice_array = data_array[slice_idx] + slice_name = f"{data_value.name}__slice_{slice_idx}" + if slice_name in new_initializers: + replacement = new_initializers[slice_name] + else: + slice_tensor = ir.Tensor(slice_array, name=slice_name) + replacement = ir.Value( + name=slice_name, + type=ir.TensorType(slice_tensor.dtype), + shape=ir.Shape(slice_array.shape), + const_value=slice_tensor, + ) + new_initializers[slice_name] = replacement + + return gather_node, node, replacement + + +def _fold_transpose_of_initializer(graph: ir.Graph, candidate_initializers: dict[str, ir.Value]) -> int: + """Fold ``Transpose(initializer) → matmul.B`` chains into transposed initializers. + + Replaces any ``Transpose`` node whose input is in + ``candidate_initializers`` with a pre-transposed initializer and + rewires the downstream consumers to use it directly. + + Returns the number of Transpose nodes folded. + """ + folded = 0 + for node in list(graph.all_nodes()): + if node.op_type != "Transpose": + continue + if len(node.inputs) != 1 or node.inputs[0] is None: + continue + src = node.inputs[0] + if src.name not in candidate_initializers: + continue + if src.const_value is None: + continue + perm_attr = node.attributes.get("perm") + arr = src.const_value.numpy() + perm = list(perm_attr.value) if perm_attr is not None else list(range(arr.ndim))[::-1] + transposed = np.transpose(arr, perm) + new_name = f"{src.name}__T" + if new_name in graph.initializers: + new_value = graph.initializers[new_name] + else: + new_tensor = ir.Tensor(transposed, name=new_name) + new_value = ir.Value( + name=new_name, + type=ir.TensorType(new_tensor.dtype), + shape=ir.Shape(transposed.shape), + const_value=new_tensor, + ) + graph.register_initializer(new_value) + ir.convenience.replace_all_uses_with(node.outputs[0], new_value) + graph.remove(node, safe=True) + folded += 1 + return folded + + +def _scalar_int_list(node: ir.Node, idx: int, initializers: dict[str, ir.Value]) -> list[int] | None: + """Resolve ``node.inputs[idx]`` to a ``list[int]`` if it is a static 1-D int source. + + Returns the resolved list when ``node.inputs[idx]`` is either a + static integer initializer or the output of an inline + ``Constant`` node carrying ``value`` / ``value_ints`` / + ``value_int``; otherwise returns ``None``. + """ + if idx >= len(node.inputs) or node.inputs[idx] is None: + return None + value = node.inputs[idx] + # Initializer path + if value.name in initializers and initializers[value.name].const_value is not None: + arr = initializers[value.name].const_value.numpy() + return arr.reshape(-1).astype(np.int64).tolist() + # Inline Constant node path + producer = value.producer() + if producer is not None and producer.op_type == "Constant": + for attr in producer.attributes.values(): + if attr.name == "value" and attr.type == ir.AttributeType.TENSOR: + arr = attr.value.numpy() + return arr.reshape(-1).astype(np.int64).tolist() + if attr.name == "value_ints" and attr.type == ir.AttributeType.INTS: + return list(attr.value) + if attr.name == "value_int" and attr.type == ir.AttributeType.INT: + return [int(attr.value)] + return None + + class GraphSurgeries(Pass): """ONNX graph surgeries collections. diff --git a/test/passes/onnx/test_graph_surgeries.py b/test/passes/onnx/test_graph_surgeries.py index ad44db040..b5c4ce432 100644 --- a/test/passes/onnx/test_graph_surgeries.py +++ b/test/passes/onnx/test_graph_surgeries.py @@ -3050,3 +3050,76 @@ def test_remove_memcpy_chained(tmp_path): x = np.random.randn(4, 8).astype(np.float32) result = sess.run(["output"], {"input": x})[0] np.testing.assert_allclose(result, np.maximum(x, 0), atol=1e-6) + + +def test_unstack_gather_squeeze_weight(tmp_path): + """Verify per-slice 2-D weight materialisation and Transpose folding. + + ``UnstackGatherSqueezeWeight`` should rewrite each + ``Gather(W_3d, [const], axis=0) → Squeeze([0])`` chain into a 2-D + initializer and fold any downstream ``Transpose`` into a + pre-transposed initializer. + """ + input_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 6]) + out0 = helper.make_tensor_value_info("y0", TensorProto.FLOAT, [4, 8]) + out2 = helper.make_tensor_value_info("y2", TensorProto.FLOAT, [4, 8]) + + w3d = np.random.RandomState(0).randn(3, 8, 6).astype(np.float32) + w3d_init = numpy_helper.from_array(w3d, name="W3d") + idx0_init = numpy_helper.from_array(np.array([0], dtype=np.int64), name="idx0") + idx2_init = numpy_helper.from_array(np.array([2], dtype=np.int64), name="idx2") + ax_init = numpy_helper.from_array(np.array([0], dtype=np.int64), name="ax0") + + def chain(suffix, idx_name, y_name): + return [ + helper.make_node("Gather", ["W3d", idx_name], [f"g_{suffix}"], axis=0), + helper.make_node("Squeeze", [f"g_{suffix}", "ax0"], [f"s_{suffix}"]), + helper.make_node("Transpose", [f"s_{suffix}"], [f"t_{suffix}"], perm=[1, 0]), + helper.make_node("MatMul", ["x", f"t_{suffix}"], [y_name]), + ] + + nodes = chain("a", "idx0", "y0") + chain("b", "idx2", "y2") + graph = helper.make_graph( + nodes=nodes, + name="UnstackTest", + inputs=[input_tensor], + outputs=[out0, out2], + initializer=[w3d_init, idx0_init, idx2_init, ax_init], + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)]) + model.ir_version = 10 + model_path = tmp_path / "model.onnx" + onnx.save(model, model_path) + + input_model = ONNXModelHandler(model_path=str(model_path)) + output_folder = str(tmp_path / "out") + p = create_pass_from_dict( + GraphSurgeries, + {"surgeries": [{"surgeon": "UnstackGatherSqueezeWeight"}]}, + disable_search=True, + ) + output_model = p.run(input_model, output_folder) + g = output_model.load_model().graph + + op_types = [n.op_type for n in g.node] + assert op_types == ["MatMul", "MatMul"], op_types + init_names = {i.name for i in g.initializer} + assert "W3d" not in init_names + assert "W3d__slice_0__T" in init_names + assert "W3d__slice_2__T" in init_names + + # Verify the materialised initializers match np.transpose of the + # original slices. + inits = {i.name: numpy_helper.to_array(i) for i in g.initializer} + np.testing.assert_array_equal(inits["W3d__slice_0__T"], w3d[0].T) + np.testing.assert_array_equal(inits["W3d__slice_2__T"], w3d[2].T) + + # Numerical parity: rewritten model should produce the same output + # as the original. + sess_orig = InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) + sess_new = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"]) + x = np.random.RandomState(1).randn(4, 6).astype(np.float32) + y0_orig, y2_orig = sess_orig.run(None, {"x": x}) + y0_new, y2_new = sess_new.run(None, {"x": x}) + np.testing.assert_allclose(y0_new, y0_orig, atol=1e-6) + np.testing.assert_allclose(y2_new, y2_orig, atol=1e-6)