Add UnstackGatherSqueezeWeight surgeon for static-unroll MoE quantization #2490

Open. justinchuby wants to merge 3 commits into microsoft:main from justinchuby:unstack-moe-expert-weights
@justinchuby
Fixes #2489.

Problem

When mobius exports models with statically-unrolled MoE blocks (e.g. Gemma 4 26B-A4B), each per-expert MatMul looks like:

fc1 = Squeeze(Gather(W_3d, [const_idx], axis=0), [0])
y   = MatMul(x, Transpose(fc1))

The MatMul's B input is not a graph-level initializer, so weight-quantization passes (OnnxKQuantQuantization, OnnxBlockWiseRtnQuantization) skip it via the node.inputs[1].is_initializer() check. As a result, the expert weights stay at the compute dtype (fp16), and the model shrinks by only ~6% instead of by the ratio expected from 4-bit quantization.
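
The effect on model size can be approximated with a little arithmetic. The sketch below is illustrative only: `size_reduction` and `expert_fraction` are hypothetical names, scale and zero-point overhead is ignored, and the 0.92 share is chosen simply to reproduce the ~6% figure; it is not a measured number from this PR.

```python
def size_reduction(expert_fraction: float, bits_before: int = 16, bits_after: int = 4) -> float:
    """Fractional size reduction when the quantizer skips the expert weights.

    expert_fraction is a hypothetical share of all parameters that lives in
    expert weights left at bits_before; everything else drops to bits_after.
    Per-block scale/zero-point storage is ignored for simplicity.
    """
    quantized_share = (1.0 - expert_fraction) * bits_after / bits_before
    new_relative_size = expert_fraction + quantized_share
    return 1.0 - new_relative_size


# If about 92% of the parameters sit in skipped expert weights,
# the overall reduction is only about 6%.
skipped = size_reduction(0.92)

# If nothing is skipped, fp16 -> 4-bit approaches the ideal 75% reduction.
ideal = size_reduction(0.0)
```

Under these assumptions, almost all of the achievable reduction depends on the expert weights being visible to the quantizer.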

Mobius emits this fallback because the fused com.microsoft::MoE op's SwiGLU mode is hardcoded for GPT-OSS semantics and is incompatible with the standard SwiGLU used by models like Gemma 4 26B — see microsoft/onnxruntime#28738.

Solution

Add a new GraphSurgeries surgeon UnstackGatherSqueezeWeight that:

  1. Detects Gather(W_3d_init, [const_int], axis=0) → Squeeze([0]) chains where the data is a 3-D initializer and the index is a single constant integer (a one-element index tensor).
  2. Materialises W_3d[const_int] as a new 2-D initializer (<W_3d_name>__slice_<idx>).
  3. Folds any downstream Transpose into a transposed initializer (__T suffix).
  4. Rewires consumers and removes orphaned nodes / 3-D initializers.

After running this surgery, the existing OnnxKQuantQuantization and OnnxBlockWiseRtnQuantization passes pick up every per-expert MatMul with no changes to their own logic.
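
The first three steps can be sketched in NumPy. This is a minimal illustration rather than the surgeon's actual code: `unstack_slices` is a hypothetical helper, the weight layout `[experts, K, N]` is assumed, and the Transpose is assumed to use the default 2-D permutation `[1, 0]`.

```python
import numpy as np


def unstack_slices(name, w_3d, indices, fold_transpose=True):
    """Return {initializer_name: 2-D array} for each constant expert index."""
    new_inits = {}
    for idx in indices:
        # Gather(W, [idx], axis=0) yields shape [1, K, N];
        # Squeeze([0]) then drops the leading axis, giving [K, N].
        gathered = np.take(w_3d, [idx], axis=0)
        sliced = np.squeeze(gathered, axis=0)
        slice_name = f"{name}__slice_{idx}"
        if fold_transpose:
            # Fold the downstream Transpose into the stored weight.
            new_inits[f"{slice_name}__T"] = np.ascontiguousarray(sliced.T)
        else:
            new_inits[slice_name] = sliced
    return new_inits


rng = np.random.default_rng(0)
w = rng.standard_normal((2, 4, 3)).astype(np.float32)  # [experts, K, N]
x = rng.standard_normal((5, 3)).astype(np.float32)

inits = unstack_slices("W", w, [0, 1])

# Original chain: MatMul(x, Transpose(Squeeze(Gather(W, [1], axis=0), [0])))
original = x @ np.squeeze(np.take(w, [1], axis=0), axis=0).T
# Rewritten graph: MatMul(x, <precomputed 2-D initializer>)
rewritten = x @ inits["W__slice_1__T"]
```

Because each rewritten MatMul now reads a plain 2-D constant, a check like `is_initializer()` on its B input would succeed without any change to the quantization pass.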

Usage

{
  "input_model": {"type": "OnnxModel", "model_path": "decoder/model.onnx"},
  "passes": {
    "unstack": {
      "type": "GraphSurgeries",
      "surgeries": [{"surgeon": "UnstackGatherSqueezeWeight"}]
    },
    "kquant": {"type": "OnnxKQuantQuantization", "bits": 4, "block_size": 32}
  }
}
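
For scripted workflows, the same configuration could be generated from Python. The sketch below only builds and writes the JSON; the output file name is hypothetical, and it assumes that passes run in the order they are listed, so the surgery precedes quantization.

```python
import json
import tempfile
from pathlib import Path

config = {
    "input_model": {"type": "OnnxModel", "model_path": "decoder/model.onnx"},
    "passes": {
        "unstack": {
            "type": "GraphSurgeries",
            "surgeries": [{"surgeon": "UnstackGatherSqueezeWeight"}],
        },
        "kquant": {"type": "OnnxKQuantQuantization", "bits": 4, "block_size": 32},
    },
}

# The surgery has to come first so the quantizer sees 2-D initializers.
pass_names = list(config["passes"])
assert pass_names.index("unstack") < pass_names.index("kquant")

# Hypothetical output location; any writable path would do.
config_path = Path(tempfile.mkdtemp()) / "unstack_kquant.json"
config_path.write_text(json.dumps(config, indent=2))
```

The written file can then be passed to the Olive CLI, for example with `olive run --config <path>`.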

Test

New unit test test_unstack_gather_squeeze_weight builds a tiny model with two Gather → Squeeze → Transpose → MatMul chains over a single 3-D initializer and verifies:

  • All Gather/Squeeze/Transpose ops are removed.
  • New per-slice transposed initializers are created with the correct values.
  • The original 3-D initializer is dropped.
  • Numerical parity with the unrewritten model holds via InferenceSession on CPU EP.

All 83 test_graph_surgeries.py tests pass.

Copilot AI review requested due to automatic review settings June 1, 2026 19:03

Copilot AI left a comment

Pull request overview

Adds a new UnstackGatherSqueezeWeight surgeon to GraphSurgeries that materialises per-slice 2-D initializers from Gather(W_3d, [const], axis=0) → Squeeze([0]) chains (with optional downstream Transpose folding). This unblocks weight quantization for static-unroll MoE models such as Gemma 4 26B-A4B where per-expert MatMul B inputs are not graph-level initializers, so passes like OnnxKQuantQuantization previously skipped them.

Changes:

  • New UnstackGatherSqueezeWeight surgeon in olive/passes/onnx/graph_surgeries.py that detects the Gather→Squeeze chain over a 3-D initializer, slices it at the constant index, optionally folds a downstream Transpose, rewires consumers, and cleans up orphaned nodes/initializers.
  • Helpers _fold_transpose_of_initializer and _scalar_int_list for static-int resolution from both initializers and inline Constant nodes.
  • Unit test test_unstack_gather_squeeze_weight building a tiny two-chain model and asserting node/initializer rewrites plus numerical parity via InferenceSession.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| olive/passes/onnx/graph_surgeries.py | Implements the new UnstackGatherSqueezeWeight surgeon and helpers. |
| test/passes/onnx/test_graph_surgeries.py | Adds a unit test validating slice/transpose materialisation and numerical parity. |

Comment on lines +3274 to +3284
and rewrites each match by:

1. Slicing ``W_3d`` along axis 0 at the constant index ``k``.
2. Materialising the resulting ``[K, N]`` slice as a brand-new
   initializer named ``<W_3d_name>__expert_<k>``.
3. Replacing the ``Squeeze`` output's downstream uses with that
   initializer directly.
4. Cleaning up the orphaned ``Gather`` and ``Squeeze`` nodes; the
   3-D weight itself is dropped only if all its uses have been
   rewritten.

Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment on lines +3375 to +3378
# Rewire each Squeeze output to the new initializer and remove
# the now-unused Gather + Squeeze nodes. Indices initializer is
# left in place: it may be tiny, and downstream cleanup passes
# (e.g. RemoveUnusedOpsetsPass) handle orphan removal.
Comment thread olive/passes/onnx/graph_surgeries.py Outdated
Comment on lines +3391 to +3402
# Drop any 3-D source initializers that no longer have any
# consumers in the graph after rewrite.
used_init_names: set[str] = set()
for node in graph.all_nodes():
    for inp in node.inputs:
        if inp is not None and inp.name:
            used_init_names.add(inp.name)
for out in graph.outputs:
    if out is not None and out.name:
        used_init_names.add(out.name)
for name in [n for n in graph.initializers if n not in used_init_names]:
    del graph.initializers[name]
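
The sweep quoted above can be exercised on a toy graph. The sketch below uses hypothetical `ToyNode` and `ToyGraph` stand-ins that refer to values by name; it does not use the real onnx_ir classes, and a later commit in this PR replaces the open-coded sweep with `onnx_ir.passes.common.RemoveUnusedNodesPass`.

```python
from dataclasses import dataclass, field


@dataclass
class ToyNode:  # hypothetical stand-in for an IR node
    op_type: str
    inputs: list
    outputs: list


@dataclass
class ToyGraph:  # hypothetical stand-in for an IR graph
    nodes: list
    outputs: list
    initializers: dict = field(default_factory=dict)


def drop_unused_initializers(graph):
    """Delete initializers that no node input or graph output refers to."""
    used = set()
    for node in graph.nodes:
        used.update(name for name in node.inputs if name)
    used.update(name for name in graph.outputs if name)
    removed = [name for name in graph.initializers if name not in used]
    for name in removed:
        del graph.initializers[name]
    return removed


# State after the rewrite: the MatMul reads the new 2-D initializer directly,
# so the 3-D weight and the index constant have no remaining consumers.
graph = ToyGraph(
    nodes=[ToyNode("MatMul", ["x", "W__slice_0__T"], ["y"])],
    outputs=["y"],
    initializers={"W": "3-D weight", "idx_0": "index", "W__slice_0__T": "2-D weight"},
)
removed = drop_unused_initializers(graph)
```

Note that the index initializer is swept along with the 3-D weight, which matches the later commit note that index initializers are dropped by the orphan sweep rather than left in place.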
justinchuby and others added 2 commits June 2, 2026 00:10
…tion

Mobius emits static-unroll Mixture-of-Experts (MoE) blocks as:

    fc1 = Squeeze(Gather(W_3d, [const_idx], axis=0), [0])
    proj = MatMul(x, Transpose(fc1))

because the fused 'com.microsoft::MoE' op only supports GPT-OSS-style
SwiGLU (interleaved, alpha=1.702, limit=7.0), which is incompatible with
the standard concatenated SwiGLU used by models like Gemma 4 26B-A4B.
See microsoft/onnxruntime#28738.

When the per-expert MatMul.B is the result of a Gather/Squeeze chain
(plus optional Transpose), it is not a graph-level initializer, so
weight-quantization passes such as OnnxKQuantQuantization skip it. This
leaves the per-expert weights at the model's compute dtype (e.g. fp16),
yielding only a ~6% size reduction for Gemma 4 26B-A4B instead of the
expected 4-bit ratio.

Add an UnstackGatherSqueezeWeight surgeon that:

  1. Detects 'Gather(W_3d_init, [const_int], axis=0) -> Squeeze([0])'
     subgraphs.
  2. Materialises the slice 'W_3d[const_int]' as a new 2-D initializer
     named '<W_3d_name>__slice_<idx>'.
  3. Folds any subsequent Transpose into a transposed initializer
     ('<W_3d_name>__slice_<idx>__T').
  4. Rewires downstream consumers and drops the now-orphaned 3-D
     initializer and Gather/Squeeze/Transpose nodes.

After this surgery runs, OnnxKQuantQuantization picks up every per-expert
MatMul without any changes to the quantization pass itself. Numerical
parity with the original graph is verified by the accompanying unit test.

Fixes: microsoft#2489

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
- Remove unused index_value local (PYLINT W0612 / RUFF F841).
- Replace isinstance(attr, ir.AttrTensor/AttrInt64s/AttrInt64) with
  attr.type == ir.AttributeType.{TENSOR,INTS,INT}; the ir.Attr*
  callables are factory functions, not types (PYLINT W1116).
- Drop the redundant first orphan-initializer sweep; the post-Transpose
  sweep already covers every initializer that becomes unused at any
  point during the rewrite, and runs once at the end.
- Update the misleading 'left in place' comment about index initializers
  — they are in fact dropped by the orphan sweep, not by a downstream
  pass.
- Sync the docstring with the implementation: name slices
  '<W_3d_name>__slice_<k>' (not '__expert_<k>') and document the
  Transpose-folding step explicitly.
- Replace ir.AttrInt64(...) default-construction trick with a None-check
  on attributes.get('axis'), removing a needless object allocation per
  Gather inspection.
- Fix RUFF D205 docstring summary/description spacing in three places.
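
The isinstance and axis-default fixes above can be illustrated with a small stand-alone sketch. `AttributeType`, `Attr`, `AttrInt64`, and `gather_axis` below are hypothetical stand-ins written for this example; they only mimic the shape of the onnx_ir API described in the commit message.

```python
import enum


class AttributeType(enum.Enum):  # stand-in for ir.AttributeType
    INT = 2
    INTS = 7


class Attr:  # stand-in for the IR attribute class
    def __init__(self, name, type_, value):
        self.name = name
        self.type = type_
        self.value = value


def AttrInt64(name, value):
    """Factory function: returns an Attr, but is not itself a class."""
    return Attr(name, AttributeType.INT, value)


axis = AttrInt64("axis", 0)

try:
    # The second argument is a function, not a type, so this raises.
    isinstance(axis, AttrInt64)
    raised = False
except TypeError:
    raised = True

# Comparing the attribute's type tag works regardless of how it was built.
is_int_attr = axis.type == AttributeType.INT


def gather_axis(attributes):
    """Read Gather's axis, falling back to the ONNX default of 0 when absent."""
    attr = attributes.get("axis")
    return 0 if attr is None else attr.value
```

The None-check avoids constructing a throwaway default attribute for every Gather that is inspected.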

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
@justinchuby justinchuby force-pushed the unstack-moe-expert-weights branch from 40a2047 to 97b176c Compare June 2, 2026 00:11
- Move `_fold_transpose_of_initializer` and `_scalar_int_list` from
  `@staticmethod`s on `UnstackGatherSqueezeWeight` to module-level
  private functions, per Google's Python style guide preference for
  free functions over static methods when no class state is involved.
- Extract the Gather→Squeeze match logic into a new
  `_try_match_gather_squeeze` free function so `call_ir` becomes a
  thin orchestrator that scans the graph and dispatches to the helpers.
- Replace the open-coded orphan-initializer sweep with
  `onnx_ir.passes.common.RemoveUnusedNodesPass`. That pass also
  removes dead nodes for free, and using the upstream IR pass keeps
  this surgeon consistent with other IR consumers.

No behaviour change: `test_unstack_gather_squeeze_weight` still passes
and the full 83-test `test_graph_surgeries` suite is green. Lintrunner
clean.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Development

Successfully merging this pull request may close these issues.

OnnxKQuantQuantization skips per-expert MatMuls in static-unroll MoE blocks