Add UnstackGatherSqueezeWeight surgeon for static-unroll MoE quantization#2490
Open
justinchuby wants to merge 3 commits into
Open
Add UnstackGatherSqueezeWeight surgeon for static-unroll MoE quantization#2490justinchuby wants to merge 3 commits into
justinchuby wants to merge 3 commits into
Conversation
Contributor
There was a problem hiding this comment.
Pull request overview
Adds a new UnstackGatherSqueezeWeight surgeon to GraphSurgeries that materialises per-slice 2-D initializers from Gather(W_3d, [const], axis=0) → Squeeze([0]) chains (with optional downstream Transpose folding). This unblocks weight quantization for static-unroll MoE models such as Gemma 4 26B-A4B where per-expert MatMul B inputs are not graph-level initializers, so passes like OnnxKQuantQuantization previously skipped them.
Changes:
- New
UnstackGatherSqueezeWeightsurgeon inolive/passes/onnx/graph_surgeries.pythat detects the Gather→Squeeze chain over a 3-D initializer, slices it at the constant index, optionally folds a downstreamTranspose, rewires consumers, and cleans up orphaned nodes/initializers. - Helpers
_fold_transpose_of_initializerand_scalar_int_listfor static-int resolution from both initializers and inlineConstantnodes. - Unit test
test_unstack_gather_squeeze_weightbuilding a tiny two-chain model and asserting node/initializer rewrites plus numerical parity viaInferenceSession.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| olive/passes/onnx/graph_surgeries.py | Implements the new UnstackGatherSqueezeWeight surgeon and helpers. |
| test/passes/onnx/test_graph_surgeries.py | Adds a unit test validating slice/transpose materialisation and numerical parity. |
Comment on lines
+3274
to
+3284
| and rewrites each match by: | ||
|
|
||
| 1. Slicing ``W_3d`` along axis 0 at the constant index ``k``. | ||
| 2. Materialising the resulting ``[K, N]`` slice as a brand-new | ||
| initializer named ``<W_3d_name>__expert_<k>``. | ||
| 3. Replacing the ``Squeeze`` output's downstream uses with that | ||
| initializer directly. | ||
| 4. Cleaning up the orphaned ``Gather`` and ``Squeeze`` nodes; the | ||
| 3-D weight itself is dropped only if all its uses have been | ||
| rewritten. | ||
|
|
Comment on lines
+3375
to
+3378
| # Rewire each Squeeze output to the new initializer and remove | ||
| # the now-unused Gather + Squeeze nodes. Indices initializer is | ||
| # left in place: it may be tiny, and downstream cleanup passes | ||
| # (e.g. RemoveUnusedOpsetsPass) handle orphan removal. |
Comment on lines
+3391
to
+3402
| # Drop any 3-D source initializers that no longer have any | ||
| # consumers in the graph after rewrite. | ||
| used_init_names: set[str] = set() | ||
| for node in graph.all_nodes(): | ||
| for inp in node.inputs: | ||
| if inp is not None and inp.name: | ||
| used_init_names.add(inp.name) | ||
| for out in graph.outputs: | ||
| if out is not None and out.name: | ||
| used_init_names.add(out.name) | ||
| for name in [n for n in graph.initializers if n not in used_init_names]: | ||
| del graph.initializers[name] |
…tion
Mobius emits static-unroll Mixture-of-Experts (MoE) blocks as:
fc1 = Squeeze(Gather(W_3d, [const_idx], axis=0), [0])
proj = MatMul(x, Transpose(fc1))
because the fused 'com.microsoft::MoE' op only supports GPT-OSS-style
SwiGLU (interleaved, alpha=1.702, limit=7.0), which is incompatible with
the standard concatenated SwiGLU used by models like Gemma 4 26B-A4B.
See microsoft/onnxruntime#28738.
When the per-expert MatMul.B is the result of a Gather/Squeeze chain
(plus optional Transpose), it is not a graph-level initializer, so
weight-quantization passes such as OnnxKQuantQuantization skip it. This
leaves the per-expert weights at the model's compute dtype (e.g. fp16),
yielding only a ~6% size reduction for Gemma 4 26B-A4B instead of the
expected 4-bit ratio.
Add an UnstackGatherSqueezeWeight surgeon that:
1. Detects 'Gather(W_3d_init, [const_int], axis=0) -> Squeeze([0])'
subgraphs.
2. Materialises the slice 'W_3d[const_int]' as a new 2-D initializer
named '<W_3d_name>__slice_<idx>'.
3. Folds any subsequent Transpose into a transposed initializer
('<W_3d_name>__slice_<idx>__T').
4. Rewires downstream consumers and drops the now-orphaned 3-D
initializer and Gather/Squeeze/Transpose nodes.
After this surgery runs, OnnxKQuantQuantization picks up every per-expert
MatMul without any changes to the quantization pass itself. Numerical
parity with the original graph is verified by the accompanying unit test.
Fixes: microsoft#2489
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
- Remove unused index_value local (PYLINT W0612 / RUFF F841).
- Replace isinstance(attr, ir.AttrTensor/AttrInt64s/AttrInt64) with
attr.type == ir.AttributeType.{TENSOR,INTS,INT}; the ir.Attr*
callables are factory functions, not types (PYLINT W1116).
- Drop the redundant first orphan-initializer sweep; the post-Transpose
sweep already covers every initializer that becomes unused at any
point during the rewrite, and runs once at the end.
- Update the misleading 'left in place' comment about index initializers
— they are in fact dropped by the orphan sweep, not by a downstream
pass.
- Sync the docstring with the implementation: name slices
'<W_3d_name>__slice_<k>' (not '__expert_<k>') and document the
Transpose-folding step explicitly.
- Replace ir.AttrInt64(...) default-construction trick with a None-check
on attributes.get('axis'), removing a needless object allocation per
Gather inspection.
- Fix RUFF D205 docstring summary/description spacing in three places.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
40a2047 to
97b176c
Compare
- Move `_fold_transpose_of_initializer` and `_scalar_int_list` from `@staticmethod`s on `UnstackGatherSqueezeWeight` to module-level private functions, per Google's Python style guide preference for free functions over static methods when no class state is involved. - Extract the Gather→Squeeze match logic into a new `_try_match_gather_squeeze` free function so `call_ir` becomes a thin orchestrator that scans the graph and dispatches to the helpers. - Replace the open-coded orphan-initializer sweep with `onnx_ir.passes.common.RemoveUnusedNodesPass`. That pass also removes dead nodes for free, and using the upstream IR pass keeps this surgeon consistent with other IR consumers. No behaviour change: `test_unstack_gather_squeeze_weight` still passes and the full 83-test `test_graph_surgeries` suite is green. Lintrunner clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #2489.
Problem
When mobius exports models with statically-unrolled MoE blocks (e.g. Gemma 4 26B-A4B), each per-expert MatMul looks like:
The MatMul's B input is not a graph-level initializer, so weight-quantization passes (
OnnxKQuantQuantization,OnnxBlockWiseRtnQuantization) skip it via thenode.inputs[1].is_initializer()check. The model's expert weights stay at compute dtype (fp16), giving only ~6% size reduction instead of the expected 4-bit ratio.Mobius emits this fallback because the fused
com.microsoft::MoEop's SwiGLU mode is hardcoded for GPT-OSS semantics and is incompatible with the standard SwiGLU used by models like Gemma 4 26B — see microsoft/onnxruntime#28738.Solution
Add a new
GraphSurgeriessurgeonUnstackGatherSqueezeWeightthat:Gather(W_3d_init, [const_int], axis=0) → Squeeze([0])chains where the data is a 3-D initializer and the index is a constant scalar.W_3d[const_int]as a new 2-D initializer (<W_3d_name>__slice_<idx>).Transposeinto a transposed initializer (__Tsuffix).After running this surgery, the existing
OnnxKQuantQuantizationandOnnxBlockWiseRtnQuantizationpasses pick up every per-expert MatMul with no changes to their own logic.Usage
{ "input_model": {"type": "OnnxModel", "model_path": "decoder/model.onnx"}, "passes": { "unstack": { "type": "GraphSurgeries", "surgeries": [{"surgeon": "UnstackGatherSqueezeWeight"}] }, "kquant": {"type": "OnnxKQuantQuantization", "bits": 4, "block_size": 32} } }Test
New unit test
test_unstack_gather_squeeze_weightbuilds a tiny model with two Gather → Squeeze → Transpose → MatMul chains over a single 3-D initializer and verifies:InferenceSessionon CPU EP.All 83
test_graph_surgeries.pytests pass.