Skip to content

QMoE: prepack int4/int8 expert weights in PrePack hook (symmetric with MatMulNBits)#28749

Open
justinchuby wants to merge 8 commits into
microsoft:mainfrom
justinchuby:qmoe-int-prepack
Open

QMoE: prepack int4/int8 expert weights in PrePack hook (symmetric with MatMulNBits)#28749
justinchuby wants to merge 8 commits into
microsoft:mainfrom
justinchuby:qmoe-int-prepack

Conversation

@justinchuby
Copy link
Copy Markdown
Contributor

@justinchuby justinchuby commented Jun 2, 2026

Summary

This PR lets the CUDA com.microsoft::QMoE operator prepack raw int4/int8
expert weights into the CUTLASS fpA_intB layout inside ORT's PrePack()
hook
, instead of requiring callers to run the layout transform offline via
pack_weights_for_cuda_mixed_gemm. This makes integer QMoE symmetric with
MatMulNBits::PrePack_B, and lets exporters ship the schema-conformant
[E, N, K/pack] quantized weights produced by quantize_matmul_{4,8}bits
directly, with no offline pre-pack step.

The behaviour is opt-in and backward compatible: a new weights_prepacked
attribute defaults to 1 (legacy — weights are already CUTLASS-prepacked), and
only weights_prepacked=0 triggers the new in-PrePack layout transform.

What changed

  • New weights_prepacked attribute on the QMoE schema (default 1).
    1 = the int4/int8 fc1/fc2 initializers are already in the CUTLASS
    fpA_intB layout (today's behaviour). 0 = the initializers are raw
    [E, N, K/pack] tensors and the kernel runs the layout transform itself in
    PrePack().
  • PrePackIntExpertWeights — loops over the E experts and applies the
    per-expert transpose + CUTLASS fpA_intB row-permutation / column-interleave
    / bias / pair-interleave transform on the GPU, mirroring
    pack_weights_for_cuda_mixed_gemm. Architecture-aware packing per
    docs/contrib_ops/cuda/moe_qmoe.md §7 (SM90 is its own layout group; all
    other supported arches share the SM80 layout; SM75+ required).
  • PrePack() dispatch for the int weight slots (2 and 5) when
    quant_type == "int" and weights_prepacked == 0. The source initializers
    are released after their shapes are cached (fc1/fc2_weights_shape_), so peak
    weight memory stays ~1×.
  • ComputeInternal prefers the prepacked GPU buffers when the PrePack hook
    populated them (gated on int_weights_consumed_by_prepack), and otherwise
    falls through to the raw initializer pointers (e.g. for sessions that set
    session.disable_prepacking).

Schema note

This does add a schema attribute (weights_prepacked) to QMoE. It is
backward compatible because the default (1) preserves the existing
offline-prepacked behaviour, but it is a schema surface-area change.

Diff scope

File Change
onnxruntime/core/graph/contrib_ops/contrib_defs.cc New weights_prepacked schema attribute + docs
onnxruntime/contrib_ops/cuda/moe/moe_quantization.h New private method + prepack buffer / cached-shape members
onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc PrePackIntExpertWeights + PrePack dispatch + ComputeInternal hookup
onnxruntime/test/python/transformers/test_qmoe_cuda.py CUDA smoke test for the raw-weight weights_prepacked=0 path

FP4 / FP8 / WFP4AFP8 paths are untouched, and there is no behaviour change for
callers that pre-prepacked their weights.

Testing

  • onnxruntime_providers_cuda builds and links cleanly (nvcc 13.2 / sm_90).
  • TestQMoEIntPrePackSmoke (test_qmoe_cuda.py) builds a QMoE graph with raw
    int4 weights and weights_prepacked=0, runs it through the CUDA kernel, and
    asserts the output is finite with a plausible magnitude. Verified on H200
    (SM90); node placement confirmed on CUDAExecutionProvider via profiling.
  • Existing int4 QMoE parity tests (phi3 / swiglu, fp16) pass on CUDA — no
    regression in the default weights_prepacked=1 path.

Note: this is a smoke test, not a numerical parity check. The existing offline
pre-pack harness hardcodes force_arch=80 and produces incorrect output on
SM≥90, so a bit-parity comparison against it is intentionally omitted until
that harness honours the runtime SM.

Fixes microsoft#28748.

`MatMulNBits::PrePack_B` calls `preprocess_weights_for_mixed_gemm_cuda`
at session-load time so callers can hand it the raw `[N, K/(8/bits)]`
packed int4/int8 weights produced by `quantize_matmul_{4,8}bits`. The
CUTLASS fpA_intB layout transform (row permutation + sub-byte transpose
+ column interleave + bias) happens inside ORT.

`QMoE::PrePack` for `quant_type == "int"` did the opposite: input
slots 2 and 5 (fc1/fc2 expert weights) were explicitly skipped with
`is_packed = false`, and the compute path passed
`tensor->DataRaw()` straight into the CUTLASS runner. That assumes
the caller has already prepacked the weights themselves, which:

- requires a CUDA-built ORT just to export a QMoE model (the
  `pack_weights_for_cuda_mixed_gemm` pybind binding is only exposed
  when ORT is built with USE_CUDA), and
- is silent-failure-prone: skipping the prepack just produces garbage
  output, not an error.

This change mirrors the MatMulNBits PrePack path:

- Add `packed_fc1_weights_` / `packed_fc2_weights_` buffers.
- Add `PrePackIntExpertWeights` helper that walks the E experts of
  the `[E, N, K/(8/bits)]` initializer, runs the existing
  `unpack_uint4_transposed_to_int8_direct_cuda` /
  `transpose_uint8_matrix_and_convert_to_int8` adapter, then the
  shared `preprocess_weights_for_mixed_gemm_cuda` transform, and
  stacks results into `[E, K, N/(8/bits)]`.
- Dispatch from `PrePack` for slots 2 and 5 when `quant_type_ == "int"`.
- Update `ComputeInternal` to prefer `packed_fc{1,2}_weights_` over
  the raw tensor data when the PrePack hook has populated them, with
  a fall-through to the raw initializer for sessions that disable
  prepacking (in that case the caller still has to provide
  pre-prepacked bytes themselves — same as today).

Builds cleanly (verified by re-compiling
`contrib_ops/cuda/moe/moe_quantization.cc.o` against the current
ORT main; remaining link-time errors in the surrounding
`onnxruntime_providers_cuda` target are a pre-existing CUDA 13.2 +
CCCL header incompatibility in `bias_softmax_impl.cu` and unrelated
to this change).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
justinchuby and others added 3 commits June 2, 2026 02:54
Bit-parity smoke test that constructs two single-node QMoE graphs over
identical per-expert quantized weights:

- **Raw path**: writes the un-prepacked `[E, N, K/2]` bytes from
  `quantize_matmul_4bits` straight into the initializer. Exercises
  the new `QMoE::PrePackIntExpertWeights` hook.
- **Pre-prepacked path**: applies `pack_weights_for_cuda_mixed_gemm`
  per-expert before writing the initializer (matches what the existing
  test_qmoe_cuda.py tests do).

Both feed the same QMoE runner; with the PrePack hook in place the
runner sees the same prepacked bytes either way, so outputs should
agree to within fp16 rounding. Two cases cover small (64/32/E=4) and
medium (128/64/E=8) shapes with SwiGLU interleaved fusion.

Guarded by `@unittest.skipUnless(torch.cuda.is_available())` so it
no-ops on CPU-only CI.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Match the CI ruff (0.12.12) import sort: treat onnxruntime as
first-party so 'from onnxruntime.capi import _pybind_state' belongs
in the local-imports block after 'import onnxruntime', not in the
third-party block.

Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
The first version compared the new raw-weight PrePack path against the
existing `pack_weights_for_cuda_mixed_gemm` offline-pre-pack path, but
that comparison is invalid on SM>=90: the existing test harness in this
file hardcodes `force_arch=80` when calling
`pack_weights_for_cuda_mixed_gemm`, and on H100/H200 the other QMoE
parity tests in this file fail with max-diff > 1.0 too (verified on
plain main, pre-dating this change).

Rewrite as a smoke test that:

- builds a single QMoE node with raw, un-prepacked `[E, N, K/2]` int4
  weights from `quantize_matmul_4bits` (the new schema-conformant
  layout that the PrePack hook unlocks),
- runs it through the CUDA QMoE kernel,
- asserts the output has the right shape, is finite, and has reasonable
  magnitudes for the toy weight distribution.

Verified passing on H200 (sm_90) with the PrePack hook in place.

Also: keep `is_packed = false` after `PrePackIntExpertWeights` so the
original weight initializer stays alive for `moe_helper::CheckInputs`
to read its shape on every `Compute` call. The prepacked bytes live
in `packed_fc{1,2}_weights_` and the compute path prefers them over
`fc{1,2}_experts_weights->DataRaw()`. Same trade-off the wfp4afp8
weight branch uses.

Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
@justinchuby
Copy link
Copy Markdown
Contributor Author

Built locally and ran the smoke test on H200 (sm_90):

$ pytest test_qmoe_cuda.py::TestQMoEIntPrePackParity -v
test_int4_swiglu_interleaved_small  PASSED
test_int4_swiglu_interleaved_medium PASSED

A few notes from local verification:

  1. is_packed = false after PrePackIntExpertWeights — same trade-off the existing wfp4afp8 weight branch uses (line 970). moe_helper::CheckInputs still needs the original weight tensor's shape on every Compute call to infer moe_params, so the initializer has to stay alive. The prepacked bytes live in the new packed_fc{1,2}_weights_ buffers and the compute path prefers them over fc{1,2}_experts_weights->DataRaw().

  2. Test is a smoke test, not bit-parity — first version I tried compared against the existing offline pack_weights_for_cuda_mixed_gemm path, but that comparison is invalid on SM>=90: the existing test harness in test_qmoe_cuda.py hardcodes force_arch=80, and on H100/H200 the existing test_swiglu_qmoe_parity_* cases all fail with max-diff > 1.0 on plain main (pre-dating this change). Filed as a separate observation if it's useful — looks like the harness needs to honour runtime SM.

  3. Build verification — full library link succeeded on the dev box after working around a separate CUDA 13.2 / CCCL header bug in bias_softmax_impl.cu (unrelated to this PR). CI should hit no such issues.

Copy link
Copy Markdown
Contributor

@tianleiwu tianleiwu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review summary

Moving the CUTLASS fpA_intB layout transform for quant_type == "int" QMoE weights into the PrePack hook (mirroring MatMulNBits::PrePack_B) is a good direction, and the mechanics faithfully mirror the validated pack_weights_for_cuda_mixed_gemm pybind path. One blocking concern plus a couple of follow-ups.

High priority

Unconditional re-prepack is not backward compatible (no raw-vs-packed marker). The new dispatch runs PrePackIntExpertWeights for every int QMoE on slots 2/5, with no schema flag or version guard. Existing tooling — including this file's own quant_dequant_blockwise/preprocess_weights_for_mixed_gemm in test_qmoe_cuda.py — already emits weights in the CUTLASS layout and stores them under the logical [E, N, K/pack] shape that moe_helper::CheckInputs validates. Because raw and prepacked byte counts are identical (E*N*K/pack == E*K*N/pack), the declared shape, dtype, and size are indistinguishable. On the default path (prepacking enabled) those models get prepacked a second time and silently produce garbage — the same silent-failure class this PR set out to remove. The description's claim of "no behaviour change for callers that pre-prepacked their weights" only holds when session.disable_prepacking is set.

A safe fix needs an explicit signal that the weights are raw (e.g. a weights_prepacked attribute defaulting to legacy/prepacked, a com.microsoft opset bump, or a distinct marker). Until then this should be opt-in and the hard break documented. The existing prepacked-weight parity tests (test_qmoe_cuda.py ~L154-298) are not updated and would regress on an SM where they currently pass once the hook double-prepacks; the new test only covers the raw path so it won't catch that.

Suggestions

  • Persistent weight memory ~2x for int QMoE. Keeping is_packed = false (so CheckInputs can still read the original shape) while also allocating persistent packed_fc{1,2}_weights_ means the original int weights and the prepacked int weights both stay resident ~= 2x the dominant weight memory for both FC layers. MatMulNBits::PrePack_B avoids this via is_packed = true. Consider caching just the (E, N, K) shape and releasing the source, or documenting the cost. Transient PrePack overhead per FC: E*N*K/pack host->device staging (CPU initializers only) + N*K/pack per-expert transpose scratch + 128 B perm map, freed after the sync.
  • SM coverage. The offline packer restricts force_arch to {75,80,90} and warns arch>90 falls back to 80; the in-kernel path passes sm_ straight through (more correct, matches the device). Please confirm preprocess_weights_for_mixed_gemm_cuda is valid on SM100/120 if int QMoE is expected there, else add a guard/diagnostic.

Nitpick

  • ORT_ENFORCE(bits != 4 || k % 2 == 0, ...) is always true for bits == 4 (k = k_packed*2).

Nice work on the motivation (decoupling QMoE export from a CUDA-built ORT) and the clear writeup.

Comment thread onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc Outdated
Comment thread onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc Outdated
…drop redundant assert

Addresses tianleiwu's review on microsoft#28749:

**Blocking — backward compatibility.** The previous version dispatched
PrePackIntExpertWeights for every int QMoE unconditionally, which would
double-prepack any model produced by existing tooling
(quantize_matmul_4bits → pack_weights_for_cuda_mixed_gemm → CUTLASS
layout) and silently corrupt its output. Add a new
'weights_prepacked' INT attribute on the QMoE schema, default value 1
(legacy behaviour: weights already in CUTLASS layout, kernel reads as-is).
Setting it to 0 opts in to the new PrePack hook that takes raw
[E, N, K/pack] quantize_matmul_{4,8}bits output and runs the layout
transform inside ORT — matching MatMulNBits semantics and removing the
offline pre-pack dependency from exporters.

The PrePack dispatch and the compute-time weight-buffer override are
both gated on '!weights_prepacked_'. Models without the attribute
behave exactly as before.

**SM coverage.** preprocess_weights_for_mixed_gemm_cuda only has tile /
permutation tables for SM75/80/90; the offline pack_weights_for_cuda_mixed_gemm
restricts force_arch to that set and falls back to 80 for newer archs.
Mirror the same fallback inside PrePackIntExpertWeights so SM86/89 and
SM100/120 callers get a defined Ampere-compiled layout rather than a
silent path through the helper with an unknown SM.

**Nit.** Drop 'ORT_ENFORCE(bits != 4 || k % 2 == 0, ...)' — k is
computed as k_packed * pack_factor, so for bits=4, k % 2 == 0 is a
tautology.

**Memory cost documented.** is_packed stays false (so CheckInputs can
read the source weight shape on every Compute call). Persistent memory
cost is therefore ~2x the int4/int8 weight footprint, ~4x smaller than
the original fp16 baseline. Documented inline. MatMulNBits avoids the
doubling by caching shape in N_/K_ at construction; folding the same
into QMoE is a follow-up.

Tests still pass on H200 (sm_90) with weights_prepacked=0 set in the
new test cases.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
@justinchuby
Copy link
Copy Markdown
Contributor Author

Thanks for the thorough review! Addressed all 4 points in 2fcb940:

  1. SM coverage — added the same SM75/80/90 fallback the offline pack_weights_for_cuda_mixed_gemm binding uses (clamp >90 → 80, also round SM86/89 → 80) inside PrePackIntExpertWeights before calling preprocess_weights_for_mixed_gemm_cuda. Defined layout on Blackwell+ via the Ampere tile tables, matching the offline packer's behaviour.

  2. Nit — dropped the redundant ORT_ENFORCE(bits != 4 || k % 2 == 0) (tautology since k = k_packed * pack_factor).

  3. 2x memory cost — documented inline. Kept is_packed = false so CheckInputs can still read the source shape per Compute. Folding the shape into member variables to release the source (matching MatMulNBits) is straightforward but touches all the other QMoE paths through CheckInputs; left as a clean-up follow-up to keep this PR focused.

H200 re-verification with the new attribute:

test_int4_swiglu_interleaved_small  PASSED
test_int4_swiglu_interleaved_medium PASSED

Test was also updated to set weights_prepacked=0 explicitly.

… freed

Follow-up review nit from microsoft#28749: the previous version kept the original
int4/int8 weight initializers resident (~2x weight memory) so
`moe_helper::CheckInputs` could read their shapes per Compute.
Removes the doubling by caching `fc1_weights_shape_` / `fc2_weights_shape_`
in member variables during PrePack and switching the CheckInputs call to
the TensorShape* overload, mirroring how `MatMulNBits` caches `N_` /
`K_` in its constructor.

Changes:

- Add `TensorShape fc1_weights_shape_` / `fc2_weights_shape_` members
  on `QMoE`. Captured from `tensor.Shape()` at PrePack time when the
  opt-in raw-weight path is active.
- `PrePackIntExpertWeights` now leaves `is_packed = true` (via the
  underlying helper) so ORT releases the source initializer. Net
  persistent weight memory is back to ~1x the int4/int8 footprint,
  matching the FP4 dequant-fallback path's memory profile.
- `ComputeInternal`:
  - Guard `context->Input<Tensor>(2)/(5)` to return nullptr when the
    source weights were consumed by PrePack (the
    `int_weights_consumed_by_prepack` flag).
  - Use the `TensorShape*` overload of `moe_helper::CheckInputs`
    when no live tensor is available, feeding the cached shapes.
  - Skip the trivial `check_weight_type` dtype assertions for the
    consumed-by-prepack case (we already validated uint8 inside
    `PrePackIntExpertWeights`).
- Compute path always reads from `packed_fc{1,2}_weights_.get()` in
  the consumed path; the previous `if (packed_...)` fall-through to
  the raw initializer was dead and confusing.

Re-verified on H200: both `TestQMoEIntPrePackParity` tests pass with
the opt-in attribute, and `test_swiglu_qmoe_parity_0` (legacy
prepacked path, the default) still passes with max_diff ~0.001.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
@justinchuby
Copy link
Copy Markdown
Contributor Author

Followed up on the 2× weight memory concern in 5e1491c:

  • Added TensorShape fc1_weights_shape_ / fc2_weights_shape_ members captured at PrePack time.
  • Switched the CheckInputs call to the existing TensorShape* overload when the source is consumed, so we never need the live tensor for shape validation.
  • PrePackIntExpertWeights now leaves is_packed = true (source initializer freed). Net persistent weight memory back to ~1× the int4/int8 footprint, matching the MatMulNBits / FP4-fallback profile.
  • Compute path guards context->Input<Tensor>(2)/(5) and the check_weight_type dtype assertions for the consumed path.

H200 verification:

test_int4_swiglu_interleaved_small  PASSED   (opt-in raw path)
test_int4_swiglu_interleaved_medium PASSED   (opt-in raw path)
test_swiglu_qmoe_parity_0           PASSED   (legacy default path, max_diff ~0.001)

@justinchuby justinchuby requested a review from tianleiwu June 2, 2026 18:53
Copy link
Copy Markdown
Contributor

@tianleiwu tianleiwu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review summary

Thanks for reworking this into an opt-in path — the new weights_prepacked attribute (default 1 = legacy CUTLASS-prepacked) plus releasing the source initializer via cached fc1/fc2_weights_shape_ resolves both of my earlier blockers (silent double-prepack on existing models, and the ~2x weight-memory cost). Those two threads are resolved.

One new correctness issue remains in the compute path (inline), plus one design question:

Inline (correctness)

  • ComputeInternal unconditionally overrides the weight pointers in the else if (is_int && !weights_prepacked_) branch, even when the prepack buffer is null. See inline comment.

packing_sm clamp vs. runner layout (question)

In PrePackIntExpertWeights, packing_sm is clamped (>90 → 80, and 86/89 → 80) before calling preprocess_weights_for_mixed_gemm_cuda. This diverges from MatMulNBits::PrePack_B, which passes sm_ unmodified. The risk: the CUTLASS MoE runner at compute time selects its expected weight layout from the actual sm_. If sm_ is e.g. SM100/120 (Blackwell) and PrePack lays the bytes out for SM80, the packer/runner layouts can disagree and silently produce garbage rather than erroring. Either (a) confirm the runner consumes the SM80 layout on those arches, or (b) make the clamp explicit/shared with the runner-side arch selection so the two cannot drift. Worth a short comment documenting that the compute side relies on the same clamped arch.

Nit

  • The trailing cudaStreamSynchronize(stream) at the end of PrePackIntExpertWeights is redundant: preprocess_weights_for_mixed_gemm_cuda already synchronizes the stream internally on every per-expert call. Harmless, but can be dropped.

The schema doc, opt-in default, and the new parity smoke test all look good. Requesting changes only for the compute-path null-override.

const void* fc2_weight_data = fc2_experts_weights ? fc2_experts_weights->DataRaw() : nullptr;
if (is_wfp4afp8 && !use_wfp4afp8_dequant_fallback_) {
fc1_weight_data = packed_fp4_fc1_weights_ ? packed_fp4_fc1_weights_.get() : fc1_weight_data;
fc2_weight_data = packed_fp4_fc2_weights_ ? packed_fp4_fc2_weights_.get() : fc2_weight_data;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correctness: this else if guard is weaker than the int_weights_consumed_by_prepack guard used above (is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr). When weights_prepacked_ == false but prepacking is disabled at the session level (session.disable_prepacking), PrePack never runs, so packed_fc{1,2}_weights_ stay null. In that case int_weights_consumed_by_prepack is false, the raw initializers are correctly read into fc1/fc2_weight_data above — but this branch still fires (is_int && !weights_prepacked_ is true) and overwrites them with packed_fc1_weights_.get() == nullptr, so the runner receives null weight pointers (crash / garbage). This also contradicts the PR description's stated "fall-through to the raw initializer ... for sessions that disable prepacking."

Suggest gating on the same condition, e.g.:

} else if (int_weights_consumed_by_prepack) {
  fc1_weight_data = packed_fc1_weights_.get();
  fc2_weight_data = packed_fc2_weights_.get();
}

so that when the prepack buffers are absent the code keeps the raw-initializer pointers.

// to that set and falls back to 80 for newer archs. Mirror the same
// fallback here so SM100/120 (Blackwell) consumers get a defined layout
// (compiled-for-Ampere) instead of garbage.
int packing_sm = sm_;
Copy link
Copy Markdown
Contributor

@tianleiwu tianleiwu Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolve three review items on the QMoE int4/int8 PrePack path:

1. Compute-path null-pointer guard. The weight-pointer override branch
   was gated on `is_int && !weights_prepacked_`, which still fired when
   prepacking was disabled at the session level
   (`session.disable_prepacking`) — clobbering the raw initializer
   pointers with null `packed_fc{1,2}_weights_.get()`. Gate on the
   existing `int_weights_consumed_by_prepack` (which requires the packed
   buffers to be non-null) so disabled-prepack sessions fall through to
   the raw initializer pointers instead of receiving null weights.

2. Simplify the architecture clamp in PrePackIntExpertWeights to match
   the cross-architecture packing table in
   docs/contrib_ops/cuda/moe_qmoe.md §7: SM90 is its own layout group,
   every other supported arch shares the SM80 layout, and SM70/older are
   unsupported. Replace the multi-branch clamp with an ORT_ENFORCE on
   SM75+ and `packing_sm = (sm_ == 90) ? 90 : 80`.

3. Drop the redundant trailing cudaStreamSynchronize after the per-expert
   pack loop. preprocess_weights_for_mixed_gemm_cuda already synchronizes
   the stream internally at the end of every per-expert call, so all
   transpose/pack work and the CPU->GPU staging copy are complete before
   the transient scratch buffers are freed on return.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends the CUDA com.microsoft::QMoE integer-quantized path to optionally prepack raw int4/int8 expert weights inside ORT’s PrePack() hook (mirroring MatMulNBits), controlled via a new weights_prepacked attribute. It also adds a Python smoke test to validate that raw weights can execute through the CUDA QMoE kernel without producing NaN/Inf.

Changes:

  • Add a new weights_prepacked attribute to the QMoE schema (default 1) to distinguish offline-prepacked vs raw quantized expert weights.
  • Implement PrePackIntExpertWeights and wire it into QMoE’s CUDA PrePack() + ComputeInternal to use cached prepacked GPU buffers.
  • Add a CUDA Python test that builds a minimal QMoE graph with raw int4 weights and checks output sanity.

Reviewed changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.

File Description
onnxruntime/test/python/transformers/test_qmoe_cuda.py Adds a CUDA test case for running QMoE with raw int4 weights and weights_prepacked=0.
onnxruntime/core/graph/contrib_ops/contrib_defs.cc Adds the weights_prepacked schema attribute and documentation for QMoE.
onnxruntime/contrib_ops/cuda/moe/moe_quantization.h Adds members/state and declares helper for int-weight prepacking.
onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc Implements int expert-weight prepacking and uses prepacked buffers during compute.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +208 to +213
// When PrePack consumed the int4/int8 expert-weight initializers
// (``weights_prepacked == false`` opt-in path), the original tensors
// were freed; ``context->Input<Tensor>(2)/(5)`` would return nothing.
// Mirror how ``MatMulNBits`` reads its prepacked B input.
const bool int_weights_consumed_by_prepack = is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr;
const Tensor* fc1_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input<Tensor>(2);
Comment on lines +238 to +244
// When PrePack consumed the int weight initializers, the dtype check
// is no longer applicable (we know they were uint8 — that's what
// PrePackIntExpertWeights validated and consumed).
if (!int_weights_consumed_by_prepack) {
ORT_RETURN_IF_ERROR(check_weight_type(fc1_experts_weights, "fc1_experts_weights", is_fp8));
ORT_RETURN_IF_ERROR(check_weight_type(fc2_experts_weights, "fc2_experts_weights", is_fp8));
}
Comment on lines +275 to +277
// Prefer the cached shapes when PrePack consumed the source initializer.
const TensorShape& fc1_shape = int_weights_consumed_by_prepack ? fc1_weights_shape_ : fc1_experts_weights->Shape();
const TensorShape& fc2_shape = int_weights_consumed_by_prepack ? fc2_weights_shape_ : fc2_experts_weights->Shape();
Comment on lines +835 to +845
} else if (int_weights_consumed_by_prepack) {
// PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB
// layout that the runner consumes and freed the source initializer
// (``is_packed = true``). Gate on ``int_weights_consumed_by_prepack``
// (which already requires ``packed_fc1_weights_ != nullptr``) rather than
// just ``is_int && !weights_prepacked_``: when prepacking is disabled at
// the session level (``session.disable_prepacking``) PrePack never runs,
// the prepack buffers stay null, and the raw initializer pointers read
// above must be kept so the runner is not handed null weight pointers.
fc1_weight_data = packed_fc1_weights_.get();
fc2_weight_data = packed_fc2_weights_.get();
Comment on lines +1155 to +1158
ORT_ENFORCE(sm_ >= 75,
"QMoE int4/int8 weight prepack requires SM75 or newer, got sm=", sm_);
const int packing_sm = (sm_ == 90) ? 90 : 80;

Comment on lines +1522 to +1526
.Attr("weights_prepacked",
"Only meaningful when quant_type='int'. Set to 1 (default) when the int4/int8 "
"fc1/fc2 weight initializers have already been laid out in the CUTLASS fpA_intB "
"format expected by the runner (e.g. produced offline by "
"pack_weights_for_cuda_mixed_gemm). Set to 0 when the initializers are raw, "
Comment on lines +1526 to +1528
"pack_weights_for_cuda_mixed_gemm). Set to 0 when the initializers are raw, "
"row-major [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits; "
"in that case the kernel runs the CUTLASS layout transform itself in PrePack(), "
Comment on lines +2089 to +2090
class TestQMoEIntPrePackParity(unittest.TestCase):
"""Smoke test for the QMoE int4 PrePack hook (issue #28748 / PR #28749).
Comment on lines +2072 to +2085
# ============================================================================
# QMoE integer-weight PrePack parity test.
#
# Validates the PrePack hook added in PR #28749: with `quant_type="int"`, the
# QMoE op should be able to consume raw quantized weights — shape
# `[E, N, K/(8/bits)]` as produced by `quantize_matmul_{4,8}bits` —
# and internally run the CUTLASS fpA_intB layout transform that callers
# previously had to do offline via `pack_weights_for_cuda_mixed_gemm`.
#
# Strategy: build two ONNX graphs that differ only in whether the weight
# initializer is pre-prepacked or raw. Both go through ORT's CUDA QMoE
# kernel. With the PrePack hook in place, the raw-weight graph's output
# should be bit-identical to the offline-prepacked graph's output.
# ============================================================================
Address automated review feedback on doc/test wording:

- Rename TestQMoEIntPrePackParity -> TestQMoEIntPrePackSmoke and rewrite the
  module-level comment block. The test is intentionally a smoke test (finite +
  plausible-magnitude output) with no bit-parity assertion, so the old
  "parity" name and "bit-identical" strategy comment were misleading.
- Schema doc: describe the raw weights as "un-prepacked [E, N, K/pack]" instead
  of "row-major", so it no longer conflicts with the QMoE schema docstring,
  which states weights are stored in column-major order per expert.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
@justinchuby
Copy link
Copy Markdown
Contributor Author

Thanks for the automated pass. Summary of how I addressed the 9 comments (pushed in 3cbbf51):

Fixed

  • PR description / schema note (Add doxygen generated website for the project #6): Rewrote the description; it now explicitly states weights_prepacked is a (backward-compatible) schema attribute addition and lists all changed files.
  • Schema doc wording (Set up CI with Azure Pipelines #7): Changed "raw, row-major [E, N, K/pack]" → "raw, un-prepacked [E, N, K/pack]" so it no longer conflicts with the column-major schema docstring.
  • Test name + comment (Fix build #8, Enable Mac pipeline #9): Renamed TestQMoEIntPrePackParityTestQMoEIntPrePackSmoke and rewrote the module comment to describe the actual smoke test (single graph, finite + plausible-magnitude assertions, no bit-parity).

Declined, with rationale

  • SM75 layout (Add doxygen web for the project #5): This is incorrect. getLayoutDetailsForTransform has a separate arch<80 branch, but the underlying cutlass::gemm::kernel::LayoutDetailsB for uint4b_t/uint8_t/half/bf16 is a single specialization gated on Arch::kMinComputeCapability >= 75 (mixed_gemm_B_layout.h:49–105), so SM75 and SM80 produce an identical packed layout. This matches docs/contrib_ops/cuda/moe_qmoe.md §7 (SM75 is in the universal Group A). packing_sm = (sm_ == 90) ? 90 : 80 is correct.
  • Per-weight prepack tracking (Set up CI with Azure Pipelines #1Incremental updates. #4): For a valid int QMoE, slots 2 (fc1) and 5 (fc2) are both constant initializers and are always prepacked together in the same PrePack pass; the "only one prepacked" case is unreachable for valid models (a non-initializer weight would already be rejected). int_weights_consumed_by_prepack requiring packed_fc1_weights_ != nullptr is sufficient, and the single-flag design was the approach agreed in the previous review round. Splitting into per-weight flags adds branching for a state that cannot occur.

Build + tests: onnxruntime_providers_cuda builds/links cleanly (nvcc 13.2, sm_90); TestQMoEIntPrePackSmoke and the existing int4 QMoE parity tests pass on H200, with the QMoE node confirmed on CUDAExecutionProvider via profiling.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants