-
Notifications
You must be signed in to change notification settings - Fork 4k
QMoE: prepack int4/int8 expert weights in PrePack hook (symmetric with MatMulNBits) #28749
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
834eae8
66c7f47
516812c
197819f
2fcb940
5e1491c
360217b
3cbbf51
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,6 +15,8 @@ | |
| #include "contrib_ops/cuda/moe/qmoe_kernels.h" | ||
| #include "contrib_ops/cuda/llm/common/env_utils.h" | ||
| #include "contrib_ops/cuda/llm/common/logger.h" | ||
| #include "contrib_ops/cuda/llm/fpA_intB_gemm_adaptor.h" | ||
| #include "contrib_ops/cuda/llm/fpA_intB_gemm_preprocessors.h" | ||
|
|
||
| #include "contrib_ops/cuda/utils/dump_cuda_tensor.h" | ||
| #include "contrib_ops/cpu/utils/debug_macros.h" | ||
|
|
@@ -60,6 +62,10 @@ QMoE::QMoE(const OpKernelInfo& op_kernel_info) : CudaKernel(op_kernel_info), MoE | |
| this->quant_type_ = op_kernel_info.GetAttrOrDefault<std::string>("quant_type", "int"); | ||
| ORT_ENFORCE(quant_type_ == "int" || quant_type_ == "fp4" || quant_type_ == "fp8" || quant_type_ == "wfp4afp8", | ||
| "quant_type must be 'int', 'fp4', 'fp8', or 'wfp4afp8', but got '", quant_type_, "'"); | ||
| // Backward-compat opt-in: default is 1 (callers ship CUTLASS-prepacked | ||
| // weights, matching all pre-existing tooling). Setting to 0 tells the | ||
| // PrePack hook to lay out raw [E, N, K/pack] quantized weights itself. | ||
| weights_prepacked_ = op_kernel_info.GetAttrOrDefault<int64_t>("weights_prepacked", 1) != 0; | ||
| #if !defined(ENABLE_FP4) || !defined(USE_FP4_QMOE) | ||
| ORT_ENFORCE(quant_type_ != "fp4", "QMoE quant_type='fp4' requires USE_FP4_QMOE with CUDA 12.8 or newer."); | ||
| ORT_ENFORCE(quant_type_ != "wfp4afp8", | ||
|
|
@@ -199,10 +205,15 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { | |
| const bool uses_global_weight_scales = is_fp4 || is_fp8 || is_wfp4afp8; | ||
| const Tensor* input = context->Input<Tensor>(0); | ||
| const Tensor* router_probs = context->Input<Tensor>(1); | ||
| const Tensor* fc1_experts_weights = context->Input<Tensor>(2); | ||
| // When PrePack consumed the int4/int8 expert-weight initializers | ||
| // (``weights_prepacked == false`` opt-in path), the original tensors | ||
| // were freed; ``context->Input<Tensor>(2)/(5)`` would return nothing. | ||
| // Mirror how ``MatMulNBits`` reads its prepacked B input. | ||
| const bool int_weights_consumed_by_prepack = is_int && !weights_prepacked_ && packed_fc1_weights_ != nullptr; | ||
| const Tensor* fc1_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input<Tensor>(2); | ||
| const Tensor* fc1_scales = (is_int && !packed_fc1_scales_) ? context->Input<Tensor>(3) : nullptr; | ||
| const Tensor* fc1_experts_bias_optional = context->Input<Tensor>(4); | ||
| const Tensor* fc2_experts_weights = context->Input<Tensor>(5); | ||
| const Tensor* fc2_experts_weights = int_weights_consumed_by_prepack ? nullptr : context->Input<Tensor>(5); | ||
| const Tensor* fc2_scales = (is_int && !packed_fc2_scales_) ? context->Input<Tensor>(6) : nullptr; | ||
| const Tensor* fc2_experts_bias_optional = context->Input<Tensor>(7); | ||
| // The CUTLASS MoE runner has no separate FC3 GEMM — gate and up projection weights must be | ||
|
|
@@ -224,8 +235,13 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { | |
| return Status::OK(); | ||
| }; | ||
|
|
||
| ORT_RETURN_IF_ERROR(check_weight_type(fc1_experts_weights, "fc1_experts_weights", is_fp8)); | ||
| ORT_RETURN_IF_ERROR(check_weight_type(fc2_experts_weights, "fc2_experts_weights", is_fp8)); | ||
| // When PrePack consumed the int weight initializers, the dtype check | ||
| // is no longer applicable (we know they were uint8 — that's what | ||
| // PrePackIntExpertWeights validated and consumed). | ||
| if (!int_weights_consumed_by_prepack) { | ||
| ORT_RETURN_IF_ERROR(check_weight_type(fc1_experts_weights, "fc1_experts_weights", is_fp8)); | ||
| ORT_RETURN_IF_ERROR(check_weight_type(fc2_experts_weights, "fc2_experts_weights", is_fp8)); | ||
| } | ||
|
justinchuby marked this conversation as resolved.
|
||
|
|
||
| // Unified FP4 inputs: block scales in fc*_scales (3/6), global scales in 15/16. | ||
| const Tensor* fp4_fc1_block_scales = (uses_fp4_weight_scales && !packed_fp4_fc1_block_scales_) ? context->Input<Tensor>(3) : nullptr; | ||
|
|
@@ -256,10 +272,13 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { | |
| int64_t pack_size = expert_weight_bits_ == 4 ? 2 : 1; | ||
| bool is_fused_swiglu = activation_type_ == onnxruntime::llm::kernels::cutlass_kernels::ActivationType::Swiglu; | ||
| MoEParameters moe_params; | ||
| // Prefer the cached shapes when PrePack consumed the source initializer. | ||
| const TensorShape& fc1_shape = int_weights_consumed_by_prepack ? fc1_weights_shape_ : fc1_experts_weights->Shape(); | ||
| const TensorShape& fc2_shape = int_weights_consumed_by_prepack ? fc2_weights_shape_ : fc2_experts_weights->Shape(); | ||
|
justinchuby marked this conversation as resolved.
|
||
| ORT_RETURN_IF_ERROR(onnxruntime::contrib::moe_helper::CheckInputs<Tensor>( | ||
| moe_params, input, router_probs, fc1_experts_weights, | ||
| moe_params, input, router_probs, &fc1_shape, | ||
| fc1_experts_bias_optional, fc1_scales, fc1_zeros, | ||
| fc2_experts_weights, fc2_experts_bias_optional, fc2_scales, fc2_zeros, | ||
| &fc2_shape, fc2_experts_bias_optional, fc2_scales, fc2_zeros, | ||
| nullptr, nullptr, nullptr, nullptr, | ||
| pack_size, is_fused_swiglu, block_size_)); | ||
|
|
||
|
|
@@ -808,11 +827,22 @@ Status QMoE::ComputeInternal(OpKernelContext* context) const { | |
|
|
||
| Tensor* output = context->Output(0, input->Shape()); | ||
|
|
||
| const void* fc1_weight_data = fc1_experts_weights->DataRaw(); | ||
| const void* fc2_weight_data = fc2_experts_weights->DataRaw(); | ||
| const void* fc1_weight_data = fc1_experts_weights ? fc1_experts_weights->DataRaw() : nullptr; | ||
| const void* fc2_weight_data = fc2_experts_weights ? fc2_experts_weights->DataRaw() : nullptr; | ||
| if (is_wfp4afp8 && !use_wfp4afp8_dequant_fallback_) { | ||
| fc1_weight_data = packed_fp4_fc1_weights_ ? packed_fp4_fc1_weights_.get() : fc1_weight_data; | ||
| fc2_weight_data = packed_fp4_fc2_weights_ ? packed_fp4_fc2_weights_.get() : fc2_weight_data; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Correctness: this Suggest gating on the same condition, e.g.: } else if (int_weights_consumed_by_prepack) {
fc1_weight_data = packed_fc1_weights_.get();
fc2_weight_data = packed_fc2_weights_.get();
}so that when the prepack buffers are absent the code keeps the raw-initializer pointers. |
||
| } else if (int_weights_consumed_by_prepack) { | ||
| // PrePack converted the raw int4/int8 weights to the CUTLASS fpA_intB | ||
| // layout that the runner consumes and freed the source initializer | ||
| // (``is_packed = true``). Gate on ``int_weights_consumed_by_prepack`` | ||
| // (which already requires ``packed_fc1_weights_ != nullptr``) rather than | ||
| // just ``is_int && !weights_prepacked_``: when prepacking is disabled at | ||
| // the session level (``session.disable_prepacking``) PrePack never runs, | ||
| // the prepack buffers stay null, and the raw initializer pointers read | ||
| // above must be kept so the runner is not handed null weight pointers. | ||
| fc1_weight_data = packed_fc1_weights_.get(); | ||
| fc2_weight_data = packed_fc2_weights_.get(); | ||
|
justinchuby marked this conversation as resolved.
|
||
| } | ||
| IAllocatorUniquePtr<void> dequant_fc1_weights; | ||
| IAllocatorUniquePtr<void> dequant_fc2_weights; | ||
|
|
@@ -972,6 +1002,19 @@ Status QMoE::PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, | |
| } else if (input_idx == 5 && quant_type_ == "wfp4afp8" && !use_wfp4afp8_dequant_fallback_) { | ||
| PrePackRepackFP4Weights(tensor, stream, alloc, packed_fp4_fc2_weights_, is_packed); | ||
| is_packed = false; | ||
| } else if (input_idx == 2 && quant_type_ == "int" && !weights_prepacked_) { | ||
| // Caller opted in (``weights_prepacked=0`` attribute) to having ORT | ||
| // do the CUTLASS fpA_intB layout transform internally, instead of | ||
| // shipping pre-prepacked bytes. Mirrors ``MatMulNBits::PrePack_B`` | ||
| // looped over the E experts of ``[E, N, K/pack]``. We cache the | ||
| // source shape in ``fc1_weights_shape_`` so ``CheckInputs`` can be | ||
| // satisfied without holding the original initializer alive, then | ||
| // set ``is_packed = true`` to let ORT free it. | ||
| fc1_weights_shape_ = tensor.Shape(); | ||
| PrePackIntExpertWeights(tensor, stream, alloc, packed_fc1_weights_, is_packed); | ||
| } else if (input_idx == 5 && quant_type_ == "int" && !weights_prepacked_) { | ||
| fc2_weights_shape_ = tensor.Shape(); | ||
| PrePackIntExpertWeights(tensor, stream, alloc, packed_fc2_weights_, is_packed); | ||
| } else if (input_idx == 3) { // fc1_scales | ||
| DUMP_TENSOR("fc1_scales", tensor); | ||
| if (quant_type_ == "wfp4afp8" && !use_wfp4afp8_dequant_fallback_) { | ||
|
|
@@ -1078,6 +1121,104 @@ void QMoE::PrePackCopyToGpu(const Tensor& tensor, cudaStream_t stream, Allocator | |
| is_packed = true; | ||
| } | ||
|
|
||
| // --------------------------------------------------------------------------- | ||
| // PrePack helper: int4/int8 per-expert weights → CUTLASS fpA_intB layout. | ||
| // --------------------------------------------------------------------------- | ||
| // Mirrors ``MatMulNBits::PrePack_B`` but loops over the leading E (experts) | ||
| // dimension. Input ``tensor`` is the row-major 3-D ``[E, N, K/(8/bits)]`` | ||
| // quantized weight initializer; output is a GPU buffer in the | ||
| // kernel-expected ``[E, K, N/(8/bits)]`` layout. | ||
| void QMoE::PrePackIntExpertWeights(const Tensor& tensor, cudaStream_t stream, AllocatorPtr alloc, | ||
| IAllocatorUniquePtr<void>& packed_buf, bool& is_packed) { | ||
| ORT_ENFORCE(expert_weight_bits_ == 4 || expert_weight_bits_ == 8, | ||
| "PrePackIntExpertWeights: only 4 and 8 bits are supported, got ", expert_weight_bits_); | ||
| const auto& shape = tensor.Shape(); | ||
| ORT_ENFORCE(shape.NumDimensions() == 3, | ||
| "PrePackIntExpertWeights: expected 3-D weight tensor [E, N, K/pack], got ndim=", | ||
| shape.NumDimensions()); | ||
|
|
||
| const int bits = static_cast<int>(expert_weight_bits_); | ||
| const int pack_factor = 8 / bits; | ||
| const int64_t num_experts = shape[0]; | ||
| const int64_t n = shape[1]; | ||
| const int64_t k_packed = shape[2]; | ||
| const int64_t k = k_packed * pack_factor; | ||
|
|
||
| // Weight packing is architecture-aware (see | ||
| // docs/contrib_ops/cuda/moe_qmoe.md §7 "Cross-Architecture Packing | ||
| // Compatibility"). SM90 (Hopper) uses its own Permuted-Linear layout that | ||
| // skips column interleaving, so it is its own compatibility group. Every | ||
| // other supported arch — SM75/80/86/89 and SM100/120 (Blackwell) — shares | ||
| // the SM80 fpA_intB layout, so they all pack as SM80. SM70 and older lack | ||
| // INT8 LDSM and are unsupported. The compute-side runner selects the same | ||
| // layout from this clamped arch, so the two cannot drift. | ||
| ORT_ENFORCE(sm_ >= 75, | ||
| "QMoE int4/int8 weight prepack requires SM75 or newer, got sm=", sm_); | ||
| const int packing_sm = (sm_ == 90) ? 90 : 80; | ||
|
|
||
|
justinchuby marked this conversation as resolved.
|
||
| // Per-expert sizes. | ||
| const size_t per_expert_bytes = static_cast<size_t>(n) * static_cast<size_t>(k) / pack_factor; | ||
| const size_t total_bytes = per_expert_bytes * static_cast<size_t>(num_experts); | ||
|
|
||
| // Output buffer holds all E prepacked experts back-to-back in | ||
| // [E, K, N/pack_factor] layout. | ||
| packed_buf = IAllocator::MakeUniquePtr<void>(alloc, total_bytes, /*use_reserve=*/true); | ||
| int8_t* dst_all = reinterpret_cast<int8_t*>(packed_buf.get()); | ||
|
|
||
| // Two transient per-expert scratch buffers reused across experts. | ||
| IAllocatorUniquePtr<void> transposed_scratch = | ||
| this->GetTransientScratchBuffer<void>(per_expert_bytes); | ||
| int8_t* transposed_scratch_ptr = reinterpret_cast<int8_t*>(transposed_scratch.get()); | ||
|
|
||
| IAllocatorUniquePtr<void> src_gpu_scratch; | ||
| const uint8_t* src_base_gpu = nullptr; | ||
| if (tensor.Location().device.Type() == OrtDevice::CPU) { | ||
| src_gpu_scratch = this->GetTransientScratchBuffer<void>(total_bytes); | ||
| CUDA_CALL_THROW(cudaMemcpyAsync(src_gpu_scratch.get(), tensor.DataRaw(), total_bytes, | ||
| cudaMemcpyHostToDevice, stream)); | ||
| src_base_gpu = reinterpret_cast<const uint8_t*>(src_gpu_scratch.get()); | ||
| } else { | ||
| src_base_gpu = reinterpret_cast<const uint8_t*>(tensor.DataRaw()); | ||
| } | ||
|
|
||
| IAllocatorUniquePtr<int32_t> permutation_map = this->GetTransientScratchBuffer<int32_t>(32); | ||
|
|
||
| using onnxruntime::llm::kernels::weight_only::QuantType; | ||
| const QuantType quant_type = (bits == 4) ? QuantType::W4_A16 : QuantType::W8_A16; | ||
|
|
||
| for (int64_t e = 0; e < num_experts; ++e) { | ||
| const uint8_t* src_e = src_base_gpu + static_cast<size_t>(e) * per_expert_bytes; | ||
| int8_t* dst_e = dst_all + static_cast<size_t>(e) * per_expert_bytes; | ||
|
|
||
| // Step 1: transpose + (for int4) unpack/zero-point bias into the | ||
| // transposed-int8 scratch buffer. Mirrors MatMulNBits's PrePack_B. | ||
| if (bits == 4) { | ||
| onnxruntime::llm::kernels::fpA_intB_gemv::unpack_uint4_transposed_to_int8_direct_cuda( | ||
| stream, transposed_scratch_ptr, src_e, static_cast<int>(n), static_cast<int>(k)); | ||
| } else { | ||
| onnxruntime::llm::kernels::fpA_intB_gemv::transpose_uint8_matrix_and_convert_to_int8( | ||
| stream, transposed_scratch_ptr, src_e, static_cast<int>(n), static_cast<int>(k)); | ||
| } | ||
|
|
||
| // Step 2: apply the CUTLASS fpA_intB row-permutation / column-interleave / | ||
| // bias / pair-interleave transform into the per-expert output slot. | ||
| onnxruntime::llm::kernels::weight_only::preprocess_weights_for_mixed_gemm_cuda( | ||
| stream, | ||
| packing_sm, | ||
| dst_e, | ||
| transposed_scratch_ptr, | ||
| permutation_map.get(), | ||
| {static_cast<size_t>(k), static_cast<size_t>(n)}, | ||
| quant_type); | ||
| } | ||
|
|
||
| // No explicit cudaStreamSynchronize here: preprocess_weights_for_mixed_gemm_cuda | ||
| // synchronizes the stream internally at the end of every per-expert call, so | ||
| // after the final expert all transpose/pack work (and the CPU->GPU staging | ||
| // copy above) is complete and the transient scratch buffers are safe to free. | ||
| is_packed = true; | ||
| } | ||
|
|
||
| // --------------------------------------------------------------------------- | ||
| // PrePack helper: Swizzle MXFP block scales for SM120 TMA layout using GPU kernel. | ||
| // --------------------------------------------------------------------------- | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1519,6 +1519,17 @@ | |
| "fc*_scales inputs contain MXFP4 block scales, and fc*_global_scale inputs must be provided.", | ||
| AttributeProto::STRING, | ||
| std::string("int")) | ||
| .Attr("weights_prepacked", | ||
| "Only meaningful when quant_type='int'. Set to 1 (default) when the int4/int8 " | ||
| "fc1/fc2 weight initializers have already been laid out in the CUTLASS fpA_intB " | ||
| "format expected by the runner (e.g. produced offline by " | ||
| "pack_weights_for_cuda_mixed_gemm). Set to 0 when the initializers are raw, " | ||
|
justinchuby marked this conversation as resolved.
|
||
| "un-prepacked [E, N, K/pack] tensors as produced by quantize_matmul_{4,8}bits; " | ||
| "in that case the kernel runs the CUTLASS layout transform itself in PrePack(), " | ||
|
justinchuby marked this conversation as resolved.
|
||
| "matching the behaviour of MatMulNBits and removing the offline pre-pack " | ||
|
Check warning on line 1529 in onnxruntime/core/graph/contrib_ops/contrib_defs.cc
|
||
| "requirement from exporters. Default is 1 for backward compatibility.", | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's mark this attribute as optional but not set default value, and each EP shall decide its own default value for backward-compatible: cuda assumes prepacked weight, but other EPs might not. You will need update operator doc. You can download it from a link in failed "GPU Doc Gen CI" job after this change. There is a step call "Upload updated documentation" has the link. |
||
| AttributeProto::INT, | ||
| static_cast<int64_t>(1)) | ||
| .Input(0, | ||
| "input", | ||
| "2D tensor with shape (num_tokens, hidden_size), or " | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.