Skip to content
157 changes: 149 additions & 8 deletions onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
#include "contrib_ops/cuda/moe/qmoe_kernels.h"
#include "contrib_ops/cuda/llm/common/env_utils.h"
#include "contrib_ops/cuda/llm/common/logger.h"
#include "contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h"
#include "contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h"

#include "contrib_ops/cuda/utils/dump_cuda_tensor.h"
#include "contrib_ops/cpu/utils/debug_macros.h"
Expand Down Expand Up @@ -60,6 +62,10 @@ QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoE
this->quant_type_ = op_kernel_info.GetAttrOrDefault<std::string>("quant_type", "int");
ORT_ENFORCE(quant_type_ == "int" || quant_type_ == "fp4" || quant_type_ == "fp8" || quant_type_ == "wfp4afp8",
"quant_type must be 'int', 'fp4', 'fp8', or 'wfp4afp8', but got '", quant_type_, "'");
// Backward-compat opt-in: default is 1 (callers ship CUTLASS-prepacked
// weights, matching all pre-existing tooling). Setting to 0 tells the
// PrePack hook to lay out raw [E, N, K/pack] quantized weights itself.
weights_prepacked_ = op_kernel_info.GetAttrOrDefault<int64_t>("weights_prepacked", 1) != 0;
#if !defined(ENABLE_FP4) || !defined(USE_FP4_QMOE)
ORT_ENFORCE(quant_type_ != "fp4", "QMoE quant_type='fp4' requires USE_FP4_QMOE with CUDA 12.8 or newer.");
ORT_ENFORCE(quant_type_ != "wfp4afp8",
Expand Down Expand Up @@ -199,10 +205,15 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
const bool uses_global_weight_scales = is_fp4 || is_fp8 || is_wfp4afp8;
const Tensor* input = context->Input<Tensor>(0);
const Tensor* router_probs = context->Input<Tensor>(1);
const Tensor* fc1_experts_weights = context->Input<Tensor>(2);
// When PrePack consumed the int4/int8 expert-weight initializers
// (``weights_prepacked == false`` opt-in path), the original tensors
// were freed; ``context->Input<Tensor>(2)/(5)`` would return nothing.
// Mirror how ``MatMulNBits`` reads its prepacked B input.
const bool int_weights_consumed_by_prepack = is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr;
const Tensor* fc1_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input<Tensor>(2);
Comment thread
justinchuby marked this conversation as resolved.
const Tensor* fc1_scales = (is_int && !packed_fc1_scales_) ? context->Input<Tensor>(3) : nullptr;
const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4);
const Tensor* fc2_experts_weights = context->Input<Tensor>(5);
const Tensor* fc2_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input<Tensor>(5);
const Tensor* fc2_scales = (is_int && !packed_fc2_scales_) ? context->Input<Tensor>(6) : nullptr;
const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(7);
// The CUTLASS MoE runner has no separate FC3 GEMM — gate and up projection weights must be
Expand All @@ -224,8 +235,13 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
return Status::OK();
};

ORT_RETURN_IF_ERROR(check_weight_type(fc1_experts_weights, "fc1_experts_weights", is_fp8));
ORT_RETURN_IF_ERROR(check_weight_type(fc2_experts_weights, "fc2_experts_weights", is_fp8));
// When PrePack consumed the int weight initializers, the dtype check
// is no longer applicable (we know they were uint8 — that's what
// PrePackIntExpertWeights validated and consumed).
if (!int_weights_consumed_by_prepack) {
ORT_RETURN_IF_ERROR(check_weight_type(fc1_experts_weights, "fc1_experts_weights", is_fp8));
ORT_RETURN_IF_ERROR(check_weight_type(fc2_experts_weights, "fc2_experts_weights", is_fp8));
}
Comment thread
justinchuby marked this conversation as resolved.

// Unified FP4 inputs: block scales in fc*_scales (3/6), global scales in 15/16.
const Tensor* fp4_fc1_block_scales = (uses_fp4_weight_scales && !packed_fp4_fc1_block_scales_) ? context->Input<Tensor>(3) : nullptr;
Expand Down Expand Up @@ -256,10 +272,13 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {
int64_t pack_size = expert_weight_bits_ == 4 ? 2 : 1;
bool is_fused_swiglu = activation_type_ == onnxruntime::llm::kernels::cutlass_kernels::ActivationType::Swiglu;
MoEParameters moe_params;
// Prefer the cached shapes when PrePack consumed the source initializer.
const TensorShape& fc1_shape = int_weights_consumed_by_prepack ? fc1_weights_shape_ : fc1_experts_weights->Shape();
const TensorShape& fc2_shape = int_weights_consumed_by_prepack ? fc2_weights_shape_ : fc2_experts_weights->Shape();
Comment thread
justinchuby marked this conversation as resolved.
ORT_RETURN_IF_ERROR(onnxruntime::contrib::moe_helper::CheckInputs<Tensor>(
moe_params, input, router_probs, fc1_experts_weights,
moe_params, input, router_probs, &fc1_shape,
fc1_experts_bias_optional, fc1_scales, fc1_zeros,
fc2_experts_weights, fc2_experts_bias_optional, fc2_scales, fc2_zeros,
&fc2_shape, fc2_experts_bias_optional, fc2_scales, fc2_zeros,
nullptr, nullptr, nullptr, nullptr,
pack_size, is_fused_swiglu, block_size_));

Expand Down Expand Up @@ -808,11 +827,22 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const {

Tensor* output = context->Output(0, input->Shape());

const void* fc1_weight_data = fc1_experts_weights->DataRaw();
const void* fc2_weight_data = fc2_experts_weights->DataRaw();
const void* fc1_weight_data = fc1_experts_weights ? fc1_experts_weights->DataRaw() : nullptr;
const void* fc2_weight_data = fc2_experts_weights ? fc2_experts_weights->DataRaw() : nullptr;
if (is_wfp4afp8 && !use_wfp4afp8_dequant_fallback_) {
fc1_weight_data = packed_fp4_fc1_weights_ ? packed_fp4_fc1_weights_.get() : fc1_weight_data;
fc2_weight_data = packed_fp4_fc2_weights_ ? packed_fp4_fc2_weights_.get() : fc2_weight_data;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correctness: this else if guard is weaker than the int_weights_consumed_by_prepack guard used above (is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr). When weights_prepacked_ == false but prepacking is disabled at the session level (session.disable_prepacking), PrePack never runs, so packed_fc{1,2}_weights_ stay null. In that case int_weights_consumed_by_prepack is false, the raw initializers are correctly read into fc1/fc2_weight_data above — but this branch still fires (is_int && !weights_prepacked_ is true) and overwrites them with packed_fc1_weights_.get() == nullptr, so the runner receives null weight pointers (crash / garbage). This also contradicts the PR description's stated "fall-through to the raw initializer ... for sessions that disable prepacking."

Suggest gating on the same condition, e.g.:

} else if (int_weights_consumed_by_prepack) {
  fc1_weight_data = packed_fc1_weights_.get();
  fc2_weight_data = packed_fc2_weights_.get();
}

so that when the prepack buffers are absent the code keeps the raw-initializer pointers.

} else if (int_weights_consumed_by_prepack) {
// PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB
// layout that the runner consumes and freed the source initializer
// (``is_packed = true``). Gate on ``int_weights_consumed_by_prepack``
// (which already requires ``packed_fc1_weights_ != nullptr``) rather than
// just ``is_int && !weights_prepacked_``: when prepacking is disabled at
// the session level (``session.disable_prepacking``) PrePack never runs,
// the prepack buffers stay null, and the raw initializer pointers read
// above must be kept so the runner is not handed null weight pointers.
fc1_weight_data = packed_fc1_weights_.get();
fc2_weight_data = packed_fc2_weights_.get();
Comment thread
justinchuby marked this conversation as resolved.
}
IAllocatorUniquePtr<void> dequant_fc1_weights;
IAllocatorUniquePtr<void> dequant_fc2_weights;
Expand Down Expand Up @@ -972,6 +1002,19 @@ Status QMoE::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
} else if (input_idx == 5 && quant_type_ == "wfp4afp8" && !use_wfp4afp8_dequant_fallback_) {
PrePackRepackFP4Weights(tensor, stream, alloc, packed_fp4_fc2_weights_, is_packed);
is_packed = false;
} else if (input_idx == 2 && quant_type_ == "int" && !weights_prepacked_) {
// Caller opted in (``weights_prepacked=0`` attribute) to having ORT
// do the CUTLASS fpA_intB layout transform internally, instead of
// shipping pre-prepacked bytes. Mirrors ``MatMulNBits::PrePack_B``
// looped over the E experts of ``[E, N, K/pack]``. We cache the
// source shape in ``fc1_weights_shape_`` so ``CheckInputs`` can be
// satisfied without holding the original initializer alive, then
// set ``is_packed = true`` to let ORT free it.
fc1_weights_shape_ = tensor.Shape();
PrePackIntExpertWeights(tensor, stream, alloc, packed_fc1_weights_, is_packed);
} else if (input_idx == 5 && quant_type_ == "int" && !weights_prepacked_) {
fc2_weights_shape_ = tensor.Shape();
PrePackIntExpertWeights(tensor, stream, alloc, packed_fc2_weights_, is_packed);
} else if (input_idx == 3) { // fc1_scales
DUMP_TENSOR("fc1_scales", tensor);
if (quant_type_ == "wfp4afp8" && !use_wfp4afp8_dequant_fallback_) {
Expand Down Expand Up @@ -1078,6 +1121,104 @@ void QMoE::PrePackCopyToGpu(const Tensor& tensor, cudaStream_t stream, Allocator
is_packed = true;
}

// ---------------------------------------------------------------------------
// PrePack helper: int4/int8 per-expert weights → CUTLASS fpA_intB layout.
// ---------------------------------------------------------------------------
// Mirrors ``MatMulNBits::PrePack_B`` but loops over the leading E (experts)
// dimension. Input ``tensor`` is the row-major 3-D ``[E, N, K/(8/bits)]``
// quantized weight initializer; output is a GPU buffer in the
// kernel-expected ``[E, K, N/(8/bits)]`` layout.
void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, AllocatorPtr alloc,
IAllocatorUniquePtr<void>& packed_buf, bool& is_packed) {
ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8,
"PrePackIntExpertWeights: only 4 and 8 bits are supported, got ", expert_weight_bits_);
const auto& shape = tensor.Shape();
ORT_ENFORCE(shape.NumDimensions() == 3,
"PrePackIntExpertWeights: expected 3-D weight tensor [E, N, K/pack], got ndim=",
shape.NumDimensions());

const int bits = static_cast<int>(expert_weight_bits_);
const int pack_factor = 8 / bits;
const int64_t num_experts = shape[0];
const int64_t n = shape[1];
const int64_t k_packed = shape[2];
const int64_t k = k_packed * pack_factor;

// Weight packing is architecture-aware (see
// docs/contrib_ops/cuda/moe_qmoe.md §7 "Cross-Architecture Packing
// Compatibility"). SM90 (Hopper) uses its own Permuted-Linear layout that
// skips column interleaving, so it is its own compatibility group. Every
// other supported arch — SM75/80/86/89 and SM100/120 (Blackwell) — shares
// the SM80 fpA_intB layout, so they all pack as SM80. SM70 and older lack
// INT8 LDSM and are unsupported. The compute-side runner selects the same
// layout from this clamped arch, so the two cannot drift.
ORT_ENFORCE(sm_ >= 75,
"QMoE int4/int8 weight prepack requires SM75 or newer, got sm=", sm_);
const int packing_sm = (sm_ == 90) ? 90 : 80;

Comment thread
justinchuby marked this conversation as resolved.
// Per-expert sizes.
const size_t per_expert_bytes = static_cast<size_t>(n) * static_cast<size_t>(k) / pack_factor;
const size_t total_bytes = per_expert_bytes * static_cast<size_t>(num_experts);

// Output buffer holds all E prepacked experts back-to-back in
// [E, K, N/pack_factor] layout.
packed_buf = IAllocator::MakeUniquePtr<void>(alloc, total_bytes, /*use_reserve=*/true);
int8_t* dst_all = reinterpret_cast<int8_t*>(packed_buf.get());

// Two transient per-expert scratch buffers reused across experts.
IAllocatorUniquePtr<void> transposed_scratch =
this->GetTransientScratchBuffer<void>(per_expert_bytes);
int8_t* transposed_scratch_ptr = reinterpret_cast<int8_t*>(transposed_scratch.get());

IAllocatorUniquePtr<void> src_gpu_scratch;
const uint8_t* src_base_gpu = nullptr;
if (tensor.Location().device.Type() == OrtDevice::CPU) {
src_gpu_scratch = this->GetTransientScratchBuffer<void>(total_bytes);
CUDA_CALL_THROW(cudaMemcpyAsync(src_gpu_scratch.get(), tensor.DataRaw(), total_bytes,
cudaMemcpyHostToDevice, stream));
src_base_gpu = reinterpret_cast<const uint8_t*>(src_gpu_scratch.get());
} else {
src_base_gpu = reinterpret_cast<const uint8_t*>(tensor.DataRaw());
}

IAllocatorUniquePtr<int32_t> permutation_map = this->GetTransientScratchBuffer<int32_t>(32);

using onnxruntime::llm::kernels::weight_only::QuantType;
const QuantType quant_type = (bits == 4) ? QuantType::W4_A16 : QuantType::W8_A16;

for (int64_t e = 0; e < num_experts; ++e) {
const uint8_t* src_e = src_base_gpu + static_cast<size_t>(e) * per_expert_bytes;
int8_t* dst_e = dst_all + static_cast<size_t>(e) * per_expert_bytes;

// Step 1: transpose + (for int4) unpack/zero-point bias into the
// transposed-int8 scratch buffer. Mirrors MatMulNBits's PrePack_B.
if (bits == 4) {
onnxruntime::llm::kernels::fpA_intB_gemv::unpack_uint4_transposed_to_int8_direct_cuda(
stream, transposed_scratch_ptr, src_e, static_cast<int>(n), static_cast<int>(k));
} else {
onnxruntime::llm::kernels::fpA_intB_gemv::transpose_uint8_matrix_and_convert_to_int8(
stream, transposed_scratch_ptr, src_e, static_cast<int>(n), static_cast<int>(k));
}

// Step 2: apply the CUTLASS fpA_intB row-permutation / column-interleave /
// bias / pair-interleave transform into the per-expert output slot.
onnxruntime::llm::kernels::weight_only::preprocess_weights_for_mixed_gemm_cuda(
stream,
packing_sm,
dst_e,
transposed_scratch_ptr,
permutation_map.get(),
{static_cast<size_t>(k), static_cast<size_t>(n)},
quant_type);
}

// No explicit cudaStreamSynchronize here: preprocess_weights_for_mixed_gemm_cuda
// synchronizes the stream internally at the end of every per-expert call, so
// after the final expert all transpose/pack work (and the CPU->GPU staging
// copy above) is complete and the transient scratch buffers are safe to free.
is_packed = true;
}

