OnnxMoEQuantization: port pack_weights_for_cuda_mixed_gemm to numpy / cupy so the pass runs on CPU-only hosts #2492

@justinchuby

Context

#2491 adds OnnxMoEQuantization which rewrites com.microsoft::MoE nodes to com.microsoft::QMoE. Per-expert quantization uses ORT's quantize_matmul_4bits / quantize_matmul_8bits pybind helpers (CPU-side, always available), but the CUTLASS layout prepack uses pack_weights_for_cuda_mixed_gemm which is only exported when ORT is built with USE_CUDA (source).

Practical consequence: today, running OnnxMoEQuantization requires installing onnxruntime-gpu (or a nightly built from main with CUDA). A CPU-only onnxruntime install gets a clean RuntimeError with install instructions, but it is still friction.
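
For illustration, the availability check behind that RuntimeError could look roughly like the sketch below. The function name `get_cuda_prepack`, the `binding` argument (standing in for whichever ORT pybind module exposes the helper), and the optional `fallback` parameter are all hypothetical; this is not the code in #2491.

```python
from types import SimpleNamespace

HELPER_NAME = "pack_weights_for_cuda_mixed_gemm"


def get_cuda_prepack(binding, fallback=None):
    """Return the ORT prepack helper if this build exports it.

    `binding` is an assumption: any object that may or may not carry the
    helper as an attribute. `fallback` is a hypothetical pure-Python port.
    """
    helper = getattr(binding, HELPER_NAME, None)
    if callable(helper):
        return helper
    if fallback is not None:
        return fallback
    raise RuntimeError(
        f"{HELPER_NAME} is not exported by this onnxruntime build; "
        "install onnxruntime-gpu (or a CUDA build from main)."
    )


# Stand-ins for a CPU-only and a CUDA-enabled binding module.
cpu_only = SimpleNamespace()
cuda_build = SimpleNamespace(pack_weights_for_cuda_mixed_gemm=lambda *a: "packed")
```

With a numpy port available, passing it as `fallback` would let the same call site work on CPU-only hosts.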

Why the prepack is needed

The QMoE CUDA kernel reads fc1_experts_weights->DataRaw() straight from the initializer and hands it to the CUTLASS fpA_intB GEMM runner without any prepacking (moe_quantization.cc: slots 2 and 5 are skipped in PrePack for quant_type='int'). The CUTLASS kernel needs weights in a specific row-permuted, column-interleaved, byte-pair-interleaved layout that matches its ldsm + tensor-core MMA tile shape; otherwise it silently produces garbage output.

Filed an upstream issue to fix this asymmetry (microsoft/onnxruntime#28748, "QMoE should prepack in PrePack() like MatMulNBits already does"), but that's not going to land and ship for a while.

What we could do in Olive in the meantime

Port the prepack to numpy with optional cupy acceleration. The transform is pure data movement:

  1. Permute rows in 32-row tiles using a small static table (kPerm_W4_A16 for int4 on sm75/80; sm90 skips this step). Tables live in fpA_intB_gemm_preprocessors_impl.h.
  2. Sub-byte transpose to column-major. Need to unpack pairs of int4 nibbles, transpose, repack.
  3. Column interleave by columns_interleaved (per arch; 4 for sm80 W4, skipped for sm90).
  4. Bias + byte-pair interleave: add +8 to each nibble (signed→unsigned shift) and interleave the four bytes in each 4-byte group.
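
Step 2 and the bias half of step 4 are the easiest pieces to pin down in isolation. A minimal numpy sketch follows, assuming the low nibble of each byte holds the even-indexed element and that nibbles are two's-complement int4. Both are assumptions to verify against the ORT reference, and the function names are hypothetical. It does not attempt the byte interleave or the row/column permutations, whose tables live in fpA_intB_gemm_preprocessors_impl.h.

```python
import numpy as np


def unpack_int4(packed: np.ndarray) -> np.ndarray:
    """Split each uint8 into two 4-bit values along the last axis.

    Assumption: element 2*i is the low nibble of byte i, element 2*i+1 the high nibble.
    """
    packed = np.asarray(packed, dtype=np.uint8)
    lo = packed & np.uint8(0x0F)
    hi = packed >> np.uint8(4)
    return np.stack([lo, hi], axis=-1).reshape(*packed.shape[:-1], -1)


def pack_int4(nibbles: np.ndarray) -> np.ndarray:
    """Inverse of unpack_int4: pair adjacent 4-bit values into one uint8."""
    nibbles = np.asarray(nibbles, dtype=np.uint8)
    packed = nibbles[..., 0::2] | (nibbles[..., 1::2] << np.uint8(4))
    return np.ascontiguousarray(packed, dtype=np.uint8)


def subbyte_transpose_int4(packed_nk: np.ndarray) -> np.ndarray:
    """[N, K/2] packed bytes -> [K, N/2] packed bytes (N and K must be even)."""
    return pack_int4(unpack_int4(packed_nk).T)


def add_bias_int4(packed: np.ndarray) -> np.ndarray:
    """Add 8 (mod 16) to every nibble.

    For a 4-bit value, adding 8 modulo 16 only flips bit 3, so it equals
    XOR with 0x8 per nibble, i.e. XOR with 0x88 per byte.
    """
    return np.asarray(packed, dtype=np.uint8) ^ np.uint8(0x88)
```

Applying `subbyte_transpose_int4` twice returns the original array, which gives a cheap self-check that is independent of the reference implementation.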

Skeleton:

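import numpy as np

def _prepack_int4_cutlass(packed_qweight_nk: np.ndarray, n: int, k: int, sm: int) -> np.ndarray:
    """Numpy (optionally cupy) port of pack_weights_for_cuda_mixed_gemm for W4_A16."""
    # The _-prefixed helpers are placeholders, one per step of the list above.
    xp = _array_lib()                         # numpy by default, cupy if available
    w = xp.asarray(packed_qweight_nk)
    w = _permute_rows_w4(w, sm)               # static-table permutation; expected to be a no-op for sm90
    w = _subbyte_transpose_w4(w, n, k)        # unpack nibbles, transpose, repack
    if sm != 90:
        w = _interleave_columns_w4(w, sm)     # columns_interleaved depends on the arch
    w = _add_bias_and_pair_interleave_w4(w)   # +8 per nibble, then byte interleave
    return xp.asnumpy(w) if xp is not np else w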

Validation strategy

Without a CUDA-built ORT in CI, we can't directly diff against the reference. Options:

  • Golden vectors: capture small (E=2, N=16, K=32) outputs from a CUDA-built ORT once and check them into the repo as a fixture.
  • Round-trip: dequantize the prepacked tensor (in numpy, reverse-engineering the layout) back to [N, K] int values; check they equal the input. Plus a small pytest.mark.cuda_runtime integration test that loads the produced model on actual ORT-GPU and confirms numerical parity vs the fp16 baseline.
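
The golden-vector option could be wired up along these lines. This is a sketch under stated assumptions: the fixture is a `.npz` with hypothetical keys `packed_input` and `expected`, captured once from a CUDA-built ORT, and `prepack_fn` is whatever port is under test.

```python
import numpy as np


def save_golden(file, packed_input, expected):
    """Write one fixture; meant to be run once on a CUDA-built ORT host.

    `file` may be a path or a binary file-like object.
    """
    np.savez_compressed(file, packed_input=packed_input, expected=expected)


def assert_matches_golden(prepack_fn, file):
    """Run `prepack_fn` on the stored input and compare bit-for-bit."""
    with np.load(file) as data:
        packed_input = data["packed_input"]
        expected = data["expected"]
    actual = np.asarray(prepack_fn(packed_input))
    # The layout is pure data movement, so exact equality is the right bar.
    assert actual.dtype == expected.dtype
    assert actual.shape == expected.shape
    np.testing.assert_array_equal(actual, expected)
```

Because the comparison is exact, a fixture as small as the E=2, N=16, K=32 case should already catch most permutation or nibble-order mistakes.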

Open questions

  • Worth doing now, or wait for #28748 to land in ORT? Upside of doing it now: pass becomes installable from CPU CI / dev laptops. Downside: maintenance burden of keeping the layout port in sync with any future CUTLASS-side tweaks.
  • If yes: numpy-only first (correctness-focused), and only add cupy if profiling shows the numpy version is unacceptably slow on real models.
