From f6a12c19cdf0bcb7a35d206c8f965b4bae1ba377 Mon Sep 17 00:00:00 2001 From: Justin Chu <11205048+justinchuby@users.noreply.github.com> Date: Mon, 1 Jun 2026 23:41:50 +0000 Subject: [PATCH 1/3] =?UTF-8?q?Add=20OnnxMoEQuantization=20pass=20(com.mic?= =?UTF-8?q?rosoft::MoE=20=E2=86=92=20QMoE)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a new ONNX graph pass that rewrites every com.microsoft::MoE node into a com.microsoft::QMoE node with the per-expert FC1/FC2 weight initializers quantized to symmetric int4 (default) or int8, plus corresponding fp16 scale initializers. Motivation: mobius (and similar exporters) emit the fused com.microsoft::MoE op with the per-expert weights as 3-D fp16/bf16/fp32 initializers. The existing weight-quantization passes (OnnxKQuantQuantization, OnnxBlockWiseRtnQuantization, OnnxBnb4Quantization) only target MatMul nodes, so for MoE models the per-expert weights (~80% of total parameters) stay at the model's compute dtype, leaving just ~6% size reduction after quantization. The QMoE op is the correct target for MoE weights and is supported by the CUDA + experimental CPU kernels in ORT main (PR microsoft/onnxruntime#28467). Implementation: - Walks the graph and finds every com.microsoft::MoE node whose fc1_experts_weights and fc2_experts_weights are 3-D static initializers. - For each expert, calls ORT's pybind quantize_matmul_{4,8}bits to produce per-expert int4/int8 weights + symmetric fp16 scales, then CUTLASS-prepacks them via pack_weights_for_cuda_mixed_gemm so the QMoE kernels can consume the bytes directly. - Stacks per-expert tensors along axis 0 and registers them as new initializers (uint8 weight + fp16 scale per expert). - Replaces the MoE node with a QMoE node carrying the original activation/routing attributes plus expert_weight_bits, optional block_size, and quant_type='int'. - Orphaned fp16 weight initializers are dropped. Supports per-row scales (block_size=0, default) and block-wise scales (block_size ≥ 16, must be power of two). Nodes can be selectively excluded via nodes_to_exclude. The pass requires a CUDA-enabled ONNX Runtime build because pack_weights_for_cuda_mixed_gemm is only exposed when ORT is compiled with USE_CUDA. A descriptive RuntimeError is raised at run time when the binding is unavailable. Limitations / out-of-scope: - fc3 inputs (3-fold MoE variants) are not supported and trigger a warning-skip per node. - Only symmetric int quantization (matching the kernel's preferred layout). FP4 / FP8 / WFP4AFP8 quant_types are left for a follow-up. - Calibration-aware quantization (GPTQ / AWQ) is out of scope; this pass is pure RTN. Tests: 5 unit tests covering (a) end-to-end MoE → QMoE conversion with int4 + per-row scales, (b) block-wise int4, (c) graceful skip when weights are not static initializers, (d) bits validation, and (e) block_size validation. The CUTLASS prepack helper is patched during tests so CI without onnxruntime-gpu can still exercise the graph transform. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com> --- olive/olive_config.json | 8 + olive/passes/onnx/moe_quantization.py | 453 ++++++++++++++++++++++ test/passes/onnx/test_moe_quantization.py | 220 +++++++++++ 3 files changed, 681 insertions(+) create mode 100644 olive/passes/onnx/moe_quantization.py create mode 100644 test/passes/onnx/test_moe_quantization.py diff --git a/olive/olive_config.json b/olive/olive_config.json index b1bca7bb2..5267f6f07 100644 --- a/olive/olive_config.json +++ b/olive/olive_config.json @@ -283,6 +283,14 @@ "supported_algorithms": [ "kquant" ], "supported_quantization_encodings": [ ] }, + "OnnxMoEQuantization": { + "module_path": "olive.passes.onnx.moe_quantization.OnnxMoEQuantization", + "supported_providers": [ "CUDAExecutionProvider" ], + "supported_accelerators": [ "gpu" ], + "supported_precisions": [ "int4", "int8" ], + "supported_algorithms": [ "rtn" ], + "supported_quantization_encodings": [ ] + }, "OnnxBnb4Quantization": { "module_path": "olive.passes.onnx.bnb_quantization.OnnxBnb4Quantization", "supported_providers": [ "CPUExecutionProvider" ], diff --git a/olive/passes/onnx/moe_quantization.py b/olive/passes/onnx/moe_quantization.py new file mode 100644 index 000000000..dc912abe5 --- /dev/null +++ b/olive/passes/onnx/moe_quantization.py @@ -0,0 +1,453 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Convert ``com.microsoft::MoE`` nodes to ``com.microsoft::QMoE``. + +The ``MoE`` op carries the per-expert ``fc1_experts_weights`` and +``fc2_experts_weights`` as 3-D fp16 / bf16 / fp32 initializers. The +``QMoE`` op accepts the same logical inputs but with the weights packed +as symmetric int4 (or int8) plus per-row or block-wise scale tensors, +laid out in the CUTLASS ``fpA_intB`` mixed-precision GEMM format that +the CUDA / CPU QMoE kernels consume. + +This pass: + +1. Walks the graph and finds every ``com.microsoft::MoE`` node whose + ``fc1_experts_weights`` and ``fc2_experts_weights`` are static 3-D + initializers. +2. For each expert, symmetrically quantizes the per-expert weight slice + using ORT's ``quantize_matmul_4bits`` / ``quantize_matmul_8bits`` + pybind helper. +3. Stacks the per-expert quantized weights and scales into 3-D / + 2-D / 3-D initializers (matching the QMoE schema) and registers them + on the graph. +4. Replaces the ``MoE`` node with a ``QMoE`` node carrying the original + activation / routing attributes plus ``expert_weight_bits`` and + ``block_size``. + +The resulting model targets ORT ≥ 1.28 / nightly post #28467; the QMoE +kernel is currently CUDA-only (plus an experimental CPU fallback). +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +import numpy as np +import onnx_ir as ir + +from olive.constants import MSFT_DOMAIN +from olive.model.utils import resolve_onnx_path +from olive.passes import Pass +from olive.passes.onnx.common import get_external_data_config, ir_model_to_olive_model +from olive.passes.pass_config import BasePassConfig, PassConfigParam + +if TYPE_CHECKING: + from olive.hardware.accelerator import AcceleratorSpec + from olive.model import ONNXModelHandler + +logger = logging.getLogger(__name__) + + +_MOE_OP_TYPE = "MoE" +_QMOE_OP_TYPE = "QMoE" + +# Input slot layout for both ``com.microsoft::MoE`` and ``com.microsoft::QMoE`` +# (the QMoE op interleaves scale tensors after each weight tensor): +# +# MoE: [input, router_probs, fc1_W, fc1_b, fc2_W, fc2_b, fc3_W, fc3_b] +# QMoE: [input, router_probs, +# fc1_W, fc1_scales, fc1_zp, fc1_b, +# fc2_W, fc2_scales, fc2_zp, fc2_b, +# fc3_W, fc3_scales, fc3_zp, fc3_b] (zp optional) +_MOE_INPUT_INDEX = { + "input": 0, + "router_probs": 1, + "fc1_W": 2, + "fc1_b": 3, + "fc2_W": 4, + "fc2_b": 5, + "fc3_W": 6, + "fc3_b": 7, +} + + +class OnnxMoEQuantization(Pass): + """Convert ``com.microsoft::MoE`` ops to ``com.microsoft::QMoE``. + + Quantizes the per-expert FC1 / FC2 weight initializers to symmetric + int4 (default) or int8 and rewires each ``MoE`` node to a ``QMoE`` + node, preserving all routing / activation attributes. + """ + + @classmethod + def _default_config(cls, accelerator_spec: AcceleratorSpec) -> dict[str, PassConfigParam]: + return { + "bits": PassConfigParam( + type_=int, + default_value=4, + description=("Number of bits per quantized weight. Supported: 4 (default) or 8."), + ), + "block_size": PassConfigParam( + type_=int, + default_value=0, + description=( + "Block size along the K dimension. 0 means per-row scales (one " + "scale per output channel). When > 0, must be a power of two " + "≥ 16 and the K dimension of each expert weight must be " + "divisible by it." + ), + ), + "nodes_to_exclude": PassConfigParam( + type_=list[str] | None, + default_value=None, + description="List of MoE node names to leave unquantized.", + ), + "force_arch": PassConfigParam( + type_=int, + default_value=80, + description=( + "Target CUDA SM version for the CUTLASS weight prepacking " + "(80 = Ampere, 90 = Hopper). Most deployments are forward " + "compatible at sm_80." + ), + ), + **get_external_data_config(), + } + + def _run_for_config( + self, + model: ONNXModelHandler, + config: type[BasePassConfig], + output_model_path: str, + ) -> ONNXModelHandler: + output_model_path = resolve_onnx_path(output_model_path, "model.onnx") + + ir_model = model.load_ir_model() + ir.external_data.load_to_model(ir_model) + ir_model.graph.opset_imports[MSFT_DOMAIN] = 1 + + if config.bits not in (4, 8): + raise ValueError(f"OnnxMoEQuantization: bits must be 4 or 8, got {config.bits}.") + if config.block_size < 0: + raise ValueError(f"OnnxMoEQuantization: block_size must be ≥ 0, got {config.block_size}.") + if config.block_size > 0 and (config.block_size < 16 or config.block_size & (config.block_size - 1)): + raise ValueError( + f"OnnxMoEQuantization: block_size must be 0 or a power of two ≥ 16, got {config.block_size}." + ) + + converted = self._convert_moe_to_qmoe( + ir_model, + bits=config.bits, + block_size=config.block_size, + nodes_to_exclude=config.nodes_to_exclude or [], + force_arch=config.force_arch, + ) + logger.info("OnnxMoEQuantization: converted %d MoE node(s) to QMoE.", converted) + + # Drop initializers that are no longer referenced (the original 3-D + # fp16 weights are replaced by new uint8 weight + fp16 scale tensors). + self._drop_unused_initializers(ir_model.graph) + + return ir_model_to_olive_model(ir_model, output_model_path, config) + + @staticmethod + def _drop_unused_initializers(graph: ir.Graph) -> None: + used: set[str] = set() + for node in graph.all_nodes(): + for inp in node.inputs: + if inp is not None and inp.name: + used.add(inp.name) + for out in graph.outputs: + if out is not None and out.name: + used.add(out.name) + unused = [name for name in graph.initializers if name not in used] + for name in unused: + del graph.initializers[name] + if unused: + logger.info("OnnxMoEQuantization: removed %d orphan initializers.", len(unused)) + + def _convert_moe_to_qmoe( + self, + ir_model: ir.Model, + bits: int, + block_size: int, + nodes_to_exclude: list[str], + force_arch: int, + ) -> int: + graph = ir_model.graph + initializers: dict[str, ir.Value] = dict(graph.initializers) + excluded = set(nodes_to_exclude) + converted = 0 + + for node in list(graph.all_nodes()): + if node.op_type != _MOE_OP_TYPE or node.domain != MSFT_DOMAIN: + continue + if node.name in excluded: + logger.debug("Skipping MoE node %s (in nodes_to_exclude).", node.name) + continue + + try: + qmoe_node = self._convert_single_moe( + node, initializers, bits=bits, block_size=block_size, force_arch=force_arch + ) + except _UnsupportedMoEError as exc: + logger.warning("Skipping MoE node %s: %s", node.name or "", exc) + continue + + ir.convenience.replace_nodes_and_values(graph, node, [node], [qmoe_node], node.outputs, qmoe_node.outputs) + converted += 1 + + return converted + + def _convert_single_moe( + self, + node: ir.Node, + initializers: dict[str, ir.Value], + bits: int, + block_size: int, + force_arch: int, + ) -> ir.Node: + # Extract and validate weight initializers. + fc1_w_value = _get_input(node, _MOE_INPUT_INDEX["fc1_W"]) + fc2_w_value = _get_input(node, _MOE_INPUT_INDEX["fc2_W"]) + fc1_array = _require_initializer(fc1_w_value, "fc1_experts_weights", initializers) + fc2_array = _require_initializer(fc2_w_value, "fc2_experts_weights", initializers) + + if fc1_array.ndim != 3 or fc2_array.ndim != 3: + raise _UnsupportedMoEError( + f"Expected 3-D weights; got fc1.ndim={fc1_array.ndim}, fc2.ndim={fc2_array.ndim}." + ) + + num_experts = fc1_array.shape[0] + if fc2_array.shape[0] != num_experts: + raise _UnsupportedMoEError(f"fc1/fc2 num_experts disagree: {fc1_array.shape[0]} vs {fc2_array.shape[0]}.") + + # Quantize each expert weight independently and stack. + fc1_qweights, fc1_scales = _quantize_stacked_weights( + fc1_array, bits=bits, block_size=block_size, force_arch=force_arch + ) + fc2_qweights, fc2_scales = _quantize_stacked_weights( + fc2_array, bits=bits, block_size=block_size, force_arch=force_arch + ) + + # Build new initializers, named to keep the original tensor name as a prefix. + graph = node.graph + fc1_w_init = _make_initializer(f"{fc1_w_value.name}_q", fc1_qweights) + fc1_s_init = _make_initializer(f"{fc1_w_value.name}_scales", fc1_scales) + fc2_w_init = _make_initializer(f"{fc2_w_value.name}_q", fc2_qweights) + fc2_s_init = _make_initializer(f"{fc2_w_value.name}_scales", fc2_scales) + for init in (fc1_w_init, fc1_s_init, fc2_w_init, fc2_s_init): + graph.register_initializer(init) + + # Carry biases through unchanged. + fc1_bias = _maybe_input(node, _MOE_INPUT_INDEX["fc1_b"]) + fc2_bias = _maybe_input(node, _MOE_INPUT_INDEX["fc2_b"]) + fc3_w = _maybe_input(node, _MOE_INPUT_INDEX["fc3_W"]) + if fc3_w is not None: + raise _UnsupportedMoEError("fc3 inputs are not yet supported by this pass.") + + # QMoE inputs (zero_points left absent because we use symmetric int4/int8): + # 0: input + # 1: router_probs + # 2: fc1_experts_weights (quantized) + # 3: fc1_scales + # 4: fc1_zero_points (None — symmetric) + # 5: fc1_experts_bias + # 6: fc2_experts_weights (quantized) + # 7: fc2_scales + # 8: fc2_zero_points (None) + # 9: fc2_experts_bias + qmoe_inputs = [ + _get_input(node, _MOE_INPUT_INDEX["input"]), + _get_input(node, _MOE_INPUT_INDEX["router_probs"]), + fc1_w_init, + fc1_s_init, + None, + fc1_bias, + fc2_w_init, + fc2_s_init, + None, + fc2_bias, + ] + + # Copy all routing / activation attributes verbatim, then add the + # quantization-specific ones. + new_attrs = list(node.attributes.values()) + new_attrs.append(ir.AttrInt64("expert_weight_bits", bits)) + if block_size > 0: + new_attrs.append(ir.AttrInt64("block_size", block_size)) + # ``quant_type`` defaults to "int" in the schema; emit it explicitly + # for forward compatibility against future schema revisions that may + # introduce other defaults. + if not any(a.name == "quant_type" for a in node.attributes.values()): + new_attrs.append(ir.AttrString("quant_type", "int")) + + output_value = ir.Value(name=node.outputs[0].name) + return ir.Node( + domain=MSFT_DOMAIN, + op_type=_QMOE_OP_TYPE, + inputs=qmoe_inputs, + attributes=new_attrs, + outputs=[output_value], + name=node.name + "_QMoE" if node.name else None, + ) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +class _UnsupportedMoEError(Exception): + """Raised when a particular MoE node can't be converted by this pass.""" + + +def _get_input(node: ir.Node, idx: int) -> ir.Value: + if idx >= len(node.inputs) or node.inputs[idx] is None: + raise _UnsupportedMoEError(f"Missing required input at slot {idx}.") + return node.inputs[idx] + + +def _maybe_input(node: ir.Node, idx: int) -> ir.Value | None: + if idx >= len(node.inputs): + return None + return node.inputs[idx] + + +def _require_initializer(value: ir.Value, what: str, initializers: dict[str, ir.Value]) -> np.ndarray: + init = initializers.get(value.name) + if init is None or init.const_value is None: + raise _UnsupportedMoEError(f"{what} ({value.name!r}) is not a static initializer.") + return init.const_value.numpy() + + +def _make_initializer(name: str, array: np.ndarray) -> ir.Value: + tensor = ir.Tensor(array, name=name) + return ir.Value( + name=name, + type=ir.TensorType(tensor.dtype), + shape=ir.Shape(array.shape), + const_value=tensor, + ) + + +def _quantize_stacked_weights( + weights_3d: np.ndarray, bits: int, block_size: int, force_arch: int +) -> tuple[np.ndarray, np.ndarray]: + """Quantize each expert's [N, K] slice and stack along axis 0. + + Returns a tuple ``(packed_weights, scales)`` where: + + - ``packed_weights`` has shape ``[E, K, N // pack_size]`` (uint8), + laid out in the CUTLASS ``fpA_intB`` mixed-precision GEMM format + expected by the QMoE kernels. + - ``scales`` has shape ``[E, N]`` for per-row scales, or + ``[E, N, K // block_size]`` for block-wise scales (fp16). + """ + if bits not in (4, 8): + raise ValueError(f"bits must be 4 or 8, got {bits}") + num_experts = weights_3d.shape[0] + + packed_per_expert = [] + scales_per_expert = [] + + pack_fn = _load_cuda_pack_fn() + + for e in range(num_experts): + weight = weights_3d[e] # [N, K] + packed, scale = _quantize_one_expert( + weight, bits=bits, block_size=block_size, pack_fn=pack_fn, force_arch=force_arch + ) + packed_per_expert.append(packed) + scales_per_expert.append(scale) + + return np.stack(packed_per_expert, axis=0), np.stack(scales_per_expert, axis=0) + + +def _quantize_one_expert( + weight: np.ndarray, bits: int, block_size: int, pack_fn, force_arch: int +) -> tuple[np.ndarray, np.ndarray]: + """Quantize a single expert's ``[N, K]`` weight matrix. + + Mirrors the test harness in + ``onnxruntime/test/python/transformers/test_qmoe_cuda.py:quant_dequant_blockwise``: + + 1. Transpose to ``[K, N]`` (CUTLASS column-major convention). + 2. Call ORT's ``quantize_matmul_{bits}bits`` to produce per-block + ``q_weight`` and ``scale``. + 3. Call ORT's ``pack_weights_for_cuda_mixed_gemm`` to permute / + interleave the bytes into the fpA_intB layout the kernel reads. + """ + from onnxruntime.capi import _pybind_state as _p + + quant_fn_name = f"quantize_matmul_{bits}bits" + quantize = getattr(_p, quant_fn_name, None) + if quantize is None: + raise RuntimeError( + f"onnxruntime.capi._pybind_state.{quant_fn_name} is not available; " + "this Olive pass needs a recent build of onnxruntime." + ) + + # Promote fp16/bf16 to fp32 for the quantizer (bindings only accept + # fp16 or fp32; bf16 isn't supported as a python numpy dtype). + if weight.dtype in (np.float16, np.float32): + weight_for_quant = weight + else: + weight_for_quant = weight.astype(np.float32) + + n, k = weight_for_quant.shape # per-expert weight is [N, K] + weight_t = np.ascontiguousarray(weight_for_quant.T) # [K, N] + + effective_block = block_size if block_size > 0 else k + if k % effective_block != 0: + raise _UnsupportedMoEError(f"K ({k}) is not divisible by block_size ({effective_block}).") + block_per_k = k // effective_block + + pack_factor = 8 // bits # 2 for int4, 1 for int8 + blob_size = effective_block // pack_factor + q_weight = np.zeros((n, block_per_k, blob_size), dtype=np.uint8) + scale = np.zeros((n, block_per_k), dtype=np.float32) + zero_point = np.zeros((n, (block_per_k + pack_factor - 1) // pack_factor), dtype=np.uint8) + + # Symmetric quantization (kernel uses (q - bias) * scale internally). + quantize(q_weight, weight_t, scale, zero_point, effective_block, n, k, True) # pylint: disable=not-callable + scale = np.abs(scale) + + # CUTLASS mixed-precision GEMM expects a specific byte layout. + q_weight_flat = q_weight.reshape(n, -1) + packed = pack_fn(q_weight_flat, n, k, bits, force_arch) + packed = np.ascontiguousarray(packed.reshape(k, n // pack_factor)).view(np.uint8) + + # Squeeze trivial block dim to match the spec when block_size == 0: + # row-wise scales → [N] + # block-wise scales → [N, block_per_k] + if block_size == 0: + scale_out = scale.reshape(n).astype(np.float16) + else: + scale_out = scale.reshape(n, block_per_k).astype(np.float16) + + return packed, scale_out + + +def _load_cuda_pack_fn(): + """Locate ``pack_weights_for_cuda_mixed_gemm`` in the installed ORT. + + Raises a descriptive error if the binding isn't available; without + it the produced model can't be loaded by the CUDA or CPU QMoE + kernels because they read the CUTLASS pre-packed layout. + """ + from onnxruntime.capi import _pybind_state as _p + + pack_fn = getattr(_p, "pack_weights_for_cuda_mixed_gemm", None) + if pack_fn is None: + raise RuntimeError( + "OnnxMoEQuantization requires the CUTLASS weight-packing helper " + "`pack_weights_for_cuda_mixed_gemm`, which is only exported by " + "ONNX Runtime when built with CUDA support. Install onnxruntime-gpu " + ">= 1.28 (or a nightly built from main with USE_CUDA after PR " + "microsoft/onnxruntime#28467)." + ) + return pack_fn diff --git a/test/passes/onnx/test_moe_quantization.py b/test/passes/onnx/test_moe_quantization.py new file mode 100644 index 000000000..8bb6885b0 --- /dev/null +++ b/test/passes/onnx/test_moe_quantization.py @@ -0,0 +1,220 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +"""Tests for OnnxMoEQuantization (com.microsoft::MoE → com.microsoft::QMoE).""" + +from __future__ import annotations + +from unittest.mock import patch + +import numpy as np +import onnx +import pytest +from onnx import TensorProto, helper, numpy_helper + +from olive.model import ONNXModelHandler +from olive.passes.olive_pass import create_pass_from_dict +from olive.passes.onnx.moe_quantization import OnnxMoEQuantization + + +def _build_moe_model(tmp_path, num_experts=4, hidden=16, inter=32, top_k=2): + """Build a tiny ONNX model containing one ``com.microsoft::MoE`` node. + + Layout matches what mobius emits for Gemma 4 MoE: + fc1_experts_weights [E, 2*inter, H] + fc2_experts_weights [E, H, inter] + """ + rng = np.random.RandomState(0) + fc1 = rng.randn(num_experts, 2 * inter, hidden).astype(np.float32) * 0.02 + fc2 = rng.randn(num_experts, hidden, inter).astype(np.float32) * 0.02 + + fc1_init = numpy_helper.from_array(fc1, name="fc1_W") + fc2_init = numpy_helper.from_array(fc2, name="fc2_W") + + input_t = helper.make_tensor_value_info("x", TensorProto.FLOAT, [None, hidden]) + router_t = helper.make_tensor_value_info("router", TensorProto.FLOAT, [None, num_experts]) + output_t = helper.make_tensor_value_info("y", TensorProto.FLOAT, [None, hidden]) + + moe = helper.make_node( + "MoE", + inputs=["x", "router", "fc1_W", "", "fc2_W"], + outputs=["y"], + name="moe_layer_0", + domain="com.microsoft", + k=top_k, + normalize_routing_weights=1, + activation_type="swiglu", + swiglu_fusion=1, + activation_alpha=1.0, + activation_beta=0.0, + swiglu_limit=float("inf"), + ) + + graph = helper.make_graph( + nodes=[moe], + name="moe_only", + inputs=[input_t, router_t], + outputs=[output_t], + initializer=[fc1_init, fc2_init], + ) + model = helper.make_model( + graph, + opset_imports=[ + helper.make_opsetid("", 20), + helper.make_opsetid("com.microsoft", 1), + ], + ) + model.ir_version = 10 + path = tmp_path / "moe.onnx" + onnx.save(model, path) + return path, fc1, fc2 + + +def _fake_pack_weights_for_cuda_mixed_gemm(q_weights, n, k, bits, force_arch): + """Pass-through replacement for the CUTLASS prepack helper. + + The real helper permutes / interleaves bytes for the fpA_intB kernel + and is only available when ORT is built with USE_CUDA. The structural + pass test doesn't depend on the byte layout; this stub keeps the + shape (``[k, n // pack_factor]`` after the caller's reshape) but + leaves the data identical so the test runs in CPU-only CI. + """ + pack_factor = 8 // bits + out = np.ascontiguousarray(q_weights.reshape(n, k // pack_factor)) + # Caller reshapes to (k, n // pack_factor) — return the same total byte + # count so that reshape doesn't fail. + return out.reshape(k, n // pack_factor) + + +def test_moe_to_qmoe_conversion(tmp_path): + """Replace one MoE node with a QMoE node, int4-quantize, carry attrs. + + End-to-end: a single MoE node is replaced by a QMoE node, weights are + quantized to int4, scales are added, and all routing/activation attrs + are carried over. + """ + model_path, fc1_fp32, fc2_fp32 = _build_moe_model(tmp_path) + + input_model = ONNXModelHandler(model_path=str(model_path)) + p = create_pass_from_dict( + OnnxMoEQuantization, + {"bits": 4, "block_size": 0}, + disable_search=True, + ) + + with patch( + "olive.passes.onnx.moe_quantization._load_cuda_pack_fn", + return_value=_fake_pack_weights_for_cuda_mixed_gemm, + ): + output_model = p.run(input_model, str(tmp_path / "out")) + + g = output_model.load_model().graph + + moe_nodes = [n for n in g.node if n.op_type == "MoE"] + qmoe_nodes = [n for n in g.node if n.op_type == "QMoE"] + assert moe_nodes == [], "Original MoE node should be replaced." + assert len(qmoe_nodes) == 1 + qmoe = qmoe_nodes[0] + assert qmoe.domain == "com.microsoft" + # QMoE input ordering: input, router, fc1_W, fc1_scales, fc1_zp, fc1_b, + # fc2_W, fc2_scales, fc2_zp, fc2_b + assert qmoe.input[0] == "x" + assert qmoe.input[1] == "router" + assert qmoe.input[2].endswith("_q"), "Quantized weight initializer expected at slot 2." + assert qmoe.input[3].endswith("_scales"), "Scale initializer expected at slot 3." + assert qmoe.input[4] == "", "Zero-point (symmetric mode) should be empty at slot 4." + assert qmoe.input[6].endswith("_q") + assert qmoe.input[7].endswith("_scales") + + # Attributes: routing/activation carried over plus expert_weight_bits. + attrs = {a.name: a for a in qmoe.attribute} + assert attrs["k"].i == 2 + assert attrs["normalize_routing_weights"].i == 1 + assert attrs["activation_type"].s.decode() == "swiglu" + assert attrs["swiglu_fusion"].i == 1 + assert attrs["expert_weight_bits"].i == 4 + assert "block_size" not in attrs, "block_size should not be emitted when 0." + assert attrs["quant_type"].s.decode() == "int" + + # Initializer dtype + shape checks. + inits = {i.name: i for i in g.initializer} + fc1_q = numpy_helper.to_array(inits[qmoe.input[2]]) + fc1_s = numpy_helper.to_array(inits[qmoe.input[3]]) + fc2_q = numpy_helper.to_array(inits[qmoe.input[6]]) + fc2_s = numpy_helper.to_array(inits[qmoe.input[7]]) + + e_dim, two_inter, hidden = fc1_fp32.shape # [E, 2*inter, H] + pack_factor = 2 # int4 + assert fc1_q.dtype == np.uint8 + assert fc1_q.shape == (e_dim, hidden, two_inter // pack_factor) + assert fc1_s.dtype == np.float16 + assert fc1_s.shape == (e_dim, two_inter) # per-row scales when block_size == 0 + + e_dim2, hidden2, inter = fc2_fp32.shape # [E, H, inter] + assert fc2_q.dtype == np.uint8 + assert fc2_q.shape == (e_dim2, inter, hidden2 // pack_factor) + assert fc2_s.shape == (e_dim2, hidden2) + + # The original fp32 weights must be gone. + assert "fc1_W" not in inits + assert "fc2_W" not in inits + + +def test_moe_to_qmoe_blockwise(tmp_path): + """Block-wise (block_size=16) emits 3-D scales and a block_size attribute.""" + model_path, fc1, _ = _build_moe_model(tmp_path, hidden=32, inter=32) + p = create_pass_from_dict(OnnxMoEQuantization, {"bits": 4, "block_size": 16}, disable_search=True) + input_model = ONNXModelHandler(model_path=str(model_path)) + with patch( + "olive.passes.onnx.moe_quantization._load_cuda_pack_fn", + return_value=_fake_pack_weights_for_cuda_mixed_gemm, + ): + output_model = p.run(input_model, str(tmp_path / "out")) + g = output_model.load_model().graph + qmoe = next(n for n in g.node if n.op_type == "QMoE") + attrs = {a.name: a for a in qmoe.attribute} + assert attrs["block_size"].i == 16 + + inits = {i.name: i for i in g.initializer} + fc1_s = numpy_helper.to_array(inits[qmoe.input[3]]) + e_dim, two_inter, hidden = fc1.shape + # Block-wise scales: [E, N, K // block_size] + assert fc1_s.shape == (e_dim, two_inter, hidden // 16) + + +def test_moe_to_qmoe_skip_when_not_initializer(tmp_path): + """Skip an MoE node whose weight is a dynamic input, leaving it unchanged.""" + model_path, _, _ = _build_moe_model(tmp_path) + # Edit the model: replace fc1_W initializer with a graph input so it + # isn't a static initializer. + m = onnx.load(model_path) + m.graph.initializer.pop(0) # remove fc1_W + m.graph.input.append(helper.make_tensor_value_info("fc1_W", TensorProto.FLOAT, [4, 64, 16])) + onnx.save(m, model_path) + + p = create_pass_from_dict(OnnxMoEQuantization, {"bits": 4}, disable_search=True) + input_model = ONNXModelHandler(model_path=str(model_path)) + with patch( + "olive.passes.onnx.moe_quantization._load_cuda_pack_fn", + return_value=_fake_pack_weights_for_cuda_mixed_gemm, + ): + output_model = p.run(input_model, str(tmp_path / "out")) + g = output_model.load_model().graph + assert [n.op_type for n in g.node] == ["MoE"], "Node should remain unchanged when weights are dynamic." + + +def test_invalid_bits_rejected(tmp_path): + """Bits other than 4 or 8 fails fast at config time.""" + model_path, _, _ = _build_moe_model(tmp_path) + p = create_pass_from_dict(OnnxMoEQuantization, {"bits": 5}, disable_search=True) + with pytest.raises(ValueError, match="bits must be 4 or 8"): + p.run(ONNXModelHandler(model_path=str(model_path)), str(tmp_path / "out")) + + +def test_invalid_block_size_rejected(tmp_path): + """Non-power-of-two block_size fails fast.""" + model_path, _, _ = _build_moe_model(tmp_path) + p = create_pass_from_dict(OnnxMoEQuantization, {"bits": 4, "block_size": 24}, disable_search=True) + with pytest.raises(ValueError, match="power of two"): + p.run(ONNXModelHandler(model_path=str(model_path)), str(tmp_path / "out")) From 289d56c02ae9f8e48aef678de19bf09e712d2a12 Mon Sep 17 00:00:00 2001 From: Justin Chu <11205048+justinchuby@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:10:23 +0000 Subject: [PATCH 2/3] Address PR review feedback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - _maybe_input now treats an ir.Value with empty name as missing, not present. ONNX represents unset optional inputs as empty-string slots; while onnx_ir typically normalises those to None, defensively handle the ir.Value(name='') case too so the fc3 reject path doesn't fire on MoE nodes that include empty placeholder slots for fc1_bias / fc2_bias. - Validate N % pack_factor == 0 and block_size % pack_factor == 0 in _quantize_one_expert. These were latent failure modes where the CUTLASS prepack helper would either crash or produce a wrong layout; now we emit a clear _UnsupportedMoEError and the MoE node is skipped with a warning instead of being silently corrupted. - Add a comment in _quantize_one_expert explaining that the 2-D scale / zero_point shapes match the upstream ORT test harness (test_qmoe_cuda.py::quant_dequant_blockwise) — pybind11's buffer protocol accepts any contiguous shape as long as the element count matches, so this isn't a regression vs the 1-D layout used in rtn_quantization.py. Two new unit tests cover the changes: - test_moe_to_qmoe_handles_explicit_empty_optional_inputs: appends empty-string fc2_bias / fc3_W / fc3_bias slots to the MoE node and asserts the pass still converts it (fc3 reject path doesn't trigger). - test_n_not_divisible_by_pack_factor_skipped: builds an MoE node with N=3 (odd) and asserts the conversion is skipped with a clean warning rather than crashing. All 7 tests pass, lintrunner clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com> --- olive/passes/onnx/moe_quantization.py | 25 +++++++- test/passes/onnx/test_moe_quantization.py | 77 +++++++++++++++++++++++ 2 files changed, 101 insertions(+), 1 deletion(-) diff --git a/olive/passes/onnx/moe_quantization.py b/olive/passes/onnx/moe_quantization.py index dc912abe5..f56bd6a2a 100644 --- a/olive/passes/onnx/moe_quantization.py +++ b/olive/passes/onnx/moe_quantization.py @@ -312,9 +312,19 @@ def _get_input(node: ir.Node, idx: int) -> ir.Value: def _maybe_input(node: ir.Node, idx: int) -> ir.Value | None: + """Return the optional input at ``idx``, treating empty / missing slots as absent. + + ONNX represents an unset optional input either as a slot past the end of the + inputs list, or as an in-place placeholder with an empty name. ``onnx_ir`` + typically maps the latter to ``None``, but defensively handle the + ``ir.Value(name="")`` case too so callers can rely on ``None`` meaning absent. + """ if idx >= len(node.inputs): return None - return node.inputs[idx] + value = node.inputs[idx] + if value is None or not value.name: + return None + return value def _require_initializer(value: ir.Value, what: str, initializers: dict[str, ir.Value]) -> np.ndarray: @@ -407,7 +417,20 @@ def _quantize_one_expert( block_per_k = k // effective_block pack_factor = 8 // bits # 2 for int4, 1 for int8 + if n % pack_factor != 0: + raise _UnsupportedMoEError(f"N ({n}) must be divisible by pack_factor ({pack_factor}) for {bits}-bit packing.") + if effective_block % pack_factor != 0: + raise _UnsupportedMoEError( + f"block_size ({effective_block}) must be divisible by pack_factor ({pack_factor}) for {bits}-bit packing." + ) blob_size = effective_block // pack_factor + # The pybind quantize_matmul_{4,8}bits binding takes raw contiguous buffers + # for ``scale`` and ``zero_point``; pybind11's buffer-protocol overload + # accepts any shape with the same total element count. Matching the layout + # used in the upstream ORT test harness + # (onnxruntime/test/python/transformers/test_qmoe_cuda.py::quant_dequant_blockwise) + # keeps the on-disk byte order obvious — 2-D ``[N, block_per_k]`` for scale, + # 2-D ``[N, ceil(block_per_k / pack_factor)]`` for zero_point. q_weight = np.zeros((n, block_per_k, blob_size), dtype=np.uint8) scale = np.zeros((n, block_per_k), dtype=np.float32) zero_point = np.zeros((n, (block_per_k + pack_factor - 1) // pack_factor), dtype=np.uint8) diff --git a/test/passes/onnx/test_moe_quantization.py b/test/passes/onnx/test_moe_quantization.py index 8bb6885b0..4f2700661 100644 --- a/test/passes/onnx/test_moe_quantization.py +++ b/test/passes/onnx/test_moe_quantization.py @@ -218,3 +218,80 @@ def test_invalid_block_size_rejected(tmp_path): p = create_pass_from_dict(OnnxMoEQuantization, {"bits": 4, "block_size": 24}, disable_search=True) with pytest.raises(ValueError, match="power of two"): p.run(ONNXModelHandler(model_path=str(model_path)), str(tmp_path / "out")) + + +def test_moe_to_qmoe_handles_explicit_empty_optional_inputs(tmp_path): + """Convert an MoE node with explicit empty-string optional inputs. + + Optional fc1_bias, fc3_W, and fc3_bias slots are present as empty + strings rather than absent slots; the pass should treat them as + missing and still convert the node. + """ + # _build_moe_model already emits inputs=['x','router','fc1_W','','fc2_W']; here we + # extend it to also include empty fc3 slots to exercise the slot-7 boundary. + model_path, _, _ = _build_moe_model(tmp_path) + m = onnx.load(model_path) + moe = m.graph.node[0] + # Append empty fc2_bias, fc3_W, fc3_bias slots. + moe.input.extend(["", "", ""]) + onnx.save(m, model_path) + + p = create_pass_from_dict(OnnxMoEQuantization, {"bits": 4}, disable_search=True) + with patch( + "olive.passes.onnx.moe_quantization._load_cuda_pack_fn", + return_value=_fake_pack_weights_for_cuda_mixed_gemm, + ): + output_model = p.run(ONNXModelHandler(model_path=str(model_path)), str(tmp_path / "out")) + g = output_model.load_model().graph + qmoe_nodes = [n for n in g.node if n.op_type == "QMoE"] + assert len(qmoe_nodes) == 1, "MoE node with empty optional slots should still be converted." + + +def test_n_not_divisible_by_pack_factor_skipped(tmp_path): + """Skip MoE nodes whose N is incompatible with the 4-bit packing factor. + + N (== 2*inter for fc1) not divisible by pack_factor (== 2 for int4) + should be rejected with a clear error and the MoE node left + unchanged. + """ + # Construct an MoE node with an odd fc1 second dimension. + rng = np.random.RandomState(0) + fc1 = rng.randn(2, 3, 8).astype(np.float32) # E=2, N=3 (odd), K=8 + fc2 = rng.randn(2, 8, 4).astype(np.float32) + inputs = [ + helper.make_tensor_value_info("x", TensorProto.FLOAT, [None, 8]), + helper.make_tensor_value_info("r", TensorProto.FLOAT, [None, 2]), + ] + out = helper.make_tensor_value_info("y", TensorProto.FLOAT, [None, 8]) + moe = helper.make_node( + "MoE", + ["x", "r", "fc1_W", "", "fc2_W"], + ["y"], + name="m", + domain="com.microsoft", + k=1, + activation_type="silu", + ) + graph = helper.make_graph( + [moe], + "g", + inputs, + [out], + initializer=[numpy_helper.from_array(fc1, "fc1_W"), numpy_helper.from_array(fc2, "fc2_W")], + ) + model = helper.make_model( + graph, + opset_imports=[helper.make_opsetid("", 20), helper.make_opsetid("com.microsoft", 1)], + ) + model.ir_version = 10 + p_in = tmp_path / "m.onnx" + onnx.save(model, p_in) + + p = create_pass_from_dict(OnnxMoEQuantization, {"bits": 4}, disable_search=True) + with patch( + "olive.passes.onnx.moe_quantization._load_cuda_pack_fn", + return_value=_fake_pack_weights_for_cuda_mixed_gemm, + ): + output_model = p.run(ONNXModelHandler(model_path=str(p_in)), str(tmp_path / "out")) + g = output_model.load_model().graph + assert [n.op_type for n in g.node] == ["MoE"], "Odd-N MoE node should be skipped, not crash." From ea57069827b03f54ef3a3a113f3ff0c37ca077e1 Mon Sep 17 00:00:00 2001 From: Justin Chu <11205048+justinchuby@users.noreply.github.com> Date: Tue, 2 Jun 2026 00:18:33 +0000 Subject: [PATCH 3/3] Refactor: free functions + onnx_ir RemoveUnusedNodesPass for cleanup - Move `_convert_moe_to_qmoe`, `_convert_single_moe`, and the `_drop_unused_initializers` helper from the `OnnxMoEQuantization` class into module-level private functions, per Google's Python style guide preference for free functions over class methods when no class state is involved. The `OnnxMoEQuantization` class now only owns config defaulting and the `_run_for_config` entry point. - Replace the hand-rolled orphan-initializer sweep with `onnx_ir.passes.common.RemoveUnusedNodesPass`, which also handles dead-node removal and keeps the cleanup consistent with the rest of the IR pass ecosystem. No behaviour change: all 7 unit tests pass, lintrunner clean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com> --- olive/passes/onnx/moe_quantization.py | 284 +++++++++++++------------- 1 file changed, 141 insertions(+), 143 deletions(-) diff --git a/olive/passes/onnx/moe_quantization.py b/olive/passes/onnx/moe_quantization.py index f56bd6a2a..1d4923c73 100644 --- a/olive/passes/onnx/moe_quantization.py +++ b/olive/passes/onnx/moe_quantization.py @@ -37,6 +37,7 @@ import numpy as np import onnx_ir as ir +from onnx_ir.passes.common.unused_removal import RemoveUnusedNodesPass from olive.constants import MSFT_DOMAIN from olive.model.utils import resolve_onnx_path @@ -138,7 +139,7 @@ def _run_for_config( f"OnnxMoEQuantization: block_size must be 0 or a power of two ≥ 16, got {config.block_size}." ) - converted = self._convert_moe_to_qmoe( + converted = _convert_moe_to_qmoe( ir_model, bits=config.bits, block_size=config.block_size, @@ -147,157 +148,154 @@ def _run_for_config( ) logger.info("OnnxMoEQuantization: converted %d MoE node(s) to QMoE.", converted) - # Drop initializers that are no longer referenced (the original 3-D - # fp16 weights are replaced by new uint8 weight + fp16 scale tensors). - self._drop_unused_initializers(ir_model.graph) + # Drop the original 3-D fp16 weight initializers (now replaced by + # quantized uint8 weight + fp16 scale initializers). Defer to + # onnx_ir's standard dead-code elimination pass so this stays + # consistent with other consumers of the IR. + RemoveUnusedNodesPass()(ir_model) return ir_model_to_olive_model(ir_model, output_model_path, config) - @staticmethod - def _drop_unused_initializers(graph: ir.Graph) -> None: - used: set[str] = set() - for node in graph.all_nodes(): - for inp in node.inputs: - if inp is not None and inp.name: - used.add(inp.name) - for out in graph.outputs: - if out is not None and out.name: - used.add(out.name) - unused = [name for name in graph.initializers if name not in used] - for name in unused: - del graph.initializers[name] - if unused: - logger.info("OnnxMoEQuantization: removed %d orphan initializers.", len(unused)) - - def _convert_moe_to_qmoe( - self, - ir_model: ir.Model, - bits: int, - block_size: int, - nodes_to_exclude: list[str], - force_arch: int, - ) -> int: - graph = ir_model.graph - initializers: dict[str, ir.Value] = dict(graph.initializers) - excluded = set(nodes_to_exclude) - converted = 0 - - for node in list(graph.all_nodes()): - if node.op_type != _MOE_OP_TYPE or node.domain != MSFT_DOMAIN: - continue - if node.name in excluded: - logger.debug("Skipping MoE node %s (in nodes_to_exclude).", node.name) - continue - - try: - qmoe_node = self._convert_single_moe( - node, initializers, bits=bits, block_size=block_size, force_arch=force_arch - ) - except _UnsupportedMoEError as exc: - logger.warning("Skipping MoE node %s: %s", node.name or "", exc) - continue - - ir.convenience.replace_nodes_and_values(graph, node, [node], [qmoe_node], node.outputs, qmoe_node.outputs) - converted += 1 - - return converted - - def _convert_single_moe( - self, - node: ir.Node, - initializers: dict[str, ir.Value], - bits: int, - block_size: int, - force_arch: int, - ) -> ir.Node: - # Extract and validate weight initializers. - fc1_w_value = _get_input(node, _MOE_INPUT_INDEX["fc1_W"]) - fc2_w_value = _get_input(node, _MOE_INPUT_INDEX["fc2_W"]) - fc1_array = _require_initializer(fc1_w_value, "fc1_experts_weights", initializers) - fc2_array = _require_initializer(fc2_w_value, "fc2_experts_weights", initializers) - - if fc1_array.ndim != 3 or fc2_array.ndim != 3: - raise _UnsupportedMoEError( - f"Expected 3-D weights; got fc1.ndim={fc1_array.ndim}, fc2.ndim={fc2_array.ndim}." - ) - num_experts = fc1_array.shape[0] - if fc2_array.shape[0] != num_experts: - raise _UnsupportedMoEError(f"fc1/fc2 num_experts disagree: {fc1_array.shape[0]} vs {fc2_array.shape[0]}.") +# --------------------------------------------------------------------------- +# Graph rewrite helpers (module-private, per Google-style guide preference for +# free functions over static / class methods when no shared state is involved) +# --------------------------------------------------------------------------- - # Quantize each expert weight independently and stack. - fc1_qweights, fc1_scales = _quantize_stacked_weights( - fc1_array, bits=bits, block_size=block_size, force_arch=force_arch - ) - fc2_qweights, fc2_scales = _quantize_stacked_weights( - fc2_array, bits=bits, block_size=block_size, force_arch=force_arch - ) - # Build new initializers, named to keep the original tensor name as a prefix. - graph = node.graph - fc1_w_init = _make_initializer(f"{fc1_w_value.name}_q", fc1_qweights) - fc1_s_init = _make_initializer(f"{fc1_w_value.name}_scales", fc1_scales) - fc2_w_init = _make_initializer(f"{fc2_w_value.name}_q", fc2_qweights) - fc2_s_init = _make_initializer(f"{fc2_w_value.name}_scales", fc2_scales) - for init in (fc1_w_init, fc1_s_init, fc2_w_init, fc2_s_init): - graph.register_initializer(init) - - # Carry biases through unchanged. - fc1_bias = _maybe_input(node, _MOE_INPUT_INDEX["fc1_b"]) - fc2_bias = _maybe_input(node, _MOE_INPUT_INDEX["fc2_b"]) - fc3_w = _maybe_input(node, _MOE_INPUT_INDEX["fc3_W"]) - if fc3_w is not None: - raise _UnsupportedMoEError("fc3 inputs are not yet supported by this pass.") - - # QMoE inputs (zero_points left absent because we use symmetric int4/int8): - # 0: input - # 1: router_probs - # 2: fc1_experts_weights (quantized) - # 3: fc1_scales - # 4: fc1_zero_points (None — symmetric) - # 5: fc1_experts_bias - # 6: fc2_experts_weights (quantized) - # 7: fc2_scales - # 8: fc2_zero_points (None) - # 9: fc2_experts_bias - qmoe_inputs = [ - _get_input(node, _MOE_INPUT_INDEX["input"]), - _get_input(node, _MOE_INPUT_INDEX["router_probs"]), - fc1_w_init, - fc1_s_init, - None, - fc1_bias, - fc2_w_init, - fc2_s_init, - None, - fc2_bias, - ] - - # Copy all routing / activation attributes verbatim, then add the - # quantization-specific ones. - new_attrs = list(node.attributes.values()) - new_attrs.append(ir.AttrInt64("expert_weight_bits", bits)) - if block_size > 0: - new_attrs.append(ir.AttrInt64("block_size", block_size)) - # ``quant_type`` defaults to "int" in the schema; emit it explicitly - # for forward compatibility against future schema revisions that may - # introduce other defaults. - if not any(a.name == "quant_type" for a in node.attributes.values()): - new_attrs.append(ir.AttrString("quant_type", "int")) - - output_value = ir.Value(name=node.outputs[0].name) - return ir.Node( - domain=MSFT_DOMAIN, - op_type=_QMOE_OP_TYPE, - inputs=qmoe_inputs, - attributes=new_attrs, - outputs=[output_value], - name=node.name + "_QMoE" if node.name else None, - ) +def _convert_moe_to_qmoe( + ir_model: ir.Model, + bits: int, + block_size: int, + nodes_to_exclude: list[str], + force_arch: int, +) -> int: + """Walk ``ir_model.graph`` and rewrite every MoE node to a QMoE node. + + Returns the number of nodes successfully converted. Nodes whose weights + can't be statically quantized (e.g. dynamic weight inputs, shape that + doesn't divide cleanly into pack tiles) are skipped with a logger + warning rather than aborting the whole pass. + """ + graph = ir_model.graph + initializers: dict[str, ir.Value] = dict(graph.initializers) + excluded = set(nodes_to_exclude) + converted = 0 + + for node in list(graph.all_nodes()): + if node.op_type != _MOE_OP_TYPE or node.domain != MSFT_DOMAIN: + continue + if node.name in excluded: + logger.debug("Skipping MoE node %s (in nodes_to_exclude).", node.name) + continue + + try: + qmoe_node = _convert_single_moe(node, initializers, bits=bits, block_size=block_size, force_arch=force_arch) + except _UnsupportedMoEError as exc: + logger.warning("Skipping MoE node %s: %s", node.name or "", exc) + continue + + ir.convenience.replace_nodes_and_values(graph, node, [node], [qmoe_node], node.outputs, qmoe_node.outputs) + converted += 1 + + return converted + + +def _convert_single_moe( + node: ir.Node, + initializers: dict[str, ir.Value], + bits: int, + block_size: int, + force_arch: int, +) -> ir.Node: + """Build the QMoE replacement for a single MoE node. + + Quantizes the per-expert FC1/FC2 initializers and registers the new + initializers on the same graph as ``node``. Raises + ``_UnsupportedMoEError`` if the node's shape, dtype, or input set + isn't something this pass can handle (the caller logs and skips). + """ + fc1_w_value = _get_input(node, _MOE_INPUT_INDEX["fc1_W"]) + fc2_w_value = _get_input(node, _MOE_INPUT_INDEX["fc2_W"]) + fc1_array = _require_initializer(fc1_w_value, "fc1_experts_weights", initializers) + fc2_array = _require_initializer(fc2_w_value, "fc2_experts_weights", initializers) + + if fc1_array.ndim != 3 or fc2_array.ndim != 3: + raise _UnsupportedMoEError(f"Expected 3-D weights; got fc1.ndim={fc1_array.ndim}, fc2.ndim={fc2_array.ndim}.") + + num_experts = fc1_array.shape[0] + if fc2_array.shape[0] != num_experts: + raise _UnsupportedMoEError(f"fc1/fc2 num_experts disagree: {fc1_array.shape[0]} vs {fc2_array.shape[0]}.") + + fc1_qweights, fc1_scales = _quantize_stacked_weights( + fc1_array, bits=bits, block_size=block_size, force_arch=force_arch + ) + fc2_qweights, fc2_scales = _quantize_stacked_weights( + fc2_array, bits=bits, block_size=block_size, force_arch=force_arch + ) + + graph = node.graph + fc1_w_init = _make_initializer(f"{fc1_w_value.name}_q", fc1_qweights) + fc1_s_init = _make_initializer(f"{fc1_w_value.name}_scales", fc1_scales) + fc2_w_init = _make_initializer(f"{fc2_w_value.name}_q", fc2_qweights) + fc2_s_init = _make_initializer(f"{fc2_w_value.name}_scales", fc2_scales) + for init in (fc1_w_init, fc1_s_init, fc2_w_init, fc2_s_init): + graph.register_initializer(init) + + fc1_bias = _maybe_input(node, _MOE_INPUT_INDEX["fc1_b"]) + fc2_bias = _maybe_input(node, _MOE_INPUT_INDEX["fc2_b"]) + fc3_w = _maybe_input(node, _MOE_INPUT_INDEX["fc3_W"]) + if fc3_w is not None: + raise _UnsupportedMoEError("fc3 inputs are not yet supported by this pass.") + + # QMoE input layout (zero_points stay absent because we use symmetric + # int4/int8): + # 0: input + # 1: router_probs + # 2: fc1_experts_weights (quantized) + # 3: fc1_scales + # 4: fc1_zero_points (None — symmetric) + # 5: fc1_experts_bias + # 6: fc2_experts_weights (quantized) + # 7: fc2_scales + # 8: fc2_zero_points (None) + # 9: fc2_experts_bias + qmoe_inputs = [ + _get_input(node, _MOE_INPUT_INDEX["input"]), + _get_input(node, _MOE_INPUT_INDEX["router_probs"]), + fc1_w_init, + fc1_s_init, + None, + fc1_bias, + fc2_w_init, + fc2_s_init, + None, + fc2_bias, + ] + + new_attrs = list(node.attributes.values()) + new_attrs.append(ir.AttrInt64("expert_weight_bits", bits)) + if block_size > 0: + new_attrs.append(ir.AttrInt64("block_size", block_size)) + # ``quant_type`` defaults to ``"int"`` in the schema; emit it + # explicitly so future schema revisions changing the default don't + # silently alter behaviour for our exported models. + if not any(a.name == "quant_type" for a in node.attributes.values()): + new_attrs.append(ir.AttrString("quant_type", "int")) + + output_value = ir.Value(name=node.outputs[0].name) + return ir.Node( + domain=MSFT_DOMAIN, + op_type=_QMOE_OP_TYPE, + inputs=qmoe_inputs, + attributes=new_attrs, + outputs=[output_value], + name=node.name + "_QMoE" if node.name else None, + ) # --------------------------------------------------------------------------- -# Helpers +# Small graph / value helpers # ---------------------------------------------------------------------------