// ---------------------------------------------------------------------------
// PrePack helper: Swizzle MXFP block scales for SM120 TMA layout using GPU kernel.
// ---------------------------------------------------------------------------
Expand Down
33 changes: 33 additions & 0 deletions onnxruntime/contrib_ops/cuda/moe/moe_quantization.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,33 @@ class QMoE final : public CudaKernel, public MoEBase {
IAllocatorUniquePtr<void>& packed_buf, bool& is_packed);
void PrePackRepackFP4Weights(const Tensor& tensor, cudaStream_t stream, AllocatorPtr alloc,
IAllocatorUniquePtr<void>& packed_buf, bool& is_packed);
// Prepacks int4/int8 expert weights into the CUTLASS fpA_intB layout so the
// QMoE runner can consume them directly. Mirrors what MatMulNBits.PrePack
// does, looped over the E expert dimension. ``tensor`` is the 3-D
// ``[E, N, K / (8 / bits)]`` weight initializer; ``packed_buf`` receives a
// GPU buffer in the kernel-expected ``[E, K, N / (8 / bits)]`` layout.
void PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, AllocatorPtr alloc,
IAllocatorUniquePtr<void>& packed_buf, bool& is_packed);
int64_t expert_weight_bits_;
bool is_fp16_;
// When true (the schema default), the int4/int8 fc1/fc2 weight
// initializers are already in the CUTLASS fpA_intB layout — produced
// offline e.g. via ``pack_weights_for_cuda_mixed_gemm`` — and the
// compute path reads them as-is. When false, the raw schema-conformant
// ``[E, N, K/pack]`` layout (as produced by
// ``quantize_matmul_{4,8}bits``) is rewritten inside the PrePack hook
// via ``PrePackIntExpertWeights``, removing the offline prepack
// dependency. Only meaningful when ``quant_type_ == "int"``.
bool weights_prepacked_ = true;
// Cached source weight shapes captured at PrePack time. When the
// PrePack hook consumed and released the original int4/int8 weight
// initializers (``is_packed = true``), ``context->Input<Tensor>(2)``
// and ``(5)`` return nothing, so ``moe_helper::CheckInputs`` can no
// longer read the shapes from the live tensors. We feed it these
// cached shapes instead via the ``TensorShape*`` overload, matching
// how ``MatMulNBits`` caches ``N_`` / ``K_`` in its constructor.
TensorShape fc1_weights_shape_;
TensorShape fc2_weights_shape_;
bool use_fp4_dequant_fallback_ = false;
// Dequantizes FP8 weights to FP16/BF16 scratch buffers before invoking the A16 MoE runner.
bool use_fp8_dequant_fallback_ = false;
Expand All @@ -54,6 +79,14 @@ class QMoE final : public CudaKernel, public MoEBase {
// PrePack logic:
// - Copies scales to GPU buffer (if in CPU) or just keeps them. For simplicity, we allocate and copy.
// - Computes Bias from ZP and Scale using PrePack kernel.
// - For ``quant_type == "int"``, also prepacks the per-expert int4/int8
// weight tensors into the CUTLASS fpA_intB layout, mirroring
// ``MatMulNBits.PrePack_B``. Without this, callers would have to
// pre-prepack the weights offline using ``pack_weights_for_cuda_mixed_gemm``,
// which is asymmetric with how ``MatMulNBits`` is consumed and forces
// a CUDA-enabled ORT build for any offline quantization tooling.
IAllocatorUniquePtr<void> packed_fc1_weights_;
IAllocatorUniquePtr<void> packed_fc2_weights_;
IAllocatorUniquePtr<void> packed_fc1_scales_;
IAllocatorUniquePtr<void> packed_fc1_bias_;
IAllocatorUniquePtr<void> packed_fc2_scales_;
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1519,6 +1519,17 @@
"fc*_scales inputs contain MXFP4 block scales, and fc*_global_scale inputs must be provided.",
AttributeProto::STRING,
std::string("int"))
.Attr("weights_prepacked",
"Only meaningful when quant_type='int'. Set to 1 (default) when the int4/int8 "
"fc1/fc2 weight initializers have already been laid out in the CUTLASS fpA_intB "
"format expected by the runner (e.g. produced offline by "
"pack_weights_for_cuda_mixed_gemm). Set to 0 when the initializers are raw, "
Comment thread
justinchuby marked this conversation as resolved.
"un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits; "
"in that case the kernel runs the CUTLASS layout transform itself in PrePack(), "
Comment thread
justinchuby marked this conversation as resolved.
"matching the behaviour of MatMulNBits and removing the offline pre-pack "

Check warning on line 1529 in onnxruntime/core/graph/contrib_ops/contrib_defs.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "behaviour" is a misspelling of "behavior" Raw Output: ./onnxruntime/core/graph/contrib_ops/contrib_defs.cc:1529:28: "behaviour" is a misspelling of "behavior"
"requirement from exporters. Default is 1 for backward compatibility.",
Copy link
Copy Markdown
Contributor

@tianleiwu tianleiwu Jun 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's mark this attribute as optional without setting a default value, and let each EP decide its own default for backward compatibility: CUDA assumes prepacked weights, but other EPs might not.

You will need to update the operator doc. You can download it from a link in the failed "GPU Doc Gen CI" job after this change; the step called "Upload updated documentation" has the link.

AttributeProto::INT,
static_cast<int64_t>(1))
.Input(0,
"input",
"2D tensor with shape (num_tokens, hidden_size), or "
Expand Down
Loading
Loading