QMoE: prepack int4/int8 expert weights in PrePack hook (symmetric with MatMulNBits) #28749
justinchuby wants to merge 8 commits into
Fixes microsoft#28748.

`MatMulNBits::PrePack_B` calls `preprocess_weights_for_mixed_gemm_cuda` at session-load time so callers can hand it the raw `[N, K/(8/bits)]` packed int4/int8 weights produced by `quantize_matmul_{4,8}bits`. The CUTLASS fpA_intB layout transform (row permutation + sub-byte transpose + column interleave + bias) happens inside ORT.

`QMoE::PrePack` for `quant_type == "int"` did the opposite: input slots 2 and 5 (fc1/fc2 expert weights) were explicitly skipped with `is_packed = false`, and the compute path passed `tensor->DataRaw()` straight into the CUTLASS runner. That assumes the caller has already prepacked the weights themselves, which:

- requires a CUDA-built ORT just to export a QMoE model (the `pack_weights_for_cuda_mixed_gemm` pybind binding is only exposed when ORT is built with USE_CUDA), and
- is silent-failure-prone: skipping the prepack just produces garbage output, not an error.

This change mirrors the MatMulNBits PrePack path:

- Add `packed_fc1_weights_` / `packed_fc2_weights_` buffers.
- Add `PrePackIntExpertWeights` helper that walks the E experts of the `[E, N, K/(8/bits)]` initializer, runs the existing `unpack_uint4_transposed_to_int8_direct_cuda` / `transpose_uint8_matrix_and_convert_to_int8` adapter, then the shared `preprocess_weights_for_mixed_gemm_cuda` transform, and stacks results into `[E, K, N/(8/bits)]`.
- Dispatch from `PrePack` for slots 2 and 5 when `quant_type_ == "int"`.
- Update `ComputeInternal` to prefer `packed_fc{1,2}_weights_` over the raw tensor data when the PrePack hook has populated them, with a fall-through to the raw initializer for sessions that disable prepacking (in that case the caller still has to provide pre-prepacked bytes themselves, same as today).

Builds cleanly (verified by re-compiling `contrib_ops/cuda/moe/moe_quantization.cc.o` against the current ORT main; remaining link-time errors in the surrounding `onnxruntime_providers_cuda` target are a pre-existing CUDA 13.2 + CCCL header incompatibility in `bias_softmax_impl.cu` and unrelated to this change).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
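To make the raw layout concrete, here is a minimal pure-Python sketch of the first step the helper performs per expert: unpacking a raw `[N, K/2]` uint4 matrix and transposing it to `[K, N]`. It assumes the low nibble of each byte holds the even-indexed element (an assumption about the `quantize_matmul_4bits` output, not confirmed by this thread), and it deliberately stops before the CUTLASS fpA_intB transform. The function name `unpack_uint4_transposed` is hypothetical and is not the CUDA kernel named above.

```python
def unpack_uint4_transposed(raw_rows, k):
    """Unpack one expert's raw [N, K/2] uint4 weights into a [K, N] matrix.

    raw_rows: sequence of N bytes objects, each holding k // 2 packed bytes.
    Assumption: the low nibble stores element 2*j, the high nibble 2*j + 1.
    """
    n = len(raw_rows)
    out = [[0] * n for _ in range(k)]
    for row_idx, row in enumerate(raw_rows):
        if len(row) * 2 != k:
            raise ValueError("each row must hold k / 2 packed bytes")
        for byte_idx, byte in enumerate(row):
            # Writing column row_idx of the output performs the transpose.
            out[2 * byte_idx][row_idx] = byte & 0x0F
            out[2 * byte_idx + 1][row_idx] = (byte >> 4) & 0x0F
    return out


# Two output rows (N = 2), four input features (K = 4).
expert = [bytes([0x21, 0x43]), bytes([0x65, 0x87])]
transposed = unpack_uint4_transposed(expert, k=4)
```

The real hook would repeat this for each of the E experts and then hand each `[K, N]` matrix to the shared layout transform.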
Bit-parity smoke test that constructs two single-node QMoE graphs over identical per-expert quantized weights:

- **Raw path**: writes the un-prepacked `[E, N, K/2]` bytes from `quantize_matmul_4bits` straight into the initializer. Exercises the new `QMoE::PrePackIntExpertWeights` hook.
- **Pre-prepacked path**: applies `pack_weights_for_cuda_mixed_gemm` per-expert before writing the initializer (matches what the existing test_qmoe_cuda.py tests do).

Both feed the same QMoE runner; with the PrePack hook in place the runner sees the same prepacked bytes either way, so outputs should agree to within fp16 rounding. Two cases cover small (64/32/E=4) and medium (128/64/E=8) shapes with SwiGLU interleaved fusion.

Guarded by `@unittest.skipUnless(torch.cuda.is_available())` so it no-ops on CPU-only CI.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Match the CI ruff (0.12.12) import sort: treat onnxruntime as first-party so 'from onnxruntime.capi import _pybind_state' belongs in the local-imports block after 'import onnxruntime', not in the third-party block. Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
The first version compared the new raw-weight PrePack path against the
existing `pack_weights_for_cuda_mixed_gemm` offline-pre-pack path, but
that comparison is invalid on SM>=90: the existing test harness in this
file hardcodes `force_arch=80` when calling
`pack_weights_for_cuda_mixed_gemm`, and on H100/H200 the other QMoE
parity tests in this file fail with max-diff > 1.0 too (verified on
plain main, pre-dating this change).
Rewrite as a smoke test that:
- builds a single QMoE node with raw, un-prepacked `[E, N, K/2]` int4
weights from `quantize_matmul_4bits` (the new schema-conformant
layout that the PrePack hook unlocks),
- runs it through the CUDA QMoE kernel,
- asserts the output has the right shape, is finite, and has reasonable
magnitudes for the toy weight distribution.
Verified passing on H200 (sm_90) with the PrePack hook in place.
Also: keep `is_packed = false` after `PrePackIntExpertWeights` so the
original weight initializer stays alive for `moe_helper::CheckInputs`
to read its shape on every `Compute` call. The prepacked bytes live
in `packed_fc{1,2}_weights_` and the compute path prefers them over
`fc{1,2}_experts_weights->DataRaw()`. Same trade-off the wfp4afp8
weight branch uses.
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
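The three smoke assertions described in this commit (right shape, finite values, reasonable magnitudes) can be sketched in plain Python as below. The helper name `smoke_check` and the `max_abs` threshold are hypothetical; the actual test in `test_qmoe_cuda.py` may use different bounds and tensor types.

```python
import math


def smoke_check(output, expected_shape, max_abs=1e3):
    """Return True if a 2-D output looks sane: shape, finiteness, magnitude.

    output: list of rows (lists of floats).
    max_abs: hypothetical bound for a toy weight distribution.
    """
    rows = len(output)
    cols = len(output[0]) if rows else 0
    if (rows, cols) != tuple(expected_shape):
        return False
    flat = [value for row in output for value in row]
    if not all(math.isfinite(value) for value in flat):
        return False
    # default=0.0 keeps an empty output from raising in max().
    return max((abs(value) for value in flat), default=0.0) <= max_abs


ok = smoke_check([[0.1, -0.2], [0.3, 0.4]], (2, 2))
has_nan = smoke_check([[0.1, float("nan")], [0.3, 0.4]], (2, 2))
```

A check like this catches NaN/Inf and wildly wrong scales, but it cannot prove bit-level correctness of the layout transform.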
Built locally and ran the smoke test on H200 (sm_90): A few notes from local verification:
tianleiwu left a comment
Review summary
Moving the CUTLASS fpA_intB layout transform for quant_type == "int" QMoE weights into the PrePack hook (mirroring MatMulNBits::PrePack_B) is a good direction, and the mechanics faithfully mirror the validated pack_weights_for_cuda_mixed_gemm pybind path. One blocking concern plus a couple of follow-ups.
High priority
Unconditional re-prepack is not backward compatible (no raw-vs-packed marker). The new dispatch runs PrePackIntExpertWeights for every int QMoE on slots 2/5, with no schema flag or version guard. Existing tooling — including this file's own quant_dequant_blockwise/preprocess_weights_for_mixed_gemm in test_qmoe_cuda.py — already emits weights in the CUTLASS layout and stores them under the logical [E, N, K/pack] shape that moe_helper::CheckInputs validates. Because raw and prepacked byte counts are identical (E*N*K/pack == E*K*N/pack), the declared shape, dtype, and size are indistinguishable. On the default path (prepacking enabled) those models get prepacked a second time and silently produce garbage — the same silent-failure class this PR set out to remove. The description's claim of "no behaviour change for callers that pre-prepacked their weights" only holds when session.disable_prepacking is set.
A safe fix needs an explicit signal that the weights are raw (e.g. a weights_prepacked attribute defaulting to legacy/prepacked, a com.microsoft opset bump, or a distinct marker). Until then this should be opt-in and the hard break documented. The existing prepacked-weight parity tests (test_qmoe_cuda.py ~L154-298) are not updated and would regress on an SM where they currently pass once the hook double-prepacks; the new test only covers the raw path so it won't catch that.
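The byte-count argument can be checked with a few lines of arithmetic. This is a sketch only: `int_weight_layouts` is a hypothetical helper, and the shapes used are the small test case mentioned earlier in the thread.

```python
def int_weight_layouts(num_experts, n, k, bits):
    """Compare the raw and CUTLASS-prepacked layouts of int expert weights."""
    if bits not in (4, 8):
        raise ValueError("bits must be 4 or 8")
    pack = 8 // bits  # elements stored per byte
    if k % pack or n % pack:
        raise ValueError("n and k must be divisible by the pack factor")
    raw_shape = (num_experts, n, k // pack)        # quantize_matmul_{4,8}bits output
    prepacked_shape = (num_experts, k, n // pack)  # layout after the transform
    raw_bytes = num_experts * n * (k // pack)
    prepacked_bytes = num_experts * k * (n // pack)
    return raw_shape, prepacked_shape, raw_bytes, prepacked_bytes


raw_shape, prepacked_shape, raw_bytes, prepacked_bytes = int_weight_layouts(4, 64, 32, 4)
```

Because both byte counts are equal and prepacked weights are stored under the logical raw shape, nothing in the initializer itself distinguishes the two cases, which is why an explicit marker is needed.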
Suggestions
- Persistent weight memory ~2x for int QMoE. Keeping `is_packed = false` (so `CheckInputs` can still read the original shape) while also allocating persistent `packed_fc{1,2}_weights_` means the original int weights and the prepacked int weights both stay resident, ~= 2x the dominant weight memory for both FC layers. `MatMulNBits::PrePack_B` avoids this via `is_packed = true`. Consider caching just the (E, N, K) shape and releasing the source, or documenting the cost. Transient PrePack overhead per FC: `E*N*K/pack` host->device staging (CPU initializers only) + `N*K/pack` per-expert transpose scratch + 128 B perm map, freed after the sync.
- SM coverage. The offline packer restricts `force_arch` to {75,80,90} and warns arch>90 falls back to 80; the in-kernel path passes `sm_` straight through (more correct, matches the device). Please confirm `preprocess_weights_for_mixed_gemm_cuda` is valid on SM100/120 if int QMoE is expected there, else add a guard/diagnostic.
Nitpick
`ORT_ENFORCE(bits != 4 || k % 2 == 0, ...)` is always true for `bits == 4` (`k = k_packed*2`).
Nice work on the motivation (decoupling QMoE export from a CUDA-built ORT) and the clear writeup.
…drop redundant assert

Addresses tianleiwu's review on microsoft#28749:

**Blocking: backward compatibility.** The previous version dispatched PrePackIntExpertWeights for every int QMoE unconditionally, which would double-prepack any model produced by existing tooling (quantize_matmul_4bits → pack_weights_for_cuda_mixed_gemm → CUTLASS layout) and silently corrupt its output. Add a new 'weights_prepacked' INT attribute on the QMoE schema, default value 1 (legacy behaviour: weights already in CUTLASS layout, kernel reads as-is). Setting it to 0 opts in to the new PrePack hook that takes raw [E, N, K/pack] quantize_matmul_{4,8}bits output and runs the layout transform inside ORT, matching MatMulNBits semantics and removing the offline pre-pack dependency from exporters. The PrePack dispatch and the compute-time weight-buffer override are both gated on '!weights_prepacked_'. Models without the attribute behave exactly as before.

**SM coverage.** preprocess_weights_for_mixed_gemm_cuda only has tile / permutation tables for SM75/80/90; the offline pack_weights_for_cuda_mixed_gemm restricts force_arch to that set and falls back to 80 for newer archs. Mirror the same fallback inside PrePackIntExpertWeights so SM86/89 and SM100/120 callers get a defined Ampere-compiled layout rather than a silent path through the helper with an unknown SM.

**Nit.** Drop 'ORT_ENFORCE(bits != 4 || k % 2 == 0, ...)': k is computed as k_packed * pack_factor, so for bits=4, k % 2 == 0 is a tautology.

**Memory cost documented.** is_packed stays false (so CheckInputs can read the source weight shape on every Compute call). Persistent memory cost is therefore ~2x the int4/int8 weight footprint, ~4x smaller than the original fp16 baseline. Documented inline. MatMulNBits avoids the doubling by caching shape in N_/K_ at construction; folding the same into QMoE is a follow-up.

Tests still pass on H200 (sm_90) with weights_prepacked=0 set in the new test cases.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
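The opt-in gating described in this commit can be sketched as a single predicate. This is an illustrative Python model, not the C++ in `moe_quantization.cc`; the function name and the dictionary-based attribute lookup are assumptions. The slot indices (2 and 5) and the default of 1 come from the thread.

```python
FC1_WEIGHT_SLOT = 2  # fc1 expert weights input index
FC2_WEIGHT_SLOT = 5  # fc2 expert weights input index


def should_prepack_int_weights(quant_type, input_idx, attrs):
    """Decide whether the PrePack hook should transform this input.

    attrs: dict of node attributes; a missing 'weights_prepacked'
    falls back to 1, i.e. the legacy already-prepacked behaviour.
    """
    weights_prepacked = attrs.get("weights_prepacked", 1)
    return (
        quant_type == "int"
        and input_idx in (FC1_WEIGHT_SLOT, FC2_WEIGHT_SLOT)
        and weights_prepacked == 0
    )


legacy_model = should_prepack_int_weights("int", 2, {})
opt_in_model = should_prepack_int_weights("int", 2, {"weights_prepacked": 0})
```

Defaulting to 1 is what keeps existing models from being transformed a second time.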
Thanks for the thorough review! Addressed all 4 points in 2fcb940:
H200 re-verification with the new attribute: Test was also updated to set
… freed

Follow-up review nit from microsoft#28749: the previous version kept the original int4/int8 weight initializers resident (~2x weight memory) so `moe_helper::CheckInputs` could read their shapes per Compute. Removes the doubling by caching `fc1_weights_shape_` / `fc2_weights_shape_` in member variables during PrePack and switching the CheckInputs call to the TensorShape* overload, mirroring how `MatMulNBits` caches `N_` / `K_` in its constructor.

Changes:

- Add `TensorShape fc1_weights_shape_` / `fc2_weights_shape_` members on `QMoE`. Captured from `tensor.Shape()` at PrePack time when the opt-in raw-weight path is active.
- `PrePackIntExpertWeights` now leaves `is_packed = true` (via the underlying helper) so ORT releases the source initializer. Net persistent weight memory is back to ~1x the int4/int8 footprint, matching the FP4 dequant-fallback path's memory profile.
- `ComputeInternal`:
  - Guard `context->Input<Tensor>(2)/(5)` to return nullptr when the source weights were consumed by PrePack (the `int_weights_consumed_by_prepack` flag).
  - Use the `TensorShape*` overload of `moe_helper::CheckInputs` when no live tensor is available, feeding the cached shapes.
  - Skip the trivial `check_weight_type` dtype assertions for the consumed-by-prepack case (we already validated uint8 inside `PrePackIntExpertWeights`).
  - Compute path always reads from `packed_fc{1,2}_weights_.get()` in the consumed path; the previous `if (packed_...)` fall-through to the raw initializer was dead and confusing.

Re-verified on H200: both `TestQMoEIntPrePackParity` tests pass with the opt-in attribute, and `test_swiglu_qmoe_parity_0` (legacy prepacked path, the default) still passes with max_diff ~0.001.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
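The cache-the-shape-then-release pattern from this commit can be modelled in a few lines. This is a hedged Python sketch: the class `IntWeightSlot`, its method names, and the `transform` callable are all hypothetical stand-ins for the C++ members and the CUTLASS transform.

```python
class IntWeightSlot:
    """Models one int weight input (fc1 or fc2) across PrePack and Compute."""

    def __init__(self):
        self.cached_shape = None  # plays the role of fc{1,2}_weights_shape_
        self.packed = None        # plays the role of packed_fc{1,2}_weights_

    def prepack(self, raw_bytes, shape, transform):
        """Cache the shape, store transformed bytes, report is_packed."""
        self.cached_shape = tuple(shape)
        self.packed = transform(raw_bytes)
        # Returning True tells the caller the source initializer can be freed.
        return True

    def shape_for_compute(self, live_shape=None):
        """Use the cached shape once the source tensor has been consumed."""
        if self.packed is not None:
            return self.cached_shape
        if live_shape is None:
            raise ValueError("no prepacked buffer and no live tensor")
        return tuple(live_shape)


slot = IntWeightSlot()
# bytes(reversed(...)) is a placeholder, not the real layout transform.
is_packed = slot.prepack(b"\x01\x02", (1, 2, 1), lambda data: bytes(reversed(data)))
```

Once `prepack` has run, only the transformed buffer and a small shape tuple stay resident, which is the ~1x memory profile the commit describes.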
Followed up on the 2× weight memory concern in 5e1491c:
H200 verification:
tianleiwu left a comment
Review summary
Thanks for reworking this into an opt-in path — the new weights_prepacked attribute (default 1 = legacy CUTLASS-prepacked) plus releasing the source initializer via cached fc1/fc2_weights_shape_ resolves both of my earlier blockers (silent double-prepack on existing models, and the ~2x weight-memory cost). Those two threads are resolved.
One new correctness issue remains in the compute path (inline), plus one design question:
Inline (correctness)
`ComputeInternal` unconditionally overrides the weight pointers in the `else if (is_int && !weights_prepacked_)` branch, even when the prepack buffer is null. See inline comment.
packing_sm clamp vs. runner layout (question)
In PrePackIntExpertWeights, packing_sm is clamped (>90 → 80, and 86/89 → 80) before calling preprocess_weights_for_mixed_gemm_cuda. This diverges from MatMulNBits::PrePack_B, which passes sm_ unmodified. The risk: the CUTLASS MoE runner at compute time selects its expected weight layout from the actual sm_. If sm_ is e.g. SM100/120 (Blackwell) and PrePack lays the bytes out for SM80, the packer/runner layouts can disagree and silently produce garbage rather than erroring. Either (a) confirm the runner consumes the SM80 layout on those arches, or (b) make the clamp explicit/shared with the runner-side arch selection so the two cannot drift. Worth a short comment documenting that the compute side relies on the same clamped arch.
Nit
- The trailing `cudaStreamSynchronize(stream)` at the end of `PrePackIntExpertWeights` is redundant: `preprocess_weights_for_mixed_gemm_cuda` already synchronizes the stream internally on every per-expert call. Harmless, but can be dropped.
The schema doc, opt-in default, and the new parity smoke test all look good. Requesting changes only for the compute-path null-override.
    const void* fc2_weight_data = fc2_experts_weights ? fc2_experts_weights->DataRaw() : nullptr;
    if (is_wfp4afp8 && !use_wfp4afp8_dequant_fallback_) {
      fc1_weight_data = packed_fp4_fc1_weights_ ? packed_fp4_fc1_weights_.get() : fc1_weight_data;
      fc2_weight_data = packed_fp4_fc2_weights_ ? packed_fp4_fc2_weights_.get() : fc2_weight_data;
Correctness: this else if guard is weaker than the int_weights_consumed_by_prepack guard used above (is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr). When weights_prepacked_ == false but prepacking is disabled at the session level (session.disable_prepacking), PrePack never runs, so packed_fc{1,2}_weights_ stay null. In that case int_weights_consumed_by_prepack is false, the raw initializers are correctly read into fc1/fc2_weight_data above — but this branch still fires (is_int && !weights_prepacked_ is true) and overwrites them with packed_fc1_weights_.get() == nullptr, so the runner receives null weight pointers (crash / garbage). This also contradicts the PR description's stated "fall-through to the raw initializer ... for sessions that disable prepacking."
Suggest gating on the same condition, e.g.:
    } else if (int_weights_consumed_by_prepack) {
      fc1_weight_data = packed_fc1_weights_.get();
      fc2_weight_data = packed_fc2_weights_.get();
    }

so that when the prepack buffers are absent the code keeps the raw-initializer pointers.
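The difference between the two guards is easiest to see side by side. The following Python sketch models the pointer selection with `None` standing in for a null buffer; both function names are hypothetical and the logic is a simplification of the C++ under review.

```python
def select_weight_buffer_weak_guard(is_int, weights_prepacked, packed_buffer, raw_buffer):
    """Guard as originally written: fires even when PrePack never ran."""
    if is_int and not weights_prepacked:
        return packed_buffer  # may be None when prepacking is disabled
    return raw_buffer


def select_weight_buffer(is_int, weights_prepacked, packed_buffer, raw_buffer):
    """Suggested guard: also require that the packed buffer exists."""
    consumed_by_prepack = (
        is_int and not weights_prepacked and packed_buffer is not None
    )
    return packed_buffer if consumed_by_prepack else raw_buffer


# Session with prepacking disabled: the packed buffer was never populated.
weak = select_weight_buffer_weak_guard(True, False, None, b"raw")
fixed = select_weight_buffer(True, False, None, b"raw")
```

With the weak guard the runner would receive a null pointer in the disabled-prepack case; the stronger guard falls through to the raw initializer.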
    // to that set and falls back to 80 for newer archs. Mirror the same
    // fallback here so SM100/120 (Blackwell) consumers get a defined layout
    // (compiled-for-Ampere) instead of garbage.
    int packing_sm = sm_;
So packing_sm = (sm_ == 90) ? 90 : 80;
Skip packing when sm_ < 75 since those GPUs are not supported.
Resolve three review items on the QMoE int4/int8 PrePack path:
1. Compute-path null-pointer guard. The weight-pointer override branch
was gated on `is_int && !weights_prepacked_`, which still fired when
prepacking was disabled at the session level
(`session.disable_prepacking`) — clobbering the raw initializer
pointers with null `packed_fc{1,2}_weights_.get()`. Gate on the
existing `int_weights_consumed_by_prepack` (which requires the packed
buffers to be non-null) so disabled-prepack sessions fall through to
the raw initializer pointers instead of receiving null weights.
2. Simplify the architecture clamp in PrePackIntExpertWeights to match
the cross-architecture packing table in
docs/contrib_ops/cuda/moe_qmoe.md §7: SM90 is its own layout group,
every other supported arch shares the SM80 layout, and SM70/older are
unsupported. Replace the multi-branch clamp with an ORT_ENFORCE on
SM75+ and `packing_sm = (sm_ == 90) ? 90 : 80`.
3. Drop the redundant trailing cudaStreamSynchronize after the per-expert
pack loop. preprocess_weights_for_mixed_gemm_cuda already synchronizes
the stream internally at the end of every per-expert call, so all
transpose/pack work and the CPU->GPU staging copy are complete before
the transient scratch buffers are freed on return.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
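Item 2 above reduces to a two-way choice plus a lower bound. A minimal Python rendering of that rule follows; the function name `packing_sm_for` is hypothetical, and `ValueError` stands in for `ORT_ENFORCE`.

```python
def packing_sm_for(sm):
    """Map the device SM to the layout group used for packing.

    SM90 has its own layout group; every other supported arch
    shares the SM80 layout; anything below SM75 is rejected.
    """
    if sm < 75:
        raise ValueError(
            f"QMoE int4/int8 weight prepack requires SM75 or newer, got sm={sm}"
        )
    return 90 if sm == 90 else 80


layout_groups = {sm: packing_sm_for(sm) for sm in (75, 80, 86, 89, 90, 100, 120)}
```

Whether the compute-side runner selects the same layout group on SM100/120 is the open question raised in review; this sketch only covers the packing side.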
Pull request overview
This PR extends the CUDA com.microsoft::QMoE integer-quantized path to optionally prepack raw int4/int8 expert weights inside ORT’s PrePack() hook (mirroring MatMulNBits), controlled via a new weights_prepacked attribute. It also adds a Python smoke test to validate that raw weights can execute through the CUDA QMoE kernel without producing NaN/Inf.
Changes:
- Add a new `weights_prepacked` attribute to the QMoE schema (default `1`) to distinguish offline-prepacked vs raw quantized expert weights.
- Implement `PrePackIntExpertWeights` and wire it into QMoE’s CUDA `PrePack()` + `ComputeInternal` to use cached prepacked GPU buffers.
- Add a CUDA Python test that builds a minimal QMoE graph with raw int4 weights and checks output sanity.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| `onnxruntime/test/python/transformers/test_qmoe_cuda.py` | Adds a CUDA test case for running QMoE with raw int4 weights and `weights_prepacked=0`. |
| `onnxruntime/core/graph/contrib_ops/contrib_defs.cc` | Adds the `weights_prepacked` schema attribute and documentation for QMoE. |
| `onnxruntime/contrib_ops/cuda/moe/moe_quantization.h` | Adds members/state and declares helper for int-weight prepacking. |
| `onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc` | Implements int expert-weight prepacking and uses prepacked buffers during compute. |
    // When PrePack consumed the int4/int8 expert-weight initializers
    // (``weights_prepacked == false`` opt-in path), the original tensors
    // were freed; ``context->Input<Tensor>(2)/(5)`` would return nothing.
    // Mirror how ``MatMulNBits`` reads its prepacked B input.
    const bool int_weights_consumed_by_prepack = is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr;
    const Tensor* fc1_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input<Tensor>(2);

    // When PrePack consumed the int weight initializers, the dtype check
    // is no longer applicable (we know they were uint8; that's what
    // PrePackIntExpertWeights validated and consumed).
    if (!int_weights_consumed_by_prepack) {
      ORT_RETURN_IF_ERROR(check_weight_type(fc1_experts_weights, "fc1_experts_weights", is_fp8));
      ORT_RETURN_IF_ERROR(check_weight_type(fc2_experts_weights, "fc2_experts_weights", is_fp8));
    }

    // Prefer the cached shapes when PrePack consumed the source initializer.
    const TensorShape& fc1_shape = int_weights_consumed_by_prepack ? fc1_weights_shape_ : fc1_experts_weights->Shape();
    const TensorShape& fc2_shape = int_weights_consumed_by_prepack ? fc2_weights_shape_ : fc2_experts_weights->Shape();

    } else if (int_weights_consumed_by_prepack) {
      // PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB
      // layout that the runner consumes and freed the source initializer
      // (``is_packed = true``). Gate on ``int_weights_consumed_by_prepack``
      // (which already requires ``packed_fc1_weights_ != nullptr``) rather than
      // just ``is_int && !weights_prepacked_``: when prepacking is disabled at
      // the session level (``session.disable_prepacking``) PrePack never runs,
      // the prepack buffers stay null, and the raw initializer pointers read
      // above must be kept so the runner is not handed null weight pointers.
      fc1_weight_data = packed_fc1_weights_.get();
      fc2_weight_data = packed_fc2_weights_.get();

    ORT_ENFORCE(sm_ >= 75,
                "QMoE int4/int8 weight prepack requires SM75 or newer, got sm=", sm_);
    const int packing_sm = (sm_ == 90) ? 90 : 80;

    .Attr("weights_prepacked",
          "Only meaningful when quant_type='int'. Set to 1 (default) when the int4/int8 "
          "fc1/fc2 weight initializers have already been laid out in the CUTLASS fpA_intB "
          "format expected by the runner (e.g. produced offline by "
          "pack_weights_for_cuda_mixed_gemm). Set to 0 when the initializers are raw, "
          "row-major [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits; "
          "in that case the kernel runs the CUTLASS layout transform itself in PrePack(), "

    class TestQMoEIntPrePackParity(unittest.TestCase):
        """Smoke test for the QMoE int4 PrePack hook (issue #28748 / PR #28749).

    # ============================================================================
    # QMoE integer-weight PrePack parity test.
    #
    # Validates the PrePack hook added in PR #28749: with `quant_type="int"`, the
    # QMoE op should be able to consume raw quantized weights, shape
    # `[E, N, K/(8/bits)]` as produced by `quantize_matmul_{4,8}bits`,
    # and internally run the CUTLASS fpA_intB layout transform that callers
    # previously had to do offline via `pack_weights_for_cuda_mixed_gemm`.
    #
    # Strategy: build two ONNX graphs that differ only in whether the weight
    # initializer is pre-prepacked or raw. Both go through ORT's CUDA QMoE
    # kernel. With the PrePack hook in place, the raw-weight graph's output
    # should be bit-identical to the offline-prepacked graph's output.
    # ============================================================================
Address automated review feedback on doc/test wording:

- Rename TestQMoEIntPrePackParity -> TestQMoEIntPrePackSmoke and rewrite the module-level comment block. The test is intentionally a smoke test (finite + plausible-magnitude output) with no bit-parity assertion, so the old "parity" name and "bit-identical" strategy comment were misleading.
- Schema doc: describe the raw weights as "un-prepacked [E, N, K/pack]" instead of "row-major", so it no longer conflicts with the QMoE schema docstring, which states weights are stored in column-major order per expert.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Signed-off-by: Justin Chu <11205048+justinchuby@users.noreply.github.com>
Thanks for the automated pass. Summary of how I addressed the 9 comments (pushed in 3cbbf51): Fixed
Declined, with rationale
Build + tests:
Summary
This PR lets the CUDA `com.microsoft::QMoE` operator prepack raw int4/int8 expert weights into the CUTLASS `fpA_intB` layout inside ORT's `PrePack()` hook, instead of requiring callers to run the layout transform offline via `pack_weights_for_cuda_mixed_gemm`. This makes integer QMoE symmetric with `MatMulNBits::PrePack_B`, and lets exporters ship the schema-conformant `[E, N, K/pack]` quantized weights produced by `quantize_matmul_{4,8}bits` directly, with no offline pre-pack step.

The behaviour is opt-in and backward compatible: a new `weights_prepacked` attribute defaults to `1` (legacy: weights are already CUTLASS-prepacked), and only `weights_prepacked=0` triggers the new in-`PrePack` layout transform.

What changed
- `weights_prepacked` attribute on the QMoE schema (default `1`). `1` = the int4/int8 `fc1`/`fc2` initializers are already in the CUTLASS `fpA_intB` layout (today's behaviour). `0` = the initializers are raw `[E, N, K/pack]` tensors and the kernel runs the layout transform itself in `PrePack()`.
- `PrePackIntExpertWeights`: loops over the `E` experts and applies the per-expert transpose + CUTLASS `fpA_intB` row-permutation / column-interleave / bias / pair-interleave transform on the GPU, mirroring `pack_weights_for_cuda_mixed_gemm`. Architecture-aware packing per `docs/contrib_ops/cuda/moe_qmoe.md` §7 (SM90 is its own layout group; all other supported arches share the SM80 layout; SM75+ required).
- `PrePack()` dispatch for the int weight slots (2 and 5) when `quant_type == "int"` and `weights_prepacked == 0`. The source initializers are released after their shapes are cached (`fc1/fc2_weights_shape_`), so peak weight memory stays ~1×.
- `ComputeInternal` prefers the prepacked GPU buffers when the PrePack hook populated them (gated on `int_weights_consumed_by_prepack`), and otherwise falls through to the raw initializer pointers (e.g. for sessions that set `session.disable_prepacking`).

Schema note
This does add a schema attribute (`weights_prepacked`) to QMoE. It is backward compatible because the default (`1`) preserves the existing offline-prepacked behaviour, but it is a schema surface-area change.
Diff scope
- `onnxruntime/core/graph/contrib_ops/contrib_defs.cc`: `weights_prepacked` schema attribute + docs
- `onnxruntime/contrib_ops/cuda/moe/moe_quantization.h`, `onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc`: `PrePackIntExpertWeights` + PrePack dispatch + ComputeInternal hookup
- `onnxruntime/test/python/transformers/test_qmoe_cuda.py`: `weights_prepacked=0` path

FP4 / FP8 / WFP4AFP8 paths are untouched, and there is no behaviour change for callers that pre-prepacked their weights.
Testing
- `onnxruntime_providers_cuda` builds and links cleanly (nvcc 13.2 / sm_90).
- `TestQMoEIntPrePackSmoke` (`test_qmoe_cuda.py`) builds a QMoE graph with raw int4 weights and `weights_prepacked=0`, runs it through the CUDA kernel, and asserts the output is finite with a plausible magnitude. Verified on H200 (SM90); node placement confirmed on `CUDAExecutionProvider` via profiling.
- phi3/swiglu, fp16) pass on CUDA: no regression in the default `weights_prepacked=1` path.