Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 255 additions & 1 deletion olive/passes/onnx/graph_surgeries.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
import onnxscript
from onnx import ModelProto, TensorProto
from onnx.helper import make_tensor
from onnx_ir.passes.common import DeduplicateHashedInitializersPass, InlinePass, RemoveUnusedOpsetsPass
from onnx_ir.passes.common import (
DeduplicateHashedInitializersPass,
InlinePass,
RemoveUnusedNodesPass,
RemoveUnusedOpsetsPass,
)
from onnxscript import ir, rewriter
from onnxscript.rewriter import pattern

Expand Down Expand Up @@ -3259,6 +3264,255 @@
return model


class UnstackGatherSqueezeWeight(Surgeon):
"""Materialise per-expert (or per-slice) 2-D weights from stacked 3-D initializers.

Check warning on line 3268 in olive/passes/onnx/graph_surgeries.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "Materialise" is a misspelling of "Materialize" Raw Output: ./olive/passes/onnx/graph_surgeries.py:3268:7: "Materialise" is a misspelling of "Materialize"

Detects the pattern emitted by static-unroll Mixture-of-Experts (MoE)
blocks::

W_3d : ir.Tensor = [E, K, N] (initializer)
idx = [k] (1-element int64 initializer)
gathered = Gather(W_3d, idx, axis=0) # [1, K, N]
unstacked = Squeeze(gathered, [0]) # [K, N]
... = MatMul(x, unstacked) # B input is not an initializer

and rewrites each match by:

1. Slicing ``W_3d`` along axis 0 at the constant index ``k``.
2. Materialising the resulting ``[K, N]`` slice as a brand-new

Check warning on line 3282 in olive/passes/onnx/graph_surgeries.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "Materialising" is a misspelling of "Materializing" Raw Output: ./olive/passes/onnx/graph_surgeries.py:3282:7: "Materialising" is a misspelling of "Materializing"
initializer named ``<W_3d_name>__slice_<k>``.
3. Replacing the ``Squeeze`` output's downstream uses with that
initializer directly.
4. Folding any downstream ``Transpose(slice_initializer)`` into a
pre-transposed initializer named ``<W_3d_name>__slice_<k>__T``.
This is required for the common mobius emission pattern
``MatMul(x, Transpose(Squeeze(Gather(...))))``.
5. Removing the orphaned ``Gather``, ``Squeeze``, and ``Transpose``
nodes plus any initializer (3-D weight, per-Gather index, shared
squeeze-axes constant, intermediate slice) that no longer has any
node consumer.

Comment on lines +3279 to +3294
After this pass, the downstream ``MatMul`` nodes have a 2-D static
initializer as their B input, which means weight-quantization passes
such as :class:`OnnxKQuantQuantization` and
:class:`OnnxBlockWiseRtnQuantization` pick them up automatically.

Use case:
Run before any weight-quantization pass on a model that emits
unrolled per-expert MoE dispatch (e.g. Gemma 4 26B-A4B from
mobius, where the fused ``com.microsoft::MoE`` op is bypassed
due to a SwiGLU-kernel incompatibility). Without this pass the
per-expert MatMuls remain at the model's compute dtype because
the quantizer requires a 2-D static initializer as the B input.

Example surgery list entry::

{"surgeon": "UnstackGatherSqueezeWeight"}
"""

def call_ir(self, model: ir.Model) -> ir.Model:
# Defer to module-level free functions (Google style preference
# for free functions over static methods when no class state is
# involved).
graph = model.graph
initializers: dict[str, ir.Value] = dict(graph.initializers)

new_initializers: dict[str, ir.Value] = {}
# (gather_node, squeeze_node, replacement_initializer_value)
rewrites: list[tuple[ir.Node, ir.Node, ir.Value]] = []

for node in list(graph.all_nodes()):
rewrite = _try_match_gather_squeeze(node, initializers, new_initializers)
if rewrite is not None:
rewrites.append(rewrite)

if not rewrites:
return model

source_init_names: set[str] = {g.inputs[0].name for g, _, _ in rewrites if g.inputs[0] is not None}

# Register all new initializers first.
for value in new_initializers.values():
graph.register_initializer(value)

# Rewire each Squeeze output to the new initializer and remove
# the now-unused Gather + Squeeze nodes.
for gather_node, squeeze_node, replacement in rewrites:
ir.convenience.replace_all_uses_with(squeeze_node.outputs[0], replacement)
graph.remove(squeeze_node, safe=True)
# A single Gather output may feed multiple Squeeze chains; only
# remove the Gather node once all its consumers are gone.
gather_uses = [u for u in graph.all_nodes() if any(i is gather_node.outputs[0] for i in u.inputs)]
if not gather_uses:
graph.remove(gather_node, safe=True)

# Fold any Transpose(initializer) chain that consumes a newly
# materialised slice. This keeps the downstream MatMul.B as a

Check warning on line 3350 in olive/passes/onnx/graph_surgeries.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "materialised" is a misspelling of "materialized" Raw Output: ./olive/passes/onnx/graph_surgeries.py:3350:10: "materialised" is a misspelling of "materialized"
# static 2-D initializer so weight-quantization passes such as
# OnnxKQuantQuantization (which require ``inputs[1].is_initializer()``)
# pick it up. The most common producer is the mobius MoE
# fallback emitting ``MatMul(x, Transpose(Squeeze(Gather(...))))``.
folded_count = _fold_transpose_of_initializer(graph, new_initializers)

# Defer orphan cleanup to onnx_ir's standard dead-code-elimination
# pass instead of re-implementing the sweep here. This drops the
# original 3-D source weights, the per-Gather index initializers,
# the shared squeeze-axes constant, and any intermediate per-slice
# initializer that was consumed only by a folded Transpose.
RemoveUnusedNodesPass()(model)

logger.info(
"UnstackGatherSqueezeWeight: materialised %d per-slice initializers from "

Check warning on line 3365 in olive/passes/onnx/graph_surgeries.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "materialised" is a misspelling of "materialized" Raw Output: ./olive/passes/onnx/graph_surgeries.py:3365:41: "materialised" is a misspelling of "materialized"
"%d stacked source(s); rewrote %d Gather→Squeeze chains; folded %d Transpose ops.",
len(new_initializers),
len(source_init_names),
len(rewrites),
folded_count,
)
return model


# Free functions used by ``UnstackGatherSqueezeWeight``. Kept at module
# scope (not as class methods) per Google's Python style guide: prefer
# module-level helpers when no class state is involved.


def _try_match_gather_squeeze(
node: ir.Node,
initializers: dict[str, ir.Value],
new_initializers: dict[str, ir.Value],
) -> tuple[ir.Node, ir.Node, ir.Value] | None:
"""Match a ``Gather(W_3d, [const], axis=0) → Squeeze([0])`` chain at ``node``.

Returns ``(gather_node, squeeze_node, replacement_initializer)`` on a
successful match, or ``None`` otherwise. On match, the
``replacement_initializer`` is also recorded in
``new_initializers`` so callers can register it later.
"""
if node.op_type != "Squeeze":
return None
if len(node.inputs) < 1 or node.inputs[0] is None:
return None
squeeze_axes = _scalar_int_list(node, idx=1, initializers=initializers)
if squeeze_axes != [0]:
return None

gather_value = node.inputs[0]
gather_node = gather_value.producer()
if gather_node is None or gather_node.op_type != "Gather":
return None
axis_attr = gather_node.attributes.get("axis")
axis_val = int(axis_attr.value) if axis_attr is not None else 0
if axis_val != 0:
return None
if len(gather_node.inputs) < 2 or any(i is None for i in gather_node.inputs[:2]):
return None

data_value = gather_node.inputs[0]
if data_value.name not in initializers:
return None
data_init = initializers[data_value.name]
if data_init.const_value is None:
return None
data_array = data_init.const_value.numpy()
if data_array.ndim != 3:
return None

idx_list = _scalar_int_list(gather_node, idx=1, initializers=initializers)
if idx_list is None or len(idx_list) != 1:
return None
slice_idx = idx_list[0]
if not 0 <= slice_idx < data_array.shape[0]:
return None

slice_array = data_array[slice_idx]
slice_name = f"{data_value.name}__slice_{slice_idx}"
if slice_name in new_initializers:
replacement = new_initializers[slice_name]
else:
slice_tensor = ir.Tensor(slice_array, name=slice_name)
replacement = ir.Value(
name=slice_name,
type=ir.TensorType(slice_tensor.dtype),
shape=ir.Shape(slice_array.shape),
const_value=slice_tensor,
)
new_initializers[slice_name] = replacement

return gather_node, node, replacement


def _fold_transpose_of_initializer(graph: ir.Graph, candidate_initializers: dict[str, ir.Value]) -> int:
"""Fold ``Transpose(initializer) → matmul.B`` chains into transposed initializers.

Replaces any ``Transpose`` node whose input is in
``candidate_initializers`` with a pre-transposed initializer and
rewires the downstream consumers to use it directly.

Returns the number of Transpose nodes folded.
"""
folded = 0
for node in list(graph.all_nodes()):
if node.op_type != "Transpose":
continue
if len(node.inputs) != 1 or node.inputs[0] is None:
continue
src = node.inputs[0]
if src.name not in candidate_initializers:
continue
if src.const_value is None:
continue
perm_attr = node.attributes.get("perm")
arr = src.const_value.numpy()
perm = list(perm_attr.value) if perm_attr is not None else list(range(arr.ndim))[::-1]
transposed = np.transpose(arr, perm)
new_name = f"{src.name}__T"
if new_name in graph.initializers:
new_value = graph.initializers[new_name]
else:
new_tensor = ir.Tensor(transposed, name=new_name)
new_value = ir.Value(
name=new_name,
type=ir.TensorType(new_tensor.dtype),
shape=ir.Shape(transposed.shape),
const_value=new_tensor,
)
graph.register_initializer(new_value)
ir.convenience.replace_all_uses_with(node.outputs[0], new_value)
graph.remove(node, safe=True)
folded += 1
return folded


def _scalar_int_list(node: ir.Node, idx: int, initializers: dict[str, ir.Value]) -> list[int] | None:
"""Resolve ``node.inputs[idx]`` to a ``list[int]`` if it is a static 1-D int source.

Returns the resolved list when ``node.inputs[idx]`` is either a
static integer initializer or the output of an inline
``Constant`` node carrying ``value`` / ``value_ints`` /
``value_int``; otherwise returns ``None``.
"""
if idx >= len(node.inputs) or node.inputs[idx] is None:
return None
value = node.inputs[idx]
# Initializer path
if value.name in initializers and initializers[value.name].const_value is not None:
arr = initializers[value.name].const_value.numpy()
return arr.reshape(-1).astype(np.int64).tolist()
# Inline Constant node path
producer = value.producer()
if producer is not None and producer.op_type == "Constant":
for attr in producer.attributes.values():
if attr.name == "value" and attr.type == ir.AttributeType.TENSOR:
arr = attr.value.numpy()
return arr.reshape(-1).astype(np.int64).tolist()
if attr.name == "value_ints" and attr.type == ir.AttributeType.INTS:
return list(attr.value)
if attr.name == "value_int" and attr.type == ir.AttributeType.INT:
return [int(attr.value)]
return None


class GraphSurgeries(Pass):
"""ONNX graph surgeries collections.

Expand Down
73 changes: 73 additions & 0 deletions test/passes/onnx/test_graph_surgeries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3050,3 +3050,76 @@
x = np.random.randn(4, 8).astype(np.float32)
result = sess.run(["output"], {"input": x})[0]
np.testing.assert_allclose(result, np.maximum(x, 0), atol=1e-6)


def test_unstack_gather_squeeze_weight(tmp_path):
"""Verify per-slice 2-D weight materialisation and Transpose folding.

Check warning on line 3056 in test/passes/onnx/test_graph_surgeries.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "materialisation" is a misspelling of "materialization" Raw Output: ./test/passes/onnx/test_graph_surgeries.py:3056:35: "materialisation" is a misspelling of "materialization"

``UnstackGatherSqueezeWeight`` should rewrite each
``Gather(W_3d, [const], axis=0) → Squeeze([0])`` chain into a 2-D
initializer and fold any downstream ``Transpose`` into a
pre-transposed initializer.
"""
input_tensor = helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 6])
out0 = helper.make_tensor_value_info("y0", TensorProto.FLOAT, [4, 8])
out2 = helper.make_tensor_value_info("y2", TensorProto.FLOAT, [4, 8])

w3d = np.random.RandomState(0).randn(3, 8, 6).astype(np.float32)
w3d_init = numpy_helper.from_array(w3d, name="W3d")
idx0_init = numpy_helper.from_array(np.array([0], dtype=np.int64), name="idx0")
idx2_init = numpy_helper.from_array(np.array([2], dtype=np.int64), name="idx2")
ax_init = numpy_helper.from_array(np.array([0], dtype=np.int64), name="ax0")

def chain(suffix, idx_name, y_name):
return [
helper.make_node("Gather", ["W3d", idx_name], [f"g_{suffix}"], axis=0),
helper.make_node("Squeeze", [f"g_{suffix}", "ax0"], [f"s_{suffix}"]),
helper.make_node("Transpose", [f"s_{suffix}"], [f"t_{suffix}"], perm=[1, 0]),
helper.make_node("MatMul", ["x", f"t_{suffix}"], [y_name]),
]

nodes = chain("a", "idx0", "y0") + chain("b", "idx2", "y2")
graph = helper.make_graph(
nodes=nodes,
name="UnstackTest",
inputs=[input_tensor],
outputs=[out0, out2],
initializer=[w3d_init, idx0_init, idx2_init, ax_init],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
model.ir_version = 10
model_path = tmp_path / "model.onnx"
onnx.save(model, model_path)

input_model = ONNXModelHandler(model_path=str(model_path))
output_folder = str(tmp_path / "out")
p = create_pass_from_dict(
GraphSurgeries,
{"surgeries": [{"surgeon": "UnstackGatherSqueezeWeight"}]},
disable_search=True,
)
output_model = p.run(input_model, output_folder)
g = output_model.load_model().graph

op_types = [n.op_type for n in g.node]
assert op_types == ["MatMul", "MatMul"], op_types
init_names = {i.name for i in g.initializer}
assert "W3d" not in init_names
assert "W3d__slice_0__T" in init_names
assert "W3d__slice_2__T" in init_names

# Verify the materialised initializers match np.transpose of the

Check warning on line 3111 in test/passes/onnx/test_graph_surgeries.py

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "materialised" is a misspelling of "materialized" Raw Output: ./test/passes/onnx/test_graph_surgeries.py:3111:17: "materialised" is a misspelling of "materialized"
# original slices.
inits = {i.name: numpy_helper.to_array(i) for i in g.initializer}
np.testing.assert_array_equal(inits["W3d__slice_0__T"], w3d[0].T)
np.testing.assert_array_equal(inits["W3d__slice_2__T"], w3d[2].T)

# Numerical parity: rewritten model should produce the same output
# as the original.
sess_orig = InferenceSession(str(model_path), providers=["CPUExecutionProvider"])
sess_new = InferenceSession(output_model.model_path, providers=["CPUExecutionProvider"])
x = np.random.RandomState(1).randn(4, 6).astype(np.float32)
y0_orig, y2_orig = sess_orig.run(None, {"x": x})
y0_new, y2_new = sess_new.run(None, {"x": x})
np.testing.assert_allclose(y0_new, y0_orig, atol=1e-6)
np.testing.assert_allclose(y2_new, y2_orig, atol=1e-6)
Loading