From dc9af4abc6bad7b81d01ece364c01db9ff8a0e65 Mon Sep 17 00:00:00 2001 From: Ziang Li Date: Fri, 22 May 2026 15:58:11 -0700 Subject: [PATCH 1/2] Implement 4over6 NVFP4 recipe (#2972) * Initial implementation Signed-off-by: Ziang Li * Make 4over6 compile time for dequant Signed-off-by: Ziang Li * Expand 1d fwd+bwd test Signed-off-by: Ziang Li * Refactor Signed-off-by: Ziang Li * Clean up Signed-off-by: Ziang Li * Clean up Signed-off-by: Ziang Li * Add gemm test Signed-off-by: Ziang Li * Add more tests and fix offload Signed-off-by: Ziang Li * Fix offload Signed-off-by: Ziang Li * Clean up arg Signed-off-by: Ziang Li * Add more test Signed-off-by: Ziang Li * Add more tests Signed-off-by: Ziang Li * Clean up test Signed-off-by: Ziang Li * Refactor cuh kernel impl Signed-off-by: Ziang Li * Further extract Signed-off-by: Ziang Li * Clean up Signed-off-by: Ziang Li * Add recipe_id Signed-off-by: Ziang Li * Fix failing unit tests Signed-off-by: Ziang Li * Clean up test Signed-off-by: Ziang Li * Clean up Signed-off-by: Ziang Li * Refactor ref Signed-off-by: Ziang Li * Update comments and docs Signed-off-by: Ziang Li * Drop unnecessary test_sanity workaround The following tests passed: `NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py ` `NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=tests/pytorch/debug/test_configs/dummy_feature.yaml NVTE_TEST_NVINSPECT_FEATURE_DIRS=transformer_engine/debug/features PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto tests/pytorch/test_sanity.py ` Signed-off-by: Ziang Li * Refactor `QuantizerRole` Signed-off-by: Ziang Li * Allow separate recipe 4over6 config Signed-off-by: Ziang Li * Support 2d Signed-off-by: Ziang Li * Refactor 2d Signed-off-by: Ziang Li * Clean up anti pattern Signed-off-by: Ziang Li * Enforce 4over6 consistency Signed-off-by: Ziang Li * Update comments Signed-off-by: Ziang Li * Update docs Signed-off-by: Ziang Li * Fix test Signed-off-by: Ziang Li * Drop test_fusible_ops Signed-off-by: Ziang Li * Revert "Drop test_fusible_ops" This reverts commit 69f9ccc36a9c459f50c2f00b6cd6a62c5e1bdf13. Signed-off-by: Ziang Li * Refactor test_fusible_ops Signed-off-by: Ziang Li * Refactor ref and extend cpp test Signed-off-by: Ziang Li * Clean up cpp test Signed-off-by: Ziang Li * Minor comment Signed-off-by: Ziang Li * Drop doc Signed-off-by: Ziang Li * Explicit handle conditional smem buffer Signed-off-by: Ziang Li * Further clean up Signed-off-by: Ziang Li * More templates Signed-off-by: Ziang Li * Simplify cpp Signed-off-by: Ziang Li * Drop write back lifting Signed-off-by: Ziang Li * Add MAE and dedicated fast math env var Signed-off-by: Ziang Li * Harden cpp test Signed-off-by: Ziang Li * Add warning and err fast math coverage Signed-off-by: Ziang Li * Fold test case and clean up cpp test Signed-off-by: Ziang Li * Initial 448 vs 256 implementation Signed-off-by: Ziang Li * Use e4m3 max instead of boolean, more template Signed-off-by: Ziang Li * Add benchmark script and minor optimization Signed-off-by: Ziang Li * Use standalone kernels Signed-off-by: Ziang Li * Use cp async Signed-off-by: Ziang Li * Add benchmark script Signed-off-by: Ziang Li * Minor fix after rebase Signed-off-by: Ziang Li * Naming consistency Signed-off-by: Ziang Li * Remove 4over6 benchmark Signed-off-by: Ziang Li * Refactor modes Signed-off-by: Ziang Li * Relax tol for `test_layernorm_mlp` for `nvfp4_4over6` Signed-off-by: Ziang Li * Minor fix recipe naming Signed-off-by: Ziang Li * Remove gradient 4over6 quantization and partially allow SR/RHT Signed-off-by: Ziang Li * Allow RHT in pytorch ref Signed-off-by: Ziang Li * Update transformer_engine/pytorch/csrc/quantizer.cpp Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Minor fix TODO lint Signed-off-by: Ziang Li * Use standard nvfp4 for grad ref in test_fusible_ops.py since 4over6 is not applied to gradient quantizers Signed-off-by: Ziang Li * Minor fix test-fusible_ops 4over6 helper Signed-off-by: Ziang Li * Default to 256 for 4over6 Signed-off-by: Ziang Li * Reset RNG state for each TE ops test Adding tests affected RNG in unrelated tests. Signed-off-by: Tim Moon * Remove loosened NVFP4 tols in layernorm MLP test. Make sure tensors are representable in quantized format. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ziang Li Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/envvars.rst | 24 + .../cpp/operator/test_cast_nvfp4_transpose.cu | 616 +++++++++++++--- tests/cpp/operator/test_dequantize_nvfp4.cu | 68 +- tests/cpp/test_common.cu | 12 + tests/cpp/test_common.h | 3 + tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py | 52 ++ .../nvfp4/test_nvfp4_quantize_exact.py | 65 +- tests/pytorch/test_backward_override.py | 21 +- tests/pytorch/test_cpu_offloading.py | 91 ++- tests/pytorch/test_cuda_graphs.py | 21 +- tests/pytorch/test_fusible_ops.py | 100 ++- tests/pytorch/test_numerics.py | 72 +- tests/pytorch/test_quantized_tensor.py | 21 +- tests/pytorch/test_recipe.py | 134 +++- tests/pytorch/test_sanity.py | 64 +- tests/pytorch/test_torch_compile.py | 43 +- tests/pytorch/utils.py | 39 +- .../common/cast/dispatch/quantize.cuh | 41 +- .../common/cast/nvfp4/core_nvfp4.cuh | 8 +- .../common/cast/nvfp4/dequantize_nvfp4.cuh | 40 +- .../cast/nvfp4/quantize_4over6_nvfp4.cuh | 668 ++++++++++++++++++ .../comm_gemm_overlap/comm_gemm_overlap.cpp | 4 + transformer_engine/common/common.h | 16 +- .../transformer_engine/transformer_engine.h | 55 ++ transformer_engine/common/recipe/__init__.py | 30 + transformer_engine/common/recipe/nvfp4.cu | 13 +- .../common/transformer_engine.cpp | 28 + transformer_engine/pytorch/csrc/common.h | 4 + .../pytorch/csrc/extensions/cast.cpp | 70 +- transformer_engine/pytorch/csrc/quantizer.cpp | 49 +- .../pytorch/csrc/type_converters.cpp | 2 + .../custom_recipes/quantization_ref_nvfp4.py | 190 ++++- transformer_engine/pytorch/quantization.py | 39 +- .../pytorch/tensor/grouped_tensor.py | 6 + .../pytorch/tensor/nvfp4_tensor.py | 72 +- .../tensor/storage/grouped_tensor_storage.py | 49 ++ .../tensor/storage/nvfp4_tensor_storage.py | 16 + 37 files changed, 2595 insertions(+), 251 deletions(-) create mode 100644 transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh diff --git a/docs/envvars.rst b/docs/envvars.rst index ffbad409d4..bd62ccac46 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -287,6 +287,30 @@ Kernel Configuration :Default: ``0`` :Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar. +.. envvar:: NVTE_NVFP4_4OVER6 + + :Type: ``str`` (``none``, ``weights``, ``activations``, or ``all``) + :Default: ``none`` + :Description: Enable 4over6 adaptive NVFP4 block scaling for weights, activations, or both in the ``NVFP4BlockScaling`` recipe. For each selected FP4 block, quantization compares map-to-4 and map-to-6 candidates and stores the candidate with lower configured error. ``none`` keeps standard NVFP4. Current 4over6 support targets RL and post-training scenarios; pre-training paths that combine 4over6 with RHT are not yet implemented. + +.. envvar:: NVTE_NVFP4_4OVER6_E4M3_USE_256 + + :Type: ``str`` (``none``, ``weights``, ``activations``, or ``all``) + :Default: ``all`` + :Description: Select NVFP4 4over6 quantizers that use 256 instead of 448 as the global E4M3 scale bound. By default, all 4over6 quantizers use 256. Set the env var to ``none`` (or set ``NVFP4BlockScaling(nvfp4_4over6_e4m3_use_256="none")``) to use the standard NVFP4 448 bound for all 4over6 quantizers. This option is only meaningful for tensor roles that also enable :envvar:`NVTE_NVFP4_4OVER6`. + +.. envvar:: NVTE_NVFP4_4OVER6_ERR_MODE + + :Type: ``str`` (``MAE`` or ``MSE``) + :Default: ``MAE`` + :Description: Select the input-domain error metric used by NVFP4 4over6 map-to-4 versus map-to-6 candidate selection in the ``NVFP4BlockScaling`` recipe. + +.. envvar:: NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH + + :Type: ``int`` (0 or 1) + :Default: ``0`` + :Description: Allow the NVFP4 4over6 candidate error computation to use faster non-strict floating-point expressions. By default, 4over6 error comparison uses strict expressions; ``NVTE_USE_FAST_MATH`` does not control this error-comparison path. + Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index a8f58f8598..d6ab4b6740 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -62,12 +62,14 @@ std::vector create_transpose(const InputType* const input, const size } // Compute the global encode scale factor for a given global amax -float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math) { - constexpr float fp8_max = 448.0f; // 448.0f; +float compute_global_encode_scaling_factor_FP4(const float global_amax, const bool use_fast_math, + const int e4m3_max = 448) { + NVTE_CHECK(e4m3_max == 448 || e4m3_max == 256, "Unsupported NVFP4 E4M3 max."); + const float fp8_max = static_cast(e4m3_max); constexpr float fp4_max = 6.0f; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return the max normalized value - const float max_norm_clamp = use_fast_math + const float max_norm_clamp = (use_fast_math && e4m3_max == 448) ? Numeric_Traits::maxNorm : Numeric_Traits::maxNorm; @@ -79,6 +81,103 @@ float compute_global_encode_scaling_factor_FP4(const float global_amax, const bo return global_encode_scale; } +struct NVFP4FourOverSixQuantization { + fp8e4m3 scale_map4; + fp8e4m3 scale_map6; + float reciprocal_map4; + float reciprocal_map6; + fp4e2m1x2 quantized_map4; + fp4e2m1x2 quantized_map6; +}; + +enum class NVFP4FourOverSixCandidate { + Map4, + Map6, +}; + +enum class NVFP4ScalingMode { + Block1D, + RowScaled1D, + Block2D, +}; + +struct NVFP4FourOverSixTestConfig { + NVTENVFP44Over6Mode mode = kNVTENVFP44Over6Disabled; + int e4m3_max = 448; + bool err_use_fast_math = false; +}; + +bool use_2d_quantization(const NVFP4ScalingMode scaling_mode) { + return scaling_mode == NVFP4ScalingMode::Block2D; +} + +NVFP4FourOverSixQuantization compute_4over6_quantization_scales( + const float block_amax, const float global_encode_scale) { + constexpr float fp4_max = 6.0f; + constexpr float fp8_max = 448.0f; + constexpr float scale_expansion_factor = 1.5f; + const float base_sf_high_precision = block_amax / fp4_max * global_encode_scale; + const float sf_high_precision_map4 = + fminf(base_sf_high_precision * scale_expansion_factor, fp8_max); + const float sf_high_precision_map6 = fminf(base_sf_high_precision, fp8_max); + const fp8e4m3 scale_map4 = static_cast(sf_high_precision_map4); + const fp8e4m3 scale_map6 = static_cast(sf_high_precision_map6); + + const float global_decode_scale = 1.0f / global_encode_scale; + const float scale_map4_fp32 = static_cast(scale_map4); + const float reciprocal_map4 = + fminf(1.0f / (scale_map4_fp32 * global_decode_scale), Numeric_Traits::maxNorm); + const float scale_map6_fp32 = static_cast(scale_map6); + const float reciprocal_map6 = + fminf(1.0f / (scale_map6_fp32 * global_decode_scale), Numeric_Traits::maxNorm); + + const float2 zero = {0.0f, 0.0f}; + return { + scale_map4, + scale_map6, + reciprocal_map4, + reciprocal_map6, + fp4e2m1x2(zero), + fp4e2m1x2(zero), + }; +} + +fp8e4m3 select_4over6_scale(const NVFP4FourOverSixQuantization& quantization, + const NVFP4FourOverSixCandidate candidate) { + if (candidate == NVFP4FourOverSixCandidate::Map4) { + return quantization.scale_map4; + } + return quantization.scale_map6; +} + +fp4e2m1x2 select_4over6_quantized_pair(const NVFP4FourOverSixQuantization& quantization, + const NVFP4FourOverSixCandidate candidate) { + if (candidate == NVFP4FourOverSixCandidate::Map4) { + return quantization.quantized_map4; + } + return quantization.quantized_map6; +} + +NVFP4FourOverSixQuantization quantize_4over6_pair( + const float x, const float y, const NVFP4FourOverSixQuantization& quantization) { + const float2 scaled_map4 = {x * quantization.reciprocal_map4, + y * quantization.reciprocal_map4}; + const fp4e2m1x2 quantized_map4(scaled_map4); + + const float2 scaled_map6 = {x * quantization.reciprocal_map6, + y * quantization.reciprocal_map6}; + const fp4e2m1x2 quantized_map6(scaled_map6); + + return { + quantization.scale_map4, + quantization.scale_map6, + quantization.reciprocal_map4, + quantization.reciprocal_map6, + quantized_map4, + quantized_map6, + }; +} + // 1D Scaling: Original implementation with 1x16 blocks template void quantize_nvfp4_1d(float (*OP)(const float), @@ -89,10 +188,15 @@ void quantize_nvfp4_1d(float (*OP)(const float), const size_t cols, const size_t scales_stride, const float global_amax, - const bool use_fast_math) { + const bool use_fast_math, + const bool use_4over6 = false, + const int e4m3_max = 448, + const NVFP4FourOverSixCandidate four_over_six_candidate = + NVFP4FourOverSixCandidate::Map6) { // Compute a global encoding/decoding scaling factor for all S_dec_b - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, + e4m3_max); constexpr size_t block_size_X = 16; const size_t blocks_X = divide_round_up(cols, block_size_X); @@ -122,6 +226,27 @@ void quantize_nvfp4_1d(float (*OP)(const float), block_amax = std::max(block_amax, std::abs(elt)); } + const size_t scale_idx = i * scales_stride + block_X; + + if (use_4over6) { + const NVFP4FourOverSixQuantization quantization = + compute_4over6_quantization_scales(block_amax, S_enc); + scales[scale_idx] = select_4over6_scale(quantization, four_over_six_candidate); + + for (size_t j = j_min; j < j_max; j += 2) { + const int idx_pair = (i * cols + j) / 2; + const int cache_idx_x = j - j_min; + const int cache_idx_y = cache_idx_x + 1; + const float cached_x = cache_buffer[cache_idx_x]; + const float cached_y = cache_buffer[cache_idx_y]; + const NVFP4FourOverSixQuantization pair_quantization = + quantize_4over6_pair(cached_x, cached_y, quantization); + output[idx_pair] = + select_4over6_quantized_pair(pair_quantization, four_over_six_candidate); + } + continue; + } + // Compute and store the per-block FP8 decode scale const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f)); const fp8e4m3 S_dec_b_fp8 = static_cast(fminf(S_dec_b, Numeric_Traits::maxNorm)); @@ -131,7 +256,6 @@ void quantize_nvfp4_1d(float (*OP)(const float), const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits::maxNorm); - const size_t scale_idx = i * scales_stride + block_X; scales[scale_idx] = S_dec_b_fp8; float scale_reciprocal = S_enc_b_fp8; @@ -167,9 +291,14 @@ void compute_2d_mathematical_scales(float (*OP)(const float), const size_t cols, const float global_amax, std::vector>& math_scales, - const bool use_fast_math) { - - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); + const bool use_fast_math, + const bool use_4over6 = false, + const int e4m3_max = 448, + const NVFP4FourOverSixCandidate four_over_six_candidate = + NVFP4FourOverSixCandidate::Map6) { + + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, + e4m3_max); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -197,9 +326,16 @@ void compute_2d_mathematical_scales(float (*OP)(const float), } // Compute E4M3 scaling factor for this 16x16 block - const float S_dec_b = block_amax / 6.0f; - const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); - math_scales[block_Y][block_X] = S_dec_b_fp8; + if (use_4over6) { + const NVFP4FourOverSixQuantization quantization = + compute_4over6_quantization_scales(block_amax, S_enc); + math_scales[block_Y][block_X] = + select_4over6_scale(quantization, four_over_six_candidate); + } else { + const float S_dec_b = block_amax / 6.0f * S_enc; + const fp8e4m3 S_dec_b_fp8_map6 = static_cast(S_dec_b); + math_scales[block_Y][block_X] = S_dec_b_fp8_map6; + } } } } @@ -214,13 +350,19 @@ void quantize_nvfp4_2d(float (*OP)(const float), const size_t cols, const size_t scales_stride, const float global_amax, - const bool use_fast_math) { + const bool use_fast_math, + const bool use_4over6 = false, + const int e4m3_max = 448, + const NVFP4FourOverSixCandidate four_over_six_candidate = + NVFP4FourOverSixCandidate::Map6) { // Step 1: Compute mathematical 8x8 scaling factors std::vector> math_scales; - compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); + compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math, + use_4over6, e4m3_max, four_over_six_candidate); - const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math); + const float S_enc = compute_global_encode_scaling_factor_FP4(global_amax, use_fast_math, + e4m3_max); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; const size_t blocks_Y = divide_round_up(rows, block_size_Y); @@ -250,7 +392,7 @@ void quantize_nvfp4_2d(float (*OP)(const float), // Get the scaling factor for this block const float S_dec_b_fp8 = static_cast(math_scales[block_Y][block_X]); - const float S_enc_b_fp8 = S_dec_b_fp8 == 0 ? 0.f : S_enc / S_dec_b_fp8; + const float S_enc_b_fp8 = S_dec_b_fp8 == 0.0f ? 0.0f : S_enc / S_dec_b_fp8; const float scale_reciprocal = S_enc_b_fp8; // Process and cache data for this 16x16 block @@ -302,11 +444,17 @@ void quantize_nvfp4(float (*OP)(const float), const size_t scales_stride, const float global_amax, const bool use_fast_math, - const bool use_2d_quantization = false) { + const bool use_2d_quantization = false, + const bool use_4over6 = false, + const int e4m3_max = 448, + const NVFP4FourOverSixCandidate four_over_six_candidate = + NVFP4FourOverSixCandidate::Map6) { if (use_2d_quantization) { - quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); + quantize_nvfp4_2d(OP, input, output, scales, rows, cols, scales_stride, global_amax, + use_fast_math, use_4over6, e4m3_max, four_over_six_candidate); } else { - quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, use_fast_math); + quantize_nvfp4_1d(OP, input, output, scales, rows, cols, scales_stride, global_amax, + use_fast_math, use_4over6, e4m3_max, four_over_six_candidate); } } @@ -324,7 +472,11 @@ void compute_ref(float (*OP)(const float), const size_t scales_stride_t, const bool use_fast_math, const bool use_2d_quantization = false, - const bool row_scaled_nvfp4 = false) + const bool row_scaled_nvfp4 = false, + const bool use_4over6 = false, + const int e4m3_max = 448, + const NVFP4FourOverSixCandidate four_over_six_candidate = + NVFP4FourOverSixCandidate::Map6) { std::vector input_t = create_transpose(input, rows, cols); NVTE_CHECK(!(use_2d_quantization && row_scaled_nvfp4), @@ -334,7 +486,8 @@ void compute_ref(float (*OP)(const float), if (use_2d_quantization) { // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; - compute_2d_mathematical_scales(OP, input, rows, cols, *amax, math_scales, use_fast_math); + compute_2d_mathematical_scales(OP, input, rows, cols, *amax, math_scales, use_fast_math, + use_4over6, e4m3_max, four_over_six_candidate); constexpr size_t block_size_Y = 16; constexpr size_t block_size_X = 16; @@ -362,9 +515,11 @@ void compute_ref(float (*OP)(const float), // Step 4: Process quantized outputs using the same algorithm as quantize_nvfp4_2d // (This part processes the actual FP4 data using the mathematical scaling factors) quantize_nvfp4_2d(OP, input, output, nullptr, rows, cols, scales_stride, *amax, - use_fast_math); // scales already filled + use_fast_math, use_4over6, e4m3_max, + four_over_six_candidate); // scales already filled quantize_nvfp4_2d(OP, input_t.data(), output_t, nullptr, cols, rows, scales_stride_t, *amax, - use_fast_math); // scales_t already filled + use_fast_math, use_4over6, e4m3_max, + four_over_six_candidate); // scales_t already filled return; } @@ -381,16 +536,21 @@ void compute_ref(float (*OP)(const float), scales_stride, amax[row], use_fast_math, - use_2d_quantization); + use_2d_quantization, + use_4over6, + e4m3_max, + four_over_six_candidate); } return; } // Ref impl for basic NVFP4 quantize_nvfp4(OP, input, output, scales, rows, cols, scales_stride, *amax, - use_fast_math, use_2d_quantization); + use_fast_math, use_2d_quantization, use_4over6, e4m3_max, + four_over_six_candidate); quantize_nvfp4(OP, input_t.data(), output_t, scales_t, cols, rows, scales_stride_t, *amax, - use_fast_math, use_2d_quantization); + use_fast_math, use_2d_quantization, use_4over6, e4m3_max, + four_over_six_candidate); } void compare_nvfp4_tensors(const std::string& name, @@ -515,6 +675,92 @@ void compareResults_nvfp4(Tensor &test, } } +template +bool bitwise_equal(const T& x, const T& y) { + const auto *x_bytes = reinterpret_cast(&x); + const auto *y_bytes = reinterpret_cast(&y); + for (size_t i = 0; i < sizeof(T); ++i) { + if (x_bytes[i] != y_bytes[i]) { + return false; + } + } + return true; +} + +bool nvfp4_output_block_matches(const fp4e2m1x2* const test_data, + const fp4e2m1x2* const ref_data, + const size_t row, + const size_t cols, + const size_t block_x) { + constexpr size_t block_size_X = 16; + const size_t j_min = block_x * block_size_X; + const size_t j_max = std::min(j_min + block_size_X, cols); + for (size_t j = j_min; j < j_max; j += 2) { + const size_t idx_pair = (row * cols + j) / 2; + if (!bitwise_equal(test_data[idx_pair], ref_data[idx_pair])) { + return false; + } + } + return true; +} + +void compare_nvfp4_4over6_candidates(const std::string& name, + const fp4e2m1* const test_data, + const fp8e4m3* const test_scales, + const fp4e2m1x2* const ref_data_map4, + const fp8e4m3* const ref_scales_map4, + const fp4e2m1x2* const ref_data_map6, + const fp8e4m3* const ref_scales_map6, + const size_t rows, + const size_t cols, + const size_t blocks_X, + const size_t scales_stride) { + constexpr int max_mismatches_to_print = 3; + const auto* const test_data_pairs = reinterpret_cast(test_data); + size_t total_mismatches = 0; + + for (size_t row = 0; row < rows; ++row) { + for (size_t block_x = 0; block_x < blocks_X; ++block_x) { + const size_t scale_idx = row * scales_stride + block_x; + const bool scale_matches_map4 = + bitwise_equal(test_scales[scale_idx], ref_scales_map4[scale_idx]); + const bool data_matches_map4 = + nvfp4_output_block_matches(test_data_pairs, ref_data_map4, row, cols, block_x); + const bool scale_matches_map6 = + bitwise_equal(test_scales[scale_idx], ref_scales_map6[scale_idx]); + const bool data_matches_map6 = + nvfp4_output_block_matches(test_data_pairs, ref_data_map6, row, cols, block_x); + + if ((scale_matches_map4 && data_matches_map4) || + (scale_matches_map6 && data_matches_map6)) { + continue; + } + + ++total_mismatches; + if (total_mismatches <= max_mismatches_to_print) { + std::cout << "Error in tensor " << name << ": 4over6 block mismatch at row " + << row << ", block_x " << block_x + << ". The output did not match either map-to-4 or map-to-6 exactly." + << std::endl; + } + } + } + + std::cout << "=== SUMMARY for tensor " << name << " ===" << std::endl; + std::cout << "Total 4over6 blocks checked: " << (rows * blocks_X) << std::endl; + if (total_mismatches > 0) { + std::cout << "STATUS: FAILED for output" << std::endl; + std::cout << "Total mismatched 4over6 blocks found: " << total_mismatches << std::endl; + std::cout << "============================" << std::endl; + GTEST_FAIL() << "Found " << total_mismatches << " 4over6 block mismatches in tensor " + << name; + } + + std::cout << "STATUS: PASSED for output" << std::endl; + std::cout << "Each 4over6 block matched either map-to-4 or map-to-6 exactly" << std::endl; + std::cout << "============================" << std::endl; +} + void compare_rowwise_amax(Tensor &output, const std::vector &ref_amax) { ASSERT_EQ(output.rowwise_amax_size(), ref_amax.size()); const auto *amax_ptr = output.cpu_rowwise_amax_ptr(); @@ -529,12 +775,25 @@ template void performTest(float (*OP)(const float), const std::vector& shape, const bool use_fast_math, - const bool row_scaled_nvfp4 = false) { + const NVFP4ScalingMode scaling_mode = NVFP4ScalingMode::Block1D, + const NVTENVFP44Over6Mode mode = kNVTENVFP44Over6Disabled, + const int e4m3_max = 448, + const bool use_4over6_err_use_fast_math = false) { using namespace test; + const bool use_4over6 = mode != kNVTENVFP44Over6Disabled; + + if (use_4over6 && use_fast_math) { + std::cout << "WARNING: Plain NVFP4 fast math is ignored for 4over6. " + "Use use_4over6_err_use_fast_math to test the 4over6 candidate " + "error fast-math path." + << std::endl; + } DType itype = TypeInfo::dtype; DType otype = DType::kFloat4E2M1; + const bool is_2d_quantization = use_2d_quantization(scaling_mode); + const bool row_scaled_nvfp4 = scaling_mode == NVFP4ScalingMode::RowScaled1D; const bool rowwise = true; const bool columnwise = !row_scaled_nvfp4; @@ -560,14 +819,52 @@ void performTest(float (*OP)(const float), Tensor input("input", shape, itype); Tensor output("output", shape, otype, rowwise, columnwise, NVTE_NVFP4_1D_SCALING); + output.set_nvfp4_e4m3_max(e4m3_max); std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); std::unique_ptr ref_scales = std::make_unique(blocks_Y * blocks_X); std::unique_ptr ref_scales_t = std::make_unique(blocks_Y_t * blocks_X_t); + std::unique_ptr ref_output_map6; + std::unique_ptr ref_output_t_map6; + std::unique_ptr ref_scales_map6; + std::unique_ptr ref_scales_t_map6; fillCase(&input, InputsFillCase::uniform); + if (use_4over6 && row_scaled_nvfp4) { + const float target_row_amax = static_cast(e4m3_max) * 6.0f * 8.0f; + auto *input_vals = input.rowwise_cpu_dptr(); + for (size_t row = 0; row < rows; ++row) { + float row_amax = 0.0f; + size_t max_col = 0; + for (size_t col = 0; col < cols; ++col) { + const float val = static_cast(input_vals[row * cols + col]); + const float abs_val = fabsf(val); + if (abs_val > row_amax) { + row_amax = abs_val; + max_col = col; + } + } + + if (row_amax == 0.0f) { + continue; + } + + const float row_scale = target_row_amax / row_amax; + for (size_t col = 0; col < cols; ++col) { + float scaled = static_cast(input_vals[row * cols + col]) * row_scale; + scaled = fminf(fmaxf(scaled, -target_row_amax), target_row_amax); + input_vals[row * cols + col] = static_cast(scaled); + } + + const float max_val = static_cast(input_vals[row * cols + max_col]); + input_vals[row * cols + max_col] = + static_cast(max_val < 0.0f ? -target_row_amax : target_row_amax); + } + input.from_cpu(); + } + // Compute 2nd stage NVFP4 scaling factor std::vector ref_amax; if (row_scaled_nvfp4) { @@ -587,7 +884,11 @@ void performTest(float (*OP)(const float), output.set_row_scaled_nvfp4(row_scaled_nvfp4); } else { // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues - ref_amax.assign(1, 448.0f * 6.0f * 8.0f); + if (use_4over6) { + ref_amax.assign(1, static_cast(e4m3_max) * 6.0f * 8.0f); + } else { + ref_amax.assign(1, 448.0f * 6.0f * 8.0f); + } // Update tensor if (rowwise) { @@ -599,22 +900,63 @@ void performTest(float (*OP)(const float), output.from_cpu(); } - // Reference implementation - bool use_2d_quantization = false; - compute_ref(OP, - input.rowwise_cpu_dptr(), - ref_output.get(), - ref_output_t.get(), - ref_scales.get(), - ref_scales_t.get(), - ref_amax.data(), - rows, - cols, - scales_stride, - scales_stride_t, - use_fast_math, - use_2d_quantization, - row_scaled_nvfp4); + if (use_4over6) { + ref_output_map6 = std::make_unique(rows * (cols / 2)); + ref_output_t_map6 = std::make_unique(cols * (rows / 2)); + ref_scales_map6 = std::make_unique(blocks_Y * blocks_X); + ref_scales_t_map6 = std::make_unique(blocks_Y_t * blocks_X_t); + + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + ref_amax.data(), + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + is_2d_quantization, + row_scaled_nvfp4, + use_4over6, + e4m3_max, + NVFP4FourOverSixCandidate::Map4); + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output_map6.get(), + ref_output_t_map6.get(), + ref_scales_map6.get(), + ref_scales_t_map6.get(), + ref_amax.data(), + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + is_2d_quantization, + row_scaled_nvfp4, + use_4over6, + e4m3_max, + NVFP4FourOverSixCandidate::Map6); + } else { + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + ref_amax.data(), + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + is_2d_quantization, + row_scaled_nvfp4, + use_4over6); + } // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); @@ -624,10 +966,12 @@ void performTest(float (*OP)(const float), // Quantization options QuantizationConfigWrapper quant_config; - quant_config.set_use_fast_math(use_fast_math); + quant_config.set_use_fast_math(use_fast_math && !use_4over6); quant_config.set_stochastic_rounding(false); quant_config.set_rng_state(rng_state.data()); - quant_config.set_nvfp4_2d_quantization(use_2d_quantization); + quant_config.set_nvfp4_2d_quantization(is_2d_quantization); + quant_config.set_nvfp4_4over6_mode(mode); + quant_config.set_nvfp4_4over6_err_use_fast_math(use_4over6 && use_4over6_err_use_fast_math); // Call appropriate function based on operation type // Activation functions take 3 parameters (input, output, stream) @@ -656,21 +1000,50 @@ void performTest(float (*OP)(const float), const double atol = 1.0E-6; const double rtol = 1.0E-6; - // Set dump_data=true to enable dumping tensor data to files for analysis - compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, - false, !row_scaled_nvfp4); - - size_t scale_mismatches_num = 0; - compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), - ref_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - scale_mismatches_num); - - if (!row_scaled_nvfp4) { - compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_t.get(), - unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + if (use_4over6) { + output.to_cpu(); + compare_nvfp4_4over6_candidates("output", + output.rowwise_cpu_dptr(), + output.rowwise_cpu_scale_inv_ptr(), + ref_output.get(), + ref_scales.get(), + ref_output_map6.get(), + ref_scales_map6.get(), + rows, + cols, + unpadded_blocks_X, + scales_stride); + if (!row_scaled_nvfp4) { + compare_nvfp4_4over6_candidates("output_t", + output.columnwise_cpu_dptr(), + output.columnwise_cpu_scale_inv_ptr(), + ref_output_t.get(), + ref_scales_t.get(), + ref_output_t_map6.get(), + ref_scales_t_map6.get(), + cols, + rows, + unpadded_blocks_X_t, + scales_stride_t); + } + } else { + // Set dump_data=true to enable dumping tensor data to files for analysis + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, + true, false, !row_scaled_nvfp4); + + size_t scale_mismatches_num = 0; + compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), + ref_scales.get(), + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, scale_mismatches_num); + + if (!row_scaled_nvfp4) { + compare_scaling_factors("scales_t", + output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, + scales_stride_t, scale_mismatches_num); + } } compare_rowwise_amax(output, ref_amax); @@ -707,7 +1080,8 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam std::vector, transformer_engine::DType, bool, - bool>> {}; + NVFP4ScalingMode, + NVFP4FourOverSixTestConfig>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { // Skip tests for pre-Blackwell architectures @@ -722,7 +1096,8 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); - const bool row_scaled_nvfp4 = std::get<4>(GetParam()); + const NVFP4ScalingMode scaling_mode = std::get<4>(GetParam()); + const NVFP4FourOverSixTestConfig config = std::get<5>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -740,7 +1115,9 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math, row_scaled_nvfp4); + performTest(OP, tensor_dims, use_fast_math, scaling_mode, config.mode, + config.e4m3_max, + config.err_use_fast_math); ); } @@ -756,49 +1133,96 @@ std::string to_string(const ActivationType Act_type) { } } +std::string to_string(const NVFP4ScalingMode scaling_mode) { + switch (scaling_mode) { + case NVFP4ScalingMode::Block1D: return ""; + case NVFP4ScalingMode::RowScaled1D: return "XROW_SCALED"; + case NVFP4ScalingMode::Block2D: return "X2D"; + default: return ""; + } +} + +std::string test_name(const FusedCastTransposeNVFP4TestSuite::ParamType& param) { + std::string name = to_string(std::get<0>(param)); + const auto& shape = std::get<1>(param); + for (const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(param)); + if (std::get<3>(param)) { + name += "X_FAST_SCALING"; + } + name += to_string(std::get<4>(param)); + const NVFP4FourOverSixTestConfig& config = std::get<5>(param); + if (config.mode != kNVTENVFP44Over6Disabled) { + name += "X4OVER6"; + if (config.e4m3_max == 448) { + name += "XE4M3_MAX_448"; + } else { + name += "XE4M3_MAX_256"; + } + if (config.mode == kNVTENVFP44Over6MinMSE) { + name += "XMSE"; + } else if (config.mode == kNVTENVFP44Over6MinMAE) { + name += "XMAE"; + } else { + name += "XINVALID_MODE"; + } + if (config.err_use_fast_math) { + name += "XERR_USE_FAST_MATH"; + } + } + return name; +} + INSTANTIATE_TEST_SUITE_P( OperatorTest, FusedCastTransposeNVFP4TestSuite, ::testing::Combine( - ::testing::ValuesIn(Activation_types), - ::testing::ValuesIn(tensor_dims), - ::testing::Values(DType::kBFloat16), - ::testing::Values(false), - ::testing::Values(false)), + ::testing::ValuesIn(Activation_types), // activation_type + ::testing::ValuesIn(tensor_dims), // tensor_dims + ::testing::Values(DType::kBFloat16), // input_type + ::testing::Values(false), // use_fast_math + ::testing::Values(NVFP4ScalingMode::Block1D), // scaling_mode + ::testing::Values(NVFP4FourOverSixTestConfig{})), // four_over_six_config [](const testing::TestParamInfo& info) { - std::string name = to_string(std::get<0>(info.param)); - const auto& shape = std::get<1>(info.param); - for ( const auto& s: shape) { - name += "X" + std::to_string(s); - } - name += "X" + test::typeName(std::get<2>(info.param)); - if (std::get<3>(info.param)) { - name += "X_FAST_SCALING"; - } - return name; + return test_name(info.param); }); INSTANTIATE_TEST_SUITE_P( OperatorTestRowScaled, FusedCastTransposeNVFP4TestSuite, ::testing::Combine( - ::testing::Values(ActivationType::Identity), - ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), - ::testing::Values(DType::kBFloat16, DType::kFloat32), - ::testing::Values(false), - ::testing::Values(true)), + ::testing::ValuesIn(Activation_types), // activation_type + ::testing::ValuesIn(tensor_dims), // tensor_dims + ::testing::Values(DType::kBFloat16, DType::kFloat32), // input_type + ::testing::Values(false), // use_fast_math + ::testing::Values(NVFP4ScalingMode::RowScaled1D), // scaling_mode + ::testing::Values(NVFP4FourOverSixTestConfig{})), // four_over_six_config [](const testing::TestParamInfo& info) { - std::string name = to_string(std::get<0>(info.param)); - const auto& shape = std::get<1>(info.param); - for (const auto& s: shape) { - name += "X" + std::to_string(s); - } - name += "X" + test::typeName(std::get<2>(info.param)); - if (std::get<3>(info.param)) { - name += "X_FAST_SCALING"; - } - if (std::get<4>(info.param)) { - name += "XROW_SCALED"; - } - return name; + return test_name(info.param); + }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTest4Over6, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::ValuesIn(Activation_types), // activation_type + ::testing::ValuesIn(tensor_dims), // tensor_dims + ::testing::Values(DType::kBFloat16, DType::kFloat32), // input_type + ::testing::Values(false), // use_fast_math + ::testing::Values(NVFP4ScalingMode::Block1D, + NVFP4ScalingMode::RowScaled1D, + NVFP4ScalingMode::Block2D), // scaling_mode + ::testing::Values( + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAE, 448, false}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAE, 448, true}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 448, false}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 448, true}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAE, 256, false}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMAE, 256, true}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 256, false}, + NVFP4FourOverSixTestConfig{kNVTENVFP44Over6MinMSE, 256, true})), // four_over_six_config + [](const testing::TestParamInfo& info) { + return test_name(info.param); }); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index eb9e8bce23..40c1fbd235 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -46,8 +46,9 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, OType *output, size_t rows, size_t cols, - size_t scale_stride) { - constexpr float factor_inv = 1.0f / (6.0f * 448.0f); + size_t scale_stride, + int e4m3_max) { + const float factor_inv = 1.0f / (6.0f * static_cast(e4m3_max)); constexpr size_t BLOCK_SIZE = 16; const size_t Mread = cols / BLOCK_SIZE; const size_t bytes_per_block = BLOCK_SIZE / 2; @@ -86,11 +87,18 @@ float compute_amax(test::Tensor &t, size_t rows, size_t cols) { return amax; } +struct NVFP4DequantizeTestConfig { + NVTENVFP44Over6Mode mode = kNVTENVFP44Over6Disabled; + int e4m3_max = 448; +}; + // Quantize a high-precision input to NVFP4, then dequantize and compare // against a CPU reference computed from the quantized data. template void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, - const bool row_scaled_nvfp4) { + const bool row_scaled_nvfp4, + const NVTENVFP44Over6Mode mode, + const int e4m3_max) { using namespace test; DType otype = TypeInfo::dtype; @@ -105,6 +113,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Configure quantized tensor amax size_t amax_size = 1; + quantized.set_nvfp4_e4m3_max(e4m3_max); + ASSERT_EQ(quantized.nvfp4_e4m3_max(), e4m3_max); if (row_scaled_nvfp4) { quantized.set_row_scaled_nvfp4(true); amax_size = rows; @@ -116,7 +126,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Quantize if (rows > 0 && cols > 0) { - nvte_quantize(input.data(), quantized.data(), 0); + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_4over6_mode(mode); + nvte_quantize_v2(input.data(), quantized.data(), quant_config, 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); @@ -146,7 +158,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, std::make_unique(rows * cols); compute_ref_dequantize_nvfp4( fp4_data, scales, amax_vals, ref_output.get(), - rows, cols, scale_stride); + rows, cols, scale_stride, e4m3_max); // Compare results from TE and reference impls auto [atol, rtol] = getTolerances(otype); @@ -156,7 +168,9 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path. template void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, - const bool row_scaled_nvfp4) { + const bool row_scaled_nvfp4, + const NVTENVFP44Over6Mode mode, + const int e4m3_max) { using namespace test; DType otype = TypeInfo::dtype; @@ -165,6 +179,8 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); + quantized_compact.set_nvfp4_e4m3_max(e4m3_max); + ASSERT_EQ(quantized_compact.nvfp4_e4m3_max(), e4m3_max); if (row_scaled_nvfp4) { quantized_compact.set_row_scaled_nvfp4(true); } else if (rows > 0 && cols > 0) { @@ -174,7 +190,9 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, } if (rows > 0 && cols > 0) { - nvte_quantize(input.data(), quantized_compact.data(), 0); + QuantizationConfigWrapper quant_config; + quant_config.set_nvfp4_4over6_mode(mode); + nvte_quantize_v2(input.data(), quantized_compact.data(), quant_config, 0); cudaDeviceSynchronize(); } @@ -186,6 +204,8 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, // Create tensor with same FP4 data but swizzled scales Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); + quantized_swizzled.set_nvfp4_e4m3_max(e4m3_max); + ASSERT_EQ(quantized_swizzled.nvfp4_e4m3_max(), e4m3_max); if (row_scaled_nvfp4) { quantized_swizzled.set_row_scaled_nvfp4(true); } else { @@ -260,7 +280,8 @@ std::vector> nvfp4_tensor_dims = { class DequantizeNVFP4TestSuite : public ::testing::TestWithParam , transformer_engine::DType, - bool>> {}; + bool, + NVFP4DequantizeTestConfig>> {}; TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) { @@ -271,10 +292,12 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const bool row_scaled_nvfp4 = std::get<2>(GetParam()); + const NVFP4DequantizeTestConfig config = std::get<3>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4( - tensor_size.first, tensor_size.second, row_scaled_nvfp4); + tensor_size.first, tensor_size.second, row_scaled_nvfp4, config.mode, + config.e4m3_max); ); } @@ -284,13 +307,20 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::ValuesIn(nvfp4_tensor_dims), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Bool()), + ::testing::Bool(), + ::testing::Values(NVFP4DequantizeTestConfig{}, + NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 448}, + NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 256})), [](const testing::TestParamInfo& info) { + const NVFP4DequantizeTestConfig config = std::get<3>(info.param); + const bool use_4over6 = config.mode != kNVTENVFP44Over6Disabled; std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + test::typeName(std::get<1>(info.param)) + "X" + - (std::get<2>(info.param) ? "RowScaled" : "PerTensor"); + (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" + + (use_4over6 ? "FourOverSix" : "Default") + "X" + + (config.e4m3_max == 256 ? "E4M3Max256" : "E4M3Max448"); return name; } ); @@ -298,7 +328,8 @@ INSTANTIATE_TEST_SUITE_P( class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam , transformer_engine::DType, - bool>> {}; + bool, + NVFP4DequantizeTestConfig>> {}; TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) { @@ -309,10 +340,12 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); const bool row_scaled_nvfp4 = std::get<2>(GetParam()); + const NVFP4DequantizeTestConfig config = std::get<3>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4_swizzled( - tensor_size.first, tensor_size.second, row_scaled_nvfp4); + tensor_size.first, tensor_size.second, row_scaled_nvfp4, config.mode, + config.e4m3_max); ); } @@ -322,13 +355,20 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Combine( ::testing::ValuesIn(nvfp4_tensor_dims), ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), - ::testing::Bool()), + ::testing::Bool(), + ::testing::Values(NVFP4DequantizeTestConfig{}, + NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 448}, + NVFP4DequantizeTestConfig{kNVTENVFP44Over6MinMAE, 256})), [](const testing::TestParamInfo& info) { + const NVFP4DequantizeTestConfig config = std::get<3>(info.param); + const bool use_4over6 = config.mode != kNVTENVFP44Over6Disabled; std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + test::typeName(std::get<1>(info.param)) + "X" + (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" + + (use_4over6 ? "FourOverSix" : "Default") + "X" + + (config.e4m3_max == 256 ? "E4M3Max256" : "E4M3Max448") + "X" + "Swizzled"; return name; } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 4fd75bb927..e35f5e029d 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -440,6 +440,18 @@ void Tensor::set_row_scaled_nvfp4(bool row_scaled_nvfp4) { } } +void Tensor::set_nvfp4_e4m3_max(int nvfp4_e4m3_max) { + NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, + "NVFP4 E4M3 max is only supported for NVFP4 tensors."); + tensor_.set_nvfp4_e4m3_max(nvfp4_e4m3_max); +} + +int Tensor::nvfp4_e4m3_max() const { + NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, + "NVFP4 E4M3 max is only supported for NVFP4 tensors."); + return tensor_.get_nvfp4_e4m3_max(); +} + void Tensor::to_cpu() { if (data_rowwise_) { data_rowwise_->to_cpu(); } if (data_columnwise_) { data_columnwise_->to_cpu(); } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 17f36a99dd..fd03d283d7 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -293,10 +293,13 @@ class Tensor { return columnwise_; } + int nvfp4_e4m3_max() const; + void set_tensor_amax_nullptr(); void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales); void set_row_scaled_nvfp4(bool row_scaled_nvfp4); + void set_nvfp4_e4m3_max(int nvfp4_e4m3_max); void to_cpu(); void from_cpu(); diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index a7ea4f089f..bd4d029729 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -28,7 +28,12 @@ def check_nvfp4_gemm_versus_reference( x_columnwise: bool = False, w_columnwise: bool = False, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", ): + if nvfp4_e4m3_max != 448 and not use_4over6: + pytest.skip("E4M3 max 256 is only meaningful for 4over6") te_dtype = tex.DType.kFloat4E2M1 # Setup device and random seed @@ -59,6 +64,9 @@ def check_nvfp4_gemm_versus_reference( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -68,6 +76,9 @@ def check_nvfp4_gemm_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + nvfp4_use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) # Quantize x and w @@ -123,6 +134,9 @@ def check_nvfp4_gemm_versus_reference( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -131,6 +145,9 @@ def check_nvfp4_gemm_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + nvfp4_use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) # Create reference quantized tensors needed by reference GEMM @@ -232,6 +249,8 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( *, use_bias: bool, single_output: bool, + use_4over6: bool = False, + nvfp4_4over6_err_mode: str = "MAE", ): te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -249,6 +268,8 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=True, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -258,6 +279,8 @@ def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4 = [] @@ -321,6 +344,8 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( M: int, K: int, N: int, + use_4over6: bool = False, + nvfp4_4over6_err_mode: str = "MAE", ): te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -339,6 +364,8 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=True, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_tensorwise_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -348,6 +375,8 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -357,6 +386,8 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_row_scaled = x_row_scaled_quantizer.update_quantized( @@ -417,6 +448,9 @@ def check_nvfp4_row_scaled_gemm_matches_emulated( ids=["rowxrow", "colxrow", "colxcol"], ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, @@ -428,6 +462,9 @@ def test_nvfp4_gemm_versus_reference( is_x_columnwise: bool, is_w_columnwise: bool, row_scaled_nvfp4: bool, + use_4over6: bool, + nvfp4_e4m3_max: int, + nvfp4_4over6_err_mode: str, ): if row_scaled_nvfp4: if accumulate: @@ -446,6 +483,9 @@ def test_nvfp4_gemm_versus_reference( x_columnwise=is_x_columnwise, w_columnwise=is_w_columnwise, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -471,6 +511,8 @@ def test_nvfp4_gemm_versus_reference( @pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) @pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( m_splits: list[int], k: int, @@ -480,6 +522,8 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( out_dtype: torch.dtype, use_bias: bool, single_output: bool, + use_4over6: bool, + nvfp4_4over6_err_mode: str, ): check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( x_dtype=x_dtype, @@ -490,6 +534,8 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( n=n, use_bias=use_bias, single_output=single_output, + use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -513,6 +559,8 @@ def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_row_scaled_gemm_matches_emulated( M: int, K: int, @@ -520,6 +568,8 @@ def test_nvfp4_row_scaled_gemm_matches_emulated( x_dtype: torch.dtype, w_dtype: torch.dtype, out_dtype: torch.dtype, + use_4over6: bool, + nvfp4_4over6_err_mode: str, ): check_nvfp4_row_scaled_gemm_matches_emulated( x_dtype=x_dtype, @@ -528,4 +578,6 @@ def test_nvfp4_row_scaled_gemm_matches_emulated( M=M, K=K, N=N, + use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 53569d90d9..5bb92f70dc 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -20,7 +20,14 @@ def maybe_skip_row_scaled_unsupported_quantization( row_scaled_nvfp4: bool, return_transpose: bool, with_2d_quantization: bool = False, + use_4over6: bool = False, + x_dtype: torch.dtype | None = None, + M: int | None = None, + N: int | None = None, ) -> None: + if use_4over6 and with_2d_quantization: + if x_dtype != torch.bfloat16 or M is None or N is None or M % 32 != 0 or N % 32 != 0: + pytest.skip("NVFP4 2D 4over6 exact tests require the optimized BF16 kernel path") if not row_scaled_nvfp4: return if return_transpose: @@ -45,9 +52,14 @@ def check_quantization_nvfp4_versus_reference( use_cpp_allocator: bool, with_2d_quantization: bool, row_scaled_nvfp4: bool = False, + use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", ) -> None: + if nvfp4_e4m3_max != 448 and not use_4over6: + pytest.skip("E4M3 max 256 is only meaningful for 4over6") maybe_skip_row_scaled_unsupported_quantization( - row_scaled_nvfp4, return_transpose, with_2d_quantization + row_scaled_nvfp4, return_transpose, with_2d_quantization, use_4over6, x_dtype, M, N ) te_dtype = tex.DType.kFloat4E2M1 @@ -71,6 +83,9 @@ def check_quantization_nvfp4_versus_reference( with_post_rht_amax=False, with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -104,6 +119,9 @@ def check_quantization_nvfp4_versus_reference( eps=0.0, quant_tile_shape=quant_tile_shape, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -179,6 +197,9 @@ def check_quantization_nvfp4_versus_reference( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("nvfp4_e4m3_max", [448, 256], ids=["e4m3_448", "e4m3_256"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -188,6 +209,9 @@ def test_quantization_block_tiling_versus_reference( use_cpp_allocator: bool, with_2d_quantization: bool, row_scaled_nvfp4: bool, + use_4over6: bool, + nvfp4_e4m3_max: int, + nvfp4_4over6_err_mode: str, ) -> None: check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, @@ -198,6 +222,9 @@ def test_quantization_block_tiling_versus_reference( use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, row_scaled_nvfp4=row_scaled_nvfp4, + use_4over6=use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) @@ -215,6 +242,8 @@ def test_quantization_block_tiling_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -223,8 +252,12 @@ def test_nvfp4_quantization_extrema_versus_reference( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, + use_4over6: bool, + nvfp4_4over6_err_mode: str, ): - maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + maybe_skip_row_scaled_unsupported_quantization( + row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + ) te_dtype = tex.DType.kFloat4E2M1 @@ -247,6 +280,8 @@ def test_nvfp4_quantization_extrema_versus_reference( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) if use_cpp_allocator: @@ -278,6 +313,8 @@ def test_nvfp4_quantization_extrema_versus_reference( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -322,6 +359,8 @@ def test_nvfp4_quantization_extrema_versus_reference( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, @@ -329,13 +368,17 @@ def test_nvfp4_quantization_boundary_values( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, + use_4over6: bool, + nvfp4_4over6_err_mode: str, ): """ Stress rounding/threshold behavior by placing values just below/above many potential bin edges within each 16-element microblock. Validates native vs reference byte-for-byte and scale parity. """ - maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + maybe_skip_row_scaled_unsupported_quantization( + row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + ) te_dtype = tex.DType.kFloat4E2M1 @@ -367,6 +410,8 @@ def test_nvfp4_quantization_boundary_values( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) if use_cpp_allocator: @@ -398,6 +443,8 @@ def test_nvfp4_quantization_boundary_values( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -442,6 +489,8 @@ def test_nvfp4_quantization_boundary_values( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, @@ -449,8 +498,12 @@ def test_nvfp4_quantization_noncontiguous_inputs( return_transpose: bool, use_cpp_allocator: bool, row_scaled_nvfp4: bool, + use_4over6: bool, + nvfp4_4over6_err_mode: str, ): - maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + maybe_skip_row_scaled_unsupported_quantization( + row_scaled_nvfp4, return_transpose, use_4over6=use_4over6 + ) te_dtype = tex.DType.kFloat4E2M1 @@ -473,6 +526,8 @@ def test_nvfp4_quantization_noncontiguous_inputs( with_rht=False, with_post_rht_amax=False, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) if use_cpp_allocator: @@ -504,6 +559,8 @@ def test_nvfp4_quantization_noncontiguous_inputs( eps=0.0, quant_tile_shape=(1, 16), row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index 43e9587d95..5e6f36e8b4 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -83,6 +83,11 @@ marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4RowScaledBlockScaling", ), + pytest.param( + "nvfp4_4over6", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP44Over6BlockScaling", + ), ] @@ -170,7 +175,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name in ("nvfp4", "nvfp4_row_scaled"): + if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -185,6 +190,8 @@ def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: s pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") if module_type == "ops_linear" and recipe_name == "nvfp4_row_scaled": pytest.skip("Row-scaled NVFP4 currently does not support fused te_ops paths.") + if module_type == "grouped_linear" and recipe_name == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 currently does not support grouped quantization.") def _make_quantized_forward_reference_recipe(recipe_name: str) -> recipe.Recipe: @@ -208,7 +215,7 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name in ("nvfp4", "nvfp4_row_scaled") and ( + if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and ( flat_first_dim % 16 != 0 or last_dim % 16 != 0 ): pytest.skip( @@ -235,7 +242,7 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name in ("nvfp4", "nvfp4_row_scaled") and ( + if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and ( flat_first_dim % 16 != 0 or last_dim % 16 != 0 ): pytest.skip( @@ -256,9 +263,13 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and any( + m % 16 != 0 for m in non_empty_splits + ): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 64 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and any( + m % 64 != 0 for m in non_empty_splits + ): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 50196782f2..35cc98a976 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -19,7 +19,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch as te from transformer_engine.common import recipe -from utils import ModelConfig, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, skip_unsupported_backward_override import transformer_engine_torch as tex # Check supported quantization schemes @@ -28,6 +28,33 @@ mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() nvfp4_available, _ = FP8GlobalStateManager.is_nvfp4_available() + +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="dequantized", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + nvfp4_4over6="all", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + quantization_recipes: List[Optional[recipe.Recipe]] = [None] if fp8_available: quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) @@ -37,6 +64,8 @@ quantization_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: quantization_recipes.append(recipe.NVFP4BlockScaling()) + quantization_recipes.append(nvfp4_4over6()) + quantization_recipes.append(nvfp4_row_scaled()) model_config = { @@ -176,7 +205,20 @@ def create_tensor(recipe: Optional[recipe.Recipe], requires_grad: bool = False) quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) return quantizer(tensor) elif recipe.nvfp4(): - quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer() + qparams = recipe.fp4_quant_fwd_inp + use_4over6 = False + if recipe.nvfp4_4over6 in ("activations", "all"): + use_4over6 = True + quantizer = te.tensor.nvfp4_tensor.NVFP4Quantizer( + rowwise=True, + columnwise=not recipe.row_scaled_activation, + with_rht=qparams.random_hadamard_transform, + with_post_rht_amax=qparams.random_hadamard_transform, + with_2d_quantization=qparams.fp4_2d_quantization, + stochastic_rounding=qparams.stochastic_rounding, + row_scaled_nvfp4=recipe.row_scaled_activation, + nvfp4_use_4over6=use_4over6, + ) return quantizer(tensor) @staticmethod @@ -191,10 +233,24 @@ def get_tensor_size_mb(tensor): if tensor is None: return 0 if isinstance(tensor, te.quantized_tensor.QuantizedTensorStorage): - return sum(Utils.get_tensor_size_mb(t) for t in tensor.get_data_tensors()) + tensors = [ + value for value in tensor.get_metadata().values() if isinstance(value, torch.Tensor) + ] + return sum(Utils.get_tensor_size_mb(t) for t in tensors) else: return tensor.numel() * tensor.element_size() / (1024**2) + @staticmethod + def get_saved_tensor_gpu_size_mb(tensor): + if tensor is None or isinstance(tensor, int): + return 0 + if isinstance(tensor, tuple): + push_results, _ = tensor + return Utils.get_saved_tensor_gpu_size_mb(push_results) + if isinstance(tensor, list): + return sum(Utils.get_saved_tensor_gpu_size_mb(t) for t in tensor) + return Utils.get_tensor_size_mb(tensor) + @staticmethod def memory_leak_check(): # Should be called before each test. @@ -212,7 +268,7 @@ def memory_leak_check(): class TestsOffloadableLayerState: @pytest.mark.parametrize("random_num_tensors", [True, False]) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_general(self, random_num_tensors, recipe): """ Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers, @@ -289,7 +345,7 @@ def test_offload_base_tensor(self): class TestsDefaultOffloadSynchronizer: @pytest.mark.parametrize("random_num_tensors", [True, False]) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_general(self, random_num_tensors, recipe): """ Test general functionality of DefaultOffloadSynchronizer - offload NUM_LAYERS-1 out of NUM_LAYERS layers, @@ -335,7 +391,7 @@ def test_general(self, random_num_tensors, recipe): offload_synchronizer.finish_part_of_bwd() torch.cuda.synchronize() - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_memory(self, recipe): torch.cuda.synchronize() Utils.memory_leak_check() @@ -363,11 +419,16 @@ def test_memory(self, recipe): del tensor, tensor_id torch.cuda.synchronize() + resident_gpu_size = sum( + Utils.get_saved_tensor_gpu_size_mb(tensor_id) for tensor_id in tensor_ids + ) if recipe is None: assert Utils.get_max_cuda_memory_mb() == pytest.approx( - init_cuda_memory + tensor_size, 0.1 + init_cuda_memory + resident_gpu_size, 0.1 ) - assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory + tensor_size, 0.1) + assert Utils.get_cuda_memory_mb() == pytest.approx( + init_cuda_memory + resident_gpu_size, 0.1 + ) for i in range(NUM_LAYERS - 1, -1, -1): offload_synchronizer.bwd_step(i) @@ -385,7 +446,7 @@ def test_memory(self, recipe): ) assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) def test_multiple_tensor_offload(self, recipe): Utils.memory_leak_check() init_cpu_memory = Utils.get_cpu_memory_mb() @@ -416,7 +477,7 @@ def test_multiple_tensor_offload(self, recipe): class TestTELayers: @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) def test_sanity(self, layer_type, recipe, backward_override): Utils.memory_leak_check() @@ -463,7 +524,7 @@ def test_sanity(self, layer_type, recipe, backward_override): del out, inp, layers @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) def test_memory(self, layer_type, recipe, backward_override): Utils.memory_leak_check() @@ -536,7 +597,9 @@ def test_memory(self, layer_type, recipe, backward_override): out = out + 1 out = sync_function(out) del inp - if backward_override is None: + if recipe is not None and recipe.nvfp4() and recipe.row_scaled_activation: + assert Utils.get_cuda_memory_mb() <= cuda_memory_no_offload + elif backward_override is None: assert Utils.get_cuda_memory_mb() == pytest.approx(init_cuda_memory, 0.1) else: assert ( @@ -554,7 +617,7 @@ def test_memory(self, layer_type, recipe, backward_override): out.sum().backward() @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) def test_manual_synchronization(self, recipe, layer_type, backward_override): Utils.memory_leak_check() @@ -623,7 +686,7 @@ def test_manual_synchronization(self, recipe, layer_type, backward_override): out_1.sum().backward() out_2.sum().backward() - @pytest.mark.parametrize("recipe", quantization_recipes) + @pytest.mark.parametrize("recipe", quantization_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("layer_type", Utils.get_layer_names()) @pytest.mark.parametrize("use_cuda_graphs", [True, False]) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 33ba65e0d9..bb4a4e3857 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -65,13 +65,31 @@ def nvfp4_rht_and_2d_quantization(): def nvfp4_row_scaled(): - nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="dequantized", + ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() return nvfp4_recipe +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + nvfp4_4over6="all", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def check_rht_usage(recipe: recipe.Recipe) -> bool: # if using RHT, we can only support bf16 # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad @@ -101,6 +119,7 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) fp8_recipes.append(nvfp4_row_scaled()) + fp8_recipes.append(nvfp4_4over6()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 3a3aa8be91..8e63caa987 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -40,6 +40,7 @@ Float8Quantizer, MXFP8Quantizer, NVFP4Quantizer, + QuantizerRole, is_bf16_available, ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor @@ -78,9 +79,10 @@ _quantization_list.append("mxfp8") if nvfp4_available: _quantization_list.append("nvfp4") + _quantization_list.append("nvfp4_4over6") -@pytest.fixture(autouse=True, scope="class") +@pytest.fixture(autouse=True, scope="function") def _reset_rng_states_per_test(): """Restore torch, CUDA, and Python ``random`` before each test in this module.""" reset_rng_states() @@ -107,7 +109,7 @@ def maybe_skip_quantization( pytest.skip(reason_for_no_fp8) if quantization == "mxfp8" and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if quantization == "nvfp4" and not nvfp4_available: + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") and not nvfp4_available: pytest.skip(reason_for_no_nvfp4) # Check dims @@ -120,13 +122,16 @@ def maybe_skip_quantization( elif quantization == "mxfp8": if math.prod(dims[:-1]) % 32 != 0 or dims[-1] % 32 != 0: pytest.skip("MXFP8 GEMMs require dims that are divisible by 32") - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): if math.prod(dims[:-1]) % 16 != 0 or dims[-1] % 16 != 0: pytest.skip("NVFP4 GEMMs require dims that are divisible by 16") # Check dtype if dtype is not None: - if quantization == "nvfp4" and dtype != torch.bfloat16: + if ( + quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + and dtype != torch.bfloat16 + ): pytest.skip("NVFP4 quantization is only supported with BF16 data") @@ -142,6 +147,7 @@ def make_reference_and_test_tensors( test_dtype: torch.dtype = torch.float32, test_device: torch.device = "cuda", test_is_quantized: bool = False, + quantizer_role: Optional[QuantizerRole] = None, requires_grad: bool = True, ) -> tuple[torch.Tensor, torch.Tensor]: """Construct tensors with the same values @@ -181,7 +187,7 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled"): test = NVFP4Quantizer( with_rht=False, with_post_rht_amax=False, @@ -189,6 +195,29 @@ def make_reference_and_test_tensors( stochastic_rounding=False, with_random_sign_mask=False, )(test) + elif quantization == "nvfp4_4over6": + tensor_type = "input" + if quantizer_role is not None: + tensor_type = quantizer_role.tensor_type + + nvfp4_use_4over6 = False + with_2d_quantization = False + nvfp4_e4m3_max = 448 + if tensor_type not in ("grad_output", "grad_input"): + nvfp4_use_4over6 = True + nvfp4_e4m3_max = 256 + if tensor_type == "weight": + with_2d_quantization = True + + test = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=with_2d_quantization, + stochastic_rounding=False, + with_random_sign_mask=False, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") if isinstance(test, QuantizedTensor) and not test_is_quantized: @@ -504,6 +533,7 @@ def test_dtype_cast( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) # Construct operation @@ -818,9 +848,13 @@ def test_quantize( test_device=device, requires_grad=True, ) + grad_quantization = quantization + if quantization == "nvfp4_4over6" and cast_backward: + # 4over6 is not applied to gradient quantizers. + grad_quantization = "nvfp4" dy_ref, dy_test = make_reference_and_test_tensors( in_shape, - quantization=quantization, + quantization=grad_quantization, test_dtype=dtype, test_device=device, requires_grad=False, @@ -911,6 +945,7 @@ def _test_basic_linear( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -1083,6 +1118,7 @@ def test_linear( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b_ref, b_test = None, None if bias: @@ -1513,7 +1549,7 @@ def test_add_extra_input( if in_place: if quantization in ("fp8_delayed_scaling", "fp8_current_scaling", "mxfp8"): tols = dtype_tols(x1_test._fp8_dtype) - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = dtype_tols(x1_test._fp4_dtype) y_test = y_test.to(dtype=torch.float64, device="cpu") dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") @@ -1884,7 +1920,7 @@ def test_clamped_swiglu( # Expected numerical error tols = dtype_tols(dtype) - if quantized_compute and quantization == "nvfp4": + if quantized_compute and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = dtype_tols(tex.DType.kFloat4E2M1) elif quantized_compute: tols = dtype_tols(tex.DType.kFloat8E4M3) @@ -2077,6 +2113,8 @@ def test_grouped_linear( pytest.skip("Quantization scheme is not used") if quantization is not None and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if quantization == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 grouped quantization is not supported") if single_grouped_bias and not bias: pytest.skip("single_grouped_bias requires bias=True") @@ -2113,6 +2151,7 @@ def test_grouped_linear( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), requires_grad=weight_requires_grad, ) b_ref, b_test = None, None @@ -2685,6 +2724,7 @@ def test_forward_linear_bias_activation( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b_ref, b_test = None, None if bias: @@ -2790,6 +2830,7 @@ def test_forward_linear_bias_add( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b_ref, b_test = None, None if bias: @@ -2903,6 +2944,7 @@ def test_forward_linear_scale_add( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) x2_ref, x2_test = make_reference_and_test_tensors( out_shape, @@ -3185,6 +3227,7 @@ def test_backward_linear_add( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) dy1_ref, dy1_test = make_reference_and_test_tensors( out_shape, @@ -3288,6 +3331,7 @@ def test_backward_linear_scale( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) dy_ref, dy_test = make_reference_and_test_tensors( out_shape, @@ -3504,55 +3548,62 @@ def test_layernorm_mlp( ) norm_w_ref, norm_w_test = make_reference_and_test_tensors( hidden_size, + min=-0.5, + max=0.5, test_dtype=dtype, test_device=device, ) norm_b_ref, norm_b_test = make_reference_and_test_tensors( hidden_size, + min=-0.5, + max=0.5, test_dtype=dtype, test_device=device, ) w1_ref, w1_test = make_reference_and_test_tensors( (ffn_hidden_size, hidden_size), quantization=quantization, + min=0, + max=1 / 64, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) w2_ref, w2_test = make_reference_and_test_tensors( (hidden_size, ffn_hidden_size // 2), + min=0, + max=1 / 64, quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) b1_ref, b1_test, b2_ref, b2_test = None, None, None, None if bias: b1_ref, b1_test = make_reference_and_test_tensors( ffn_hidden_size, + min=-0.5, + max=0.5, test_dtype=dtype, test_device=device, ) b2_ref, b2_test = make_reference_and_test_tensors( hidden_size, + min=-0.5, + max=0.5, test_dtype=dtype, test_device=device, ) dy_ref, dy_test = make_reference_and_test_tensors( in_shape, + min=-0.5, + max=0.5, quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="grad_output"), requires_grad=False, ) - with torch.no_grad(): - for t in (norm_w_ref, norm_w_test, norm_b_ref, norm_b_test): - t -= 0.5 - for t in (w1_ref, w1_test, w2_ref, w2_test): - t *= 1 / 64 - if bias: - for t in (b1_ref, b1_test, b2_ref, b2_test): - t -= 0.5 - for t in (dy_ref, dy_test): - t -= 0.5 # Reference implementation x = x_ref @@ -3686,7 +3737,14 @@ def test_grouped_mlp( pytest.skip("Scaled unary grouped MLP fusion is only supported with MXFP8") if not activation_is_glu and glu_interleave_size is not None: pytest.skip("Unary activations do not use GLU interleaving") - if quantization == "nvfp4" and activation == "scaled_clamped_qgeglu" and bias: + if quantization == "nvfp4_4over6": + pytest.skip("NVFP4 4over6 grouped quantization is not supported") + if ( + with_quantization + and quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6") + and activation == "scaled_clamped_qgeglu" + and bias + ): # TODO: ksivaman: Need to debug numerics for this case. pytest.skip("Bias/dbias not yet supported in NVFP4 fused grouped MLP with GeGLU") fc1_out_features = 2 * hidden_size if activation_is_glu else hidden_size @@ -3726,6 +3784,7 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) fc2_w_ref, fc2_w_test = make_reference_and_test_tensors( (hidden_size, hidden_size), @@ -3734,6 +3793,7 @@ def test_grouped_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + quantizer_role=QuantizerRole(tensor_type="weight"), ) fc1_b_ref, fc1_b_test = None, None fc2_b_ref, fc2_b_test = None, None @@ -3918,7 +3978,7 @@ def test_grouped_mlp( # Loose tols for sanity checking tols = {"rtol": 0.125, "atol": 0.25} - if quantization == "nvfp4": + if quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): tols = {"rtol": 0.25, "atol": 0.5} # Check values diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a718ea2a8a..5f82bfcba2 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -54,7 +54,7 @@ from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.common import recipe import transformer_engine_torch as tex -from utils import ModelConfig, reset_rng_states +from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override # Only run FP8 tests on supported devices. @@ -138,6 +138,32 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="high_precision", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + nvfp4_4over6="all", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def check_rht_usage(recipe: recipe.Recipe) -> bool: # if using RHT, we can only support bf16 # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad @@ -171,6 +197,8 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> fp8_recipes.append(recipe.DelayedScaling()) if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) + fp8_recipes.append(nvfp4_4over6()) + fp8_recipes.append(nvfp4_row_scaled()) use_cutlass_grouped_gemm = [False] # Only enable cutlass grouped gemm on Hopper @@ -627,11 +655,15 @@ def _test_e2e_selective_recompute( @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 or fp8_model_params: + skip_unsupported_backward_override( + "transformer_layer", recipe, getattr(recipe, "backward_override", None) + ) if fp8 and recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( @@ -739,7 +771,7 @@ def _test_e2e_full_recompute( @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", all_boolean) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean) def test_gpt_full_activation_recompute( @@ -747,6 +779,10 @@ def test_gpt_full_activation_recompute( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if fp8 or fp8_model_params: + skip_unsupported_backward_override( + "transformer_layer", recipe, getattr(recipe, "backward_override", None) + ) if fp8 and recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): pytest.skip( @@ -1324,7 +1360,7 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("model", ["small"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) def test_linear_accuracy_save_original_input(dtype, model, recipe): bs = 1 fuse_wgrad_accumulation = True @@ -1333,6 +1369,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") + skip_unsupported_backward_override("linear", recipe, getattr(recipe, "backward_override", None)) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -1894,7 +1931,7 @@ def _test_grouped_linear_accuracy( @pytest.mark.parametrize("num_gemms", [3, 6]) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @@ -1917,6 +1954,9 @@ def test_grouped_linear_accuracy( pytest.skip("FP8 parameters are not supported in debug mode.") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2037,7 +2077,7 @@ def test_grouped_linear_accuracy_cutlass( @pytest.mark.parametrize("num_gemms", [3]) @pytest.mark.parametrize("bs", [1]) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", [False]) @pytest.mark.parametrize("fuse_wgrad_accumulation", [True]) @pytest.mark.parametrize("bias", [False]) @@ -2061,6 +2101,9 @@ def test_grouped_linear_accuracy_save_original_input( pytest.skip("DelayedScaling recipe is not supported with save_original_input") if NVTE_TEST_NVINSPECT_ENABLED and delay_wgrad_compute: pytest.skip("Delayed wgrad compute is not supported in debug mode.") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2139,7 +2182,7 @@ def test_grouped_linear_accuracy_save_original_input( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) -@pytest.mark.parametrize("recipe", fp8_recipes + [None]) +@pytest.mark.parametrize("recipe", fp8_recipes + [None], ids=recipe_id) def test_grouped_linear_accuracy_single_gemm(recipe): """Split the tests to save CI time""" test_grouped_linear_accuracy( @@ -2253,7 +2296,7 @@ def _generate_random_numbers(n, total_sum): @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_padding_grouped_linear_accuracy( dtype, @@ -2267,6 +2310,9 @@ def test_padding_grouped_linear_accuracy( ): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2328,7 +2374,7 @@ def test_padding_grouped_linear_accuracy( @pytest.mark.parametrize("bs", [1]) @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("fp8", [True]) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("fp8_model_params", [False]) def test_padding_grouped_linear_accuracy_save_original_input( dtype, @@ -2344,6 +2390,9 @@ def test_padding_grouped_linear_accuracy_save_original_input( pytest.skip("FP8 parameters are not supported in debug mode.") if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") + skip_unsupported_backward_override( + "grouped_linear", recipe, getattr(recipe, "backward_override", None) + ) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -2559,10 +2608,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes) @pytest.mark.parametrize("model", ["126m"]) -@pytest.mark.parametrize("recipe", fp8_recipes) +@pytest.mark.parametrize("recipe", fp8_recipes, ids=recipe_id) def test_gpt_fp8_parameters(dtype, bs, model, recipe): if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + skip_unsupported_backward_override( + "transformer_layer", recipe, getattr(recipe, "backward_override", None) + ) if recipe.nvfp4(): if dtype not in get_nvfp4_inp_supported_dtypes(recipe, dtype): diff --git a/tests/pytorch/test_quantized_tensor.py b/tests/pytorch/test_quantized_tensor.py index 119914fbc3..c5161349ef 100644 --- a/tests/pytorch/test_quantized_tensor.py +++ b/tests/pytorch/test_quantized_tensor.py @@ -28,7 +28,7 @@ import transformer_engine_torch as tex from references.ref_per_tensor_cs import ref_per_tensor_cs_cast -from utils import assert_close, quantization_tols +from utils import assert_close # PyTorch tensor dtypes _dtypes: List[torch.dtype] = [torch.float32, torch.float16, torch.bfloat16] @@ -69,6 +69,8 @@ def _to_list(x: Union[Iterable, Any]) -> List: _quantization_list.append("mxfp8") if nvfp4_available: _quantization_list.append("nvfp4") + _quantization_list.append("nvfp4_row_scaled") + _quantization_list.append("nvfp4_4over6") # delayed scaling @@ -163,13 +165,17 @@ def make_reference_and_test_tensors( test = quantizer(test) elif quantization == "mxfp8": test = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(test) - elif quantization == "nvfp4": + elif quantization in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + row_scaled_nvfp4 = quantization == "nvfp4_row_scaled" test = NVFP4Quantizer( + columnwise=not row_scaled_nvfp4, with_rht=False, with_post_rht_amax=False, with_2d_quantization=False, stochastic_rounding=False, + row_scaled_nvfp4=row_scaled_nvfp4, with_random_sign_mask=False, + nvfp4_use_4over6=(quantization == "nvfp4_4over6"), )(test) else: raise ValueError(f"Unsupported quantization scheme ({quantization})") @@ -785,13 +791,16 @@ def test_update_nd_tensor( ) elif quantization == "mxfp8": quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) - elif quantization in ("nvfp4", "nvfp4_2d"): + elif quantization in ("nvfp4", "nvfp4_2d", "nvfp4_row_scaled", "nvfp4_4over6"): + row_scaled_nvfp4 = quantization == "nvfp4_row_scaled" quantizer = NVFP4Quantizer( rowwise=True, - columnwise=True, + columnwise=not row_scaled_nvfp4, with_rht=False, with_post_rht_amax=False, with_2d_quantization=(quantization == "nvfp4_2d"), + row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=(quantization == "nvfp4_4over6"), ) quantization = "nvfp4" else: @@ -806,9 +815,9 @@ def test_update_nd_tensor( q_x.copy_(x_new) # Check results + q_ref = quantizer(x_new) assert q_x.shape == torch.Size(shape) - tols = quantization_tols(quantization) - assert_close(q_x, x_new, **tols) + assert_close(q_x, q_ref, rtol=0, atol=0) @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 5f5221af76..9a14cee7fd 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -26,6 +26,7 @@ from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, NVFP4BlockScalingRecipeState, + QuantizerRole, _amax_and_scale_update, ) import transformer_engine.pytorch.ops as te_ops @@ -514,8 +515,52 @@ def test_quantizer_update(self, module_class): @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) -def test_nvfp4_row_scaled_quantizer_roles(): - recipe = NVFP4BlockScaling(row_scaled_activation=True) +@pytest.mark.parametrize( + "nvfp4_4over6", + ["none", "weights", "activations", "all"], + ids=["disabled", "weights", "activations", "all"], +) +@pytest.mark.parametrize( + "nvfp4_4over6_e4m3_use_256", + ["none", "weights", "activations", "all"], + ids=["e4m3_448", "e4m3_256_weights", "e4m3_256_activations", "e4m3_256_all"], +) +@pytest.mark.parametrize("nvfp4_4over6_err_mode", ["MAE", "MSE"], ids=["mae_err", "mse_err"]) +def test_nvfp4_row_scaled_quantizer_roles( + nvfp4_4over6, nvfp4_4over6_e4m3_use_256, nvfp4_4over6_err_mode +): + recipe = NVFP4BlockScaling( + disable_rht=True, + disable_2d_quantization=True, + nvfp4_4over6=nvfp4_4over6, + nvfp4_4over6_e4m3_use_256=nvfp4_4over6_e4m3_use_256, + nvfp4_4over6_err_mode=nvfp4_4over6_err_mode, + row_scaled_activation=True, + ) + + def expected_use_4over6(tensor_type): + if tensor_type in ("grad_output", "grad_input"): + return False + if nvfp4_4over6 == "all": + return True + if nvfp4_4over6 == "weights": + return tensor_type == "weight" + if nvfp4_4over6 == "activations": + return tensor_type != "weight" + return False + + def expected_e4m3_max(tensor_type): + if not expected_use_4over6(tensor_type): + return 448 + if nvfp4_4over6_e4m3_use_256 == "all": + return 256 + if nvfp4_4over6_e4m3_use_256 == "weights": + if tensor_type == "weight": + return 256 + if nvfp4_4over6_e4m3_use_256 == "activations": + if tensor_type != "weight": + return 256 + return 448 forward_quantizers = NVFP4BlockScalingRecipeState( recipe, @@ -523,20 +568,85 @@ def test_nvfp4_row_scaled_quantizer_roles(): num_quantizers=3, ).make_quantizers() assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] + assert [q.stochastic_rounding for q in forward_quantizers] == [False, False, False] + assert [q.with_rht for q in forward_quantizers] == [False, False, False] + assert [q.nvfp4_use_4over6 for q in forward_quantizers] == [ + expected_use_4over6(tensor_type) for tensor_type in ("input", "weight", "output") + ] + assert [q.nvfp4_e4m3_max for q in forward_quantizers] == [ + expected_e4m3_max(tensor_type) for tensor_type in ("input", "weight", "output") + ] + assert [q.nvfp4_4over6_err_mode for q in forward_quantizers] == [nvfp4_4over6_err_mode] * 3 assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) + role_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="forward", + num_quantizers=4, + roles=[ + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="output"), + None, + ], + ).make_quantizers() + assert [q.row_scaled_nvfp4 for q in role_quantizers] == [False, True, True, True] + assert [q.nvfp4_use_4over6 for q in role_quantizers] == [ + expected_use_4over6(tensor_type) for tensor_type in ("weight", "input", "output", "input") + ] + assert [q.nvfp4_e4m3_max for q in role_quantizers] == [ + expected_e4m3_max(tensor_type) for tensor_type in ("weight", "input", "output", "input") + ] + assert [q.nvfp4_4over6_err_mode for q in role_quantizers] == [nvfp4_4over6_err_mode] * 4 + backward_quantizers = NVFP4BlockScalingRecipeState( recipe, mode="backward", num_quantizers=2, + roles=[ + QuantizerRole(module_type="linear", tensor_type="grad_output"), + QuantizerRole(module_type="linear", tensor_type="grad_input"), + ], ).make_quantizers() assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] + assert [q.nvfp4_use_4over6 for q in backward_quantizers] == [False, False] + assert [q.nvfp4_e4m3_max for q in backward_quantizers] == [448, 448] + assert [q.nvfp4_4over6_err_mode for q in backward_quantizers] == [nvfp4_4over6_err_mode] * 2 + assert [q.stochastic_rounding for q in backward_quantizers] == [True, True] + assert [q.with_rht for q in backward_quantizers] == [False, False] + + backward_operand_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="backward", + num_quantizers=4, + roles=[ + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="grad_output"), + QuantizerRole(module_type="linear", tensor_type="grad_input"), + ], + ).make_quantizers() + assert [q.nvfp4_use_4over6 for q in backward_operand_quantizers] == [ + expected_use_4over6(tensor_type) + for tensor_type in ("input", "weight", "grad_output", "grad_input") + ] + assert [q.nvfp4_e4m3_max for q in backward_operand_quantizers] == [ + expected_e4m3_max(tensor_type) + for tensor_type in ("input", "weight", "grad_output", "grad_input") + ] + assert [q.stochastic_rounding for q in backward_operand_quantizers] == [ + False, + False, + True, + True, + ] @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) +@pytest.mark.parametrize("use_4over6", [False, True], ids=["default", "4over6"]) @pytest.mark.parametrize( "M, N", [ @@ -552,24 +662,30 @@ def test_nvfp4_row_scaled_quantizer_roles(): (8192, 8192), ], ) -def test_fp4_dequantize(dtype, row_scaled_nvfp4, M, N): +def test_fp4_dequantize(dtype, row_scaled_nvfp4, use_4over6, M, N): q = NVFP4Quantizer( columnwise=not row_scaled_nvfp4, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=use_4over6, ) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) assert starting_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert starting_tensor._nvfp4_use_4over6 == use_4over6 assert starting_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) dequantized_tensor = starting_tensor.dequantize() new_tensor = q(dequantized_tensor) assert new_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert new_tensor._nvfp4_use_4over6 == use_4over6 assert new_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) - torch.testing.assert_close( - new_tensor._rowwise_data, - starting_tensor._rowwise_data, - rtol=0, - atol=0, - ) + # 4over6 can re-encode a dequantized block with the alternate 4/6 scale + # choice while preserving the dequantized values. + if not use_4over6: + torch.testing.assert_close( + new_tensor._rowwise_data, + starting_tensor._rowwise_data, + rtol=0, + atol=0, + ) new_dequantized_tensor = new_tensor.dequantize() torch.testing.assert_close(dequantized_tensor, new_dequantized_tensor) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index c811342df5..27eafbecdc 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -95,27 +95,43 @@ def nvfp4_vanilla(): def nvfp4_row_scaled(): - nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="dequantized", + ) nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() return nvfp4_recipe +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + nvfp4_4over6="all", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this + fp8_recipes.append(nvfp4_4over6()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) fp8_recipes.append(None) -fp8_recipes_with_row_scaled = fp8_recipes.copy() -if nvfp4_available: - fp8_recipes_with_row_scaled.insert(-1, nvfp4_row_scaled()) param_types = [torch.float32, torch.float16] if is_bf16_available(): # bf16 requires sm_80 or higher @@ -415,7 +431,11 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -463,7 +483,11 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -501,7 +525,11 @@ def test_sanity_linear( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -542,7 +570,11 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) +@pytest.mark.parametrize( + "fp8_recipe", + fp8_recipes + ([nvfp4_row_scaled()] if nvfp4_available else []), + ids=recipe_id, +) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -621,7 +653,7 @@ def test_sanity_grouped_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("zero_centered_gamma", all_boolean) @@ -671,7 +703,7 @@ def test_sanity_layernorm_mlp( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("bias", all_boolean) @@ -744,7 +776,7 @@ def test_sanity_gpt_126m(): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) @@ -800,7 +832,7 @@ def test_sanity_bert_126m(): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) @@ -856,7 +888,7 @@ def test_sanity_T5_126m(): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): @@ -889,7 +921,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) def test_sanity_drop_path(dtype, fp8_recipe, model): config = model_configs[model] @@ -924,7 +956,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): @@ -960,7 +992,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad): diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 51f72b1e56..137e5f5a77 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -39,6 +39,33 @@ fp8_block_scaling_available = is_fp8_block_scaling_available() nvfp4_available = is_nvfp4_available() + +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + backward_override="dequantized", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + +def nvfp4_4over6(): + nvfp4_recipe = recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + nvfp4_4over6="all", + ) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams(fp4_2d_quantization=True) + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + _all_recipes: list = [] if fp8_available: _all_recipes.append(recipe.Float8CurrentScaling()) @@ -48,7 +75,8 @@ _all_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: _all_recipes.append(recipe.NVFP4BlockScaling()) - _all_recipes.append(recipe.NVFP4BlockScaling(row_scaled_activation=True)) + _all_recipes.append(nvfp4_4over6()) + _all_recipes.append(nvfp4_row_scaled()) # --------------------------------------------------------------------------- @@ -97,8 +125,19 @@ def __fx_repr__(self): def _make_qfactory(tag: str): """Return a qfactory that produces ToyQuantizer instances tagged with *tag*.""" + quantizers = { + role: ToyQuantizer(tag=f"{tag}:{role}") + for role in ( + "linear_input", + "linear_weight", + "linear_output", + "linear_grad_output", + "linear_grad_input", + ) + } + def qfactory(role: str): - return ToyQuantizer(tag=f"{tag}:{role}") + return quantizers[role] return qfactory diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 2ee18aaf57..19cc118a90 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -118,7 +118,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name in ("nvfp4", "nvfp4_row_scaled"): + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -145,21 +145,17 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: ) if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling(**recipe_kwargs) - if name == "nvfp4": - return transformer_engine.common.recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - **recipe_kwargs, - ) - if name == "nvfp4_row_scaled": - return transformer_engine.common.recipe.NVFP4BlockScaling( - disable_rht=True, - disable_stochastic_rounding=True, - disable_2d_quantization=True, - row_scaled_activation=True, - **recipe_kwargs, - ) + if name in ("nvfp4", "nvfp4_row_scaled", "nvfp4_4over6"): + use_4over6 = name == "nvfp4_4over6" + kwargs = { + "disable_rht": True, + "disable_stochastic_rounding": True, + "disable_2d_quantization": not use_4over6, + "row_scaled_activation": name == "nvfp4_row_scaled", + "nvfp4_4over6": "all" if use_4over6 else "none", + } + kwargs.update(recipe_kwargs) + return transformer_engine.common.recipe.NVFP4BlockScaling(**kwargs) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -167,6 +163,10 @@ def recipe_id(recipe: Optional[Recipe]) -> str: """Readable pytest id for a quantization recipe.""" if not isinstance(recipe, Recipe): return "None" + if recipe.nvfp4() and recipe.row_scaled_activation and recipe.nvfp4_4over6 != "none": + return "NVFP4RowScaled4Over6BlockScaling" + if recipe.nvfp4() and recipe.nvfp4_4over6 != "none": + return "NVFP44Over6BlockScaling" if recipe.nvfp4() and recipe.row_scaled_activation: return "NVFP4RowScaledBlockScaling" return type(recipe).__name__ @@ -185,6 +185,13 @@ def skip_unsupported_backward_override( and backward_override is None ): pytest.skip("Row-scaled NVFP4 does not support default quantized backward.") + if ( + quant_recipe is not None + and quant_recipe.nvfp4() + and quant_recipe.nvfp4_4over6 != "none" + and layer_type == "grouped_linear" + ): + pytest.skip("NVFP4 4over6 currently does not support grouped quantization.") if backward_override is None: return if quant_recipe is None and backward_override is not None: diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 123362ce10..316243c975 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -21,6 +21,7 @@ #include "../mxfp8/group_quantize_mxfp8.cuh" #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" +#include "../nvfp4/quantize_4over6_nvfp4.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -101,6 +102,11 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; + const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; + NVTE_CHECK(nvfp4_use_4over6 || output_tensor->nvfp4_e4m3_max == 448, + "Non-4over6 NVFP4 quantization requires E4M3 max 448."); + NVTE_CHECK(!nvfp4_use_4over6 || !quant_config_cpp.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); if (row_scaled_nvfp4) { NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "Row-scaled NVFP4 quantization does not support 2D quantization."); @@ -112,7 +118,15 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { + if (nvfp4_use_4over6) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_4over6( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_4over6( + *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else if (use_optimized_kernel) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_transpose( *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); @@ -249,13 +263,26 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); + const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; + NVTE_CHECK(nvfp4_use_4over6 || output_tensor->nvfp4_e4m3_max == 448, + "Non-4over6 NVFP4 quantization requires E4M3 max 448."); + NVTE_CHECK(!nvfp4_use_4over6 || !quant_config_cpp.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); NVTE_CHECK(!output_tensor->row_scaled_nvfp4, "Backward NVFP4 quantization does not support row-scaled outputs."); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { + if (nvfp4_use_4over6) { + if (quant_config_cpp.nvfp4_2d_quantization) { + nvfp4::quantize_4over6( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } else { + nvfp4::quantize_4over6( + *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); + } + } else if (use_optimized_kernel) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_transpose( *grad_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); @@ -277,7 +304,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*row_scaled_nvfp4=*/false, /*noop_tensor=*/noop_tensor->data, + /*row_scaled_nvfp4=*/false, + /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); } break; @@ -372,8 +400,15 @@ void group_quantize_fwd_host_aware_helper(const NVTETensor input, NVTETensor *ou int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); + const bool nvfp4_use_4over6 = quant_config_cpp.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; + for (const auto *output_tensor : output_tensors) { + NVTE_CHECK(nvfp4_use_4over6 || output_tensor->nvfp4_e4m3_max == 448, + "Non-4over6 NVFP4 quantization requires E4M3 max 448."); + } NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, "2D quantization is not supported for group quantize."); + NVTE_CHECK(!nvfp4_use_4over6, + "NVFP4 4over6 quantization is not supported for group quantize."); // Launch NVFP4 group quantize kernel nvfp4::group_quantize_transpose( diff --git a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh index 792b068cbc..3820430d5b 100644 --- a/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/core_nvfp4.cuh @@ -75,10 +75,14 @@ namespace core { #if FP4_TYPE_SUPPORTED using namespace ptx; -// Compute the global encode scale factor for a given global amax +// Compute the global encode scale factor for a given global amax. +// NVFP4 uses the full E4M3 range by default. Some 4over6 tensors dispatch +// E4M3_MAX=256 to leave room for map-to-4 scale expansion. +template __device__ __forceinline__ float compute_global_encode_scaling_factor_FP4(const float global_amax) { using namespace detail; - constexpr float fp8_max = TypeExtrema::max; // 448.0f; + static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max."); + constexpr float fp8_max = static_cast(E4M3_MAX); constexpr float fp4_max = TypeExtrema::max; // 6.0f; float global_encode_scale = fp8_max * fp4_max / global_amax; // If scale is infinity, return max value of float32 diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index d549a050ee..faf3c58adf 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -31,12 +31,11 @@ namespace dispatch { namespace nvfp4 { namespace dequantize_kernel { #if FP4_TYPE_SUPPORTED -template +template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const bool row_scaled_nvfp4, - const size_t N, const size_t M, const size_t scale_stride, - const size_t num_scale_tiles_X) { + const float *const tensor_amax, const size_t N, const size_t M, + const size_t scale_stride, const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; const size_t y = thread_idx / M; @@ -64,8 +63,9 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = row_scaled_nvfp4 ? tensor_amax[y] : tensor_amax[0]; - constexpr float factor_inv = 1.0 / (6.0 * 448.0); + float amax = ROW_SCALED_NVFP4 ? tensor_amax[y] : tensor_amax[0]; + static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max."); + constexpr float factor_inv = 1.0f / (6.0f * static_cast(E4M3_MAX)); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll for (int i = 0; i < 4; i++) { @@ -92,6 +92,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales; const bool row_scaled_nvfp4 = input.row_scaled_nvfp4; + const int e4m3_max = input.nvfp4_e4m3_max; constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); @@ -112,14 +113,25 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) output->data.dtype, OType, TRANSFORMER_ENGINE_SWITCH_CONDITION( with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES, - - dequantize_fp4_kernel<<>>( - input.data.dptr, reinterpret_cast(output->data.dptr), - reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), row_scaled_nvfp4, N, Mread, - input.scale_inv.shape.back(), - num_scale_tiles_X);); // NOLINT(*) - ); // NOLINT(*) + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, ROW_SCALED_NVFP4, + if (e4m3_max == 256) { + dequantize_fp4_kernel + <<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back(), num_scale_tiles_X); + } else { + NVTE_CHECK(e4m3_max == 448, "Unsupported NVFP4 E4M3 max (got ", e4m3_max, ")"); + dequantize_fp4_kernel + <<>>( + input.data.dptr, reinterpret_cast(output->data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(input.amax.dptr), N, Mread, + input.scale_inv.shape.back(), num_scale_tiles_X); + });); // NOLINT(*) + ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("CUDA 12.8 or higher is needed for FP4 calculation!"); diff --git a/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh new file mode 100644 index 0000000000..b6057370dc --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_4over6_nvfp4.cuh @@ -0,0 +1,668 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_4over6_nvfp4.cuh + * \brief Dedicated kernels for NVFP4 4over6 quantization. + * + * Four Over Six evaluates two TE-style NVFP4 encodings for every 1x16 + * quantization group. The map-to-6 candidate uses the normal scale. The + * map-to-4 candidate expands the E4M3 block scale by 1.5x so FP4 value 4 + * reaches the same range that FP4 value 6 reaches in the normal encoding. + * The selected candidate is the one with lower configured dequantization + * error; ties select map-to-6. The quantized candidates, dequantized values, + * and errors are kept in registers, matching the structure of the official + * Four Over Six implementation. + */ + +#ifndef TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ + +#include +#include +#include +#include +#include + +#include + +#include "../../common.h" +#include "../../util/math.h" +#include "../../utils.cuh" +#include "core_nvfp4.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +#if FP4_TYPE_SUPPORTED + +#define TRANSFORMER_ENGINE_NVFP4_4OVER6_MODE_SWITCH(MODE, MODE_CONST, ...) \ + switch (MODE) { \ + case kNVTENVFP44Over6MinMAE: { \ + constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMAE; \ + { __VA_ARGS__ } \ + } break; \ + case kNVTENVFP44Over6MinMSE: { \ + constexpr NVTENVFP44Over6Mode MODE_CONST = kNVTENVFP44Over6MinMSE; \ + { __VA_ARGS__ } \ + } break; \ + default: { \ + NVTE_ERROR("Unsupported NVFP4 4over6 mode."); \ + } \ + } + +#define TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH(E4M3_MAX_VALUE, E4M3_MAX_CONST, ...) \ + if ((E4M3_MAX_VALUE) == 256) { \ + constexpr int E4M3_MAX_CONST = 256; \ + { __VA_ARGS__ } \ + } else { \ + NVTE_CHECK((E4M3_MAX_VALUE) == 448, "Unsupported NVFP4 E4M3 max."); \ + constexpr int E4M3_MAX_CONST = 448; \ + { __VA_ARGS__ } \ + } + +namespace quantize_4over6_kernel { + +constexpr int kThreads = 128; +constexpr int kWarpThreads = 32; +constexpr int kGroupSize = 16; +constexpr int kTileRows = 128; +constexpr int kTileCols = 64; +constexpr int kTileColGroups = kTileCols / kGroupSize; +constexpr int kTileRowGroups = kTileRows / kGroupSize; +constexpr int kPipelineStages = 2; +constexpr int kStageRows = kTileRows / kPipelineStages; +constexpr int kStageRowGroups = kStageRows / kGroupSize; +constexpr int kElementsPerHalfGroup = 8; +constexpr int kPackedWordsPerGroup = 2; +static_assert(kTileRows == kPipelineStages * kStageRows); +static_assert(kStageRows % kGroupSize == 0); + +template +struct Config { + static constexpr NVTENVFP44Over6Mode mode = kMode; + static constexpr bool err_use_fast_math = kErrUseFastMath; +}; + +struct Candidate { + uint32_t packed[kPackedWordsPerGroup]; + float err; +}; + +struct CandidatePair { + Candidate map4; + Candidate map6; +}; + +struct ScalePair { + nvfp4_scale_t map4; + nvfp4_scale_t map6; + float inv_map4; + float inv_map6; +}; + +template +__device__ __forceinline__ float compute_error_rn(const float diff) { + if constexpr (kMode == kNVTENVFP44Over6MinMSE) { + return __fmul_rn(diff, diff); + } else if constexpr (kMode == kNVTENVFP44Over6MinMAE) { + return fabsf(diff); + } else { + NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 mode."); + return fabsf(diff); + } +} + +template +__device__ __forceinline__ float compute_error(const float diff) { + if constexpr (kMode == kNVTENVFP44Over6MinMSE) { + return diff * diff; + } else if constexpr (kMode == kNVTENVFP44Over6MinMAE) { + return fabsf(diff); + } else { + NVTE_DEVICE_ERROR("Unsupported NVFP4 4over6 mode."); + return fabsf(diff); + } +} + +template +__device__ __forceinline__ ScalePair compute_scale_pair(const float block_amax, + const float global_amax) { + static_assert(E4M3_MAX == 448 || E4M3_MAX == 256, "Unsupported NVFP4 E4M3 max."); + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + constexpr float fp8_max = detail::TypeExtrema::max; // 448.0f + constexpr float expand_to_map4 = 1.5f; + const float S_enc = core::compute_global_encode_scaling_factor_FP4(global_amax); + const float base = block_amax / fp4_max * S_enc; + + ScalePair scales; + scales.map4 = static_cast(fminf(base * expand_to_map4, fp8_max)); + scales.map6 = static_cast(fminf(base, fp8_max)); + + const float S_dec = 1.0f / S_enc; + scales.inv_map4 = + fminf(1.0f / (static_cast(scales.map4) * S_dec), detail::TypeExtrema::max); + scales.inv_map6 = + fminf(1.0f / (static_cast(scales.map6) * S_dec), detail::TypeExtrema::max); + return scales; +} + +template +__device__ __forceinline__ float load_input(const IType *ptr, const size_t idx) { + return static_cast(ptr[idx]); +} + +template +__device__ __forceinline__ void load_row_group(const IType *tile, const int row, + const int col_start, float (&x0)[8], float (&x1)[8], + float *amax) { + Vec x0_vec; + Vec x1_vec; + x0_vec.load_from(&tile[row * kTileCols + col_start]); + x1_vec.load_from(&tile[row * kTileCols + col_start + kElementsPerHalfGroup]); + + *amax = 0.0f; +#pragma unroll + for (int i = 0; i < kElementsPerHalfGroup; ++i) { + const float v0 = static_cast(x0_vec.data.elt[i]); + const float v1 = static_cast(x1_vec.data.elt[i]); + x0[i] = v0; + x1[i] = v1; + *amax = fmaxf(*amax, fabsf(v0)); + *amax = fmaxf(*amax, fabsf(v1)); + } +} + +template +__device__ __forceinline__ void load_col_group(const IType *tile, const int row_start, + const int col, float (&x0)[8], float (&x1)[8], + float *amax) { + *amax = 0.0f; +#pragma unroll + for (int i = 0; i < kElementsPerHalfGroup; ++i) { + const float v0 = load_input(tile, (row_start + i) * kTileCols + col); + const float v1 = load_input(tile, (row_start + i + kElementsPerHalfGroup) * kTileCols + col); + x0[i] = v0; + x1[i] = v1; + *amax = fmaxf(*amax, fabsf(v0)); + *amax = fmaxf(*amax, fabsf(v1)); + } +} + +template +__device__ __forceinline__ void accumulate_dequant_error(const uint32_t dequant_bits, const float x, + const float sf, const float global_amax, + float *err) { + constexpr float fp4_max = detail::TypeExtrema::max; // 6.0f + constexpr float fp8_max = static_cast(E4M3_MAX); + constexpr float err_denom = fp4_max * fp8_max; + const uint16_t half_bits = (dequant_bits >> SHIFT) & 0xFFFF; + + if constexpr (Cfg::err_use_fast_math) { + const float dequant = __half2float(__ushort_as_half(half_bits)); + const float val = dequant * sf * global_amax / err_denom; + const float diff = val - x; + *err += compute_error(diff); + } else { + const float dequant = __half2float(__ushort_as_half(half_bits)); + const float val = __fdiv_rn(__fmul_rn(__fmul_rn(dequant, sf), global_amax), err_denom); + const float diff = __fsub_rn(val, x); + *err = __fadd_rn(*err, compute_error_rn(diff)); + } +} + +template +__device__ __forceinline__ uint32_t cvt_fp32_to_fp4_8x_with_error(const float (&x)[8], + const float block_scale_inverse, + const nvfp4_scale_t sf, + const float global_amax, + float *err) { + uint32_t out = 0; + uint32_t out_dequant_1 = 0; + uint32_t out_dequant_2 = 0; + uint32_t out_dequant_3 = 0; + uint32_t out_dequant_4 = 0; + + constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY; + if constexpr (is_blackwell) { + asm volatile( + "{\n" + ".reg .b8 byte0, byte1, byte2, byte3;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte0, %6, %5;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte1, %8, %7;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte2, %10, %9;\n" + "cvt.rn.satfinite.e2m1x2.f32 byte3, %12, %11;\n" + "mov.b32 %0, {byte0, byte1, byte2, byte3};\n" + "cvt.rn.f16x2.e2m1x2 %1, byte0;\n" + "cvt.rn.f16x2.e2m1x2 %2, byte1;\n" + "cvt.rn.f16x2.e2m1x2 %3, byte2;\n" + "cvt.rn.f16x2.e2m1x2 %4, byte3;\n" + "}" + : "=r"(out), "=r"(out_dequant_1), "=r"(out_dequant_2), "=r"(out_dequant_3), + "=r"(out_dequant_4) + : "f"(__fmul_rn(x[0], block_scale_inverse)), "f"(__fmul_rn(x[1], block_scale_inverse)), + "f"(__fmul_rn(x[2], block_scale_inverse)), "f"(__fmul_rn(x[3], block_scale_inverse)), + "f"(__fmul_rn(x[4], block_scale_inverse)), "f"(__fmul_rn(x[5], block_scale_inverse)), + "f"(__fmul_rn(x[6], block_scale_inverse)), "f"(__fmul_rn(x[7], block_scale_inverse))); + } else { + NVTE_DEVICE_ERROR( + "FP4 cvt PTX instructions are architecture-specific. " + "Try recompiling with sm_XXXa instead of sm_XXX."); + } + + const float sf_float = static_cast(sf); + accumulate_dequant_error(out_dequant_1, x[0], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_1, x[1], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_2, x[2], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_2, x[3], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_3, x[4], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_3, x[5], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_4, x[6], sf_float, global_amax, err); + accumulate_dequant_error(out_dequant_4, x[7], sf_float, global_amax, err); + return out; +} + +template +__device__ __forceinline__ CandidatePair make_candidates(const float (&x0)[8], const float (&x1)[8], + const ScalePair &scales, + const float global_amax) { + CandidatePair candidates; + candidates.map4.err = 0.0f; + candidates.map6.err = 0.0f; + candidates.map4.packed[0] = cvt_fp32_to_fp4_8x_with_error( + x0, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err); + candidates.map6.packed[0] = cvt_fp32_to_fp4_8x_with_error( + x0, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err); + candidates.map4.packed[1] = cvt_fp32_to_fp4_8x_with_error( + x1, scales.inv_map4, scales.map4, global_amax, &candidates.map4.err); + candidates.map6.packed[1] = cvt_fp32_to_fp4_8x_with_error( + x1, scales.inv_map6, scales.map6, global_amax, &candidates.map6.err); + return candidates; +} + +__device__ __forceinline__ float reduce_group_sum_16(float value) { + const int lane = threadIdx.x & (kWarpThreads - 1); + const int group_base = lane & ~(kGroupSize - 1); + const unsigned mask = 0xffffu << group_base; +#pragma unroll + for (int offset = kGroupSize / 2; offset > 0; offset /= 2) { + value += __shfl_down_sync(mask, value, offset, kGroupSize); + } + return __shfl_sync(mask, value, group_base, kWarpThreads); +} + +__device__ __forceinline__ float reduce_group_max_16(float value) { + const int lane = threadIdx.x & (kWarpThreads - 1); + const int group_base = lane & ~(kGroupSize - 1); + const unsigned mask = 0xffffu << group_base; +#pragma unroll + for (int offset = kGroupSize / 2; offset > 0; offset /= 2) { + value = fmaxf(value, __shfl_down_sync(mask, value, offset, kGroupSize)); + } + return __shfl_sync(mask, value, group_base, kWarpThreads); +} + +__device__ __forceinline__ void store_packed_group(const uint32_t *packed, fp4e2m1x2 *dst) { + const uint64_t packed64 = + static_cast(packed[0]) | (static_cast(packed[1]) << 32); + *reinterpret_cast(dst) = packed64; +} + +__device__ __forceinline__ const uint32_t *select_packed(const CandidatePair &candidates, + const bool pick_map4) { + if (pick_map4) { + return candidates.map4.packed; + } + return candidates.map6.packed; +} + +__device__ __forceinline__ nvfp4_scale_t select_scale(const ScalePair &scales, + const bool pick_map4) { + if (pick_map4) { + return scales.map4; + } + return scales.map6; +} + +__device__ __forceinline__ void cp_async_cg_16(void *dst, const void *src) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + const uint32_t dst_smem_ptr = __cvta_generic_to_shared(dst); + const uint64_t src_gmem_ptr = reinterpret_cast(src); + asm volatile("cp.async.cg.shared.global [%0], [%1], 16;\n" ::"r"(dst_smem_ptr), + "l"(src_gmem_ptr)); +#else + NVTE_DEVICE_ERROR("cp.async is only supported on SM 8.0+."); +#endif +} + +__device__ __forceinline__ void cp_async_commit_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("cp.async.commit_group;\n" ::); +#else + NVTE_DEVICE_ERROR("cp.async is only supported on SM 8.0+."); +#endif +} + +template +__device__ __forceinline__ void cp_async_wait_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); +#else + NVTE_DEVICE_ERROR("cp.async is only supported on SM 8.0+."); +#endif +} + +template +__device__ void load_stage_to_shared_async(const IType *input, IType *tile, const size_t rows, + const size_t cols, const size_t stage_row, + const size_t tile_col) { + constexpr int vec_elems = 16 / sizeof(IType); + constexpr int vecs_per_row = kTileCols / vec_elems; + constexpr int vecs = kStageRows * vecs_per_row; + using TileVec = Vec; + + for (int idx = threadIdx.x; idx < vecs; idx += blockDim.x) { + const int local_row = idx / vecs_per_row; + const int local_vec_col = idx - local_row * vecs_per_row; + const int local_col = local_vec_col * vec_elems; + const size_t global_row = stage_row + local_row; + const size_t global_col = tile_col + local_col; + IType *stage_ptr = &tile[local_row * kTileCols + local_col]; + + if (global_row < rows && global_col + vec_elems <= cols) { + cp_async_cg_16(stage_ptr, &input[global_row * cols + global_col]); + } else { + TileVec vec; + vec.clear(); +#pragma unroll + for (int i = 0; i < vec_elems; ++i) { + if (global_row < rows && global_col + i < cols) { + vec.data.elt[i] = input[global_row * cols + global_col + i]; + } + } + vec.store_to(stage_ptr); + } + } +} + +template +__device__ void quantize_stage_rowwise(const IType *tile, fp4e2m1x2 *output, nvfp4_scale_t *scales, + const float *amax, const size_t rows, const size_t cols, + const size_t stage_row, const size_t tile_col, + const size_t scale_stride) { + constexpr int groups = kStageRows * kTileColGroups; + for (int group = threadIdx.x; group < groups; group += blockDim.x) { + const int local_row = group % kStageRows; + const int local_col_group = group / kStageRows; + const int local_col = local_col_group * kGroupSize; + const size_t global_row = stage_row + local_row; + const size_t global_col = tile_col + local_col; + if (global_row >= rows || global_col >= cols) { + continue; + } + + float x0[8]; + float x1[8]; + float group_amax = 0.0f; + load_row_group(tile, local_row, local_col, x0, x1, &group_amax); + + float block_amax = group_amax; + if constexpr (USE_2D_QUANTIZATION) { + block_amax = reduce_group_max_16(group_amax); + } + + float global_amax = amax[0]; + if constexpr (ROW_SCALED_NVFP4) { + global_amax = amax[global_row]; + } + + const ScalePair scale_pair = compute_scale_pair(block_amax, global_amax); + CandidatePair candidates = make_candidates(x0, x1, scale_pair, global_amax); + + float err_map4 = candidates.map4.err; + float err_map6 = candidates.map6.err; + if constexpr (USE_2D_QUANTIZATION) { + err_map4 = reduce_group_sum_16(err_map4); + err_map6 = reduce_group_sum_16(err_map6); + } + + const bool pick_map4 = err_map4 < err_map6; + const nvfp4_scale_t selected_scale = select_scale(scale_pair, pick_map4); + const uint32_t *selected = select_packed(candidates, pick_map4); + + const size_t global_col_group = global_col / kGroupSize; + scales[global_row * scale_stride + global_col_group] = selected_scale; + store_packed_group(selected, &output[(global_row * cols + global_col) / 2]); + } +} + +template +__device__ void quantize_stage_colwise(const IType *tile, fp4e2m1x2 *output_t, + nvfp4_scale_t *scales_t, const float *amax, + const size_t rows, const size_t cols, const size_t stage_row, + const size_t tile_col, const size_t scale_stride_t) { + constexpr int groups = kStageRowGroups * kTileCols; + for (int group = threadIdx.x; group < groups; group += blockDim.x) { + const int local_row_group = group / kTileCols; + const int local_col = group - local_row_group * kTileCols; + const int local_row = local_row_group * kGroupSize; + const size_t global_row = stage_row + local_row; + const size_t global_col = tile_col + local_col; + if (global_row >= rows || global_col >= cols) { + continue; + } + + float x0[8]; + float x1[8]; + float group_amax = 0.0f; + load_col_group(tile, local_row, local_col, x0, x1, &group_amax); + + float block_amax = group_amax; + if constexpr (USE_2D_QUANTIZATION) { + block_amax = reduce_group_max_16(group_amax); + } + + const float global_amax = amax[0]; + const ScalePair scale_pair = compute_scale_pair(block_amax, global_amax); + CandidatePair candidates = make_candidates(x0, x1, scale_pair, global_amax); + + float err_map4 = candidates.map4.err; + float err_map6 = candidates.map6.err; + if constexpr (USE_2D_QUANTIZATION) { + err_map4 = reduce_group_sum_16(err_map4); + err_map6 = reduce_group_sum_16(err_map6); + } + + const bool pick_map4 = err_map4 < err_map6; + const nvfp4_scale_t selected_scale = select_scale(scale_pair, pick_map4); + const uint32_t *selected = select_packed(candidates, pick_map4); + + const size_t global_row_group = global_row / kGroupSize; + scales_t[global_col * scale_stride_t + global_row_group] = selected_scale; + store_packed_group(selected, &output_t[(global_col * rows + global_row) / 2]); + } +} + +template +__global__ void __launch_bounds__(kThreads) + quantize_4over6_kernel(const IType *input, fp4e2m1x2 *output, fp4e2m1x2 *output_t, + nvfp4_scale_t *scales, nvfp4_scale_t *scales_t, + const float *amax_rowwise, const float *amax_colwise, const size_t rows, + const size_t cols, const size_t scale_stride, + const size_t scale_stride_t, const float *noop) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + extern __shared__ char dynamic_shmem[]; + auto *tiles = reinterpret_cast(dynamic_shmem); + const size_t tile_col = blockIdx.x * kTileCols; + const size_t tile_row = blockIdx.y * kTileRows; + + IType *stage_tiles[kPipelineStages]; +#pragma unroll + for (int stage = 0; stage < kPipelineStages; ++stage) { + stage_tiles[stage] = &tiles[stage * kStageRows * kTileCols]; + } + + load_stage_to_shared_async(input, stage_tiles[0], rows, cols, tile_row, tile_col); + cp_async_commit_group(); + cp_async_wait_group<0>(); + __syncthreads(); + + for (int stage = 0; stage < kPipelineStages; ++stage) { + const int next_stage = stage + 1; + if (next_stage < kPipelineStages) { + const size_t next_stage_row = tile_row + next_stage * kStageRows; + load_stage_to_shared_async(input, stage_tiles[next_stage], rows, cols, next_stage_row, + tile_col); + cp_async_commit_group(); + } + + const size_t stage_row = tile_row + stage * kStageRows; + IType *stage_tile = stage_tiles[stage]; + + if constexpr (RETURN_IDENTITY) { + quantize_stage_rowwise( + stage_tile, output, scales, amax_rowwise, rows, cols, stage_row, tile_col, scale_stride); + } + + if constexpr (RETURN_TRANSPOSE) { + const float *columnwise_amax = amax_colwise; + if (columnwise_amax == nullptr) { + columnwise_amax = amax_rowwise; + } + quantize_stage_colwise( + stage_tile, output_t, scales_t, columnwise_amax, rows, cols, stage_row, tile_col, + scale_stride_t); + } + + if (next_stage < kPipelineStages) { + cp_async_wait_group<0>(); + __syncthreads(); + } + } +#else + NVTE_DEVICE_ERROR("sm_100 or higher is required."); +#endif +} + +template +void launch_quantize_4over6(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; + const bool return_identity = output->has_data(); + const bool return_transpose = output->has_columnwise_data(); + + const auto *input_ptr = reinterpret_cast(input.data.dptr); + auto *output_ptr = reinterpret_cast(output->data.dptr); + auto *output_t_ptr = reinterpret_cast(output->columnwise_data.dptr); + auto *scales_ptr = reinterpret_cast(output->scale_inv.dptr); + auto *scales_t_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); + const auto *amax_rowwise_ptr = reinterpret_cast(output->amax.dptr); + const auto *amax_colwise_ptr = reinterpret_cast(output->columnwise_amax.dptr); + const auto *noop_ptr = reinterpret_cast(noop->data.dptr); + + const dim3 grid(DIVUP(cols, static_cast(kTileCols)), + DIVUP(rows, static_cast(kTileRows))); + const dim3 block(kThreads); + const size_t shmem = kPipelineStages * kStageRows * kTileCols * sizeof(IType); + const size_t scale_stride = return_identity ? output->scale_inv.shape[1] : 0; + const size_t scale_stride_t = return_transpose ? output->columnwise_scale_inv.shape[1] : 0; + + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_identity, RETURN_IDENTITY, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { + auto kernel = quantize_4over6_kernel; + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem); + kernel<<>>(input_ptr, output_ptr, output_t_ptr, scales_ptr, + scales_t_ptr, amax_rowwise_ptr, amax_colwise_ptr, + rows, cols, scale_stride, scale_stride_t, noop_ptr); + }); + }); + }); +} + +} // namespace quantize_4over6_kernel + +#endif // FP4_TYPE_SUPPORTED + +template +void quantize_4over6(const Tensor &input, const Tensor *noop, Tensor *output, + const QuantizationConfig *quant_config, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace quantize_4over6_kernel; + + checkCuDriverContext(stream); + CheckNoopTensor(*noop, "cast_noop"); + CheckInputTensor(input, "input"); + CheckOutputTensor(*output, "output", false); + + NVTE_CHECK(quant_config != nullptr && quant_config->nvfp4_4over6_mode != kNVTENVFP44Over6Disabled, + "NVFP4 4over6 quantization requires a non-disabled 4over6 mode."); + NVTE_CHECK(!quant_config->stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "NVFP4 4over6 output tensor must have rowwise or columnwise data."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); + NVTE_CHECK(input.flat_last_dim() % kGroupSize == 0, + "NVFP4 4over6 quantization requires columns divisible by ", kGroupSize, "."); + NVTE_CHECK(!(output->has_columnwise_data() || use_2d_quantization) || + input.flat_first_dim() % kGroupSize == 0, + "NVFP4 4over6 columnwise or 2D quantization requires rows divisible by ", kGroupSize, + "."); + NVTE_CHECK(!output->row_scaled_nvfp4 || !use_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); + NVTE_CHECK(!output->row_scaled_nvfp4 || !output->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); + NVTE_CHECK(!use_2d_quantization || output->has_data(), + "NVFP4 4over6 2D quantization requires rowwise output."); + + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, "Rowwise amax tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); + } + if (output->has_columnwise_data()) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "Transposed scaling tensor must be allocated."); + NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), + "Transposed output must have FP4 type."); + NVTE_CHECK(output->columnwise_amax.dptr != nullptr || output->amax.dptr != nullptr, + "NVFP4 4over6 columnwise quantization requires columnwise amax or rowwise amax."); + } + + TRANSFORMER_ENGINE_NVFP4_4OVER6_E4M3_MAX_SWITCH( + output->nvfp4_e4m3_max, E4M3_MAX, + TRANSFORMER_ENGINE_NVFP4_4OVER6_MODE_SWITCH( + quant_config->nvfp4_4over6_mode, MODE, + TRANSFORMER_ENGINE_SWITCH_CONDITION( + quant_config->nvfp4_4over6_err_use_fast_math, ERR_USE_FAST_MATH, { + using Cfg = quantize_4over6_kernel::Config; + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + input.dtype(), IType, + quantize_4over6_kernel::launch_quantize_4over6( + input, noop, output, stream);); + }););); + + NVTE_CHECK_CUDA(cudaGetLastError()); +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_4OVER6_NVFP4_CUH_ diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 28218e2b43..a1a0dd9d0b 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -230,6 +230,10 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz chunk.set_row_scaled_nvfp4(source.get_row_scaled_nvfp4()); continue; } + if (param_type == NVTETensorParam::kNVTENVFP4E4M3Max) { + chunk.set_nvfp4_e4m3_max(source.get_nvfp4_e4m3_max()); + continue; + } auto param = source.get_parameter(param_type); auto param_dptr = reinterpret_cast(param.data_ptr); auto param_dtype = static_cast(param.dtype); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 12479f2a9c..5b6a9bf414 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -178,6 +178,12 @@ struct Tensor { * Only meaningful for NVFP4 tensors. */ bool row_scaled_nvfp4 = false; + /*! \brief Global E4M3 scale bound used by NVFP4. + * + * Standard NVFP4 uses 448. Some 4over6 tensors use 256 to leave room for + * map-to-4 local scale expansion. + */ + int nvfp4_e4m3_max = 448; /*! Map from NVTETensorParam to parameter sizes */ static constexpr size_t attr_sizes[] = { @@ -189,7 +195,8 @@ struct Tensor { sizeof(NVTEBasicTensor), // kNVTEColumnwiseScaleInv sizeof(NVTEBasicTensor), // kNVTEColumnwiseAmax sizeof(uint8_t), // kNVTEWithGEMMSwizzledScales - sizeof(uint8_t) // kNVTERowScaledNVFP4 + sizeof(uint8_t), // kNVTERowScaledNVFP4 + sizeof(int) // kNVTENVFP4E4M3Max }; Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {} @@ -206,6 +213,7 @@ struct Tensor { scaling_mode = NVTE_DELAYED_TENSOR_SCALING; with_gemm_swizzled_scales = false; row_scaled_nvfp4 = false; + nvfp4_e4m3_max = 448; } explicit operator NVTETensor() const noexcept { return nvte_tensor; } @@ -477,6 +485,8 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; + NVTENVFP44Over6Mode nvfp4_4over6_mode = kNVTENVFP44Over6Disabled; + bool nvfp4_4over6_err_use_fast_math = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -486,7 +496,9 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t) // use_fast_math + sizeof(uint8_t), // use_fast_math + sizeof(uint8_t), // nvfp4_4over6_mode + sizeof(uint8_t) // nvfp4_4over6_err_use_fast_math }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 045ae88893..ffb3243154 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -83,6 +83,13 @@ enum NVTETensorParam { * its values are populated during quantization. */ kNVTERowScaledNVFP4 = 8, + /*! Global E4M3 scale bound used by an NVFP4 tensor. + * + * This is part of the tensor data contract. Downstream dequantization and + * GEMM scale consumers must use the same bound used during quantization. + * Standard NVFP4 uses 448; 4over6 may use 256 for map-to-4 headroom. + */ + kNVTENVFP4E4M3Max = 9, kNVTENumTensorParams }; @@ -111,6 +118,15 @@ enum NVTEScalingMode { NVTE_INVALID_SCALING = 100 }; +/*! \enum NVTENVFP44Over6Mode + * \brief Method for NVFP4 4over6 quantization. + */ +enum NVTENVFP44Over6Mode { + kNVTENVFP44Over6Disabled = 0, /*!< 4over6 is not applied */ + kNVTENVFP44Over6MinMAE = 1, /*!< Select the candidate with lower mean absolute error */ + kNVTENVFP44Over6MinMSE = 2, /*!< Select the candidate with lower mean squared error */ +}; + /*! \brief TE Tensor type * * NVTETensor is a contiguous tensor type storing a pointer @@ -381,6 +397,20 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, + /*! Method for NVFP4 4over6 block scale selection. + * + * Non-disabled modes evaluate map-to-4 and map-to-6 candidates for each + * 1x16 block and store the lower-error candidate. The value is an + * NVTENVFP44Over6Mode encoded as uint8_t. + */ + kNVTEQuantizationConfigNVFP44Over6Mode = 8, + /*! Whether the NVFP4 4over6 candidate error computation may use fast math. + * + * This is intentionally separate from kNVTEQuantizationConfigUseFastMath so + * callers can keep candidate selection bitwise deterministic independent + * of ordinary NVFP4 fast-math settings. + */ + kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath = 9, kNVTEQuantizationConfigNumAttributes }; @@ -781,6 +811,11 @@ class TensorWrapper { nvte_set_tensor_param_v2(tensor_, kNVTERowScaledNVFP4, &val, sizeof(val)); } + void set_nvfp4_e4m3_max(int nvfp4_e4m3_max) { + const auto val = nvfp4_e4m3_max; + nvte_set_tensor_param_v2(tensor_, kNVTENVFP4E4M3Max, &val, sizeof(val)); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -823,6 +858,12 @@ class TensorWrapper { return static_cast(val); } + int get_nvfp4_e4m3_max() const { + int val = 448; + nvte_get_tensor_param_v2(tensor_, kNVTENVFP4E4M3Max, &val, sizeof(val), nullptr); + return val; + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. @@ -1318,6 +1359,20 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief Set NVFP4 4over6 candidate-selection mode */ + void set_nvfp4_4over6_mode(NVTENVFP44Over6Mode mode) { + const auto val = static_cast(mode); + nvte_set_quantization_config_attribute(config_, kNVTEQuantizationConfigNVFP44Over6Mode, &val, + sizeof(val)); + } + + /*! \brief Set whether NVFP4 4over6 candidate error computation uses fast math */ + void set_nvfp4_4over6_err_use_fast_math(bool use_fast_math) { + const auto val = static_cast(use_fast_math); + nvte_set_quantization_config_attribute( + config_, kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath, &val, sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index b773a81d1b..8a03f2f51a 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -13,6 +13,8 @@ _BACKWARD_OVERRIDES = (None, "high_precision", "dequantized") +_NVFP4_4OVER6_SCOPES = ("none", "weights", "activations", "all") +_NVFP4_4OVER6_ERR_MODES = ("MAE", "MSE") class _FormatHelper(NamedTuple): @@ -522,6 +524,19 @@ class NVFP4BlockScaling(Recipe): If set to `True`, forward activation quantizers emit row-scaled NVFP4 tensors. In this mode, rowwise ``amax`` metadata is stored as a vector with one FP32 value per tensor row. + nvfp4_4over6 : {'none', 'weights', 'activations', 'all'}, default = 'none' + Enable 4over6 adaptive NVFP4 block scaling for selected tensor + scopes. For each selected FP4 block, quantization compares + map-to-4 and map-to-6 candidates and stores the candidate with + lower configured error. Current 4over6 support targets RL and + post-training scenarios; pre-training paths that combine 4over6 + with RHT are not yet implemented. + nvfp4_4over6_e4m3_use_256 : {'none', 'weights', 'activations', 'all'}, default = 'all' + Select 4over6 tensors that use 256 as the global E4M3 scale + bound. By default, all 4over6 tensors use 256. Use ``'none'`` + to keep the standard NVFP4 448 bound for 4over6 tensors. + nvfp4_4over6_err_mode : {'MAE', 'MSE'}, default = 'MAE' + Error metric used by NVFP4 4over6 candidate selection. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -536,6 +551,9 @@ class NVFP4BlockScaling(Recipe): ) disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" row_scaled_activation: bool = os.getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION", "0") == "1" + nvfp4_4over6: str = os.getenv("NVTE_NVFP4_4OVER6", "none") + nvfp4_4over6_e4m3_use_256: str = os.getenv("NVTE_NVFP4_4OVER6_E4M3_USE_256", "all") + nvfp4_4over6_err_mode: str = os.getenv("NVTE_NVFP4_4OVER6_ERR_MODE", "MAE").upper() fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -551,6 +569,15 @@ def __post_init__(self) -> None: assert ( self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." + assert ( + self.nvfp4_4over6 in _NVFP4_4OVER6_SCOPES + ), "NVTE_NVFP4_4OVER6 must be one of: 'none', 'weights', 'activations', 'all'." + assert ( + self.nvfp4_4over6_e4m3_use_256 in _NVFP4_4OVER6_SCOPES + ), "NVTE_NVFP4_4OVER6_E4M3_USE_256 must be one of: 'none', 'weights', 'activations', 'all'." + assert ( + self.nvfp4_4over6_err_mode in _NVFP4_4OVER6_ERR_MODES + ), "NVTE_NVFP4_4OVER6_ERR_MODE must be one of: 'MAE', 'MSE'." # Quantization params # Note: RHT is currently only applied to column-wise usage so that @@ -580,6 +607,9 @@ def _make_repr(self) -> str: f"fp8_mha={self.fp8_mha}, " f"backward_override={self.backward_override}, " f"row_scaled_activation={self.row_scaled_activation}, " + f"nvfp4_4over6={self.nvfp4_4over6}, " + f"nvfp4_4over6_e4m3_use_256={self.nvfp4_4over6_e4m3_use_256}, " + f"nvfp4_4over6_err_mode={self.nvfp4_4over6_err_mode}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/common/recipe/nvfp4.cu b/transformer_engine/common/recipe/nvfp4.cu index 1c419d4f8c..576e6139c7 100644 --- a/transformer_engine/common/recipe/nvfp4.cu +++ b/transformer_engine/common/recipe/nvfp4.cu @@ -65,15 +65,15 @@ namespace nvfp4_recipe { * --------------------------------------------------------------------------- */ -// constexpr float factor = 6.0 * 6.0 * 448.0 * 448.0; -constexpr float factor_inv = 1.0 / (6.0 * 6.0 * 448.0 * 448.0); constexpr int kTileDim = 16; constexpr int kThreadsPerBlock = 256; // Kernel to compute alpha *= amax_A * amax_B / factor __global__ void compute_nvfp4_per_tensor_scale_kernel(float alpha_in, const float *amax_A, - const float *amax_B, float *alpha_out) { - // factor is defined in the enclosing namespace + const float *amax_B, float fp8_max_A, + float fp8_max_B, float *alpha_out) { + constexpr float fp4_max = 6.0f; + const float factor_inv = 1.0f / (fp4_max * fp4_max * fp8_max_A * fp8_max_B); *alpha_out = alpha_in * (*amax_A) * (*amax_B) * factor_inv; } @@ -924,6 +924,8 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r void *amax_A_ptr = use_rowwise_amax_A ? tA->amax.dptr : tA->columnwise_amax.dptr; void *amax_B_ptr = use_rowwise_amax_B ? tB->amax.dptr : tB->columnwise_amax.dptr; void *alpha_ptr = tOut->data.dptr; + const float fp8_max_A = static_cast(tA->nvfp4_e4m3_max); + const float fp8_max_B = static_cast(tB->nvfp4_e4m3_max); // check for not null pointers NVTE_CHECK(amax_A_ptr != nullptr, "amax_A_ptr is null"); @@ -932,7 +934,8 @@ void nvte_nvfp4_compute_per_tensor_scale(const NVTETensor inpA, const bool use_r nvfp4_recipe::compute_nvfp4_per_tensor_scale_kernel<<<1, 1, 0, stream>>>( alpha_in, reinterpret_cast(amax_A_ptr), - reinterpret_cast(amax_B_ptr), reinterpret_cast(alpha_ptr)); + reinterpret_cast(amax_B_ptr), fp8_max_A, fp8_max_B, + reinterpret_cast(alpha_ptr)); NVTE_CHECK_CUDA(cudaGetLastError()); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1a52d76019..561f64d591 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -855,6 +855,11 @@ void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const vo case kNVTERowScaledNVFP4: t.row_scaled_nvfp4 = static_cast(*reinterpret_cast(buf)); break; + case kNVTENVFP4E4M3Max: + std::memcpy(&t.nvfp4_e4m3_max, buf, attr_size); + NVTE_CHECK(t.nvfp4_e4m3_max == 448 || t.nvfp4_e4m3_max == 256, + "Unsupported NVFP4 E4M3 max (got ", t.nvfp4_e4m3_max, ")"); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -938,6 +943,9 @@ void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, vo case kNVTERowScaledNVFP4: *reinterpret_cast(buf) = static_cast(t->row_scaled_nvfp4); break; + case kNVTENVFP4E4M3Max: + std::memcpy(buf, &t->nvfp4_e4m3_max, attr_size); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -1049,6 +1057,14 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; + case kNVTEQuantizationConfigNVFP44Over6Mode: { + const auto val = static_cast(config_.nvfp4_4over6_mode); + std::memcpy(buf, &val, attr_size); + break; + } + case kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath: + bool_to_uint8(config_.nvfp4_4over6_err_use_fast_math, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1104,6 +1120,18 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; + case kNVTEQuantizationConfigNVFP44Over6Mode: { + const auto val = *reinterpret_cast(buf); + NVTE_CHECK(val == static_cast(kNVTENVFP44Over6Disabled) || + val == static_cast(kNVTENVFP44Over6MinMAE) || + val == static_cast(kNVTENVFP44Over6MinMSE), + "Invalid NVFP4 4over6 mode (got ", static_cast(val), ")"); + config_.nvfp4_4over6_mode = static_cast(val); + break; + } + case kNVTEQuantizationConfigNVFP44Over6ErrUseFastMath: + uint8_to_bool(buf, config_.nvfp4_4over6_err_use_fast_math); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 94350da1e6..b376b3022d 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -327,6 +327,10 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; + // 4over6 candidate-selection mode used when quantizing emitted NVFP4 tensors. + NVTENVFP44Over6Mode nvfp4_4over6_mode; + // Global E4M3 scale bound used by emitted NVFP4 tensors. + int nvfp4_e4m3_max; // Whether tensors emitted by this quantizer use row-scaled NVFP4 metadata. bool row_scaled_nvfp4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2b38339d67..d1a9cd8587 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -84,6 +84,8 @@ void group_quantize_nvfp4_impl(const GroupedTensorWrapper &grouped_input_tensor, // assert the 2D scaling case, since 2D scaling grouped quant kernel is not ready yet NVTE_CHECK(!nvfp4_quantizer_cpp->with_2d_quantization, "2D scaling grouped quant kernel is not ready yet"); + NVTE_CHECK(nvfp4_quantizer_cpp->nvfp4_4over6_mode == kNVTENVFP44Over6Disabled, + "NVFP4 4over6 quantization is not supported for grouped quantization."); auto quant_config_cpp = QuantizationConfigWrapper(); @@ -722,6 +724,9 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; + const bool nvfp4_use_4over6 = + quantizer_cpp_list[0]->nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; + const int nvfp4_e4m3_max = quantizer_cpp_list[0]->nvfp4_e4m3_max; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); @@ -866,10 +871,12 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, - columnwise_scale, amax_rowwise, amax_columnwise, - fp4_dtype, quantizer_py_list[i], - with_gemm_swizzled_scales, row_scaled_nvfp4)); + tensor_py_list.emplace_back( + NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, + amax_rowwise, amax_columnwise, fp4_dtype, quantizer_py_list[i], + with_gemm_swizzled_scales, py::arg("row_scaled_nvfp4") = row_scaled_nvfp4, + py::arg("nvfp4_use_4over6") = nvfp4_use_4over6, + py::arg("nvfp4_e4m3_max") = nvfp4_e4m3_max)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -887,6 +894,7 @@ std::tuple, std::vector, bool> bulk_alloc columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); tensor_wrapper.set_row_scaled_nvfp4(row_scaled_nvfp4); + tensor_wrapper.set_nvfp4_e4m3_max(nvfp4_e4m3_max); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { @@ -997,6 +1005,9 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, cudaStream_t stream) { const size_t num_tensors = split_sections.size(); const auto &quantizer = *quantizers.front(); + const bool nvfp4_use_4over6 = quantizer.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; + NVTE_CHECK(!nvfp4_use_4over6, + "NVFP4 4over6 quantization is not supported with RHT split quantization."); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -1032,6 +1043,13 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, need_separate_rng_states, quant_config_list, quant_config_list_colwise); + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6_mode(quantizer.nvfp4_4over6_mode); + } + for (auto &config : quant_config_list_colwise) { + config.set_nvfp4_4over6_mode(quantizer.nvfp4_4over6_mode); + } + // Enable NVFP4 kernels to use math operations that sacrifice // accuracy for performance. These optimizations are experimental // and inconsistently implemented. @@ -1039,8 +1057,10 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, // 1. replace 1 / x by reciprocal_approximate_ftz(x) // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { + if (use_fast_math && !nvfp4_use_4over6) { for (auto &config : quant_config_list) { config.set_use_fast_math(true); } @@ -1049,6 +1069,17 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, } } + const auto use_4over6_err_use_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"); + if (use_4over6_err_use_fast_math) { + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6_err_use_fast_math(true); + } + for (auto &config : quant_config_list_colwise) { + config.set_nvfp4_4over6_err_use_fast_math(true); + } + } + auto &quant_config_list_colwise_to_use = need_separate_rng_states ? quant_config_list_colwise : quant_config_list; @@ -1157,6 +1188,9 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, cudaStream_t stream) { const size_t num_tensors = input_list.size(); const auto &quantizer = *quantizers.front(); + const bool nvfp4_use_4over6 = quantizer.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; + NVTE_CHECK(!nvfp4_use_4over6 || !quantizer.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; @@ -1189,6 +1223,27 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, need_separate_rng_states, quant_config_list, dummy_quant_config_list_colwise); // colwise rng states are not needed in this case + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6_mode(quantizer.nvfp4_4over6_mode); + } + + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. + const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); + if (use_fast_math && !nvfp4_use_4over6) { + for (auto &config : quant_config_list) { + config.set_use_fast_math(true); + } + } + + const auto use_4over6_err_use_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"); + if (use_4over6_err_use_fast_math) { + for (auto &config : quant_config_list) { + config.set_nvfp4_4over6_err_use_fast_math(true); + } + } + // We need: // 1. Rowwise amax = amax for input // 2. Columnwise amax = amax for input too @@ -1259,6 +1314,11 @@ void split_quantize_nvfp4_impl(const TensorWrapper &input, "NVFP4 split-quantize does not support 2D quantization"); NVTE_CHECK(!quantizer.with_amax_reduction, "NVFP4 split-quantize does not support amax reduction"); + if (quantizer.nvfp4_4over6_mode != kNVTENVFP44Over6Disabled) { + NVTE_CHECK(!quantizer.with_rht, "NVFP4 4over6 quantization does not support RHT."); + NVTE_CHECK(!quantizer.stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); + } // Check input tensor shape const size_t input_last_dim = input.ndim() > 0 ? input.size(input.ndim() - 1) : 1; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 7045995dd7..bc87b54ba8 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1729,6 +1729,20 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + const bool nvfp4_use_4over6 = quantizer.attr("nvfp4_use_4over6").cast(); + this->nvfp4_e4m3_max = quantizer.attr("nvfp4_e4m3_max").cast(); + NVTE_CHECK(this->nvfp4_e4m3_max == 448 || this->nvfp4_e4m3_max == 256, + "Unsupported NVFP4 E4M3 max: ", this->nvfp4_e4m3_max); + const auto nvfp4_4over6_err_mode = quantizer.attr("nvfp4_4over6_err_mode").cast(); + if (!nvfp4_use_4over6) { + this->nvfp4_4over6_mode = kNVTENVFP44Over6Disabled; + } else if (nvfp4_4over6_err_mode == "MAE") { + this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMAE; + } else if (nvfp4_4over6_err_mode == "MSE") { + this->nvfp4_4over6_mode = kNVTENVFP44Over6MinMSE; + } else { + NVTE_ERROR("Unsupported NVFP4 4over6 error mode: ", nvfp4_4over6_err_mode); + } this->row_scaled_nvfp4 = quantizer.attr("row_scaled_nvfp4").cast(); // Get amax reduction group if needed for NVFP4 AG @@ -1778,6 +1792,8 @@ std::pair NVFP4Quantizer::create_tensor( "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + const bool nvfp4_use_4over6 = this->nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; + const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -1845,6 +1861,8 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); + kwargs["nvfp4_use_4over6"] = py::cast(nvfp4_use_4over6); + kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); @@ -1875,6 +1893,8 @@ std::pair NVFP4Quantizer::create_tensor( kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); kwargs["device"] = py::cast(device); kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); + kwargs["nvfp4_use_4over6"] = py::cast(nvfp4_use_4over6); + kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args.ptr(), kwargs.ptr()); @@ -1908,6 +1928,7 @@ std::pair NVFP4Quantizer::create_tensor( } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); + out_cpp.set_nvfp4_e4m3_max(nvfp4_e4m3_max); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(out_py)}; @@ -1936,6 +1957,8 @@ std::pair NVFP4Quantizer::create_grouped_tenso std::optional columnwise_amax; const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + const bool nvfp4_use_4over6 = this->nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; + const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, @@ -2010,6 +2033,8 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); + kwargs["nvfp4_use_4over6"] = py::cast(nvfp4_use_4over6); + kwargs["nvfp4_e4m3_max"] = py::cast(nvfp4_e4m3_max); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -2085,12 +2110,16 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( const auto [flat_first_dim, flat_last_dim] = get_2d_dims(shape); const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + const bool nvfp4_use_4over6 = this->nvfp4_4over6_mode != kNVTENVFP44Over6Disabled; + const int nvfp4_e4m3_max = this->nvfp4_e4m3_max; if (row_scaled_nvfp4) { NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); NVTE_CHECK(!columnwise_usage, "Row-scaled NVFP4 quantization does not support columnwise usage."); } tensor.attr("_row_scaled_nvfp4") = py::cast(row_scaled_nvfp4); + tensor.attr("_nvfp4_use_4over6") = py::cast(nvfp4_use_4over6); + tensor.attr("_nvfp4_e4m3_max") = py::cast(nvfp4_e4m3_max); // Coerce row-wise data if (rowwise_usage) { @@ -2195,6 +2224,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); + out_cpp.set_nvfp4_e4m3_max(nvfp4_e4m3_max); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(tensor)}; @@ -2285,6 +2315,14 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); quant_config.set_stochastic_rounding(this->stochastic_rounding); + quant_config.set_nvfp4_4over6_mode(this->nvfp4_4over6_mode); + quant_config_columnwise.set_nvfp4_4over6_mode(this->nvfp4_4over6_mode); + + if (this->nvfp4_4over6_mode != kNVTENVFP44Over6Disabled) { + NVTE_CHECK(!this->with_rht, "NVFP4 4over6 quantization does not support RHT."); + NVTE_CHECK(!this->stochastic_rounding, + "NVFP4 4over6 quantization does not support stochastic rounding."); + } // We only need RHT for columnwise usage. // flat first dim and last dim for multi dimensional input @@ -2425,12 +2463,21 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou // 1. replace 1 / x by reciprocal_approximate_ftz(x) // 2. when RHT cast fusion is available, fusion allows cast to be performed on FP32 data, // this will essentially remove a round trip between FP32 to BF16 then FP32 + // NVFP4 4over6 candidate error math is controlled separately by + // NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH. const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { + if (use_fast_math && this->nvfp4_4over6_mode == kNVTENVFP44Over6Disabled) { quant_config.set_use_fast_math(true); quant_config_columnwise.set_use_fast_math(true); } + const auto use_4over6_err_use_fast_math = + transformer_engine::getenv("NVTE_NVFP4_4OVER6_ERR_USE_FAST_MATH"); + if (use_4over6_err_use_fast_math) { + quant_config.set_nvfp4_4over6_err_use_fast_math(true); + quant_config_columnwise.set_nvfp4_4over6_err_use_fast_math(true); + } + if (this->with_rht) { if (eligible_for_rht_cast_fusion) { // fusion kernel requires passing in RHT matrix directly for maximum performance diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 37ab0b0535..ddb85808a5 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -135,6 +135,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); const bool row_scaled_nvfp4 = tensor.attr("_row_scaled_nvfp4").cast(); + const int nvfp4_e4m3_max = tensor.attr("_nvfp4_e4m3_max").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -165,6 +166,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) // Scale layout ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); ret.set_row_scaled_nvfp4(row_scaled_nvfp4); + ret.set_nvfp4_e4m3_max(nvfp4_e4m3_max); // Quantizer state quantizer->set_quantization_params(&ret); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index acb7abefd1..5c23c76703 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -221,6 +221,8 @@ class NVFP4TensorRef(QuantizedTensorStorage): scale_t: Optional[torch.Tensor] = None global_amax_row: Optional[torch.Tensor] = None global_amax_col: Optional[torch.Tensor] = None + nvfp4_use_4over6: bool = False + nvfp4_e4m3_max: int = 448 dtype: Optional[torch.dtype] = None device: Optional[torch.device] = None @@ -350,9 +352,15 @@ def __init__( eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", with_rht: bool = False, with_random_sign_mask: bool = True, ): + nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() + if nvfp4_4over6_err_mode not in ("MAE", "MSE"): + raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") if row_scaled_nvfp4: if not rowwise: raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") @@ -360,6 +368,11 @@ def __init__( raise ValueError( "Row-scaled NVFP4 reference quantization does not support columnwise usage." ) + if nvfp4_use_4over6: + if pow_2_scales: + raise ValueError("4over6 is only supported for NVFP4 (non-pow2) mode.") + if quant_tile_shape not in ((1, 16), (16, 16)): + raise ValueError("4over6 reference quantization only supports 1x16 or 16x16 tiles.") super().__init__(rowwise=rowwise, columnwise=columnwise) self.internal = True @@ -368,6 +381,11 @@ def __init__( self.eps = eps self.quant_tile_shape = quant_tile_shape self.row_scaled_nvfp4 = row_scaled_nvfp4 + self.nvfp4_use_4over6 = nvfp4_use_4over6 + self.nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448 + if self.nvfp4_e4m3_max not in (448, 256): + raise ValueError("nvfp4_e4m3_max must be 448 or 256.") + self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -446,6 +464,113 @@ def _recover_swizzled_scales( result = torch.reshape(tmp, (rounded_m, rounded_n)) return result[:m, :scale_n] + @staticmethod + def _quantize_blockwise_4over6_reference( + x: torch.Tensor, + vec_max: torch.Tensor, + global_amax: torch.Tensor, + global_encode_scale: torch.Tensor, + global_decode_scale: torch.Tensor, + row_scaled_nvfp4: bool, + tile_len_y: int, + nvfp4_4over6_err_mode: str, + nvfp4_e4m3_max: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize NVFP4 with 4over6 candidate selection. + + This mirrors the CUDA path: map-to-4 uses a 1.5x expanded E4M3 block scale, + the configured error is computed in the original input domain with the + selected global E4M3 denominator, and ties choose map-to-6. + """ + m, num_blocks, tile_len_x = x.shape + n = num_blocks * tile_len_x + FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) + FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + GLOBAL_SCALE_E4M3_MAX = torch.tensor( + float(nvfp4_e4m3_max), device=x.device, dtype=torch.float32 + ) + + decode_scale_base = torch.div(vec_max, FLOAT4_E2M1_MAX) * global_encode_scale + decode_scale_map4 = decode_scale_base * 1.5 + decode_scale_map6 = decode_scale_base + decode_scale_map4 = torch.clamp( + decode_scale_map4, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX + ).to(torch.float8_e4m3fn) + decode_scale_map6 = torch.clamp( + decode_scale_map6, min=-FLOAT8_E4M3_MAX, max=FLOAT8_E4M3_MAX + ).to(torch.float8_e4m3fn) + + fp32_max = torch.tensor( + torch.finfo(torch.float32).max, + device=decode_scale_map4.device, + dtype=torch.float32, + ) + encode_scale_map4 = torch.min( + torch.div(1.0, decode_scale_map4.to(torch.float32) * global_decode_scale), + fp32_max, + ) + encode_scale_map6 = torch.min( + torch.div(1.0, decode_scale_map6.to(torch.float32) * global_decode_scale), + fp32_max, + ) + + clipped_x_map4 = torch.clamp( + x.to(torch.float32) * encode_scale_map4, + -FLOAT4_E2M1_MAX, + FLOAT4_E2M1_MAX, + ).reshape(m, n) + clipped_x_map6 = torch.clamp( + x.to(torch.float32) * encode_scale_map6, + -FLOAT4_E2M1_MAX, + FLOAT4_E2M1_MAX, + ).reshape(m, n) + qx_map4 = cast_to_fp4x2(clipped_x_map4) + qx_map6 = cast_to_fp4x2(clipped_x_map6) + + fp4_map4 = cast_from_fp4x2(qx_map4, torch.float32).view(m, num_blocks, tile_len_x) + fp4_map6 = cast_from_fp4x2(qx_map6, torch.float32).view(m, num_blocks, tile_len_x) + denom = FLOAT4_E2M1_MAX * GLOBAL_SCALE_E4M3_MAX + sf_map4 = decode_scale_map4.to(torch.float32).squeeze(-1) + sf_map6 = decode_scale_map6.to(torch.float32).squeeze(-1) + if row_scaled_nvfp4: + error_global_amax = global_amax.squeeze(-1) + else: + error_global_amax = global_amax + x_float = x.to(torch.float32) + err_map4 = torch.zeros_like(vec_max) + err_map6 = torch.zeros_like(vec_max) + for idx in range(tile_len_x): + val_map4 = fp4_map4[:, :, idx] * sf_map4 + val_map4 = val_map4 * error_global_amax + val_map4 = val_map4 / denom + diff_map4 = val_map4 - x_float[:, :, idx] + if nvfp4_4over6_err_mode == "MSE": + err_map4 = err_map4 + (diff_map4 * diff_map4).unsqueeze(-1) + else: + err_map4 = err_map4 + torch.abs(diff_map4).unsqueeze(-1) + + val_map6 = fp4_map6[:, :, idx] * sf_map6 + val_map6 = val_map6 * error_global_amax + val_map6 = val_map6 / denom + diff_map6 = val_map6 - x_float[:, :, idx] + if nvfp4_4over6_err_mode == "MSE": + err_map6 = err_map6 + (diff_map6 * diff_map6).unsqueeze(-1) + else: + err_map6 = err_map6 + torch.abs(diff_map6).unsqueeze(-1) + if tile_len_y == 1: + pick_map4 = err_map4 < err_map6 + else: + err_map4_blocks = err_map4.view(m // tile_len_y, tile_len_y, num_blocks, 1).sum(dim=1) + err_map6_blocks = err_map6.view(m // tile_len_y, tile_len_y, num_blocks, 1).sum(dim=1) + pick_map4 = (err_map4_blocks < err_map6_blocks).repeat_interleave(tile_len_y, dim=0) + qx = torch.where( + pick_map4.expand(-1, -1, tile_len_x // 2), + qx_map4.view(m, num_blocks, tile_len_x // 2), + qx_map6.view(m, num_blocks, tile_len_x // 2), + ).reshape(m, n // 2) + decode_scale = torch.where(pick_map4, decode_scale_map4, decode_scale_map6).squeeze(-1) + return qx, decode_scale + @classmethod def _quantize_blockwise_reference( cls, @@ -456,6 +581,9 @@ def _quantize_blockwise_reference( *, pow_2_scales: bool, row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -488,6 +616,10 @@ def _quantize_blockwise_reference( x = x.view(m, n // tile_len_x, tile_len_x) FLOAT4_E2M1_MAX = torch.tensor(6.0, device=x.device, dtype=torch.float32) FLOAT8_E4M3_MAX = torch.tensor(448.0, device=x.device, dtype=torch.float32) + global_scale_e4m3_max = float(nvfp4_e4m3_max if nvfp4_use_4over6 else 448) + GLOBAL_SCALE_E4M3_MAX = torch.tensor( + global_scale_e4m3_max, device=x.device, dtype=torch.float32 + ) decode_scale = torch.div(vec_max, FLOAT4_E2M1_MAX) if pow_2_scales: @@ -500,7 +632,7 @@ def _quantize_blockwise_reference( if row_scaled_nvfp4: global_amax = global_amax.to(torch.float32).view(m, 1, 1) - global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) + global_encode_scale = torch.div(GLOBAL_SCALE_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) global_encode_scale = torch.min( global_encode_scale, torch.tensor( @@ -519,6 +651,22 @@ def _quantize_blockwise_reference( global_encode_scale, ) global_decode_scale = torch.div(1.0, global_encode_scale) + if nvfp4_use_4over6: + # FourOverSix compares map-to-4 and map-to-6 candidates using + # the configured original input-domain error, while keeping TE-style FP4 + # quantization for each candidate. + return cls._quantize_blockwise_4over6_reference( + x, + vec_max, + global_amax, + global_encode_scale, + global_decode_scale, + row_scaled_nvfp4, + tile_len_y, + nvfp4_4over6_err_mode, + nvfp4_e4m3_max, + ) + global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) # Match the kernel's default path: fold the FP4 reciprocal into the @@ -679,6 +827,9 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, row_scaled_nvfp4=self.row_scaled_nvfp4, + nvfp4_use_4over6=self.nvfp4_use_4over6, + nvfp4_e4m3_max=self.nvfp4_e4m3_max, + nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, eps=self.eps, ) if transpose_scales: @@ -702,6 +853,9 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, + nvfp4_use_4over6=self.nvfp4_use_4over6, + nvfp4_e4m3_max=self.nvfp4_e4m3_max, + nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, eps=self.eps, ) @@ -741,6 +895,8 @@ def quantize( scale_t=sx_t, global_amax_row=global_amax_row, global_amax_col=global_amax_col, + nvfp4_use_4over6=self.nvfp4_use_4over6, + nvfp4_e4m3_max=self.nvfp4_e4m3_max, dtype=tensor.dtype, device=tensor.device, quant_dtype=self.dtype, @@ -788,6 +944,8 @@ def update_quantized( dst.scale_t = sx_t dst.global_amax_row = global_amax_row dst.global_amax_col = global_amax_col + dst.nvfp4_use_4over6 = self.nvfp4_use_4over6 + dst.nvfp4_e4m3_max = self.nvfp4_e4m3_max dst.dtype = src.dtype dst.quant_dtype = self.dtype dst.original_shape = original_shape @@ -893,7 +1051,35 @@ def qgemm( sx = sx.to(torch.float32) sw = sw.to(torch.float32) - factor = 6.0 * 6.0 * 448.0 * 448.0 + qresult_x_nvfp4_use_4over6 = getattr( + qresult_x, + "nvfp4_use_4over6", + getattr(qresult_x, "_nvfp4_use_4over6", self.nvfp4_use_4over6), + ) + qresult_w_nvfp4_use_4over6 = getattr( + qresult_w, + "nvfp4_use_4over6", + getattr(qresult_w, "_nvfp4_use_4over6", self.nvfp4_use_4over6), + ) + qresult_x_e4m3_max = getattr( + qresult_x, + "nvfp4_e4m3_max", + getattr(qresult_x, "_nvfp4_e4m3_max", self.nvfp4_e4m3_max), + ) + qresult_w_e4m3_max = getattr( + qresult_w, + "nvfp4_e4m3_max", + getattr(qresult_w, "_nvfp4_e4m3_max", self.nvfp4_e4m3_max), + ) + if qresult_x_nvfp4_use_4over6: + fp8_max_x = float(qresult_x_e4m3_max) + else: + fp8_max_x = 448.0 + if qresult_w_nvfp4_use_4over6: + fp8_max_w = float(qresult_w_e4m3_max) + else: + fp8_max_w = 448.0 + factor = 6.0 * 6.0 * fp8_max_x * fp8_max_w if gemm_type == quantization.GEMMType.WGRAD: partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 0c40723517..e503b4b560 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1635,7 +1635,11 @@ def make_quantizers(self) -> list: * Forward, ``"weight"`` -> ``recipe.fp4_quant_fwd_weight``. * Forward, ``"input"`` / ``"output"`` (and any unknown forward type) -> ``recipe.fp4_quant_fwd_inp``. - * Backward, any slot -> ``recipe.fp4_quant_bwd_grad``. + * ``"grad_output"`` / ``"grad_input"`` -> ``recipe.fp4_quant_bwd_grad``. + * NVFP4 4over6 is applied to non-gradient slots selected by + ``recipe.nvfp4_4over6``. Gradient slots always use standard NVFP4, + which lets gradient RHT and stochastic rounding follow the base + recipe. When the owning module/op provides a role list via ``get_quantizer_roles``, the per-slot ``tensor_type`` drives dispatch. @@ -1647,7 +1651,7 @@ def make_quantizers(self) -> list: from .tensor.nvfp4_tensor import NVFP4Quantizer def _qparams(tensor_type: str): - if self.mode == "backward": + if tensor_type in ("grad_output", "grad_input"): return self.recipe.fp4_quant_bwd_grad if tensor_type == "weight": return self.recipe.fp4_quant_fwd_weight @@ -1655,6 +1659,34 @@ def _qparams(tensor_type: str): def _make(tensor_type: str) -> NVFP4Quantizer: qparams = _qparams(tensor_type) + nvfp4_use_4over6 = False + if tensor_type not in ("grad_output", "grad_input"): + if self.recipe.nvfp4_4over6 == "all": + nvfp4_use_4over6 = True + elif self.recipe.nvfp4_4over6 == "weights": + nvfp4_use_4over6 = tensor_type == "weight" + elif self.recipe.nvfp4_4over6 == "activations": + nvfp4_use_4over6 = tensor_type != "weight" + nvfp4_e4m3_max = 448 + if nvfp4_use_4over6: + # Current 4over6 kernels target RL and post-training quantization paths. + # Pre-training usage still needs a fused RHT + 4over6 quantization kernel. + if qparams.random_hadamard_transform: + raise ValueError("NVFP4 4over6 quantization does not support RHT.") + if qparams.stochastic_rounding: + raise ValueError( + "NVFP4 4over6 quantization does not support stochastic rounding." + ) + if self.recipe.nvfp4_4over6_e4m3_use_256 == "all": + nvfp4_e4m3_max = 256 + elif self.recipe.nvfp4_4over6_e4m3_use_256 == "weights": + if tensor_type == "weight": + nvfp4_e4m3_max = 256 + elif self.recipe.nvfp4_4over6_e4m3_use_256 == "activations": + if tensor_type != "weight": + nvfp4_e4m3_max = 256 + elif self.recipe.nvfp4_4over6_e4m3_use_256 == "none": + nvfp4_e4m3_max = 448 return NVFP4Quantizer( fp4_dtype=self.dtype, rowwise=True, @@ -1668,6 +1700,9 @@ def _make(tensor_type: str) -> NVFP4Quantizer: and tensor_type != "weight" and self.recipe.row_scaled_activation ), + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, + nvfp4_4over6_err_mode=self.recipe.nvfp4_4over6_err_mode, ) if self.mode not in ("forward", "backward"): diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index f28f972b58..0cc03602a1 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -93,6 +93,8 @@ def __new__( stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, ): if ( shapes is not None @@ -166,6 +168,8 @@ def __new__( columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) return instance @@ -198,6 +202,8 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst.quantized_tensors = src.quantized_tensors dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales dst.row_scaled_nvfp4 = src.row_scaled_nvfp4 + dst.nvfp4_use_4over6 = src.nvfp4_use_4over6 + dst.nvfp4_e4m3_max = src.nvfp4_e4m3_max def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 2ebefefaaa..24962d67f2 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -130,6 +130,12 @@ class NVFP4Quantizer(Quantizer): """Whether emitted NVFP4 tensors store one FP32 amax per row.""" row_scaled_nvfp4: bool + """Whether to use NVFP4 4over6 map-to-4/map-to-6 block selection.""" + nvfp4_use_4over6: bool + """Global E4M3 scale bound used by emitted NVFP4 tensors.""" + nvfp4_e4m3_max: int + """NVFP4 4over6 candidate-selection error mode.""" + nvfp4_4over6_err_mode: str """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int @@ -147,6 +153,9 @@ def __init__( with_2d_quantization: bool = False, stochastic_rounding: bool = False, row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, + nvfp4_4over6_err_mode: str = "MAE", with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -158,6 +167,13 @@ def __init__( self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding self.row_scaled_nvfp4 = row_scaled_nvfp4 + self.nvfp4_use_4over6 = nvfp4_use_4over6 + self.nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448 + if self.nvfp4_e4m3_max not in (448, 256): + raise ValueError("nvfp4_e4m3_max must be 448 or 256.") + self.nvfp4_4over6_err_mode = nvfp4_4over6_err_mode.upper() + if self.nvfp4_4over6_err_mode not in ("MAE", "MSE"): + raise ValueError("nvfp4_4over6_err_mode must be 'MAE' or 'MSE'.") self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -204,6 +220,9 @@ def copy(self) -> NVFP4Quantizer: with_2d_quantization=self.with_2d_quantization, stochastic_rounding=self.stochastic_rounding, row_scaled_nvfp4=self.row_scaled_nvfp4, + nvfp4_use_4over6=self.nvfp4_use_4over6, + nvfp4_e4m3_max=self.nvfp4_e4m3_max, + nvfp4_4over6_err_mode=self.nvfp4_4over6_err_mode, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -356,6 +375,8 @@ def __new__( quantizer: Quantizer, with_gemm_swizzled_scales: bool, row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, **kwargs, ): instance = super().__new__( @@ -371,6 +392,8 @@ def __new__( with_gemm_swizzled_scales, *args, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, **kwargs, ) return instance @@ -528,6 +551,9 @@ def fsdp_pre_all_gather(self, mesh, orig_size, contiguous_orig_stride, module, m columnwise_usage, self._amax_rowwise, self._amax_columnwise, + self._row_scaled_nvfp4, + self._nvfp4_use_4over6, + self._nvfp4_e4m3_max, self.shape[-1], ) return sharded_tensors, metadata @@ -546,7 +572,16 @@ def fsdp_post_all_gather( all-gathered rowwise data. Columnwise data is derived locally via _create_columnwise() instead of being all-gathered. """ - fp4_dtype, columnwise_usage, amax_rowwise, amax_columnwise, K = metadata + ( + fp4_dtype, + columnwise_usage, + amax_rowwise, + amax_columnwise, + row_scaled_nvfp4, + nvfp4_use_4over6, + nvfp4_e4m3_max, + K, + ) = metadata # Only rowwise data+scales were all-gathered rowwise_data, rowwise_scale_inv = all_gather_outputs[:2] @@ -569,6 +604,9 @@ def fsdp_post_all_gather( out._rowwise_scale_inv = rowwise_scale_inv out._amax_rowwise = amax_rowwise out._amax_columnwise = amax_columnwise + out._row_scaled_nvfp4 = row_scaled_nvfp4 + out._nvfp4_use_4over6 = nvfp4_use_4over6 + out._nvfp4_e4m3_max = nvfp4_e4m3_max else: # Construct new tensor (first iteration) out = NVFP4Tensor( @@ -585,6 +623,9 @@ def fsdp_post_all_gather( requires_grad=False, with_gemm_swizzled_scales=False, device=rowwise_data.device, + row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) # Derive columnwise data locally via transpose instead of all-gathering it @@ -724,6 +765,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, device=tensor.device, + row_scaled_nvfp4=tensor._row_scaled_nvfp4, + nvfp4_use_4over6=tensor._nvfp4_use_4over6, + nvfp4_e4m3_max=tensor._nvfp4_e4m3_max, ) # Default case @@ -745,6 +789,9 @@ def __reduce_ex__(self, protocol: int) -> tuple: self.dtype, self._quantizer, self._with_gemm_swizzled_scales, + self._row_scaled_nvfp4, + self._nvfp4_use_4over6, + self._nvfp4_e4m3_max, ), ) @@ -837,6 +884,9 @@ def _set_data(self, tensor: torch.Tensor) -> None: self._amax_rowwise = tensor._amax_rowwise self._amax_columnwise = tensor._amax_columnwise self._with_gemm_swizzled_scales = tensor._with_gemm_swizzled_scales + self._row_scaled_nvfp4 = tensor._row_scaled_nvfp4 + self._nvfp4_use_4over6 = tensor._nvfp4_use_4over6 + self._nvfp4_e4m3_max = tensor._nvfp4_e4m3_max return # Quantize to FP8 @@ -889,7 +939,10 @@ def _make_nvfp4_tensor_in_reduce_ex( fp4_dtype: TE_DType, dtype: torch.dtype, quantizer: Quantizer, - with_gemm_swizzled_scales: bool = False, + with_gemm_swizzled_scales: bool, + row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, ) -> NVFP4Tensor: """Reconstruct an ``NVFP4Tensor`` from its ``__reduce_ex__`` payload.""" # Infer device from whichever inner buffer is populated so the wrapper @@ -914,6 +967,9 @@ def _make_nvfp4_tensor_in_reduce_ex( requires_grad=False, with_gemm_swizzled_scales=with_gemm_swizzled_scales, device=device, + row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) @@ -997,6 +1053,9 @@ def forward( requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, device=tensor.device, + row_scaled_nvfp4=tensor._row_scaled_nvfp4, + nvfp4_use_4over6=tensor._nvfp4_use_4over6, + nvfp4_e4m3_max=tensor._nvfp4_e4m3_max, ) @staticmethod @@ -1040,6 +1099,9 @@ def backward( requires_grad=grad.requires_grad, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, device=grad.device, + row_scaled_nvfp4=grad._row_scaled_nvfp4, + nvfp4_use_4over6=grad._nvfp4_use_4over6, + nvfp4_e4m3_max=grad._nvfp4_e4m3_max, ) return dgrad, None return grad.view(ctx.shape), None @@ -1125,6 +1187,9 @@ def forward( requires_grad=tensor.requires_grad, with_gemm_swizzled_scales=tensor._with_gemm_swizzled_scales, device=tensor.device, + row_scaled_nvfp4=tensor._row_scaled_nvfp4, + nvfp4_use_4over6=tensor._nvfp4_use_4over6, + nvfp4_e4m3_max=tensor._nvfp4_e4m3_max, ) @staticmethod @@ -1168,6 +1233,9 @@ def backward( requires_grad=grad.requires_grad, with_gemm_swizzled_scales=grad._with_gemm_swizzled_scales, device=grad.device, + row_scaled_nvfp4=grad._row_scaled_nvfp4, + nvfp4_use_4over6=grad._nvfp4_use_4over6, + nvfp4_e4m3_max=grad._nvfp4_e4m3_max, ) return dgrad, None return grad.view(ctx.shape), None diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index ac56d334bc..438e124021 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -73,6 +73,8 @@ def _initialize_storage_fields( stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, ) -> None: """ Initialize a GroupedTensor. @@ -149,6 +151,8 @@ def _initialize_storage_fields( instance.quantized_tensors = None instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales instance.row_scaled_nvfp4 = row_scaled_nvfp4 + instance.nvfp4_use_4over6 = nvfp4_use_4over6 + instance.nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448 def __new__( cls, @@ -175,6 +179,8 @@ def __new__( stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, ): instance = object.__new__(cls) cls._initialize_storage_fields( @@ -201,6 +207,8 @@ def __new__( stride=stride, with_gemm_swizzled_scales=with_gemm_swizzled_scales, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) return instance @@ -307,6 +315,33 @@ def get_dtype(self) -> torch.dtype: return self.fake_dtype + @property + def row_scaled_nvfp4(self) -> bool: + """Whether grouped NVFP4 tensors use row-scaled amax metadata.""" + return self._row_scaled_nvfp4 + + @row_scaled_nvfp4.setter + def row_scaled_nvfp4(self, row_scaled_nvfp4: bool) -> None: + self._row_scaled_nvfp4 = row_scaled_nvfp4 + + @property + def nvfp4_use_4over6(self) -> bool: + """Whether grouped NVFP4 tensors carry 4over6 metadata.""" + return self._nvfp4_use_4over6 + + @nvfp4_use_4over6.setter + def nvfp4_use_4over6(self, nvfp4_use_4over6: bool) -> None: + self._nvfp4_use_4over6 = nvfp4_use_4over6 + + @property + def nvfp4_e4m3_max(self) -> int: + """Global E4M3 scale bound used by grouped NVFP4 tensors.""" + return self._nvfp4_e4m3_max + + @nvfp4_e4m3_max.setter + def nvfp4_e4m3_max(self, nvfp4_e4m3_max: int) -> None: + self._nvfp4_e4m3_max = nvfp4_e4m3_max + def prepare_for_saving( self, ) -> Tuple[list[Optional[torch.Tensor]], "GroupedTensorStorage"]: @@ -376,6 +411,8 @@ def clear(self) -> None: self.tensor_shapes = [] self.fake_dtype = torch.float32 self.row_scaled_nvfp4 = False + self.nvfp4_use_4over6 = False + self.nvfp4_e4m3_max = 448 def __repr__(self) -> str: """String representation of the GroupedTensorStorage.""" @@ -545,6 +582,8 @@ def copy(self) -> "GroupedTensorStorage": columnwise_scale_inv_offsets=self.columnwise_scale_inv_offsets, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=self.row_scaled_nvfp4, + nvfp4_use_4over6=self.nvfp4_use_4over6, + nvfp4_e4m3_max=self.nvfp4_e4m3_max, ) @staticmethod @@ -656,6 +695,8 @@ def make_grouped_tensor( scale_inv_offsets = None columnwise_scale_inv_offsets = None row_scaled_nvfp4 = False + nvfp4_use_4over6 = False + nvfp4_e4m3_max = 448 if no_quantization: assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" if rowwise_usage: @@ -715,6 +756,8 @@ def make_grouped_tensor( amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): row_scaled_nvfp4 = quantizer.row_scaled_nvfp4 + nvfp4_use_4over6 = quantizer.nvfp4_use_4over6 + nvfp4_e4m3_max = quantizer.nvfp4_e4m3_max if row_scaled_nvfp4: if not rowwise_usage: raise ValueError( @@ -843,6 +886,8 @@ def make_grouped_tensor( quantizer.optimize_for_gemm if quantizer is not None else False ), row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() return grouped_tensor @@ -957,6 +1002,8 @@ def split_into_quantized_tensors( self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets nvfp4_rowwise_amax_offsets = None row_scaled_nvfp4 = self.row_scaled_nvfp4 + nvfp4_use_4over6 = self.nvfp4_use_4over6 + nvfp4_e4m3_max = self.nvfp4_e4m3_max if recipe.nvfp4() and row_scaled_nvfp4: cum = 0 nvfp4_rowwise_amax_offsets = [0] @@ -1184,6 +1231,8 @@ def split_into_quantized_tensors( quantizer=quantizer, with_gemm_swizzled_scales=quantizer.optimize_for_gemm, row_scaled_nvfp4=row_scaled_nvfp4, + nvfp4_use_4over6=nvfp4_use_4over6, + nvfp4_e4m3_max=nvfp4_e4m3_max, ) result.append(tensor) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 490184e5f8..250fa6bdb2 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -104,6 +104,10 @@ class NVFP4TensorStorage(QuantizedTensorStorage): _with_gemm_swizzled_scales: bool # Whether this NVFP4 tensor uses row-scaled amax metadata _row_scaled_nvfp4: bool + # Whether this NVFP4 tensor uses 4over6 map-to-4/map-to-6 block selection + _nvfp4_use_4over6: bool + # Global E4M3 scale bound used by this NVFP4 tensor + _nvfp4_e4m3_max: int def __new__( cls, @@ -119,6 +123,8 @@ def __new__( *args, fake_dtype: Optional[torch.dtype] = None, row_scaled_nvfp4: bool = False, + nvfp4_use_4over6: bool = False, + nvfp4_e4m3_max: int = 448, **kwargs, ): if cls is NVFP4TensorStorage: @@ -137,6 +143,8 @@ def __new__( instance._amax_columnwise = amax_columnwise instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales instance._row_scaled_nvfp4 = row_scaled_nvfp4 + instance._nvfp4_use_4over6 = nvfp4_use_4over6 + instance._nvfp4_e4m3_max = nvfp4_e4m3_max if nvfp4_use_4over6 else 448 return instance @@ -163,6 +171,10 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: raise RuntimeError("Scale layout mismatch in copy_from_storage") if self._row_scaled_nvfp4 != src._row_scaled_nvfp4: raise RuntimeError("Rowwise amax scaling mode mismatch in copy_from_storage") + if self._nvfp4_use_4over6 != src._nvfp4_use_4over6: + raise RuntimeError("NVFP4 4over6 mode mismatch in copy_from_storage") + if self._nvfp4_e4m3_max != src._nvfp4_e4m3_max: + raise RuntimeError("NVFP4 4over6 E4M3 scale bound mismatch in copy_from_storage") def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): if dst is not None and src_tensor is not None: @@ -188,6 +200,8 @@ def get_metadata(self) -> Dict[str, Any]: "quantizer": self._quantizer, "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, "row_scaled_nvfp4": self._row_scaled_nvfp4, + "nvfp4_use_4over6": self._nvfp4_use_4over6, + "nvfp4_e4m3_max": self._nvfp4_e4m3_max, "fake_dtype": self._dtype, } @@ -321,6 +335,8 @@ def view(self, shape: torch.Size): fp4_dtype=self._fp4_dtype, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, row_scaled_nvfp4=self._row_scaled_nvfp4, + nvfp4_use_4over6=self._nvfp4_use_4over6, + nvfp4_e4m3_max=self._nvfp4_e4m3_max, fake_dtype=self._dtype, ) From 80ea3133efa4c3a3679845b8ee46dfc06e2792f0 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Fri, 22 May 2026 18:45:01 -0700 Subject: [PATCH 2/2] [PyTorch] Add `pad_between_seqs` support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) (#2596) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [PyTorch] Add pad_between_seqs support for FlashAttention 3 with CP Add support for padding between sequences (pad_between_seqs) in the FlashAttention 3 backend when used with context parallelism (CP). Key changes: - backends.py: Pass fa_pad_between_seqs through to FA3 forward/backward - context_parallel.py: Handle pad_between_seqs in A2A and P2P CP paths, zero FA3 padding garbage in CP forward, fix a2a backward alignment - dot_product_attention.py: Auto-detect pad_between_seqs from cu_seqlens - utils.py: Gate FA3 deterministic backward for hdim>=256, fix flash_attn_supported override for cross-attention and large head_dim, disable UnfusedDotProductAttention for pad_between_seqs, add SM100+ FA3 skip Signed-off-by: Sudhakar Singh * [PyTorch] Add pad_between_seqs tests for CP and non-CP FlashAttention Add test parametrization for pad_between_seqs in flash attention tests. Update run_attention_with_cp.py to support the new parameter and fix batch boundary alignment in the non-CP FA3 path. Run tests in parallel when multiple GPUs are available. Signed-off-by: Sudhakar Singh * [QA] Add CP deterministic tests to L3 and support TE_PATH in FA test Add deterministic CP test runs to L3 FA versions test. Support TE_PATH positional arg and fix GPU threshold for parallel test execution. Signed-off-by: Sudhakar Singh * [PyTorch] Fix FA3 deterministic gate to match upstream backward constraint The previous check disabled FA3 for deterministic mode whenever head_dim_qk > 128, which was overly conservative — FA3 forward supports deterministic execution at any head dim. The actual constraint from flash_api.cpp is that the backward pass does not support deterministic mode when max(head_size, head_size_v) >= 256. Narrow the gate to only disable FA3 during training (backward) and raise the threshold to >= 256, checking both head_dim_qk and head_dim_v to handle MLA configs with asymmetric head dimensions. Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370 Signed-off-by: Sudhakar Singh * [PyTorch] Disable FlashAttention 4 for pad_between_seqs with THD The pad_between_seqs gate in get_attention_backend only disabled FlashAttention 2, letting FA4 leak through to the test-time fused-vs-flash comparison. On B200 runners that install flash-attn-4, this caused test_dpa_qkv_layout_thd to compare FusedAttention against an FA4 output whose padded positions contain garbage, producing 48 numerics failures in L3_pytorch_FA_versions_test--B200_1GPU. The log message already claimed FA4 would be disabled — this change makes the code match the message: set use_flash_attention_4 = False alongside use_flash_attention_2 when pad_between_seqs is True. FA3 continues to support pad_between_seqs via seqused_q/seqused_k. Signed-off-by: Sudhakar Singh * [QA] Fix cutlass-dsl utils shadow in FA versions test FA4 install brings in nvidia-cutlass-dsl, whose `import cutlass` adds cutlass/base_dsl/ to sys.path. That directory contains a utils/ package that shadows tests/pytorch/utils.py, breaking collection of test_attention_with_cp.py with: ImportError: cannot import name 'ModelConfig' from 'utils' Prepend $TE_PATH/tests/pytorch to PYTHONPATH so the local utils.py is always resolved first, regardless of what FA4 dependencies install. Signed-off-by: Sudhakar Singh * skip tests which OOM in deterministic+backward+hopper+large_configs as its a known cudnn issue Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make cp det and nondet tests run in parallel whenever possible Signed-off-by: Sudhakar Singh * [QA] L3: gate CP tests per-arch to avoid CI timeout PR 2596 added deterministic CP runs to the L3 FA-versions matrix, multiplying CP wall time across every FA version and causing CI timeouts (pipeline 50243000). Run CP tests once per arch instead, picking the FA version each arch's CP code path actually supports: - sm90 (H100): FA3 3.0.0b1 - context_parallel.py is FA3-only on Hopper (use_flash_attn_3 threaded throughout, FA4 not wired in; pad_between_seqs gated on use_flash_attn_3 at lines 1038, 1366) - sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90 Non-CP test_attention.py still runs for every FA version in the array. Also drop FA 2.7.3 from the sm90 list (no longer maintained as a target) and bump the FA4 pin from 4.0.0b8 to 4.0.0b11. b8 has an SM90 backward kernel bug fixed by upstream PR #2513 in b11 (get_smem_store_C() got multiple values for argument 'transpose'). Signed-off-by: Sudhakar Singh * [QA] L3: skip pre-installed FA3 build, per-FA junit XMLs Three follow-ups on top of 13ba0046 (L3 per-arch CP gating): 1. Skip the inline FA3 source build when flash_attn_interface is already importable. This makes the script a no-op on FA3 install when the base image has FA3 baked in (companion to TE !573 on te_ci, which auto-sets INSTALL_FA3=${RUN_L3_TESTS} so FA3 is preinstalled for L3 pipelines). Saves ~20 min of L3 H100 wall time once both land. Falls back to the existing inline build when FA3 is not pre-installed. 2. Suffix junit XMLs with the FA version (pytest_test_attention_fa2_8_3.xml etc.) so per-iteration results are preserved instead of overwritten. Pipeline 50348672 had no per-FA timing visibility because pytest.xml was clobbered by each loop iteration. 3. Include FA version in test_fail messages so CI dashboards show which FA iteration caused a failure (was "test_attention.py", now "test_attention.py (FA 2.8.3)"). Also fold the CP_FA_VERSION assignment into the same if-block as FA_versions (was a separate if-block immediately after) since the two are arch-keyed in lockstep. Signed-off-by: Sudhakar Singh * b200 shouldnt run FA3 even if present Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * L3: drop stale RUN_L3_TESTS=1 note; use flash_attn_3 for FA3 check Address two pending review comments: 1. The "auto-set when RUN_L3_TESTS=1" annotation on the base-image FA3 preinstall is no longer accurate; drop it so readers don't grep for a coupling that doesn't exist. 2. `flash_attn_interface` reads like a generic FA API even though the top-level shim is only created by the FA3 install. Switching to `import flash_attn_3` makes the FA3-specific intent unambiguous and matches the FA3 package layout produced by the source build. Local validation on H100 (sm90) with FA3 active, TE worktree resolving to the editable install (verified via three-layer import check from /tmp): test_attention_with_cp.py parallel det+nondet — 45 passed / 0 failed nondet (3:52), 33 passed / 0 failed det (2:55). 33 pad-True nondet passes + 21 pad-True det passes confirm the FA3+THD+CP path is exercised; 5 det OOM cases skip cleanly via the existing inline guard. Same test scope is exercised by L1_pytorch_distributed_unittest (parallel det+nondet) and the FA3 iteration of L3_pytorch_FA_versions_test; the changes here are L3-only documentation/detection tweaks and do not alter the Python test code, but the L1+L3 CP execution was re-run on the cleaned PR head end-to-end as proof. Signed-off-by: Sudhakar Singh * Address review nits: bHSS-gated OOM skip; drop Dockerfile.base specifics 1. Det FusedAttention backward THD/sm90 OOM skip: gate on the actual memory pressure (b*H*S*S) instead of num_heads >= 20. The cuDNN workspace is proportional to bHSS, so a future config with H >= 20 but small b or S would be needlessly skipped under the old guard, while a config with H < 20 but large b*S that hit the same OOM wouldn't be caught. Threshold 1e9 empirically matches the existing 5-case skip set on the test_essential fused subset (cp_2_0, cp_2_2, cp_3_1, cp_4_2, cp_4_3 — bHSS in 1.07B–4.29B) and lets cp_1_0/ cp_2_1/cp_2_4/cp_3_2/cp_3_4 (bHSS ~0.40B) keep running. 2. L3 FA3 install comment: drop the "Dockerfile.base INSTALL_FA3=1" reference. The detection check is the contract; mentioning a specific image variable couples this script to an out-of-tree provisioning detail that may evolve independently. Local validation on H100 (sm90) with FA3 active and TE worktree resolving to editable (verified via /tmp-cwd three-layer import check after reinstall — the /usr/local TE shadow had reappeared between sessions): test_attention_with_cp.py parallel det+nondet — 45 passed / 0 failed nondet (4:09), 33 passed / 0 failed det (3:14). 33 pad-True nondet passes + 21 pad-True det passes; 5 det OOM cases skip via the new bHSS gate — same cases as the old num_heads-only gate. Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Name the OOM-skip threshold and explain the 128*bHSS workspace observation Address review nits on the deterministic THD-backward OOM guard: 1. Replace the magic number 1_000_000_000 with the named constant SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30, so the value is searchable and labeled. 2. Replace the prefatory comment with a short note tying the number to cuDNN's actual workspace request (~128 * bHSS bytes, measured on cuDNN 9.21.0 sm90 — see local sweep). At bHSS = 1<<30 the request is 128 GiB, which doesn't fit on H100's 80 GB. 3. Flag the b>=3 caveat for future readers: cuDNN rounds the batch up internally so workspace grows super-linearly past b=2 (b=4 asks for 4x the b=2 workspace, not 2x). The current fused-essential matrix is all b=2, so the threshold stays correct for what the test exercises; the note is there so the next person doesn't have to rediscover it. Skip set is unchanged — cp_2_0, cp_2_1, cp_3_1, cp_4_2, cp_4_3. Signed-off-by: Sudhakar Singh * Reword OOM-skip comment as observations, not cuDNN-internal claims We measured the workspace request from outside cuDNN, so the comment should say "observed" rather than asserting what cuDNN does. Reframes the ~128 * bHSS bytes formula and the super-linear b>=3 behavior as empirical observations from our sweep. No code change. Signed-off-by: Sudhakar Singh --------- Signed-off-by: Sudhakar Singh Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- qa/L1_pytorch_distributed_unittest/test.sh | 19 ++- qa/L3_pytorch_FA_versions_test/test.sh | 85 ++++++++++- .../attention/run_attention_with_cp.py | 96 +++++++----- tests/pytorch/attention/test_attention.py | 34 ++--- .../attention/test_attention_with_cp.py | 30 +++- .../dot_product_attention/backends.py | 32 +++- .../dot_product_attention/context_parallel.py | 141 +++++++++++++++--- .../dot_product_attention.py | 3 + .../attention/dot_product_attention/utils.py | 38 +++-- 9 files changed, 371 insertions(+), 107 deletions(-) diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index db13e9f1e0..7eb34a62e4 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -22,6 +22,24 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" +# Run CP tests (deterministic + non-deterministic) first so they can be parallelized. +# Each needs 4 GPUs, so >=8 GPUs allows them to run concurrently on disjoint GPU sets. +NUM_GPUS=$(python3 -c "import torch; print(torch.cuda.device_count())") +echo "Detected $NUM_GPUS GPU(s)" +if [ "$NUM_GPUS" -ge 8 ]; then + echo "Running CP tests in parallel: non-deterministic on GPUs 0-3, deterministic on GPUs 4-7" + CUDA_VISIBLE_DEVICES=0,1,2,3 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & + PID_CP_NONDET=$! + CUDA_VISIBLE_DEVICES=4,5,6,7 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & + PID_CP_DET=$! + wait $PID_CP_NONDET || test_fail "test_attention_with_cp.py" + wait $PID_CP_DET || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py" +else + echo "Running CP tests sequentially: need >=8 GPUs for parallel execution" + python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" + NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention_with_cp.py" +fi + python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics_exact.xml $TE_PATH/tests/pytorch/distributed/test_numerics_exact.py || test_fail "test_numerics_exact.py" @@ -29,7 +47,6 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_newton_schulz.xml $TE_PATH/tests/pytorch/distributed/test_newton_schulz.py || test_fail "test_newton_schulz.py" diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 642eb93b06..30f1fc38c0 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -2,13 +2,25 @@ # # See LICENSE for license information. -set -e +function error_exit() { + echo "Error: $1" + exit 1 +} + +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -pip3 install pytest==8.2.1 +pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" # Limit parallel build jobs to avoid overwhelming system resources export MAX_JOBS=32 @@ -16,12 +28,18 @@ export MAX_JOBS=32 # Iterate over Flash Attention versions sm_arch=`python3 -c "import torch; sm = torch.cuda.get_device_capability(0); print(sm[0]*10+sm[1])"` export FLASH_ATTN_CUDA_ARCHS=$sm_arch +# CP tests are expensive and run only once per arch: +# - sm90 (H100): FA3 (3.0.0b1) - context_parallel.py only supports FA3 on Hopper +# - sm>90 (B200): latest FA4 - FA3 is not built/installed for sm>90 +# Non-CP tests still run for every FA version in the array. if [ $sm_arch -gt 90 ] then - FA_versions=(2.8.3 4.0.0b8) + FA_versions=(2.8.3 4.0.0b11) + CP_FA_VERSION="${FA_versions[-1]}" elif [ $sm_arch -eq 90 ] then - FA_versions=(2.7.3 2.8.3 3.0.0b1 4.0.0b8) + FA_versions=(2.8.3 3.0.0b1 4.0.0b11) + CP_FA_VERSION="3.0.0b1" fi for fa_version in "${FA_versions[@]}" @@ -35,12 +53,63 @@ do then pip3 install flash-attn-4==${fa_version} nvidia-cutlass-dsl[cu13]==4.4.2 --no-build-isolation else - git clone https://github.com/Dao-AILab/flash-attention.git - cd flash-attention/hopper && python setup.py install - cd ../../ + # FA3 source build (~20 min). Skip if FA3 is already installed. + if python3 -c "import flash_attn_3" 2>/dev/null; then + echo "FA3 already installed (from base image); skipping source build" + else + git clone https://github.com/Dao-AILab/flash-attention.git + cd flash-attention/hopper && python setup.py install + cd ../../ + fi fi + # Ensure local test utils is found before nvidia-cutlass-dsl's utils package + export PYTHONPATH=$TE_PATH/tests/pytorch:${PYTHONPATH:-} + # Run tests - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py + NUM_GPUS=$(nvidia-smi -L | wc -l) + echo "Detected $NUM_GPUS GPU(s)" + + # Suffix junit XMLs with the FA version so per-iteration results are preserved + # (otherwise pytest.xml is overwritten on each loop iteration and we lose timing + # data for all but the last FA version). + fa_tag="${fa_version//./_}" + XML_ATTN="$XML_LOG_DIR/pytest_test_attention_fa${fa_tag}.xml" + XML_CP="$XML_LOG_DIR/pytest_test_attention_with_cp_fa${fa_tag}.xml" + + if [ "$fa_version" = "$CP_FA_VERSION" ]; then + echo "Running CP tests with FA $fa_version (CP version for sm$sm_arch)" + if [ "$NUM_GPUS" -ge 5 ]; then + CP_NUM_GPUS=$(( NUM_GPUS - 1 > 4 ? 4 : NUM_GPUS - 1 )) + CP_GPUS=$(seq -s, 1 $CP_NUM_GPUS) + echo "Running tests in parallel: test_attention.py on GPU 0, test_attention_with_cp.py on GPUs $CP_GPUS ($CP_NUM_GPUS GPUs)" + + CUDA_VISIBLE_DEVICES=0 NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ + --junitxml=$XML_ATTN \ + $TE_PATH/tests/pytorch/attention/test_attention.py & + PID_ATTN=$! + CUDA_VISIBLE_DEVICES=$CP_GPUS NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s \ + --junitxml=$XML_CP \ + $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py & + PID_CP=$! + + wait $PID_ATTN || test_fail "test_attention.py (FA $fa_version)" + wait $PID_CP || test_fail "test_attention_with_cp.py (FA $fa_version)" + else + echo "Running tests sequentially: need >=5 GPUs for parallel execution (1 for test_attention + 4 for test_attention_with_cp)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_ATTN $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_CP $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py (FA $fa_version)" + fi + else + echo "Skipping CP tests for FA $fa_version (CP only runs with FA $CP_FA_VERSION on sm$sm_arch)" + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_ATTN $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py (FA $fa_version)" + fi done + +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/tests/pytorch/attention/run_attention_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py index 9f6b4944e6..6fca61d3c0 100644 --- a/tests/pytorch/attention/run_attention_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -47,6 +47,7 @@ def generate_input_shapes( config: ModelConfig, world_size: int, kernel_backend: str, + fa_pad_between_seqs: str = "False", ): if qkv_format == "bshd": q_input_shape = ( @@ -115,9 +116,12 @@ def generate_input_shapes( ).cuda() cu_seqlens_q = torch.clone(cu_seqlens_q_padded) - # Since FlashAttention doesn't support pad b/w sequences, and FusedAttention does, - # cu_seqlens_q is updated to reflect non-padded lengths for FusedAttention only. - if kernel_backend == "FusedAttention": + # Generate padded data (cu_seqlens_q reflects non-padded lengths, so it + # differs from cu_seqlens_q_padded) for FusedAttention always, and for + # FlashAttention only when its test param requests it. DPA auto-detects + # pad_between_seqs downstream from the cu_seqlens_q vs cu_seqlens_q_padded + # mismatch. + if kernel_backend == "FusedAttention" or fa_pad_between_seqs == "True": cu_seqlens_q[1:] = seqlens_q.cumsum(0, dtype=torch.int32).cuda() # NOTE: In case of Cross-Attention, `cu_seqlens_kv` and `cu_seqlens_kv_padded` @@ -196,6 +200,7 @@ def run_dpa_with_cp( scaling_mode="delayed", f16_O="False", is_training="True", + fa_pad_between_seqs="False", deterministic="False", log_level=logging.WARNING, ): @@ -314,7 +319,7 @@ def run_dpa_with_cp( cu_seqlens_kv, cu_seqlens_q_padded, cu_seqlens_kv_padded, - ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend) + ) = generate_input_shapes(qkv_format, config, world_size, kernel_backend, fa_pad_between_seqs) q_orig = torch.clamp(torch.randn(q_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() k_orig = torch.clamp(torch.randn(k_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() v_orig = torch.clamp(torch.randn(v_input_shape, dtype=dtypes[dtype]), min=-1, max=1).cuda() @@ -557,11 +562,11 @@ def run_dpa_with_cp( tensors_to_deq[i] = tensor.dequantize() if not fp8_bwd: tensors[0], tensors[5] = tensors_to_deq - for i, tensor in enumerate(tensors): + for tensor, name in zip(tensors, names): # dbias/dbias_ could be None, so skip check for it if tensor is not None: - assert torch.all(~torch.isnan(tensor)), f"{names[i]} contains NaN" - assert torch.all(~torch.isinf(tensor)), f"{names[i]} contains Inf" + assert torch.all(~torch.isnan(tensor)), f"{name} has nan values" + assert torch.all(~torch.isinf(tensor)), f"{name} has inf values" out, dq, dk, dv, dbias, out_, dq_, dk_, dv_, dbias_ = tensors ############ compare results between CP and no-CP ############ @@ -617,49 +622,60 @@ def run_dpa_with_cp( if is_training: dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [dq, out]] dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [dk, dv]] - dq_, dk_, dv_, out_ = [dq_, dk_, dv_, out_] cu_seqlens_q_padded = cu_seqlens_q_padded // world_size cu_seqlens_q = get_cu_seqlens_on_cp_rank( cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True ) - cu_pads_q = cu_seqlens_q_padded - cu_seqlens_q - num_pads_q = cu_pads_q[1:] - cu_pads_q[:-1] - for x in [dq, out, dq_, out_]: - assert torch.count_nonzero(x[cu_seqlens_q_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_q[b] == 0 - or torch.count_nonzero( - x[ - (cu_seqlens_q_padded[b + 1] - num_pads_q[b]) : cu_seqlens_q_padded[ - b + 1 - ] - ] - ).item() - == 0 - ) + num_pads_q = (cu_seqlens_q_padded - cu_seqlens_q)[1:] - ( + cu_seqlens_q_padded - cu_seqlens_q + )[:-1] cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size cu_seqlens_kv = get_cu_seqlens_on_cp_rank( cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True ) - cu_pads_kv = cu_seqlens_kv_padded - cu_seqlens_kv - num_pads_kv = cu_pads_kv[1:] - cu_pads_kv[:-1] - for x in [dk, dv, dk_, dv_]: - assert torch.count_nonzero(x[cu_seqlens_kv_padded[-1] :]).item() == 0 - for b in range(config.batch_size): - assert ( - num_pads_kv[b] == 0 - or torch.count_nonzero( - x[ - ( - cu_seqlens_kv_padded[b + 1] - num_pads_kv[b] - ) : cu_seqlens_kv_padded[b + 1] - ] - ).item() - == 0 + num_pads_kv = (cu_seqlens_kv_padded - cu_seqlens_kv)[1:] - ( + cu_seqlens_kv_padded - cu_seqlens_kv + )[:-1] + # FA3 leaves garbage at padding positions despite seqused_q/k (tile spillover). + # Forward out_ can't be pre-zeroed because FA3's custom op returns out_ as an + # output rather than mutating it in-place, triggering PyTorch's aliasing constraint. + # Backward dq/dk/dv CAN be pre-zeroed because FA3 marks them as mutated inputs. + if fa_pad_between_seqs == "True": + # out_ is a view inside the CP custom autograd Function, so in-place + # zeroing is blocked by PyTorch. Clone to break the view relationship. + out_ = out_.clone() + for x in [out, out_, dq]: + for b in range(config.batch_size): + x[ + cu_seqlens_q_padded[b + 1] - num_pads_q[b] : cu_seqlens_q_padded[b + 1] + ] = 0.0 + x[cu_seqlens_q_padded[-1] :] = 0.0 + for x in [dk, dv]: + for b in range(config.batch_size): + x[ + cu_seqlens_kv_padded[b + 1] + - num_pads_kv[b] : cu_seqlens_kv_padded[b + 1] + ] = 0.0 + x[cu_seqlens_kv_padded[-1] :] = 0.0 + # Verify CP backward tensors have clean padding (pre-zeroed in context_parallel.py). + for xname, x, cu, np_ in [ + ("dq_", dq_, cu_seqlens_q_padded, num_pads_q), + ("dk_", dk_, cu_seqlens_kv_padded, num_pads_kv), + ("dv_", dv_, cu_seqlens_kv_padded, num_pads_kv), + ]: + nnz = torch.count_nonzero(x[cu[-1] :]).item() + assert nnz == 0, ( + f"{xname} has {nnz} nonzero values in tail padding — " + "context_parallel.py should zero padding positions" ) + for b in range(config.batch_size): + if np_[b] > 0: + nnz = torch.count_nonzero(x[cu[b + 1] - np_[b] : cu[b + 1]]).item() + assert nnz == 0, ( + f"{xname} has {nnz} nonzero values in batch {b} padding — " + "context_parallel.py should zero padding positions" + ) else: - # Forward-only: reshape only out/out_ for comparison out = out.index_select(0, seq_idx_q).contiguous() out_ = out_ diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 32ea1694ee..5c46949f67 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -124,7 +124,7 @@ def reset_global_fp8_state(): @pytest.mark.parametrize("workspace_opt", [True, False]) @pytest.mark.parametrize("qkv_layout", [None]) @pytest.mark.parametrize("swa", [False]) -@pytest.mark.parametrize("pad_between_seqs", [False]) +@pytest.mark.parametrize("pad_between_seqs", [False, True]) def test_dot_product_attention( dtype, model_configs, @@ -157,6 +157,8 @@ def test_dot_product_attention( config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0] + if pad_between_seqs and qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") if qkv_format == "thd" and "padding" not in config.attn_mask_type: config.attn_mask_type = ( "padding_" + config.attn_mask_type if config.attn_mask_type != "no_mask" else "padding" @@ -195,19 +197,6 @@ def test_dot_product_attention( ) flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends - # FlashAttention does not support pad_between_seqs, but _run_dot_product_attention - # mannually pads and unpads the input and output of FlashAttention for testing purposes - if ( - pad_between_seqs - and FlashAttentionUtils.is_installed - and not ( - config.max_seqlen_q != config.max_seqlen_kv - and config.attn_mask_type in ["causal", "padding_causal"] - ) - and (config.window_size[0] == -1 or FlashAttentionUtils.v2_3_plus) - ): - flash_attn_supported = True - # Skip if only unfused backend is supported if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: pytest.skip("Less than two backends to compare.") @@ -1301,12 +1290,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: block.softmax_offset.requires_grad = True # Run a forward and backward pass - if backend in ["FlashAttention", "UnfusedDotProductAttention"]: + if backend in ["UnfusedDotProductAttention"]: q = inp_orig[0] k = inp_orig[1] v = inp_orig[2] d_out = out_grad_orig - if backend == "FusedAttention": + if backend in ["FusedAttention", "FlashAttention"]: q = inp[0] k = inp[1] v = inp[2] @@ -1322,14 +1311,19 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: max_seqlen_kv=config.max_seqlen_kv, cu_seqlens_q=cu_seqlens_q, cu_seqlens_kv=cu_seqlens_kv, - cu_seqlens_q_padded=cu_seqlens_q_after_pad if backend == "FusedAttention" else None, - cu_seqlens_kv_padded=cu_seqlens_kv_after_pad if backend == "FusedAttention" else None, + cu_seqlens_q_padded=( + cu_seqlens_q_after_pad if backend in ["FusedAttention", "FlashAttention"] else None + ), + cu_seqlens_kv_padded=( + cu_seqlens_kv_after_pad if backend in ["FusedAttention", "FlashAttention"] else None + ), attn_mask_type=config.attn_mask_type, checkpoint_core_attention=ckpt_attn, core_attention_bias_type=config.attn_bias_type, core_attention_bias=bias, alibi_slopes=alibi_slopes, fast_zero_fill=True, + pad_between_seqs=pad_between_seqs, # Only pass num_splits when exercising the FlashAttention path num_splits=config.num_splits if backend == "FlashAttention" else 1, ) @@ -1343,12 +1337,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: if is_training and config.softmax_type != "vanilla": d_softmax_offset = block.softmax_offset.grad - if backend in ["FlashAttention", "UnfusedDotProductAttention"]: + if backend in ["UnfusedDotProductAttention"]: if is_training: return out, max_logit, (q.grad, k.grad, v.grad, d_softmax_offset) else: return out, max_logit, (None, None, None, d_softmax_offset) - if backend == "FusedAttention": + if backend in ["FusedAttention", "FlashAttention"]: if qkv_format == "thd" and pad_between_seqs: out_orig = torch.Tensor([]).to(device="cuda", dtype=dtype) if is_training: diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index f0d2c27c12..a03f51f6c9 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -306,10 +306,19 @@ def _submit(pool: PoolWorker, **kwargs) -> None: @pytest.mark.parametrize("model", model_configs_flash_attn.keys()) @pytest.mark.parametrize("qkv_format", qkv_formats) @pytest.mark.parametrize("cp_comm_type", cp_comm_types) -def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type): +@pytest.mark.parametrize("pad_between_seqs", [False, True]) +def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type, pad_between_seqs): num_gpus = 4 if cp_comm_type == "a2a+p2p" else 2 pool = cp_pool(num_gpus) + if pad_between_seqs: + if qkv_format != "thd": + pytest.skip("pad_between_seqs only applies to THD format!") + if not FlashAttentionUtils.v3_is_installed or get_device_compute_capability() > (9, 0): + pytest.skip("pad_between_seqs with CP requires Flash Attention v3 on Hopper (sm90)!") + if cp_comm_type == "a2a+p2p": + pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!") + config = model_configs_flash_attn[model] config.context_parallel = True config.cp_comm_type = cp_comm_type @@ -361,6 +370,7 @@ def test_cp_with_flash_attention(cp_pool, dtype, model, qkv_format, cp_comm_type qkv_format=qkv_format, kernel_backend="FlashAttention", cp_comm_type=cp_comm_type, + fa_pad_between_seqs=pad_between_seqs, log_level=pytest_logging_level, ) @@ -606,6 +616,7 @@ def test_cp_with_fused_attention( is_training=is_training, deterministic=_deterministic, ) + _, fused_attn_supported, _ = available_backends if fused_attn_supported and config.attn_mask_type in ["causal", "padding_causal"]: config_copy = copy.deepcopy(config) @@ -628,6 +639,23 @@ def test_cp_with_fused_attention( pytest.skip("Deterministic mode does not support non-vanilla softmax with FusedAttention") if _deterministic and config.attn_bias_type == "post_scale_bias" and is_training: pytest.skip("Deterministic mode does not support post_scale_bias with requires_grad") + # Observed: cuDNN det THD backward asks for ~128 * bHSS bytes of workspace + # on sm90; at 1<<30 that's 128 GiB, won't fit on H100's 80 GB. Held exactly + # at b=2 + power-of-2 S in our sweep; for b>=3 the workspace was observed to + # grow super-linearly (b=4 took ~4x the b=2 amount, not 2x) — revisit if a + # config uses b>2. + SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30 + if ( + _deterministic + and qkv_format == "thd" + and get_device_compute_capability() == (9, 0) + and config.batch_size * config.num_heads * config.max_seqlen_q * config.max_seqlen_kv + >= SM90_DET_FUSED_THD_BWD_MAX_BHSS + ): + pytest.skip( + "Deterministic FusedAttention backward with THD format OOMs on sm90" + " for large bHSS configs (known cuDNN issue)." + ) _submit( pool, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 6e097265ff..6c6adc6e3f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -822,10 +822,13 @@ def forward( fp8: bool = False, fp8_meta: Optional[Dict[str, Any]] = None, quantizers=None, + pad_between_seqs: Optional[bool] = False, inference_params: Optional[InferenceParams] = None, flash_attention_backend: Optional[PkgVersion] = PkgVersion("0"), fp8_output: bool = False, num_splits: Optional[int] = 1, + cu_seqlens_q_padded: Optional[torch.Tensor] = None, + cu_seqlens_kv_padded: Optional[torch.Tensor] = None, ) -> torch.Tensor: """flash-attn fprop""" @@ -1024,8 +1027,16 @@ def forward( cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, - cu_seqlens_q if qkv_format == "thd" else None, - cu_seqlens_kv if qkv_format == "thd" else None, + ( + cu_seqlens_q_padded + if pad_between_seqs + else (cu_seqlens_q if qkv_format == "thd" else None) + ), + ( + cu_seqlens_kv_padded + if pad_between_seqs + else (cu_seqlens_kv if qkv_format == "thd" else None) + ), self.attention_dropout if self.training else 0.0, cp_group, cp_global_ranks, @@ -1037,7 +1048,7 @@ def forward( deterministic=self.deterministic, window_size=window_size, quantizers=quantizers, - pad_between_seqs=False, + pad_between_seqs=pad_between_seqs, use_flash_attn_3=use_flash_attn_3, fp8_output=fp8_output, ) @@ -1082,8 +1093,12 @@ def forward( else: func = flash_attn_with_kvcache_v3 # pylint: disable=possibly-used-before-assignment if not use_flash_attn_4 and (not use_flash_attn_3 or inference_params is None): - fa_optional_forward_args_thd.append(cu_seqlens_q) - fa_optional_forward_args_thd.append(cu_seqlens_kv) + fa_optional_forward_args_thd.append( + cu_seqlens_q_padded if pad_between_seqs else cu_seqlens_q + ) + fa_optional_forward_args_thd.append( + cu_seqlens_kv_padded if pad_between_seqs else cu_seqlens_kv + ) fa_optional_forward_args_thd.append(max_seqlen_q) fa_optional_forward_args_thd.append(max_seqlen_kv) if use_flash_attn_4: @@ -1139,6 +1154,13 @@ def forward( fa_3_optional_forward_kwargs = {} fa_3_optional_forward_kwargs["window_size"] = window_size fa_3_optional_forward_kwargs["num_splits"] = num_splits + if pad_between_seqs: + fa_3_optional_forward_kwargs["seqused_q"] = ( + cu_seqlens_q[1:] - cu_seqlens_q[:-1] + ) + fa_3_optional_forward_kwargs["seqused_k"] = ( + cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + ) if inference_params is None: fa_3_optional_forward_kwargs["deterministic"] = self.deterministic else: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 35684625a5..36847e40ed 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -663,6 +663,8 @@ def get_fa_args( dq=None, dk=None, dv=None, + seqused_q=None, + seqused_k=None, ): """Get forward/backward arguments for flash-attn v2 and v3.""" if use_flash_attn_3: @@ -672,7 +674,9 @@ def get_fa_args( *[None] * 4, # k_new, v_new, qv, out cu_seqlens_q, cu_seqlens_kv, - *[None] * 3, # cu_seqlens_k_new, seqused_q, seqused_k + None, # cu_seqlens_k_new + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_kv, *[None] @@ -690,8 +694,8 @@ def get_fa_args( return [ cu_seqlens_q, cu_seqlens_kv, - None, # sequed_q - None, # sequed_k + seqused_q, + seqused_k, max_seqlen_q, max_seqlen_kv, dq, @@ -701,8 +705,8 @@ def get_fa_args( return [ None, # cu_seqlens_q None, # cu_seqlens_kv - None, # sequed_q - None, # sequed_k + None, # seqused_q + None, # seqused_k max_seqlen_q, max_seqlen_kv, dq, @@ -1020,6 +1024,9 @@ def cp_p2p_fwd_flash_attn( flash_attn_fwd, max_seqlen_q, max_seqlen_kv, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, q_part, k_part, v_part, @@ -1046,6 +1053,20 @@ def cp_p2p_fwd_flash_attn( fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 + seqused_q = None + seqused_k = None + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + # Derive actual token counts per batch element from cu_seqlens + seqused_q = cu_seqlens_q_per_step[1:] - cu_seqlens_q_per_step[:-1] + seqused_k = cu_seqlens_kv_per_step[1:] - cu_seqlens_kv_per_step[:-1] + # Override cu_seqlens to padded layout for tensor memory layout + cu_seqlens_q_ = cu_seqlens_q_padded + cu_seqlens_kv_ = cu_seqlens_kv_padded + if section == "lower-triangle": + cu_seqlens_kv_ = cu_seqlens_kv_padded // 2 + elif section == "upper-triangle": + cu_seqlens_q_ = cu_seqlens_q_padded // 2 + fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1054,6 +1075,8 @@ def cp_p2p_fwd_flash_attn( cu_seqlens_kv=cu_seqlens_kv_, max_seqlen_q=max_seqlen_q_, max_seqlen_kv=max_seqlen_kv_, + seqused_q=seqused_q, + seqused_k=seqused_k, ) fa_outputs = flash_attn_fwd( q_part, @@ -1296,6 +1319,9 @@ def cp_p2p_bwd_flash_attn( rng_states, softmax_lse, softmax_lse_, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, q_part, k_part, v_part, @@ -1304,7 +1330,10 @@ def cp_p2p_bwd_flash_attn( section, ): """Per-tile backward call of CP P2P with FlashAttention backend""" - dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] + if pad_between_seqs: + dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]] + else: + dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]] if fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus: fa_backward_kwargs["window_size"] = (-1, -1) elif use_flash_attn_3 or fa_utils.v2_7_0_plus: @@ -1329,17 +1358,33 @@ def cp_p2p_bwd_flash_attn( max_seqlen_q_ = max_seqlen_q // 2 softmax_lse__ = softmax_lse_ + seqused_q = None + seqused_k = None + cu_seqlens_q_bwd = cu_seqlens_q_per_step[cp_size - step - 1] + cu_seqlens_kv_bwd = cu_seqlens_kv_per_step[cp_size - step - 1] + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q_bwd[1:] - cu_seqlens_q_bwd[:-1] + seqused_k = cu_seqlens_kv_bwd[1:] - cu_seqlens_kv_bwd[:-1] + cu_seqlens_q_bwd = cu_seqlens_q_padded + cu_seqlens_kv_bwd = cu_seqlens_kv_padded + if section == "lower-triangle": + cu_seqlens_kv_bwd = cu_seqlens_kv_padded // 2 + elif section == "upper-triangle": + cu_seqlens_q_bwd = cu_seqlens_q_padded // 2 + fa_backward_args_thd = get_fa_args( False, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q_per_step[cp_size - step - 1], - cu_seqlens_kv=cu_seqlens_kv_per_step[cp_size - step - 1], + cu_seqlens_q=cu_seqlens_q_bwd, + cu_seqlens_kv=cu_seqlens_kv_bwd, max_seqlen_q=max_seqlen_q_, max_seqlen_kv=max_seqlen_kv_, dq=dq, dk=dk, dv=dv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if use_flash_attn_3: fa_backward_kwargs["is_causal"] = causal_ @@ -1779,6 +1824,9 @@ def forward( flash_attn_fwd, max_seqlen_q, max_seqlen_kv, + pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ] # cp_size = 4: @@ -1821,7 +1869,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) elif i <= rank: @@ -1848,7 +1898,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) else: @@ -1875,7 +1927,9 @@ def forward( else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( cp_p2p_fwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) ) else: @@ -1900,7 +1954,11 @@ def forward( ) = cp_p2p_fwd_fused_attn(*fused_attn_inputs, *prepare_outputs, section) else: out_per_step[i], softmax_lse_per_step[i], rng_states[i] = ( - cp_p2p_fwd_flash_attn(*flash_attn_inputs, *prepare_outputs, section) + cp_p2p_fwd_flash_attn( + *flash_attn_inputs, + *prepare_outputs, + section, + ) ) # softmax_lse correction @@ -2150,6 +2208,7 @@ def forward( ctx.attn_bias_shape = None if attn_bias is None else attn_bias.shape ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + ctx.pad_between_seqs = pad_between_seqs ctx.softmax_lse_in_packed_format = softmax_lse_in_packed_format ctx.second_half_lse_seqlen = second_half_lse_seqlen ctx.fp8_meta = fp8_meta @@ -2560,6 +2619,9 @@ def backward(ctx, dout, *_args): rng_states, softmax_lse, softmax_lse_, + ctx.pad_between_seqs, + cu_seqlens_q_padded, + cu_seqlens_kv_padded, ] # Reverse the steps in forward. In the cp_size x cp_size (i.e. GPU x step) matrix, @@ -2575,7 +2637,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) elif i >= (cp_size - rank - 1): section = "lower-triangle" @@ -2586,7 +2650,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) else: section = "upper-triangle" @@ -2597,7 +2663,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) else: section = "all" @@ -2608,7 +2676,9 @@ def backward(ctx, dout, *_args): ) else: dq_, dk_, dv_ = cp_p2p_bwd_flash_attn( - *flash_attn_inputs, *prepare_outputs, section + *flash_attn_inputs, + *prepare_outputs, + section, ) # dq, dk, dv are reduced across steps in higher precision @@ -3838,6 +3908,7 @@ def forward( cp_group, cp_stream, quantizers, + pad_between_seqs, use_flash_attn_3, softmax_type, softmax_offset, @@ -4073,14 +4144,25 @@ def forward( out_f16 = out_.dequantize(dtype=fwd_nominal_dtype) out_part = out_f16 else: + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = cu_seqlens_q + fa_cu_seqlens_kv = cu_seqlens_kv + if pad_between_seqs and use_flash_attn_3 and qkv_format == "thd": + seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_cu_seqlens_q = cu_seqlens_q_padded + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, qkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) fa_outputs = flash_attn_fwd( q_part, @@ -4217,6 +4299,7 @@ def forward( ctx.fwd_nominal_dtype = fwd_nominal_dtype ctx.fp8_recipe = fp8_recipe ctx.use_flash_attn_3 = use_flash_attn_3 + ctx.pad_between_seqs = pad_between_seqs ctx.softmax_type = softmax_type ctx.dQKV_quantizer = dQKV_quantizer @@ -4405,18 +4488,32 @@ def backward(ctx, dout, *_args): dq, dk, dv = [x._data for x in [dq, dk, dv]] else: softmax_lse, rng_state = aux_ctx_tensors - dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + if ctx.pad_between_seqs: + dq, dk, dv = [torch.zeros_like(x) for x in [q, k, v]] + else: + dq, dk, dv = [torch.empty_like(x) for x in [q, k, v]] + seqused_q = None + seqused_k = None + fa_cu_seqlens_q = cu_seqlens_q + fa_cu_seqlens_kv = cu_seqlens_kv + if ctx.pad_between_seqs and ctx.use_flash_attn_3 and ctx.dqkv_format == "thd": + seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqused_k = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1] + fa_cu_seqlens_q = cu_seqlens_q_padded + fa_cu_seqlens_kv = cu_seqlens_kv_padded fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, ctx.dqkv_format, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, + cu_seqlens_q=fa_cu_seqlens_q, + cu_seqlens_kv=fa_cu_seqlens_kv, max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq, dk=dk, dv=dv, + seqused_q=seqused_q, + seqused_k=seqused_k, ) if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_state @@ -4524,6 +4621,7 @@ def backward(ctx, dout, *_args): None, None, None, + None, d_softmax_offset, None, ) @@ -4740,6 +4838,7 @@ def attn_forward_func_with_cp( cp_group, cp_stream, quantizers, + pad_between_seqs, use_flash_attn_3, softmax_type, softmax_offset, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index b38b66c3e6..ca848a9480 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -1658,10 +1658,13 @@ def forward( fp8=self.fp8 and self.fp8_meta["recipe"].fp8_dpa, fp8_meta=self.fp8_meta, quantizers=self.quantizers, + pad_between_seqs=pad_between_seqs, inference_params=inference_params, flash_attention_backend=flash_attention_backend, fp8_output=fp8_output, num_splits=num_splits, + cu_seqlens_q_padded=cu_seqlens_q_padded, + cu_seqlens_kv_padded=cu_seqlens_kv_padded, ) if use_fused_attention: diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1f1637cecd..6565e9f6f6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -651,7 +651,7 @@ def get_attention_backend( # backend | precision | KV cache | architecture | qkv_format | page_size # --------------------------------------------------------------------------------------- # Fused | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 1 - # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | >= 256 + # Flash v2 | FP16/BF16 | non-paged/paged | sm80+ | bshd,sbhd,thd | % 256 == 0 # Flash v3 | FP16/BF16 | non-paged/paged | sm90 | bshd,sbhd,thd | >= 1 # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Flash v4 | FP16/BF16 | TODO | sm80+ | bshd,sbhd,thd | TODO @@ -691,9 +691,9 @@ def get_attention_backend( use_fused_attention = False use_unfused_attention = False if inference_params.is_paged: - if use_flash_attention_2 and inference_params.page_size < 256: + if use_flash_attention_2 and inference_params.page_size % 256 != 0: if FlashAttentionUtils.is_installed: - logger.debug("Disabling FlashAttention 2 for page size < 256") + logger.debug("Disabling FlashAttention 2 for page size not divisible by 256") use_flash_attention_2 = False if use_flash_attention_2: if not FlashAttentionUtils.is_installed: @@ -703,6 +703,16 @@ def get_attention_backend( "Disabling FlashAttention 2 as paged attention requires flash-attn 2.5+" ) use_flash_attention_2 = False + else: + # Non-paged KV cache still passes a block_table to FA2 for thd_2bshd support, + # and FA2 enforces page_size % 256 == 0 on the effective page size (max_seqlen_kv). + if use_flash_attention_2 and max_seqlen_kv % 256 != 0: + if FlashAttentionUtils.is_installed: + logger.debug( + "Disabling FlashAttention 2 for non-paged KV cache" + " with max_seqlen_kv not divisible by 256" + ) + use_flash_attention_2 = False if use_flash_attention_4 and FlashAttentionUtils.v4_is_installed: logger.debug("Disabling FlashAttention 4 as it does not support KV cache.") use_flash_attention_4 = False @@ -844,15 +854,18 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt if qkv_format == "thd": if pad_between_seqs: if ( # pylint: disable=too-many-boolean-expressions - (use_flash_attention_2 and FlashAttentionUtils.is_installed) - or (use_flash_attention_3 and FlashAttentionUtils.v3_is_installed) - or (use_flash_attention_4 and FlashAttentionUtils.v4_is_installed) - ): + use_flash_attention_2 and FlashAttentionUtils.is_installed + ) or (use_flash_attention_4 and FlashAttentionUtils.v4_is_installed): logger.debug( - "Disabling FlashAttention for qkv_format = thd when there is " + "Disabling FlashAttention 2 and 4 for qkv_format = thd when there is " "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) - use_flash_attention = False + use_flash_attention_2 = False + use_flash_attention_4 = False + # FA3 supports pad_between_seqs via seqused_q/seqused_k + if use_unfused_attention: + logger.debug("Disabling UnfusedDotProductAttention for pad_between_seqs = True") + use_unfused_attention = False if device_compute_capability == (12, 0): if cudnn_version < (9, 18, 1): if use_fused_attention: @@ -1273,9 +1286,12 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt ) use_flash_attention_2 = False if use_flash_attention_3 and deterministic and FlashAttentionUtils.v3_is_installed: - if head_dim_qk >= 256: + if is_training and max(head_dim_qk, head_dim_v) >= 256: logger.debug( - "Disabling FlashAttention 3 for deterministic execution with head_dim_qk >= 256." + "Disabling FlashAttention 3 for deterministic backward with" + " max(head_dim_qk, head_dim_v) >= 256. Found: head_dim_qk = %s, head_dim_v = %s.", + head_dim_qk, + head_dim_v, ) use_flash_attention_3 = False if use_fused_attention and deterministic: