diff --git a/docs/envvars.rst b/docs/envvars.rst index 1e040b4c3e..ffbad409d4 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -142,9 +142,9 @@ Attention Backend Selection .. envvar:: NVTE_FUSED_ATTN_BACKEND - :Type: ``int`` (0, 1, or 2) + :Type: ``int`` (1 or 2) :Default: Auto-selected - :Description: Force a specific FusedAttention backend. ``0`` = F16_max512_seqlen (cuDNN, ≤512 seq len), ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration. + :Description: Force a specific FusedAttention backend. ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration. .. envvar:: NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT @@ -281,6 +281,12 @@ Kernel Configuration :Default: ``0`` :Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations. +.. envvar:: NVTE_NVFP4_ROW_SCALED_ACTIVATION + + :Type: ``int`` (0 or 1) + :Default: ``0`` + :Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar. + Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 15d7c695c9..1f37520bc7 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -114,16 +114,14 @@ void quantize_nvfp4_1d(float (*OP)(const float), block_amax = std::max(block_amax, std::abs(elt)); } - // 2. Compute E4M3 scaling factor - // Compute per-block encoding/decoding scaling factor - const float S_dec_b = block_amax / 6.0f; - - // Scale & Store per-block decoding scaling factor - const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // Compute and store the per-block FP8 decode scale + const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f)); + const fp8e4m3 S_dec_b_fp8 = static_cast(fminf(S_dec_b, Numeric_Traits::maxNorm)); const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); // Compute "correct" per-block encoding scaling factor - const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32; + const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : + fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits::maxNorm); const size_t scale_idx = i * scales_stride + block_X; scales[scale_idx] = S_dec_b_fp8; @@ -317,11 +315,31 @@ void compute_ref(float (*OP)(const float), const size_t scales_stride, const size_t scales_stride_t, const bool use_fast_math, - const bool use_2d_quantization = false) + const bool use_2d_quantization = false, + std::vector *rowwise_amax = nullptr) { std::vector input_t = create_transpose(input, rows, cols); - if (use_2d_quantization) { + if (rowwise_amax != nullptr) { + rowwise_amax->resize(rows, 0.0f); + for (size_t row = 0; row < rows; ++row) { + float row_amax = 0.0f; + for (size_t col = 0; col < cols; ++col) { + row_amax = fmaxf(row_amax, fabsf(static_cast(input[row * cols + col]))); + } + (*rowwise_amax)[row] = row_amax; + quantize_nvfp4(OP, + input + row * cols, + output + row * (cols / 2), + scales + row * scales_stride, + 1, + cols, + scales_stride, + row_amax, + use_fast_math, + use_2d_quantization); + } + } else if (use_2d_quantization) { // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); @@ -504,13 +522,12 @@ void print_detailed_tensor_comparison(const std::string& name, void compareResults_nvfp4(const Tensor &test, const void *ref, const void *ref_t, const int rows, const int cols, - double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) { + double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, + bool dump_data = false, bool compare_columnwise = true) { if (if_on_gpus) test.to_cpu(); const fp4e2m1 *test_data = test.rowwise_cpu_dptr(); - const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); const fp4e2m1 *ref_data = reinterpret_cast(ref); - const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); // Print detailed element-by-element comparison // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols); @@ -519,17 +536,33 @@ void compareResults_nvfp4(const Tensor &test, // Optionally dump tensor data to files for detailed analysis if (dump_data) { dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols); - dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); } compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol); - compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); + if (compare_columnwise) { + const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); + const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); + if (dump_data) { + dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); + } + compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); + } +} + +void compare_rowwise_amax(const Tensor &output, const std::vector &ref_amax) { + const std::vector test_amax_data = output.tensor_amax_values(); + ASSERT_EQ(test_amax_data.size(), ref_amax.size()); + for (size_t row = 0; row < ref_amax.size(); ++row) { + ASSERT_EQ(test_amax_data[row], ref_amax[row]) + << "Row-scaled amax mismatch at row " << row; + } } template void performTest(float (*OP)(const float), const std::vector& shape, - const bool use_fast_math) { + const bool use_fast_math, + const bool row_scaled_nvfp4 = false) { using namespace test; DType itype = TypeInfo::dtype; @@ -556,7 +589,7 @@ void performTest(float (*OP)(const float), const size_t scales_stride_t = blocks_X_t; Tensor input("input", shape, itype); - Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); + Tensor output("output", shape, otype, true, !row_scaled_nvfp4, NVTE_NVFP4_1D_SCALING); std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -567,26 +600,44 @@ void performTest(float (*OP)(const float), // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues const float amax = 448.0f * 6.0f * 8.0f; - - // Set 2nd stage NVFP4 scaling factor - output.set_tensor_amax(amax); - output.set_tensor_amax_columnwise(amax); - + std::vector ref_rowwise_amax; bool use_2d_quantization = false; + if (row_scaled_nvfp4) { + output.set_tensor_amax_shape({rows}); + output.set_row_scaled_nvfp4(true); + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + 0.0f, + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + use_2d_quantization, + &ref_rowwise_amax); + } else { + // Set 2nd stage NVFP4 scaling factor + output.set_tensor_amax(amax); + output.set_tensor_amax_columnwise(amax); + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + amax, + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + use_2d_quantization); + } - compute_ref(OP, - input.rowwise_cpu_dptr(), - ref_output.get(), - ref_output_t.get(), - ref_scales.get(), - ref_scales_t.get(), - amax, - rows, - cols, - scales_stride, - scales_stride_t, - use_fast_math, - use_2d_quantization); // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed @@ -629,12 +680,8 @@ void performTest(float (*OP)(const float), const double rtol = 1.0E-6; // Set dump_data=true to enable dumping tensor data to files for analysis - compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); - - const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr(); - const fp8e4m3* ref_scales_ptr = ref_scales.get(); - const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr(); - const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get(); + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, + false, !row_scaled_nvfp4); size_t scale_mismatches_num = 0; compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), @@ -642,10 +689,16 @@ void performTest(float (*OP)(const float), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, scale_mismatches_num); - compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_t.get(), - unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, - scale_mismatches_num); + if (!row_scaled_nvfp4) { + compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + scale_mismatches_num); + } + + if (row_scaled_nvfp4) { + compare_rowwise_amax(output, ref_rowwise_amax); + } } std::vector> tensor_dims = { @@ -678,6 +731,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam , transformer_engine::DType, + bool, bool>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { @@ -693,6 +747,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const auto tensor_dims = std::get<1>(GetParam()); const DType input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); + const bool row_scaled_nvfp4 = std::get<4>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -710,7 +765,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math); + performTest(OP, tensor_dims, use_fast_math, row_scaled_nvfp4); ); } @@ -733,6 +788,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(tensor_dims), ::testing::Values(DType::kBFloat16), + ::testing::Values(false), ::testing::Values(false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)); @@ -746,3 +802,28 @@ INSTANTIATE_TEST_SUITE_P( } return name; }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTestRowScaled, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::Values(ActivationType::Identity), + ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), + ::testing::Values(DType::kBFloat16, DType::kFloat32), + ::testing::Values(false), + ::testing::Values(true)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for (const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(info.param)); + if (std::get<3>(info.param)) { + name += "X_FAST_SCALING"; + } + if (std::get<4>(info.param)) { + name += "XROW_SCALED"; + } + return name; + }); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 96e85cb5ed..ec405b1d90 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -42,7 +42,7 @@ float2 cvt_fp4x2_to_float2(fp4e2m1x2 fp4_pair) { template void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, const fp8e4m3 *scales, - float amax, + const std::vector &amax, OType *output, size_t rows, size_t cols, @@ -55,7 +55,8 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, for (size_t row = 0; row < rows; ++row) { for (size_t block = 0; block < Mread; ++block) { const fp8e4m3 scale = scales[row * scale_stride + block]; - const float final_scale = static_cast(scale) * amax * factor_inv; + const float final_scale = + static_cast(scale) * (amax.size() == 1 ? amax[0] : amax[row]) * factor_inv; for (size_t pair_idx = 0; pair_idx < bytes_per_block; ++pair_idx) { const size_t byte_idx = @@ -88,7 +89,8 @@ float compute_amax(const test::Tensor &t, size_t rows, size_t cols) { // Quantize a high-precision input to NVFP4, then dequantize and compare // against a CPU reference computed from the quantized data. template -void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { +void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, + const bool row_scaled_nvfp4) { using namespace test; DType otype = TypeInfo::dtype; @@ -97,7 +99,10 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { Tensor quantized("quantized", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rows > 0 && cols > 0) { + if (row_scaled_nvfp4) { + quantized.set_tensor_amax_shape({rows}); + quantized.set_row_scaled_nvfp4(true); + } else if (rows > 0 && cols > 0) { quantized.set_tensor_amax(compute_amax(input, rows, cols)); } else { quantized.set_tensor_amax(0.0f); @@ -120,7 +125,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { const uint8_t *fp4_data = reinterpret_cast(quantized.rowwise_cpu_dptr()); const fp8e4m3 *scales = quantized.rowwise_cpu_scale_inv_ptr(); - const float amax_val = quantized.amax(); + const std::vector amax_val = quantized.tensor_amax_values(); const NVTEShape scale_shape = quantized.rowwise_scale_inv_shape(); const size_t scale_stride = scale_shape.data[scale_shape.ndim - 1]; @@ -137,7 +142,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path. template -void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) { +void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, + const bool row_scaled_nvfp4) { using namespace test; DType otype = TypeInfo::dtype; @@ -146,7 +152,10 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rows > 0 && cols > 0) { + if (row_scaled_nvfp4) { + quantized_compact.set_tensor_amax_shape({rows}); + quantized_compact.set_row_scaled_nvfp4(true); + } else if (rows > 0 && cols > 0) { quantized_compact.set_tensor_amax(compute_amax(input, rows, cols)); } else { quantized_compact.set_tensor_amax(0.0f); @@ -157,7 +166,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) cudaDeviceSynchronize(); } - // Dequantize with compact scales → reference output + // Dequantize with compact scales to get the reference output. Tensor output_compact("output_compact", std::vector{rows, cols}, otype, true, false); nvte_dequantize(quantized_compact.data(), output_compact.data(), 0); cudaDeviceSynchronize(); @@ -165,13 +174,22 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) // Create tensor with same FP4 data but swizzled scales Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - quantized_swizzled.set_tensor_amax(0.0f); + if (row_scaled_nvfp4) { + quantized_swizzled.set_tensor_amax_shape({rows}); + quantized_swizzled.set_row_scaled_nvfp4(true); + } else { + quantized_swizzled.set_tensor_amax(0.0f); + } quantized_swizzled.set_with_gemm_swizzled_scales(true); // Copy amax and scale from compact to swizzled before FP4 data, // since from_cpu() uploads all CPU buffers (including zero-init data). quantized_compact.to_cpu(); - quantized_swizzled.set_tensor_amax(quantized_compact.amax()); + if (row_scaled_nvfp4) { + quantized_swizzled.copy_tensor_amax_from(quantized_compact); + } else { + quantized_swizzled.set_tensor_amax(quantized_compact.amax()); + } // Copy FP4 data after from_cpu() to avoid being overwritten const size_t data_bytes = rows * cols / 2; @@ -227,7 +245,8 @@ std::vector> nvfp4_tensor_dims = { class DequantizeNVFP4TestSuite : public ::testing::TestWithParam , - transformer_engine::DType>> {}; + transformer_engine::DType, + bool>> {}; TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) { @@ -237,10 +256,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); + const bool row_scaled_nvfp4 = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4( - tensor_size.first, tensor_size.second); + tensor_size.first, tensor_size.second, row_scaled_nvfp4); ); } @@ -249,19 +269,22 @@ INSTANTIATE_TEST_SUITE_P( DequantizeNVFP4TestSuite, ::testing::Combine( ::testing::ValuesIn(nvfp4_tensor_dims), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Bool()), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + - test::typeName(std::get<1>(info.param)); + test::typeName(std::get<1>(info.param)) + "X" + + (std::get<2>(info.param) ? "RowScaled" : "PerTensor"); return name; } ); class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam , - transformer_engine::DType>> {}; + transformer_engine::DType, + bool>> {}; TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) { @@ -271,10 +294,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); + const bool row_scaled_nvfp4 = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4_swizzled( - tensor_size.first, tensor_size.second); + tensor_size.first, tensor_size.second, row_scaled_nvfp4); ); } @@ -283,12 +307,14 @@ INSTANTIATE_TEST_SUITE_P( DequantizeNVFP4SwizzledTestSuite, ::testing::Combine( ::testing::ValuesIn(nvfp4_tensor_dims), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Bool()), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + test::typeName(std::get<1>(info.param)) + "X" + + (std::get<2>(info.param) ? "RowScaled" : "PerTensor") + "X" + "Swizzled"; return name; } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index c756b83810..96e71f9513 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -543,6 +543,59 @@ void Tensor::set_scale(float scale) { } } +void Tensor::set_tensor_amax_shape(const std::vector &shape) { + const size_t numel = product(shape); + NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, + "Amax shape override is only supported for NVFP4 test tensors."); + + auto old_amax = tensor_.get_amax(); + if (old_amax.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); + } + + float *amax = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&amax, numel * sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(amax, 0, numel * sizeof(float))); + tensor_.set_amax(amax, DType::kFloat32, shape); +} + +std::vector Tensor::tensor_amax_values() const { + const auto amax = tensor_.get_amax(); + NVTE_CHECK(static_cast(amax.dtype) == DType::kFloat32, "Tensor amax must be FP32."); + + const size_t numel = product(amax.shape); + if (numel == 0) { + return {}; + } + NVTE_CHECK(amax.data_ptr != nullptr, "Tensor amax is not allocated."); + + std::vector values(numel); + NVTE_CHECK_CUDA( + cudaMemcpy(values.data(), amax.data_ptr, numel * sizeof(float), cudaMemcpyDeviceToHost)); + return values; +} + +void Tensor::copy_tensor_amax_from(const Tensor &other) { + const auto other_amax = other.tensor_.get_amax(); + NVTE_CHECK(static_cast(other_amax.dtype) == DType::kFloat32, + "Source tensor amax must be FP32."); + + auto my_amax = tensor_.get_amax(); + NVTE_CHECK(static_cast(my_amax.dtype) == DType::kFloat32, + "Destination tensor amax must be FP32."); + NVTE_CHECK(areShapesEqual(my_amax.shape, other_amax.shape), "Amax shape mismatch."); + + const size_t numel = product(other_amax.shape); + if (numel == 0) { + return; + } + + NVTE_CHECK(other_amax.data_ptr != nullptr, "Source tensor amax is not allocated."); + NVTE_CHECK(my_amax.data_ptr != nullptr, "Destination tensor amax is not allocated."); + NVTE_CHECK_CUDA(cudaMemcpy(my_amax.data_ptr, other_amax.data_ptr, numel * sizeof(float), + cudaMemcpyDeviceToDevice)); +} + void Tensor::set_scale_inv(float scale_inv) { if (isFp8Type(dtype()) || isFp4Type(dtype())) { if (rowwise_) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index b8389d5833..b2a7da89cf 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -319,10 +319,18 @@ class Tensor { tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); } + void set_tensor_amax_shape(const std::vector &shape); + std::vector tensor_amax_values() const; + void copy_tensor_amax_from(const Tensor &other); + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales){ tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); } + void set_row_scaled_nvfp4(bool row_scaled_nvfp4) { + tensor_.set_row_scaled_nvfp4(row_scaled_nvfp4); + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index c9ea791444..32ea1694ee 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -227,40 +227,16 @@ def test_dot_product_attention( # FusedAttention backend if fused_attn_supported: - if len(fused_attn_backends) == 1: - fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention( - dtype, - config, - "FusedAttention", - ckpt_attn, - qkv_layout, - workspace_opt, - pad_between_seqs, - is_training, - ) - if len(fused_attn_backends) == 2: - os.environ["NVTE_FUSED_ATTN_BACKEND"] = "0" - fused_attn_fwd, _, fused_attn_bwd = _run_dot_product_attention( - dtype, - config, - "FusedAttention", - ckpt_attn, - qkv_layout, - workspace_opt, - pad_between_seqs, - is_training, - ) - os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" - fused_attn_fwd_1, _, fused_attn_bwd_1 = _run_dot_product_attention( - dtype, - config, - "FusedAttention", - ckpt_attn, - qkv_layout, - workspace_opt, - pad_between_seqs, - is_training, - ) + fused_attn_fwd, fused_max_logit, fused_attn_bwd = _run_dot_product_attention( + dtype, + config, + "FusedAttention", + ckpt_attn, + qkv_layout, + workspace_opt, + pad_between_seqs, + is_training, + ) # FlashAttention backend if flash_attn_supported: @@ -294,11 +270,6 @@ def test_dot_product_attention( torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols) for i, _ in enumerate(flash_attn_bwd): torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols) - if fused_attn_supported and len(fused_attn_backends) == 2: - logging.info("[test_dot_product_attention]: fused attn backend 0 vs 1") - torch.testing.assert_close(fused_attn_fwd, fused_attn_fwd_1, **tols) - for i, _ in enumerate(fused_attn_bwd): - torch.testing.assert_close(fused_attn_bwd[i], fused_attn_bwd_1[i], **tols) @pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.") @@ -2579,28 +2550,21 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), } param_types_fp8 = [torch.float16, torch.bfloat16] -cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1")) -models_v0 = ["fp8_1", "fp8_2", "fp8_5", "fp8_6"] -models_v1 = ["fp8_3", "fp8_4", "fp8_7", "fp8_8"] @pytest.mark.skipif( - ( - get_cudnn_version() < (8, 9, 3) - if cudnn_frontend_version == 0 - else get_cudnn_version() < (9, 2, 1) - ), - reason=f"""cuDNN {"8.9.3" if cudnn_frontend_version == 0 else "9.2.1"}+ is required.""", + get_cudnn_version() < (9, 2, 1), + reason="cuDNN 9.2.1+ is required for FP8 fused attention.", ) @pytest.mark.skipif(not fp8_attn_available, reason=reason_for_no_fp8_attn) @pytest.mark.parametrize("dtype", param_types_fp8) -@pytest.mark.parametrize("model", models_v1 if cudnn_frontend_version == 1 else models_v0) +@pytest.mark.parametrize("model", model_configs_fp8) def test_custom_mha_fp8_vs_f16(dtype, model): """Test FP8 dot product attention implementations based on cuDNN frontend v0.9 and v1.0+. Each test compares results from a custom implementation of an FP8 MHA module, i.e. Custom_MHA_FP8(), to results from an F16 MHA implementation, i.e. transformer_engine.pytorch.attention.MultiHeadAttention. - Both paths take F16 input and output. QKV layout is t3hd or bs3hd""" + Both paths take F16 input and output. QKV layout is bs3hd""" config = model_configs_fp8[model] @@ -2609,7 +2573,7 @@ def test_custom_mha_fp8_vs_f16(dtype, model): available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=torch.float8_e4m3fn, - qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd", + qkv_layout="bs3hd", is_training=is_training, deterministic=_deterministic, ) @@ -2816,18 +2780,17 @@ def forward( quantization_params=qkv_quantizer, use_split_accumulator=_2X_ACC_FPROP, ) - qkv_layout = "bs3hd" if cudnn_frontend_version == 1 else "t3hd" - o_format = "bshd" if cudnn_frontend_version == 1 else "thd" + qkv_layout = "bs3hd" + o_format = "bshd" qkv = qkv.view(-1, 3, h, d) qkv_fp16 = qkv.dequantize().view(b, max_s, 3, h, d).contiguous() torch.save(qkv_fp16, "qkv.pt") - if cudnn_frontend_version == 1: - qkv = qkv.view(b, max_s, 3, h, d) # bs3hd + qkv = qkv.view(b, max_s, 3, h, d) # bs3hd # FMHA - q_data = qkv._data[:, :, 0, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 0, :, :] - k_data = qkv._data[:, :, 1, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 1, :, :] - v_data = qkv._data[:, :, 2, :, :] if cudnn_frontend_version == 1 else qkv._data[:, 2, :, :] + q_data = qkv._data[:, :, 0, :, :] + k_data = qkv._data[:, :, 1, :, :] + v_data = qkv._data[:, :, 2, :, :] q = qkv.make_like(tensor=qkv, data=q_data, shape=q_data.shape) k = qkv.make_like(tensor=qkv, data=k_data, shape=k_data.shape) v = qkv.make_like(tensor=qkv, data=v_data, shape=v_data.shape) @@ -2849,7 +2812,7 @@ def forward( qkv_layout=qkv_layout, o_format=o_format, attn_bias_type="no_bias", - attn_mask_type=mask_type if cudnn_frontend_version == 1 else "padding", + attn_mask_type=mask_type, rng_gen=None, o_quantizer=o_quantizer, s_quantizer=s_quantizer, @@ -2916,9 +2879,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], do_format=ctx.o_format, dqkv_layout=ctx.qkv_layout, attn_bias_type="no_bias", - attn_mask_type=ctx.mask_type if cudnn_frontend_version == 1 else "padding", + attn_mask_type=ctx.mask_type, ) - dim = 2 if cudnn_frontend_version == 1 else 1 + dim = 2 dqkv = torch.Tensor().to(device=dq._data.device, dtype=dq._data.dtype) dqkv_shape = list(dq._data.shape) dqkv_shape.insert(dim, 3) diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 911b7660dc..b939336275 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,6 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils @@ -26,6 +27,7 @@ def check_nvfp4_gemm_versus_reference( *, x_columnwise: bool = False, w_columnwise: bool = False, + row_scaled_nvfp4: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 @@ -51,11 +53,12 @@ def check_nvfp4_gemm_versus_reference( x_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, - columnwise=True, + columnwise=not row_scaled_nvfp4, with_amax_reduction=False, amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + row_scaled_nvfp4=row_scaled_nvfp4, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -112,7 +115,16 @@ def check_nvfp4_gemm_versus_reference( sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) # Create reference quantizer for reference GEMM - ref_quantizer = NVFP4QuantizerRef( + x_ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=not row_scaled_nvfp4, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + row_scaled_nvfp4=row_scaled_nvfp4, + ) + w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, columnwise=True, @@ -124,16 +136,16 @@ def check_nvfp4_gemm_versus_reference( # Create reference quantized tensors needed by reference GEMM # Reference GEMM is only rowwise. if x_columnwise: - x_nvfp4_ref = ref_quantizer.quantize(x.t().contiguous()) + x_nvfp4_ref = x_ref_quantizer.quantize(x.t().contiguous()) else: - x_nvfp4_ref = ref_quantizer.quantize(x) + x_nvfp4_ref = x_ref_quantizer.quantize(x) if w_columnwise: - w_nvfp4_ref = ref_quantizer.quantize(w.t().contiguous()) + w_nvfp4_ref = w_ref_quantizer.quantize(w.t().contiguous()) else: - w_nvfp4_ref = ref_quantizer.quantize(w) + w_nvfp4_ref = w_ref_quantizer.quantize(w) # Reference GEMM using quantizer's qgemm method - y_ref = ref_quantizer.qgemm( + y_ref = x_ref_quantizer.qgemm( qx=qx_data, qw=qw_data, m_params=None, # MMParams not used in reference @@ -166,27 +178,38 @@ def check_nvfp4_gemm_versus_reference( x_nvfp4_native.update_usage(rowwise_usage=False) if w_columnwise: w_nvfp4_native.update_usage(rowwise_usage=False) - # Native cuBLAS GEMM - # return type is out, bias_grad, gelu_input, extra_output - # We are just capturing out. - y_native = tex.generic_gemm( - w_nvfp4_native, - transa, - x_nvfp4_native, - transb, - out.clone() if accumulate else None, - out_quantizer, - TE_DType[out_dtype], - bias, - bias_dtype, - use_gelu, - gelu_input, - use_grad, - workspace, - workspace.shape[0], - accumulate, - use_split_accumulator, - )[0] + if row_scaled_nvfp4: + layout = ("T" if transa else "N") + ("T" if transb else "N") + y_native = general_gemm( + w_nvfp4_native, + x_nvfp4_native, + out_dtype=out_dtype, + accumulate=accumulate, + layout=layout, + out=out.clone() if accumulate else None, + )[0] + else: + # Native cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. + y_native = tex.generic_gemm( + w_nvfp4_native, + transa, + x_nvfp4_native, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_input, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] # just in case of accumulation, make sure y_ref and y_native are not the same tensor assert y_ref is not y_native, "y_ref and y_native should not be the same tensor" @@ -199,6 +222,170 @@ def check_nvfp4_gemm_versus_reference( torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) +def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + m_splits: list[int], + k: int, + n: int, + *, + use_bias: bool, + single_output: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + torch.manual_seed(23) + torch.cuda.manual_seed(23) + + num_gemms = len(m_splits) + + x_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=False, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + row_scaled_nvfp4=True, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + x_nvfp4 = [] + w_nvfp4 = [] + bias = [] + expected = [] + for m in m_splits: + x = torch.randn((m, k), dtype=x_dtype, device=device) + w = torch.randn((n, k), dtype=w_dtype, device=device) + x_nvfp4.append( + x_quantizer.update_quantized( + x, x_quantizer.make_empty(x.shape, dtype=x_dtype, device=device) + ) + ) + w_nvfp4.append( + w_quantizer.update_quantized( + w, w_quantizer.make_empty(w.shape, dtype=w_dtype, device=device) + ) + ) + bias.append(torch.randn(n, dtype=torch.bfloat16, device=device) if use_bias else None) + expected.append( + general_gemm( + w_nvfp4[-1], + x_nvfp4[-1], + out_dtype=out_dtype, + layout="TN", + bias=bias[-1], + )[0] + ) + + if single_output: + out = [torch.empty((sum(m_splits), n), dtype=out_dtype, device=device)] + else: + out = [torch.empty((m, n), dtype=out_dtype, device=device) for m in m_splits] + + grouped_out, _, _ = general_grouped_gemm( + w_nvfp4, + x_nvfp4, + out, + quantization_params=[None] * num_gemms, + out_dtype=out_dtype, + layout="TN", + m_splits=m_splits, + bias=bias, + use_bias=use_bias, + single_output=single_output, + ) + + if single_output: + grouped_slices = torch.split(grouped_out, m_splits, dim=0) + else: + grouped_slices = grouped_out + for grouped, ref in zip(grouped_slices, expected): + torch.testing.assert_close(grouped, ref, atol=0.0, rtol=0.0) + + +def check_nvfp4_row_scaled_gemm_matches_emulated( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + M: int, + K: int, + N: int, +): + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + torch.manual_seed(37) + torch.cuda.manual_seed(37) + + x = torch.randn((M, K), dtype=x_dtype, device=device) + w = torch.randn((N, K), dtype=w_dtype, device=device) + + x_row_scaled_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=False, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + row_scaled_nvfp4=True, + ) + x_tensorwise_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + x_row_scaled = x_row_scaled_quantizer.update_quantized( + x, x_row_scaled_quantizer.make_empty(x.shape, dtype=x_dtype, device=device) + ) + w_nvfp4 = w_quantizer.update_quantized( + w, w_quantizer.make_empty(w.shape, dtype=w_dtype, device=device) + ) + y_row_scaled = general_gemm(w_nvfp4, x_row_scaled, out_dtype=out_dtype, layout="TN")[0] + + emulated_rows = [] + for i in range(M): + x_padded = torch.zeros((16, K), dtype=x_dtype, device=device) + x_padded[0].copy_(x[i]) + x_tensorwise = x_tensorwise_quantizer.update_quantized( + x_padded, + x_tensorwise_quantizer.make_empty(x_padded.shape, dtype=x_dtype, device=device), + ) + emulated_rows.append( + general_gemm(w_nvfp4, x_tensorwise, out_dtype=out_dtype, layout="TN")[0][:1] + ) + + y_emulated = torch.cat(emulated_rows, dim=0) + if out_dtype == torch.bfloat16: + torch.testing.assert_close(y_row_scaled, y_emulated, atol=0.0, rtol=7.8e-3) + else: + torch.testing.assert_close(y_row_scaled, y_emulated, atol=3.0517578125e-5, rtol=0.0) + + @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, K, N", @@ -229,6 +416,7 @@ def check_nvfp4_gemm_versus_reference( ], ids=["rowxrow", "colxrow", "colxcol"], ) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, @@ -239,7 +427,14 @@ def test_nvfp4_gemm_versus_reference( accumulate: bool, is_x_columnwise: bool, is_w_columnwise: bool, + row_scaled_nvfp4: bool, ): + if row_scaled_nvfp4: + if accumulate: + pytest.skip("Row-scaled NVFP4 GEMM output rescale does not support accumulation") + if is_x_columnwise: + pytest.skip("Row-scaled NVFP4 GEMM output rescale requires rowwise RHS usage") + check_nvfp4_gemm_versus_reference( x_dtype=x_dtype, w_dtype=w_dtype, @@ -250,4 +445,87 @@ def test_nvfp4_gemm_versus_reference( accumulate=accumulate, x_columnwise=is_x_columnwise, w_columnwise=is_w_columnwise, + row_scaled_nvfp4=row_scaled_nvfp4, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "m_splits, k, n", + [ + ([32, 48, 48], 128, 128), + ([64, 80, 112], 128, 256), + ([64, 80, 112], 256, 256), + ([64, 80, 112], 1024, 256), + ([256, 256, 512], 1024, 1024), + ([1024, 1536, 1536], 512, 3072), + ([16, 32, 64], 128, 96), + ([80, 96, 128], 640, 304), + ([320, 336, 352], 3072, 992), + ([64, 80, 112], 64, 256), + ([32, 48, 48], 128, 112), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) +@pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) +def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( + m_splits: list[int], + k: int, + n: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + use_bias: bool, + single_output: bool, +): + check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + m_splits=m_splits, + k=k, + n=n, + use_bias=use_bias, + single_output=single_output, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, K, N", + [ + (128, 128, 128), + (256, 128, 256), + (256, 256, 256), + (256, 1024, 256), + (1024, 1024, 1024), + (4096, 512, 3072), + (112, 128, 96), + (304, 640, 304), + (1008, 3072, 992), + (256, 64, 256), + (128, 128, 112), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +def test_nvfp4_row_scaled_gemm_matches_emulated( + M: int, + K: int, + N: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, +): + check_nvfp4_row_scaled_gemm_matches_emulated( + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + M=M, + K=K, + N=N, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index bf3f545b8b..0824a5e7bc 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -16,6 +16,19 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +def maybe_skip_row_scaled_unsupported_quantization( + row_scaled_nvfp4: bool, + return_transpose: bool, + with_2d_quantization: bool = False, +) -> None: + if not row_scaled_nvfp4: + return + if return_transpose: + pytest.skip("Row-scaled NVFP4 does not support columnwise usage") + if with_2d_quantization: + pytest.skip("Row-scaled NVFP4 does not support 2D quantization") + + def unpack_fp4(x: torch.Tensor) -> torch.Tensor: repeated = x.repeat_interleave(2, dim=1) repeated[:, 0::2] &= 0x0F @@ -31,7 +44,12 @@ def check_quantization_nvfp4_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, + row_scaled_nvfp4: bool = False, ) -> None: + maybe_skip_row_scaled_unsupported_quantization( + row_scaled_nvfp4, return_transpose, with_2d_quantization + ) + te_dtype = tex.DType.kFloat4E2M1 # Setup device and random seed @@ -52,6 +70,7 @@ def check_quantization_nvfp4_versus_reference( with_rht=False, with_post_rht_amax=False, with_2d_quantization=with_2d_quantization, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -73,6 +92,7 @@ def check_quantization_nvfp4_versus_reference( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise # Reference quantization quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16) @@ -83,6 +103,7 @@ def check_quantization_nvfp4_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=quant_tile_shape, + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -102,6 +123,7 @@ def check_quantization_nvfp4_versus_reference( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col qx = unpack_fp4(qx) qx_t = unpack_fp4(qx_t) if qx_t is not None else None @@ -121,6 +143,7 @@ def check_quantization_nvfp4_versus_reference( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -155,6 +178,7 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( "with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -163,6 +187,7 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, + row_scaled_nvfp4: bool, ) -> None: check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, @@ -172,6 +197,7 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale=swizzled_scale, use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, + row_scaled_nvfp4=row_scaled_nvfp4, ) @@ -188,6 +214,7 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -195,7 +222,10 @@ def test_nvfp4_quantization_extrema_versus_reference( extrema_high: bool, return_transpose: bool, use_cpp_allocator: bool, + row_scaled_nvfp4: bool, ): + maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -216,6 +246,7 @@ def test_nvfp4_quantization_extrema_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: @@ -237,6 +268,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -245,6 +277,7 @@ def test_nvfp4_quantization_extrema_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -257,6 +290,7 @@ def test_nvfp4_quantization_extrema_versus_reference( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -269,6 +303,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -286,18 +321,22 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, + row_scaled_nvfp4: bool, ): """ Stress rounding/threshold behavior by placing values just below/above many potential bin edges within each 16-element microblock. Validates native vs reference byte-for-byte and scale parity. """ + maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -327,6 +366,7 @@ def test_nvfp4_quantization_boundary_values( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: @@ -348,6 +388,7 @@ def test_nvfp4_quantization_boundary_values( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -356,6 +397,7 @@ def test_nvfp4_quantization_boundary_values( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -368,6 +410,7 @@ def test_nvfp4_quantization_boundary_values( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -381,6 +424,7 @@ def test_nvfp4_quantization_boundary_values( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -397,13 +441,17 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, + row_scaled_nvfp4: bool, ): + maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -424,6 +472,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: @@ -445,6 +494,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -453,6 +503,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) @@ -465,6 +516,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col # Quantized must match torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -479,5 +531,6 @@ def test_nvfp4_quantization_noncontiguous_inputs( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index ed4f73adbc..c7c5a5b99d 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -78,6 +78,11 @@ marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4BlockScaling", ), + pytest.param( + "nvfp4_row_scaled", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4RowScaledBlockScaling", + ), ] @@ -165,7 +170,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name == "nvfp4": + if recipe_name in ("nvfp4", "nvfp4_row_scaled"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -178,6 +183,14 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + if module_type == "ops_linear" and recipe_name == "nvfp4_row_scaled": + pytest.skip("Row-scaled NVFP4 currently does not support fused te_ops paths.") + + +def _make_quantized_forward_reference_recipe(recipe_name: str) -> recipe.Recipe: + if recipe_name == "nvfp4_row_scaled": + return make_recipe(recipe_name, backward_override="dequantized") + return make_recipe(recipe_name) def _maybe_skip_unsupported_recipe_shape( @@ -195,7 +208,9 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" " by 16." @@ -220,7 +235,9 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." ) @@ -239,9 +256,9 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name == "nvfp4" and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." @@ -847,7 +864,7 @@ def test_linear_like_backward_override_matches_reference( _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) @@ -1031,8 +1048,9 @@ def test_grouped_linear_backward_override_matches_reference( num_gemms = len(m_splits) num_tokens = sum(m_splits) - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) module_quantized_ref = te.GroupedLinear( num_gemms, @@ -1200,6 +1218,7 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") default_recipe = make_recipe(recipe_name) + skip_unsupported_backward_override(module_type, default_recipe, None) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) @@ -1270,7 +1289,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") default_recipe = make_recipe(recipe_name) + skip_unsupported_backward_override("grouped_linear", default_recipe, None) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state( module, @@ -1336,7 +1357,7 @@ def test_fused_linear_paths_match_backward_override_reference( reset_rng_states() - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) @@ -1476,7 +1497,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( reset_rng_states() in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) @@ -1715,7 +1736,11 @@ def test_backward_override_memory_peak_report( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - modes = (None, "high_precision", "dequantized") + modes = ( + ("high_precision", "dequantized") + if recipe_name == "nvfp4_row_scaled" + else (None, "high_precision", "dequantized") + ) mode_results: dict[str, dict[str, float] | str] = {} for mode in modes: diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index a782dadc60..33ba65e0d9 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -20,17 +20,19 @@ is_fp8_available, is_fp8_block_scaling_available, is_mxfp8_available, + is_nvfp4_available, is_bf16_available, ) from transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe -from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override # Check if FP8 is supported. fp8_available = is_fp8_available() fp8_block_scaling_available = is_fp8_block_scaling_available() mxfp8_available = is_mxfp8_available() +nvfp4_available = is_nvfp4_available() # Reset RNG states. reset_rng_states() @@ -62,6 +64,14 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def check_rht_usage(recipe: recipe.Recipe) -> bool: # if using RHT, we can only support bf16 # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad @@ -88,7 +98,9 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) +if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) + fp8_recipes.append(nvfp4_row_scaled()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -360,7 +372,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=recipe_id) @pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized")) def test_make_graphed_callables( *, @@ -390,6 +402,8 @@ def test_make_graphed_callables( f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs" ) if fp8 and fp8_recipe.nvfp4(): + if getattr(fp8_recipe, "row_scaled_activation", False) and module == "mha": + pytest.skip("Row-scaled NVFP4 CUDA graph coverage applies to GEMM modules.") if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype): pytest.skip( f"Input dtype {dtype} not supported for NVFP4 Recipe" @@ -448,7 +462,7 @@ def test_make_graphed_callables( ) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized")) def test_make_graphed_callables_with_fp8_weight_caching( *, diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 91d4b89013..5f5221af76 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -25,10 +25,16 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, + NVFP4BlockScalingRecipeState, _amax_and_scale_update, ) import transformer_engine.pytorch.ops as te_ops -from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8BlockScaling, + MXFP8BlockScaling, + NVFP4BlockScaling, +) # Check if FP8 is supported fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) @@ -507,8 +513,30 @@ def test_quantizer_update(self, module_class): y = module(x) +@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) +def test_nvfp4_row_scaled_quantizer_roles(): + recipe = NVFP4BlockScaling(row_scaled_activation=True) + + forward_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="forward", + num_quantizers=3, + ).make_quantizers() + assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] + assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) + assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) + + backward_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="backward", + num_quantizers=2, + ).make_quantizers() + assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] + + @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize( "M, N", [ @@ -524,12 +552,19 @@ def test_quantizer_update(self, module_class): (8192, 8192), ], ) -def test_fp4_dequantize(dtype, M, N): - q = NVFP4Quantizer() +def test_fp4_dequantize(dtype, row_scaled_nvfp4, M, N): + q = NVFP4Quantizer( + columnwise=not row_scaled_nvfp4, + row_scaled_nvfp4=row_scaled_nvfp4, + ) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) + assert starting_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert starting_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) dequantized_tensor = starting_tensor.dequantize() new_tensor = q(dequantized_tensor) + assert new_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert new_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) torch.testing.assert_close( new_tensor._rowwise_data, starting_tensor._rowwise_data, diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 7f2f24fd69..c811342df5 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -38,12 +38,13 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.tensor.utils import replace_raw_data -from utils import ModelConfig, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, skip_unsupported_backward_override # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, _ = te.is_nvfp4_available(return_reason=True) # Record initial RNG state from script run. seed = 1234 @@ -93,9 +94,18 @@ def nvfp4_vanilla(): return nvfp4_recipe +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) +if nvfp4_available: fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) @@ -103,6 +113,9 @@ def nvfp4_vanilla(): fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) fp8_recipes.append(None) +fp8_recipes_with_row_scaled = fp8_recipes.copy() +if nvfp4_available: + fp8_recipes_with_row_scaled.insert(-1, nvfp4_row_scaled()) param_types = [torch.float32, torch.float16] if is_bf16_available(): # bf16 requires sm_80 or higher @@ -402,7 +415,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -450,7 +463,7 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -488,7 +501,7 @@ def test_sanity_linear( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -529,7 +542,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -563,7 +576,12 @@ def test_sanity_grouped_linear( if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") if fp8_recipe.nvfp4(): - pytest.skip("NVFP4 not supported for grouped linear") + if not getattr(fp8_recipe, "row_scaled_activation", False): + pytest.skip("NVFP4 not supported for grouped linear") + if single_param: + pytest.skip("Row-scaled NVFP4 does not support GroupedTensor grouped linear") + if dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 9d0ed79888..51f72b1e56 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -32,6 +32,7 @@ is_fp8_block_scaling_available, is_nvfp4_available, ) +from utils import recipe_id fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) @@ -47,6 +48,7 @@ _all_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: _all_recipes.append(recipe.NVFP4BlockScaling()) + _all_recipes.append(recipe.NVFP4BlockScaling(row_scaled_activation=True)) # --------------------------------------------------------------------------- @@ -303,7 +305,7 @@ def fn(inp): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=recipe_id) def test_autocast_sanity(fp8_recipe): """Smoke test: torch.nn.Linear inside a single te.autocast with each built-in recipe. Forward + backward under torch.compile(fullgraph=True).""" diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index c7cbe78a6d..32e44be2af 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -117,7 +117,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name == "nvfp4": + if name in ("nvfp4", "nvfp4_row_scaled"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -151,15 +151,39 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: disable_2d_quantization=True, **recipe_kwargs, ) + if name == "nvfp4_row_scaled": + return transformer_engine.common.recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + **recipe_kwargs, + ) raise ValueError(f"Unsupported quantization scheme ({name})") +def recipe_id(recipe: Optional[Recipe]) -> str: + """Readable pytest id for a quantization recipe.""" + if not isinstance(recipe, Recipe): + return "None" + if recipe.nvfp4() and recipe.row_scaled_activation: + return "NVFP4RowScaledBlockScaling" + return type(recipe).__name__ + + def skip_unsupported_backward_override( layer_type: str, quant_recipe: Optional[Recipe], backward_override: Optional[str], ) -> None: """Skip known unsupported layer/recipe/backward-override combinations used in tests.""" + if ( + quant_recipe is not None + and quant_recipe.nvfp4() + and getattr(quant_recipe, "row_scaled_activation", False) + and backward_override is None + ): + pytest.skip("Row-scaled NVFP4 does not support default quantized backward.") if backward_override is None: return if quant_recipe is None and backward_override is not None: @@ -392,11 +416,11 @@ def test(): _attention_backends["backend_selection_requires_update"] = False return available_backends, flash_attention_backend, fused_attention_backend - backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} + backends = {1: "F16_arbitrary_seqlen", 2: "FP8"} if AttentionLogging._is_logging_setup is False: AttentionLogging.setup_logging() - for i in range(3): + for i in backends: os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) _attention_backends["backend_selection_requires_update"] = True available_backends, flash_attention_backend, fused_attention_backend = test() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 734941595d..030023d949 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -182,7 +182,6 @@ list(APPEND transformer_engine_cuda_sources dropout/dropout.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu - fused_attn/fused_attn_f16_max512_seqlen.cu fused_attn/fused_attn_f16_arbitrary_seqlen.cu fused_attn/fused_attn_fp8.cu fused_attn/utils.cu diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 5d0d3c28e8..123362ce10 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -100,6 +100,14 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t rows = input_tensor->flat_first_dim(); int32_t cols = input_tensor->flat_last_dim(); auto dtype = input_tensor->dtype(); + const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; + if (row_scaled_nvfp4) { + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); + NVTE_CHECK(!output_tensor->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); + nvfp4::compute_rowwise_amax(*input_tensor, noop_tensor, output_tensor, stream); + } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -126,7 +134,9 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + /*row_scaled_nvfp4=*/row_scaled_nvfp4, + /*noop_tensor=*/noop_tensor->data, + /*stream=*/stream); } break; } @@ -239,6 +249,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); + NVTE_CHECK(!output_tensor->row_scaled_nvfp4, + "Backward NVFP4 quantization does not support row-scaled outputs."); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -265,7 +277,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + /*row_scaled_nvfp4=*/false, /*noop_tensor=*/noop_tensor->data, + /*stream=*/stream); } break; } diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 4143208153..d549a050ee 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -34,8 +34,9 @@ namespace dequantize_kernel { template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t N, const size_t M, - const size_t scale_stride, const size_t num_scale_tiles_X) { + const float *const tensor_amax, const bool row_scaled_nvfp4, + const size_t N, const size_t M, const size_t scale_stride, + const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; const size_t y = thread_idx / M; @@ -63,7 +64,7 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = *tensor_amax; + float amax = row_scaled_nvfp4 ? tensor_amax[y] : tensor_amax[0]; constexpr float factor_inv = 1.0 / (6.0 * 448.0); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll @@ -90,6 +91,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales; + const bool row_scaled_nvfp4 = input.row_scaled_nvfp4; constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); @@ -103,6 +105,8 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const size_t threads = 512; const size_t blocks = DIVUP(total, threads); const size_t num_scale_tiles_X = DIVUP(Mread, static_cast(4)); + NVTE_CHECK(!row_scaled_nvfp4 || input.amax.numel() == N, + "Row-scaled NVFP4 dequantization requires one rowwise amax per row."); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( output->data.dtype, OType, @@ -112,7 +116,8 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) dequantize_fp4_kernel<<>>( input.data.dptr, reinterpret_cast(output->data.dptr), reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), N, Mread, input.scale_inv.shape.back(), + reinterpret_cast(input.amax.dptr), row_scaled_nvfp4, N, Mread, + input.scale_inv.shape.back(), num_scale_tiles_X);); // NOLINT(*) ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index f164636e38..9e4aef5a1c 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -16,6 +16,8 @@ #include #include +#include + #include "../../common.h" #include "../../util/math.h" #include "../../util/ptx.cuh" @@ -27,6 +29,132 @@ namespace transformer_engine { namespace dispatch { namespace nvfp4 { +namespace rowwise_amax_kernel { + +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +constexpr int ROWWISE_AMAX_BLOCK_SIZE = 256; +constexpr int ROWWISE_AMAX_SF_VEC_SIZE = 16; + +template +__device__ __forceinline__ void abs_max_2x_update(ptx::FPx2 &dst, + const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + dst.x = fmaxf(fabsf(dst.x), fabsf(val.x)); + dst.y = fmaxf(fabsf(dst.y), fabsf(val.y)); + } else { + ptx::abs_max_2x(dst, dst, val); + } +} + +template +__device__ __forceinline__ float abs_max_2x_to_float(const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + return fmaxf(fabsf(val.x), fabsf(val.y)); + } else { + return static_cast(__hmax(__habs(val.x), __habs(val.y))); + } +} + +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + compute_rowwise_amax_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + float *__restrict__ output_rowwise_amax, + const float *__restrict__ noop) { +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000) + NVTE_DEVICE_ERROR("SM 10.0+ is required."); +#else + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + using IType2 = typename ptx::FPx2; + + const int row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + const int num_vec2 = num_cols / 2; + const IType2 *input_row = reinterpret_cast(input + row_idx * num_cols); + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + for (int i = threadIdx.x; i < num_vec2; i += BLOCK_SIZE) { + const IType2 val = input_row[i]; + abs_max_2x_update(thread_amax_2x, val); + } + const float thread_max = abs_max_2x_to_float(thread_amax_2x); + + const float row_amax = + reduce_max(thread_max, threadIdx.x / THREADS_PER_WARP); + + if (threadIdx.x == 0) { + output_rowwise_amax[row_idx] = row_amax; + } +#endif +} + +template +void launch_compute_rowwise_amax(const int num_rows, const int num_cols, const IType *input, + float *output_rowwise_amax, cudaStream_t stream, + const float *noop = nullptr) { + if (num_rows == 0 || num_cols == 0) return; + + dim3 grid(num_rows); + dim3 block(ROWWISE_AMAX_BLOCK_SIZE); + + compute_rowwise_amax_kernel + <<>>(num_rows, num_cols, input, output_rowwise_amax, noop); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace rowwise_amax_kernel + +inline void compute_rowwise_amax(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace rowwise_amax_kernel; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + NVTE_CHECK(cols % ROWWISE_AMAX_SF_VEC_SIZE == 0, + "Row-scaled NVFP4 quantization requires last dim divisible by ", + ROWWISE_AMAX_SF_VEC_SIZE, "."); + + auto *amax_ptr = reinterpret_cast(output->amax.dptr); + NVTE_CHECK(amax_ptr != nullptr, "Row-scaled rowwise amax tensor must be allocated."); + NVTE_CHECK(output->amax.numel() == rows, "Row-scaled rowwise amax must have ", rows, + " entries, got ", output->amax.shape, "."); + + const auto *noop_ptr = reinterpret_cast(noop->data.dptr); + if (input.dtype() == DType::kBFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + launch_compute_rowwise_amax<__nv_bfloat16>(static_cast(rows), static_cast(cols), + input_ptr, amax_ptr, stream, noop_ptr); + } else if (input.dtype() == DType::kFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + launch_compute_rowwise_amax(static_cast(rows), static_cast(cols), input_ptr, + amax_ptr, stream, noop_ptr); + } else if (input.dtype() == DType::kFloat32) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + launch_compute_rowwise_amax(static_cast(rows), static_cast(cols), input_ptr, + amax_ptr, stream, noop_ptr); + } else { + NVTE_ERROR( + "Unsupported input dtype for row-scaled NVFP4 quantization. " + "Expected BFloat16, Float16, or Float32."); + } +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + namespace quantize_transpose_kernel { using namespace quantization_and_transposition_SF; @@ -108,7 +236,8 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 template + typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE, + bool ROW_SCALED_NVFP4> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -508,27 +637,56 @@ __global__ void __launch_bounds__(THREADS_NUM) } } - // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + float block_scale_inverse; + if constexpr (ROW_SCALED_NVFP4) { + // 2. Compute E4M3 scaling factor + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const float S_enc_rowwise_block = + scales_offset_Y < rows + ? compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[scales_offset_Y]) + : 1.0f; + const float S_dec_rowwise_block = 1.0f / S_enc_rowwise_block; + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); + + // Check boundaries + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } - // Check boundaries - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + block_scale_inverse = + fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise_block), + float_max); // S_enc_b_fp8 + } else { + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } - // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + block_scale_inverse = fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), + float_max); // S_enc_b_fp8 } - - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; // 3. Scale elements @@ -1051,7 +1209,6 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t scales_offset_X = scales_offset_X_rowwise; const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; const bool rowwise_scale_is_within_bounds_Y = (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { @@ -1162,6 +1319,9 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, using namespace ptx; bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; + NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to // return the transposed data. @@ -1186,6 +1346,10 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(!row_scaled_nvfp4 || output->amax.dptr != nullptr, + "Row-scaled NVFP4 quantization requires rowwise amax."); + NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); if (return_transpose) { NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); @@ -1268,20 +1432,23 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_kernel; + TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_kernel; - if constexpr (use_2d_quantization) { - kernel = quantize_transpose_nvfp4_2D_kernel; - } + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel; + } - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + }); });); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index fc337f6078..8adda82131 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -261,14 +261,12 @@ __device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_pt } } -template -__device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_ptr, - fp4e2m1x2 *__restrict__ sOut_ptr, - nvfp4_scale_t *__restrict__ sSFrowwise_ptr, - const float S_enc_rowwise, const int stage_Y, - const int stage_X, const int buff_in, - const int buff_out, RNG_t &rng, uint4 &random_uint4, - int &rnd_idx) { +template +__device__ __forceinline__ void rowwise_scaling( + const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, + nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y, + const int stage_X, const int buff_in, const int buff_out, const float *amax_rowwise_ptr, + const size_t row_offset, const size_t rows, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; const auto &sIn = *reinterpret_cast(sIn_ptr); @@ -315,9 +313,21 @@ __device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_pt } const float block_amax = get_amax_of_pair(thread_amax_2x); - const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - const scaling_coeff_type SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + nvfp4_scale_t S_dec_b_fp8; + scaling_coeff_type SFcoefficient; + if constexpr (ROW_SCALED_NVFP4) { + const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; + const float S_enc_rowwise_block = + row_idx < rows ? core::compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx]) + : 1.0f; + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); + SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise_block); + } else { + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + } // Store scaling factors to SMEM buffer (R2S) if (SF_storing_thread) { @@ -350,7 +360,8 @@ __device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_pt } } -template +template __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -571,9 +582,9 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization - rowwise_scaling( + rowwise_scaling( sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, - rng, random_uint4, rnd_idx); + amax_rowwise_ptr, block_offset_Y, rows, rng, random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { colwise_scaling( @@ -680,6 +691,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; // If transposed output is allocated, return the transposed data // Otherwise, it's not necesary to return the transposed data. @@ -694,6 +706,10 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(!row_scaled_nvfp4 || output->amax.dptr != nullptr, + "Row-scaled NVFP4 quantization requires rowwise amax."); + NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); if (return_transpose) { NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), @@ -783,16 +799,20 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_fast_math, USE_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); - }););); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, ROW_SCALED_NVFP4, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = + quantize_transpose_nvfp4_tuned_1D_kernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });););); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 133f1a09e6..28218e2b43 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -222,6 +222,14 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz TensorWrapper chunk(scaling_mode); for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) { auto param_type = static_cast(param_id); + if (param_type == NVTETensorParam::kNVTEWithGEMMSwizzledScales) { + chunk.set_with_gemm_swizzled_scales(source.get_with_gemm_swizzled_scales()); + continue; + } + if (param_type == NVTETensorParam::kNVTERowScaledNVFP4) { + chunk.set_row_scaled_nvfp4(source.get_row_scaled_nvfp4()); + continue; + } auto param = source.get_parameter(param_type); auto param_dptr = reinterpret_cast(param.data_ptr); auto param_dtype = static_cast(param.dtype); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index c1b3f8f427..12479f2a9c 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -173,6 +173,11 @@ struct Tensor { * Only meaningful for MXFP8 and NVFP4. */ bool with_gemm_swizzled_scales = false; + /*! \brief Whether NVFP4 rowwise amax metadata is row-scaled. + * + * Only meaningful for NVFP4 tensors. + */ + bool row_scaled_nvfp4 = false; /*! Map from NVTETensorParam to parameter sizes */ static constexpr size_t attr_sizes[] = { @@ -183,7 +188,8 @@ struct Tensor { sizeof(NVTEBasicTensor), // kNVTERowwiseScaleInv sizeof(NVTEBasicTensor), // kNVTEColumnwiseScaleInv sizeof(NVTEBasicTensor), // kNVTEColumnwiseAmax - sizeof(uint8_t) // kNVTEWithGEMMSwizzledScales + sizeof(uint8_t), // kNVTEWithGEMMSwizzledScales + sizeof(uint8_t) // kNVTERowScaledNVFP4 }; Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {} @@ -199,6 +205,7 @@ struct Tensor { columnwise_scale_inv.clear(); scaling_mode = NVTE_DELAYED_TENSOR_SCALING; with_gemm_swizzled_scales = false; + row_scaled_nvfp4 = false; } explicit operator NVTETensor() const noexcept { return nvte_tensor; } diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index ae8ddbed69..d2eb1a831c 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -11,7 +11,6 @@ #include "../util/cuda_runtime.h" #include "../util/system.h" #include "fused_attn_f16_arbitrary_seqlen.h" -#include "fused_attn_f16_max512_seqlen.h" #include "fused_attn_fp8.h" #include "utils.h" @@ -255,35 +254,31 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( if ((q_dtype == NVTEDType::kNVTEFloat8E4M3 || q_dtype == NVTEDType::kNVTEFloat8E5M2) && sm_arch_ >= 90 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && - // 8.9: t3hd, max_s=512, d=64, padding - ((cudnn_runtime_version >= 8900 && sm_arch_ < 100 && - qkv_layout == NVTE_QKV_Layout::NVTE_T3HD && max_seqlen_q == max_seqlen_kv && - max_seqlen_q <= 512 && head_dim_qk == 64 && head_dim_v == 64 && - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} - (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && - max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && - (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || - // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} - (cudnn_runtime_version >= 90700 && - // TODO (cyang): add is_training to nvte_get_fused_attn_backend - // sm90: fwd d<=256, bwd d=128 only - // sm100: fwd d<=128, bwd d<=128 - ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || - (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || - (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && - head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || - // 9.21: d_qk=192, d_v=128 - (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 && - head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && + ( + // 9.2.1: {bshd, sbhd}, any seqlen, d=128, {no_mask, causal} + (cudnn_runtime_version >= 90201 && sm_arch_ < 100 && max_seqlen_q % 128 == 0 && + max_seqlen_kv % 128 == 0 && head_dim_qk == 128 && head_dim_v == 128 && + (attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) || + // 9.7: {bshd, sbhd}, any seqlen, d<=256 for sm90 and d<=128 for sm100, {padding, padding_causal} + (cudnn_runtime_version >= 90700 && + // TODO (cyang): add is_training to nvte_get_fused_attn_backend + // sm90: fwd d<=256, bwd d=128 only + // sm100: fwd d<=128, bwd d<=128 + ((sm_arch_ < 100 && (!is_training) && head_dim_qk <= 256 && head_dim_v <= 256) || + (sm_arch_ < 100 && is_training && head_dim_qk == 128 && head_dim_v == 128) || + (sm_arch_ >= 100 && head_dim_qk <= 128 && head_dim_v <= 128)) && + head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) || + // 9.21: d_qk=192, d_v=128 + (cudnn_runtime_version >= 92100 && sm_arch_ >= 100 && head_dim_qk <= 192 && + head_dim_v <= 128 && head_dim_qk % 16 == 0 && head_dim_v % 16 == 0 && + (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || + attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK))) && // pre-9.21: {bshd, sbhd}, {vanilla} // 9.21+: {bshd, sbhd, bhsd}, {vanilla, off-by-one, learnable} ((cudnn_runtime_version < 92100 && @@ -295,37 +290,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( !requires_64bit_ragged_offset && // 9.10.0: known bugs with SDPA FP8 (cudnn_runtime_version != 91000) && !return_max_logit) { - if (cudnn_runtime_version >= 8900) { - backend = NVTE_Fused_Attn_Backend::NVTE_FP8; - } else { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP8 fused attention is supported by cuDNN 8.9.0+." - " Please upgrade your cuDNN version if possible." - << std::endl; - } + backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else if ((q_dtype == NVTEDType::kNVTEFloat16) || (q_dtype == NVTEDType::kNVTEBFloat16)) { - bool flag_m512 = false; bool flag_arb = false; - if ((sm_arch_ == 80 || sm_arch_ == 90) && (max_seqlen_q <= 512 && max_seqlen_q % 64 == 0) && - (max_seqlen_kv <= 512 && max_seqlen_kv % 64 == 0) && (head_dim_qk == 64) && - (head_dim_v == 64) && (num_attn_heads == num_gqa_groups) && - ((bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || - (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS)) && - ((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK) || - (attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && - max_seqlen_q == max_seqlen_kv) || - (attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK)) && - ((qkv_layout == NVTE_QKV_Layout::NVTE_SB3HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_SBHD_SB2HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) || - (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD)) && - ((window_size_left == -1) && (window_size_right == -1 || window_size_right == 0)) && - !requires_64bit_ragged_offset && - (softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX) && !return_max_logit) { - flag_m512 = true; - } if ( // TODO(cyang): replace with cudnn-frontend check_support for cleaner logic and better error messaging // architecture @@ -499,31 +466,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) { flag_arb = true; } - if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { + if (flag_arb) { backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; } - if ((max_seqlen_q <= 512) && (max_seqlen_kv <= 512)) { - if (flag_arb == true) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen; - } else if ((flag_arb == false) && (flag_m512 == true)) { - backend = NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen; - } - int env_backend = static_cast(backend); - env_backend = transformer_engine::getenv("NVTE_FUSED_ATTN_BACKEND", env_backend); - if (((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen)) && - flag_m512) || - ((env_backend == static_cast(NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen)) && - flag_arb)) { - backend = static_cast(env_backend); - } - } - if (cudnn_runtime_version < 8901 && - backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; - std::cout << "Warning: FP16/BF16 fused attention is supported by cuDNN 8.9.1+." - " Please upgrade your cuDNN version if possible." - << std::endl; - } if (cudnn_runtime_version < 8900 && backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend; @@ -668,12 +613,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, return_max_logit, cuda_graph, false); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - fused_attn_max_512_fwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, is_training, attn_scale, - dropout, qkv_layout, bias_type, attn_mask_type, input_Q, input_K, - input_V, input_Bias, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { fused_attn_arbitrary_seqlen_fwd( b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v, page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training, @@ -754,13 +694,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false, cuda_graph, deterministic); - if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - fused_attn_max_512_bwd(b, h_q, max_seqlen_q, max_seqlen_kv, d_qk, attn_scale, dropout, - qkv_layout, bias_type, attn_mask_type, input_Q, input_K, input_V, - input_dO, output_S, output_dQ, output_dK, output_dV, output_dBias, - input_cu_seqlens_q, input_cu_seqlens_kv, wkspace, stream, handle); - } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { + if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) { size_t i = 0; Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); @@ -782,10 +716,6 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso } else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) { size_t i = 0; const Tensor *input_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - const Tensor *input_ZInv = nullptr; - if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - input_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - } const Tensor *input_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); const Tensor *input_SoftmaxOffset = nullptr; if (softmax_type != NVTE_VANILLA_SOFTMAX) { @@ -799,10 +729,10 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso qkv_layout, o_format, do_format, dqkv_layout, qkv_scale_inv_format, do_scale_inv_format, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, - input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, - input_ZInv, input_S, input_SoftmaxOffset, input_output_dP, output_dQ, - output_dK, output_dV, output_dSoftmaxOffset, input_cu_seqlens_q, - input_cu_seqlens_kv, input_rng_state, wkspace, stream, handle); + input_Q, input_K, input_V, input_O, input_dO, input_dO_f16, input_M, input_S, + input_SoftmaxOffset, input_output_dP, output_dQ, output_dK, output_dV, + output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, + input_rng_state, wkspace, stream, handle); } else { NVTE_ERROR("Invalid combination of data type and sequence length for fused attention. \n"); } diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu deleted file mode 100644 index d5151a51f1..0000000000 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.cu +++ /dev/null @@ -1,1343 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include -#include -#include - -#include -#include - -#include "../common.h" -#include "../cudnn_utils.h" -#include "fused_attn_f16_max512_seqlen.h" -#include "utils.h" - -#define Q_ID 1 -#define K_ID 2 -#define V_ID 3 -#define O_ID 4 -#define S_ID 5 -#define B_ID 6 -#define DROPOUT_CONST_ID 7 -#define S_CONST_ID 8 -#define Q_SEQLEN_ID 9 -#define K_SEQLEN_ID 10 -#define dQ_ID 11 -#define dK_ID 12 -#define dV_ID 13 -#define dO_ID 14 -#define MASK_VAL_ID 15 -#define dS_ID 16 -#define dBias_ID 17 -#define DROPOUT_SEED_ID 18 -#define DROPOUT_OFFSET_ID 19 - -#define VIRTUAL_ID 20 - -namespace transformer_engine { -namespace fused_attn { - -static void createScale(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, - std::vector &ops) { - // scale - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - int64_t k_dim[4] = {b, h, d, s_kv}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - - auto scaleTensor = - tensor_create(tensorType, S_CONST_ID, scale_dim, scale_stride, false, true); // is by value - auto kTensor = tensor_create(tensorType, K_ID, k_dim, k_stride, false, false); - auto afterScaleKTensor = - tensor_create(tensorType, VIRTUAL_ID, k_dim, k_stride, true, false); // is virtual - - // Define the scale descriptor - auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a Scale Node. - auto scale_op = binary_pw_op_create(kTensor, scaleTensor, afterScaleKTensor, scaleDesc); - - ops.push_back(std::move(scale_op)); -} - -static cudnn_frontend::Tensor createBMM1(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, - bool zero_s, std::vector &ops) { - // Creates the necessary tensor descriptors - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - - int64_t k_dim[4] = {b, h, d, s_kv}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - - int64_t p_dim[4] = {b, h, s_q, s_kv}; - int64_t p_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); - auto afterScaleKTensor = - tensor_create(tensorType, VIRTUAL_ID, k_dim, k_stride, true, false); // is virtual - // first GEMM output - auto pTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 1, p_dim, p_stride, true, - false); // is virtual - - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - - // Define the matmul 1 desc - // set padding value optionally to 0 for writing zeros to S tensor (if not set, old behaviour) - auto matmul_1_Desc = cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); - - if (zero_s) { - matmul_1_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - } - - // Create a matmul 1 Node - auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(qTensor) - .setbMatDesc(afterScaleKTensor) - .setcMatDesc(pTensor) - .setmOverrideDesc(seqlenQTensor) - .setnOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_1_Desc) - .build(); - - ops.push_back(std::move(matmul_op1)); - - return pTensor; -} - -static cudnn_frontend::Tensor createBias(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &prevBlockOutputTensor) { - NVTE_CHECK(ops.size() != 0, "Bias op constructed incorrectly as the first one."); - - int64_t b_dim[4] = {1, h, s_q, s_kv}; - int64_t b_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t afterBias_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBias_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, afterBias_stride, layout, - NVTE_QKV_Matrix::NVTE_S_Matrix); - - // bias - auto bTensor = tensor_create(tensorType, B_ID, b_dim, b_stride, false, false); - // output - auto afterBiasTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 50, afterBias_dim, - afterBias_stride, true, false); // is virtual - - // Define the bias descriptor - auto biasDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ADD); - - // Create a Bias Node. - auto bias_op = binary_pw_op_create(prevBlockOutputTensor, bTensor, afterBiasTensor, biasDesc); - - ops.push_back(std::move(bias_op)); - - return afterBiasTensor; -} - -static cudnn_frontend::Tensor createMask(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type, - cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &prevBlockOutputTensor, - bool is_bprop) { - NVTE_CHECK(ops.size() != 0, "Padding mask constructed incorrectly as the first one."); - - // subtraction output - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - int64_t maskVal_dim[4] = {1, 1, 1, 1}; - int64_t maskVal_stride[4] = {1, 1, 1, 1}; - - // mask value to put in the masked pixels - auto maskValTensor = tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, maskVal_dim, maskVal_stride, - false, true); // is by value - - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - // gen index row output - auto rowIndexTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 100, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // gen index column output - auto columnIndexTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 101, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // less than row output - auto lessThanRowTensor = - tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 102, afterBMM1_dim, afterBMM1_stride, true, - false); // is virtual - // less than column output - auto lessThanColTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 103, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // padding mask (lessthanRow && lessthanCol) - auto paddingMaskTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 104, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // row >= col check for causal mask - auto rowGreaterColTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 105, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // create causal mask (padding && row >= col) - auto causalMaskTensor = tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 106, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // output after masking - int64_t maskOutputTensor_id = VIRTUAL_ID + 107; - int64_t maskOutputTensor_virtual = true; - cudnnDataType_t maskOutputTensor_dataType = CUDNN_DATA_FLOAT; - auto maskOutputTensor_reorderType = cudnn_frontend::TensorReordering_t::NONE; - - if (is_bprop) { - maskOutputTensor_id = dS_ID; - maskOutputTensor_virtual = false; - maskOutputTensor_dataType = tensorType; - maskOutputTensor_reorderType = cudnn_frontend::TensorReordering_t::F16x16; - } - - auto maskOutputTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setByValue(false) - .setDataType(maskOutputTensor_dataType) - .setVirtual(maskOutputTensor_virtual) - .setId(maskOutputTensor_id) - .setReorderType(maskOutputTensor_reorderType) - .build(); - - // Define the gen index for row descriptor - auto genIndexRowDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_GEN_INDEX) - .setAxis(2) - .setComputeType(CUDNN_DATA_FLOAT) - .build(); - - // Create a gen index Node. - auto genIndexRow_op = unary_pw_op_create(prevBlockOutputTensor, rowIndexTensor, genIndexRowDesc); - - // Define the gen index for row descriptor - auto genIndexColumnDesc = cudnn_frontend::PointWiseDescBuilder() - .setMode(CUDNN_POINTWISE_GEN_INDEX) - .setAxis(3) - .setComputeType(CUDNN_DATA_FLOAT) - .build(); - - // Create a gen index Node. - auto genIndexColumn_op = - unary_pw_op_create(prevBlockOutputTensor, columnIndexTensor, genIndexColumnDesc); - - // Define the less than comparison for row descriptor - auto lessThanRowDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT); - - // Create a less than comparison for row Node. - auto lessThanRow_op = - binary_pw_op_create(rowIndexTensor, seqlenQTensor, lessThanRowTensor, lessThanRowDesc); - - // Define the less than comparison for column descriptor - auto lessThanColDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_LT); - - // Create a less than comparison for col Node. - auto lessThanCol_op = - binary_pw_op_create(columnIndexTensor, seqlenKTensor, lessThanColTensor, lessThanColDesc); - - // Define the less than comparison for column descriptor - auto paddingMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND); - - // Create a and node for combining lessThanRow and lessThanCol - auto paddingMaskAnd_op = binary_pw_op_create(lessThanRowTensor, lessThanColTensor, - paddingMaskTensor, paddingMaskAndDesc); - - // Define the greater than equal to comparison descriptor - auto rowGreaterColDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_CMP_GE); - - // Create a greater than equal to Node. - auto rowGreaterCol_op = binary_pw_op_create(rowIndexTensor, columnIndexTensor, - rowGreaterColTensor, rowGreaterColDesc); - - // Define the and to create causal mask descriptor - auto causalMaskAndDesc = pw_desc_create(CUDNN_DATA_BOOLEAN, CUDNN_POINTWISE_LOGICAL_AND); - - // Create a causal Mask Node. - auto causalMaskAnd_op = binary_pw_op_create(paddingMaskTensor, rowGreaterColTensor, - causalMaskTensor, causalMaskAndDesc); - - /////////////////// Apply the mask ////////////////////////// - - auto maskTensor = (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) - ? std::move(causalMaskTensor) - : std::move(paddingMaskTensor); - - // Define the binary select to perform masking descriptor - auto maskDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT); - - // Create a binary select Node. - auto mask_op = ternary_pw_op_create(prevBlockOutputTensor, maskValTensor, maskTensor, - maskOutputTensor, maskDesc); - - ops.push_back(std::move(genIndexRow_op)); - ops.push_back(std::move(genIndexColumn_op)); - ops.push_back(std::move(lessThanRow_op)); - ops.push_back(std::move(lessThanCol_op)); - ops.push_back(std::move(paddingMaskAnd_op)); - if (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK) { - ops.push_back(std::move(rowGreaterCol_op)); - ops.push_back(std::move(causalMaskAnd_op)); - } - ops.push_back(std::move(mask_op)); - - return maskOutputTensor; -} - -static cudnn_frontend::Tensor createSoftmaxForward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - bool enable_dropout, bool softmax_output_virtual, cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &prevBlockOutputTensor) { - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t afterReduction_dim[4] = {b, h, s_q, 1}; - int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; - - cudnnDataType_t softmaxOutputType = enable_dropout ? CUDNN_DATA_FLOAT : tensorType; - uint64_t softmaxOutputName = softmax_output_virtual ? VIRTUAL_ID + 154 : S_ID; - - // max (x) - auto afterMaxReductionTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 150, afterReduction_dim, afterReduction_stride, - true, false); // is virtual - // x - max(x) - auto afterSubtractionTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 151, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // e^(x - max(x)) - auto afterExponentTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 152, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual; - // sum (e^(x - max(x))) - auto afterAddReductionTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 153, afterReduction_dim, afterReduction_stride, - true, false); // is virtual - // divide (e/ sum(e)) - - auto reorder_type = cudnn_frontend::TensorReordering_t::F16x16; - - auto afterDivisionTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setId(softmaxOutputName) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(softmaxOutputType) - .setVirtual(softmax_output_virtual) - .setByValue(false) - .setReorderType(reorder_type) - .build(); - - // Define the reduction descriptor - auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) - .build(); - - // Create a reduction max Node. - auto reductionMax_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(prevBlockOutputTensor) - .setyDesc(afterMaxReductionTensor) - .setreductionDesc(reductionMaxDesc) - .build(); - - // Define the subtract descriptor - auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - - // Create a subtract Node. - auto subtract_op = binary_pw_op_create(prevBlockOutputTensor, afterMaxReductionTensor, - afterSubtractionTensor, subtractDesc); - - // Define the exponent descriptor - auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); - - // Create a exponent Node. - auto exponent_op = unary_pw_op_create(afterSubtractionTensor, afterExponentTensor, exponentDesc); - - // Define the reduction descriptor - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - // Create a reduction add Node. - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(afterExponentTensor) - .setyDesc(afterAddReductionTensor) - .setreductionDesc(reductionAddDesc) - .build(); - - // Define the division descriptor - auto divisionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_DIV); - - // Create a subtract Node. - auto division_op = binary_pw_op_create(afterExponentTensor, afterAddReductionTensor, - afterDivisionTensor, divisionDesc); - - ops.push_back(std::move(reductionMax_op)); - ops.push_back(std::move(subtract_op)); - ops.push_back(std::move(exponent_op)); - ops.push_back(std::move(reductionAdd_op)); - ops.push_back(std::move(division_op)); - - return afterDivisionTensor; -} - -static cudnn_frontend::Tensor createDropout(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, double probability, - cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &prevBlockOutputTensor) { - NVTE_CHECK(ops.size() != 0, "Dropout DAG constructed incorrectly as the first one"); - - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // mask for the dropout - auto dropoutMaskTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 200, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - auto reorder_type = cudnn_frontend::TensorReordering_t::F16x16; - - // after dropout tensor - auto afterDropoutTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setId(S_ID) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(tensorType) - .setVirtual(false) - .setByValue(false) - .setReorderType(reorder_type) - .build(); - // scale after dropout - auto scaleDropoutTensor = - tensor_create(tensorType, DROPOUT_CONST_ID, scale_dim, scale_stride, false, - true); // is by value - // after Scale - auto afterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 201, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the reduction descriptor - auto rngDesc = cudnn_frontend::RngDescBuilder() - .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) - .setBernoulliDistProbability(1.0 - probability) - .build(); - - auto dropoutSeed = - tensor_create(CUDNN_DATA_INT64, DROPOUT_SEED_ID, scale_dim, scale_stride, false, false); - auto dropoutOffset = - tensor_create(CUDNN_DATA_INT64, DROPOUT_OFFSET_ID, scale_dim, scale_stride, false, false); - - // Create a rng Node. - auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) - .setyDesc(dropoutMaskTensor) - .setSeedDesc(dropoutSeed) - .setOffsetDesc(dropoutOffset) - .setRngDesc(rngDesc) - .build(); - - // Define the multiply mask descriptor - auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node. - auto maskMul_op = binary_pw_op_create(prevBlockOutputTensor, dropoutMaskTensor, - afterDropoutTensor, maskMulDesc); - - // Define the multiply scale descriptor - auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node. - auto scaleMul_op = - binary_pw_op_create(afterDropoutTensor, scaleDropoutTensor, afterScaleTensor, scaleMulDesc); - - ops.push_back(std::move(rng_op)); - ops.push_back(std::move(maskMul_op)); - ops.push_back(std::move(scaleMul_op)); - - return afterScaleTensor; -} - -static void createBMM2(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &prevBlockOutputTensor) { - NVTE_CHECK(ops.size() != 0, "BMM2 op constructed incorrectly as the first one"); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - int64_t v_dim[4] = {b, h, s_kv, d}; - int64_t v_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - - int64_t o_dim[4] = {b, h, s_q, d}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); - // second GEMM output - auto oTensor = tensor_create(tensorType, O_ID, o_dim, o_stride, false, false); - - // Define the matmul 2 desc - // set padding value optionally to 0 for writing zeros to O tensor (if not set, old behaviour) - auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - - // Create a matmul 2 Node - auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(prevBlockOutputTensor) - .setbMatDesc(vTensor) - .setcMatDesc(oTensor) - .setmOverrideDesc(seqlenQTensor) - .setkOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_2_Desc) - .build(); - - ops.push_back(std::move(matmul_op2)); -} - -static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, - std::vector &ops, - cudnn_frontend::Tensor const &yTensor, - cudnn_frontend::Tensor const &dyTensor) { - NVTE_CHECK(ops.size() != 0, "Softmax backward constructed incorrectly as the first one"); - - int64_t p_dim[4] = {b, h, s_q, s_kv}; - int64_t p_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t p_reduction_dim[4] = {b, h, s_q, 1}; - int64_t p_reduction_stride[4]; - - p_reduction_stride[3] = 1; - p_reduction_stride[2] = 1; - p_reduction_stride[1] = s_q; - p_reduction_stride[0] = s_q * h; - - int64_t const_dim[4] = {1, 1, 1, 1}; - int64_t const_stride[4] = {1, 1, 1, 1}; - - // creating all tensors - auto softmaxScaleTensor = - tensor_create(CUDNN_DATA_FLOAT, S_CONST_ID, const_dim, const_stride, false, true); - auto dyMulYTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 250, p_dim, p_stride, true, false); - auto dxAfterReductionTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 251, p_reduction_dim, - p_reduction_stride, true, false); - auto dxAfterSubtractionTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 252, p_dim, p_stride, true, false); - auto dxUnscaleTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 253, p_dim, p_stride, true, false); - auto dxTensor = tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 254, p_dim, p_stride, true, false); - - // creating all ops - // mul (y * dy) - auto mul_1_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_1_op = binary_pw_op_create(yTensor, dyTensor, dyMulYTensor, mul_1_desc); - - // reduction add sum (y * dy) - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(dyMulYTensor) - .setyDesc(dxAfterReductionTensor) - .setreductionDesc(reductionAddDesc) - .build(); - - // subtraction (dy - sum(y * dy)) - auto sub_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - auto sub_0_op = - binary_pw_op_create(dyTensor, dxAfterReductionTensor, dxAfterSubtractionTensor, sub_0_desc); - - // mul (y * (dy - sum(y * dy))) - auto mul_2_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_2_op = - binary_pw_op_create(yTensor, dxAfterSubtractionTensor, dxUnscaleTensor, mul_2_desc); - - // mul (scale * dx) - auto mul_3_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_3_op = binary_pw_op_create(dxUnscaleTensor, softmaxScaleTensor, dxTensor, mul_3_desc); - - ops.push_back(std::move(mul_1_op)); - ops.push_back(std::move(reductionAdd_op)); - ops.push_back(std::move(sub_0_op)); - ops.push_back(std::move(mul_2_op)); - ops.push_back(std::move(mul_3_op)); - - return dxTensor; -} - -void fused_attn_max_512_fwd_impl( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, bool is_training, - float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, void *devPtrQ, void *devPtrK, void *devPtrV, - void *devPtrS, void *devPtrO, void *devPtrBias, void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, - void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *workspace, size_t *workspace_size, - cudnnDataType_t tensorType, cudaStream_t stream, cudnnHandle_t handle) { - try { - FADescriptor descriptor{b, h, - s_q, s_kv, - d, scaling_factor, - is_training, dropout_probability, - layout, bias_type, - mask_type, tensorType, - false}; - - using CacheType = std::map; - static thread_local CacheType fmha_fprop_cache; - - // softmax auxiliary is only used in the training mode - bool enable_dropout = is_training && (dropout_probability != 0.0f); - - // two conditions that make softmax auxiliary in virtual - // 1. inference mode (not is_training) - // 2. dropout enabled: the auxiliary becomes the dropout output - bool softmax_output_virtual = !is_training || enable_dropout; - - // Get plan from cache if cache is available, otherwise create one - auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { - // if hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto plan = it->second; - return plan; - } - - // otherwise, build the op_graph and the plan. Then update cache - std::vector all_ops; - std::vector ops; - - createScale(b, h, s_q, s_kv, d, layout, tensorType, ops); - - // if bias, we need to memset the S buffer to correctly computate dbias - // WAR: causal_mask without bias needs memset the S buffer - // inference mode doesn't need the S auxiliary - auto zero_s = (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) || - (mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK || - (mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK)) && - is_training; - std::shared_ptr maskInput; - auto bmm1_output = createBMM1(b, h, s_q, s_kv, d, layout, tensorType, zero_s, ops); - - NVTE_CHECK(bias_type != NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS, - "NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS has not been implemented."); - - if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) { - auto bias_output = createBias(b, h, s_q, s_kv, d, layout, tensorType, ops, bmm1_output); - maskInput = std::make_shared(std::move(bias_output)); - } - if (bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) { - maskInput = std::make_shared(std::move(bmm1_output)); - } - - auto mask_output = createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops, - *maskInput.get(), false); - - NVTE_CHECK(dropout_probability != 1.0f, "Dropout probability cannot be 1.0."); - - auto softmax_output = - createSoftmaxForward(b, h, s_q, s_kv, d, layout, enable_dropout, softmax_output_virtual, - tensorType, ops, mask_output); - - if (enable_dropout) { - auto dropout_output = - createDropout(b, h, s_q, s_kv, d, dropout_probability, tensorType, ops, softmax_output); - createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, dropout_output); - } else { - createBMM2(b, h, s_q, s_kv, d, layout, tensorType, ops, softmax_output); - } - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_fprop: No config returned by the heuristics"); - } - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; - - auto plan = get_plan(fmha_fprop_cache, descriptor); - - auto plan_workspace_size = plan.getWorkspaceSize(); - - // Exit to request upper level API to allocate memory if needed - if (workspace == nullptr) { - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. - NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - - // Prepare actual seqlen - constexpr size_t nthreads_per_block = 128; - const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); - cu_seqlens_to_actual_seqlens<<>>( - b, b, static_cast(devPtrCuSeqlenQ), - static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), - static_cast(devActualSeqlenK)); - NVTE_CHECK_CUDA(cudaGetLastError()); - - // change this if you have access to float_min - float negInfinity = -1.0E+10; - float scale_dropout = 1 / (1 - dropout_probability); - - std::set> data_ptrs; - // add all the data pointers to be used in the variant pack - data_ptrs.insert(std::pair(Q_ID, devPtrQ)); - data_ptrs.insert(std::pair(K_ID, devPtrK)); - data_ptrs.insert(std::pair(V_ID, devPtrV)); - data_ptrs.insert(std::pair(Q_SEQLEN_ID, devActualSeqlenQ)); - data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); - data_ptrs.insert(std::pair(MASK_VAL_ID, &negInfinity)); - - __half half_cast_scaling_factor{scaling_factor}; - __nv_bfloat16 bfloat_cast_scaling_factor{scaling_factor}; - - if (tensorType == CUDNN_DATA_FLOAT) { - data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); - } else if (tensorType == CUDNN_DATA_HALF) { - data_ptrs.insert(std::pair(S_CONST_ID, &half_cast_scaling_factor)); - } else if (tensorType == CUDNN_DATA_BFLOAT16) { - data_ptrs.insert(std::pair(S_CONST_ID, &bfloat_cast_scaling_factor)); - } else { - NVTE_ERROR("Unsupported tensor type."); - } - - data_ptrs.insert(std::pair(O_ID, devPtrO)); - - if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { - data_ptrs.insert(std::pair(B_ID, devPtrBias)); - } - - // if enable_dropout, S is the result after dropout - // if not enable dropout, S is the result after softmax - if (enable_dropout || !softmax_output_virtual) { - data_ptrs.insert(std::pair(S_ID, devPtrS)); - } - - __half half_cast_scale_dropout{scale_dropout}; - __nv_bfloat16 bfloat16_cast_scale_dropout{scale_dropout}; - - if (enable_dropout) { - // TODO(rewang): make a util func - if (tensorType == CUDNN_DATA_FLOAT) { - data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &scale_dropout)); - } else if (tensorType == CUDNN_DATA_HALF) { - data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &half_cast_scale_dropout)); - } else if (tensorType == CUDNN_DATA_BFLOAT16) { - data_ptrs.insert( - std::pair(DROPOUT_CONST_ID, &bfloat16_cast_scale_dropout)); - } else { - NVTE_ERROR("Unsupported tensor type."); - } - data_ptrs.insert(std::pair(DROPOUT_SEED_ID, devPtrDropoutSeed)); - data_ptrs.insert(std::pair(DROPOUT_OFFSET_ID, devPtrDropoutOffset)); - } - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace) - .setDataPointers(data_ptrs) - .build(); - - NVTE_CHECK_CUDNN(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException &e) { - NVTE_ERROR(e.what()); - } -} - -void fused_attn_max_512_bwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - float scaling_factor, float dropout_probability, - NVTE_QKV_Layout layout, NVTE_Mask_Type mask_type, - NVTE_Bias_Type bias_type, void *devPtrQ, void *devPtrK, - void *devPtrV, void *devPtrS, void *devPtrdQ, void *devPtrdK, - void *devPtrdV, void *devPtrdO, void *devPtrdS, void *devPtrdBias, - void *devPtrCuSeqlenQ, void *devPtrCuSeqlenKV, void *workspace, - size_t *workspace_size, cudnnDataType_t tensorType, - cudaStream_t stream, cudnnHandle_t handle) { - try { - FADescriptor descriptor{ - b, h, s_q, s_kv, d, scaling_factor, true, dropout_probability, - layout, bias_type, mask_type, tensorType, false}; - - using CacheType = std::map; - static thread_local CacheType fmha_bprop_cache; - - auto get_plan = [&](CacheType &cache, const FADescriptor &descriptor) { - auto it = cache.find(descriptor); - if (it != cache.end()) { - return it->second; - } - - std::vector all_ops; - std::vector ops; - - // Creates the necessary tensor descriptors - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - - int64_t k_dim[4] = {b, h, s_kv, d}; - int64_t k_stride[4]; - generateMatrixStrides( - b, h, s_q, s_kv, d, k_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix); // type is correct as K is not transposed - - int64_t v_dim[4] = {b, h, d, s_kv}; - int64_t v_stride[4]; - generateMatrixStrides( - b, h, s_q, s_kv, d, v_stride, layout, - NVTE_QKV_Matrix::NVTE_V_Matrix_Transpose); // type is correct as V is transposed - - int64_t p_dim[4] = {b, h, s_q, s_kv}; - int64_t p_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, p_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t p_transpose_dim[4] = {b, h, s_kv, s_q}; - int64_t p_transpose_stride[4]; - p_transpose_stride[0] = p_stride[0]; - p_transpose_stride[1] = p_stride[1]; - p_transpose_stride[2] = p_stride[3]; - p_transpose_stride[3] = p_stride[2]; - - int64_t o_dim[4] = {b, h, s_q, d}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // inputs to fprop - auto qTensor = tensor_create(tensorType, Q_ID, q_dim, q_stride, false, false); - auto kTensor = tensor_create(tensorType, K_ID, k_dim, k_stride, false, false); - auto vTensor = tensor_create(tensorType, V_ID, v_dim, v_stride, false, false); - auto seqlenQTensor = - tensor_create(CUDNN_DATA_INT32, Q_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - auto seqlenKTensor = - tensor_create(CUDNN_DATA_INT32, K_SEQLEN_ID, seqlen_dim, seqlen_stride, false, false); - - // gradient of the output - auto doTensor = tensor_create(tensorType, dO_ID, o_dim, o_stride, false, false); - - auto reorder_type = cudnn_frontend::TensorReordering_t::F16x16; - - // activation from fprop - auto pTensor = cudnn_frontend::TensorBuilder() - .setDim(4, p_dim) - .setStride(4, p_stride) - .setId(S_ID) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(tensorType) - .setVirtual(false) - .setByValue(false) - .setReorderType(reorder_type) - .build(); - - // outputs from bprop - auto dqTensor = tensor_create(tensorType, dQ_ID, q_dim, q_stride, false, false); - auto dkTensor = tensor_create(tensorType, dK_ID, k_dim, k_stride, false, false); - auto dvTensor = tensor_create(tensorType, dV_ID, k_dim, k_stride, false, - false); // not transposed therefore k_dim and k_stride - - //////////////////////////////////////////////////////// - // start creating the ops and the intermediate tensors - auto pReshapeTensor = tensor_create(tensorType, VIRTUAL_ID + 300, p_transpose_dim, - p_transpose_stride, true, false); - - // reshape to perform transpose and make pReshape - auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(pTensor) - .setyDesc(pReshapeTensor) - .build(); - - ops.push_back(std::move(reshape_op)); - - // scale dropout - auto dropoutScaleTensor = tensor_create(CUDNN_DATA_FLOAT, DROPOUT_CONST_ID, scale_dim, - scale_stride, false, true); // is by value - auto pAfterScaleTensor = tensor_create(tensorType, VIRTUAL_ID + 301, p_transpose_dim, - p_transpose_stride, true, false); - - auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto scaleMul_op = - binary_pw_op_create(pReshapeTensor, dropoutScaleTensor, pAfterScaleTensor, scaleMulDesc); - ops.push_back(std::move(scaleMul_op)); - - // perform absolute operation to remove the mask bit - auto pTransposeAfterAbsTensor = tensor_create(tensorType, VIRTUAL_ID + 302, p_transpose_dim, - p_transpose_stride, true, false); - - auto absDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ABS); - auto abs_op = unary_pw_op_create(pAfterScaleTensor, pTransposeAfterAbsTensor, absDesc); - ops.push_back(std::move(abs_op)); - - // matmul to calculate dvTensor - // set padding value optionally to 0 for writing zeros to dV tensor (if not set, old - // behaviour) - auto matmul_0_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - - auto matmul_op0 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(pTransposeAfterAbsTensor) - .setbMatDesc(doTensor) - .setcMatDesc(dvTensor) - .setmOverrideDesc(seqlenKTensor) - .setkOverrideDesc(seqlenQTensor) - .setmatmulDesc(matmul_0_Desc) - .build(); - - ops.push_back(std::move(matmul_op0)); - - // matmul to calculate dpTensor - auto dpTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 303, p_dim, p_stride, true, false); - - auto matmul_1_Desc = - cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); - - auto matmul_op1 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(doTensor) - .setbMatDesc(vTensor) - .setcMatDesc(dpTensor) - .setmOverrideDesc(seqlenQTensor) - .setnOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_1_Desc) - .build(); - - ops.push_back(std::move(matmul_op1)); - - // mask the values which were dropped in dropout - auto pAbsTensor = tensor_create(tensorType, VIRTUAL_ID + 304, p_dim, p_stride, true, false); - - auto p_absDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_ABS); - auto p_abs_op = unary_pw_op_create(pTensor, pAbsTensor, p_absDesc); - ops.push_back(std::move(p_abs_op)); - - // create the dropout mask - auto zeroTensor = tensor_create(CUDNN_DATA_FLOAT, MASK_VAL_ID, scale_dim, scale_stride, false, - true); // is by value - auto dropoutMaskTensor = - tensor_create(CUDNN_DATA_BOOLEAN, VIRTUAL_ID + 305, p_dim, p_stride, true, false); - - auto greater_than_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_CMP_GT); - auto greater_than_0_op = - binary_pw_op_create(pTensor, zeroTensor, dropoutMaskTensor, greater_than_0_desc); - ops.push_back(std::move(greater_than_0_op)); - - // scale for the dropout - auto dpAfterScaleTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 306, p_dim, p_stride, true, false); - - auto mul_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto mul_0_op = - binary_pw_op_create(dpTensor, dropoutScaleTensor, dpAfterScaleTensor, mul_0_desc); - ops.push_back(std::move(mul_0_op)); - - // drop the values based on the dropout mask - auto dpAfterDropoutTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 307, p_dim, p_stride, true, false); - - auto selection_0_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_BINARY_SELECT); - auto selection_0_op = ternary_pw_op_create(dpAfterScaleTensor, zeroTensor, dropoutMaskTensor, - dpAfterDropoutTensor, selection_0_desc); - ops.push_back(std::move(selection_0_op)); - - // softmax backward - auto dsTensor = createSoftmaxBackward(b, h, s_q, s_kv, d, layout, tensorType, ops, pAbsTensor, - dpAfterDropoutTensor); - - // mask - auto dsAfterMaskTensor = - createMask(b, h, s_q, s_kv, d, layout, mask_type, tensorType, ops, dsTensor, true); - - // dbias tensor - int64_t dbias_dim[4] = {1, h, s_q, s_kv}; - int64_t dbias_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - auto dBiasTensor = tensor_create(tensorType, dBias_ID, dbias_dim, dbias_stride, false, false); - - if (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS) { - auto softmaxScaleTensor = - tensor_create(CUDNN_DATA_FLOAT, S_CONST_ID, scale_dim, scale_stride, false, true); - auto softmaxScaleReciprocalTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 401, scale_dim, scale_stride, true, false); - auto dbiasBeforeScaleTensor = - tensor_create(CUDNN_DATA_FLOAT, VIRTUAL_ID + 402, dbias_dim, dbias_stride, true, false); - - // Define the reduction descriptor - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - // Create a reduction add node to compute the dbias - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(dsAfterMaskTensor) - .setyDesc(dbiasBeforeScaleTensor) - .setreductionDesc(reductionAddDesc) - .build(); - ops.push_back(std::move(reductionAdd_op)); - - // take the reciprocal of the scale - auto reciprocal_scale_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL); - auto reciprocal_scale_op = unary_pw_op_create( - softmaxScaleTensor, softmaxScaleReciprocalTensor, reciprocal_scale_desc); - ops.push_back(std::move(reciprocal_scale_op)); - - // apply the scale - auto dBias_scale_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - auto dBias_scale_op = binary_pw_op_create( - dbiasBeforeScaleTensor, softmaxScaleReciprocalTensor, dBiasTensor, dBias_scale_desc); - ops.push_back(std::move(dBias_scale_op)); - } - - // matmul to calculate dqTensor - // set padding value optionally to 0 for writing zeros to dqTensor (if not set, old - // behaviour) - auto matmul_2_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - - auto matmul_op2 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dsAfterMaskTensor) - .setbMatDesc(kTensor) - .setcMatDesc(dqTensor) - .setmOverrideDesc(seqlenQTensor) - .setkOverrideDesc(seqlenKTensor) - .setmatmulDesc(matmul_2_Desc) - .build(); - - ops.push_back(std::move(matmul_op2)); - - // reshape for transpose of ds - auto dsAfterMaskReshapeTensor = tensor_create(tensorType, VIRTUAL_ID + 308, p_transpose_dim, - p_transpose_stride, true, false); - - auto reshape_2_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(dsAfterMaskTensor) - .setyDesc(dsAfterMaskReshapeTensor) - .build(); - - ops.push_back(std::move(reshape_2_op)); - - // matmul to calculate dkTensor - // set padding value optionally to 0 for writing zeros to dktensor (if not set, old - // behaviour) - auto matmul_3_Desc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0.0f) - .build(); - - auto matmul_op3 = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dsAfterMaskReshapeTensor) - .setbMatDesc(qTensor) - .setcMatDesc(dkTensor) - .setmOverrideDesc(seqlenKTensor) - .setkOverrideDesc(seqlenQTensor) - .setmatmulDesc(matmul_3_Desc) - .build(); - - ops.push_back(std::move(matmul_op3)); - - ///////////////////////////////////////////////////////////////// - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_bprop: No config returned by the heuristics"); - } - - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; - - auto plan = get_plan(fmha_bprop_cache, descriptor); - - auto plan_workspace_size = plan.getWorkspaceSize(); - - // Exit to request upper level API to allocate memory if needed - if (workspace == nullptr) { - size_t actual_seqlen_workspace_size = 2 * b * sizeof(int32_t); - *workspace_size = plan_workspace_size + actual_seqlen_workspace_size; - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. - NVTE_CHECK_CUDNN(cudnnSetStream(handle, stream)); - - constexpr size_t nthreads_per_block = 128; - const size_t grid = (b + nthreads_per_block - 1) / nthreads_per_block; - void *devActualSeqlenQ = static_cast(workspace) + plan_workspace_size; - void *devActualSeqlenK = static_cast(devActualSeqlenQ) + b * sizeof(int32_t); - cu_seqlens_to_actual_seqlens<<>>( - b, b, static_cast(devPtrCuSeqlenQ), - static_cast(devPtrCuSeqlenKV), static_cast(devActualSeqlenQ), - static_cast(devActualSeqlenK)); - NVTE_CHECK_CUDA(cudaGetLastError()); - - std::set> data_ptrs; - // add all the data pointers to be used in the variant pack - data_ptrs.insert(std::pair(dQ_ID, devPtrdQ)); - data_ptrs.insert(std::pair(dK_ID, devPtrdK)); - data_ptrs.insert(std::pair(dV_ID, devPtrdV)); - - data_ptrs.insert(std::pair(Q_ID, devPtrQ)); - data_ptrs.insert(std::pair(K_ID, devPtrK)); - data_ptrs.insert(std::pair(V_ID, devPtrV)); - data_ptrs.insert(std::pair(S_ID, devPtrS)); - data_ptrs.insert(std::pair(dO_ID, devPtrdO)); - data_ptrs.insert(std::pair(dS_ID, devPtrdS)); - data_ptrs.insert(std::pair(Q_SEQLEN_ID, devActualSeqlenQ)); - data_ptrs.insert(std::pair(K_SEQLEN_ID, devActualSeqlenK)); - - if (bias_type != NVTE_Bias_Type::NVTE_NO_BIAS) { - data_ptrs.insert(std::pair(dBias_ID, devPtrdBias)); - } - - float zeroVal = 0.0f; - float dropoutScale = 1.0f / (1.0f - dropout_probability); - - data_ptrs.insert(std::pair(DROPOUT_CONST_ID, &dropoutScale)); - data_ptrs.insert(std::pair(S_CONST_ID, &scaling_factor)); - data_ptrs.insert(std::pair(MASK_VAL_ID, &zeroVal)); - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace) - .setDataPointers(data_ptrs) - .build(); - - NVTE_CHECK_CUDNN(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException &e) { - NVTE_ERROR(e.what()); - } -} - -} // namespace fused_attn - -using namespace transformer_engine::fused_attn; -void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle) { - using namespace transformer_engine; - - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_K->data.dptr; - void *devPtrV = input_V->data.dptr; - - void *devPtrBias = input_Bias->data.dptr; - - void *devPtrO = output_O->data.dptr; - - void *devPtrS = nullptr; - - const DType q_type = input_Q->data.dtype; - const DType kv_type = input_K->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - - if (Aux_CTX_Tensors->size == 0) { - Aux_CTX_Tensors->size = 1; - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - output_S->data.dptr = nullptr; - output_S->data.shape = {batch, num_head, q_max_seqlen, kv_max_seqlen}; - output_S->data.dtype = q_type; - } else if (Aux_CTX_Tensors->size == 1) { - Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[0]); - devPtrS = output_S->data.dptr; - } else { - NVTE_ERROR("Unexpected Aux_CTX_Tensors->size."); - } - - void *devQCuSeqlen = q_cu_seqlens->data.dptr; - void *devKVCuSeqlen = kv_cu_seqlens->data.dptr; - - const DType rng_state_type = rng_state->data.dtype; - NVTE_CHECK(rng_state_type == DType::kInt64); - void *devPtrDropoutSeed = rng_state->data.dptr; - void *devPtrDropoutOffset = - static_cast(static_cast(rng_state->data.dptr) + 1); - - size_t workspace_size = 0; - - fused_attn_max_512_fwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, is_training, attn_scale, p_dropout, - qkv_layout, bias_type, mask_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrO, devPtrBias, - devQCuSeqlen, devKVCuSeqlen, devPtrDropoutSeed, devPtrDropoutOffset, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} - -void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_dO, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, - Tensor *output_dBias, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle) { - using namespace transformer_engine; - - void *devPtrQ = input_Q->data.dptr; - void *devPtrK = input_K->data.dptr; - void *devPtrV = input_V->data.dptr; - - void *devPtrdO = input_dO->data.dptr; - - void *devPtrdQ = output_dQ->data.dptr; - void *devPtrdK = output_dK->data.dptr; - void *devPtrdV = output_dV->data.dptr; - - void *devPtrdBias = output_dBias->data.dptr; - - void *devPtrS = output_S->data.dptr; - - // devPtrdS reuses the memory of devPtrS - void *devPtrdS = devPtrS; - - void *devPtrQCuSeqlens = q_cu_seqlens->data.dptr; - void *devPtrKVCuSeqlens = kv_cu_seqlens->data.dptr; - - const auto q_type = input_Q->data.dtype; - const auto kv_type = input_K->data.dtype; - NVTE_CHECK(q_type == kv_type, "data type of Q must be equal to data type of KV."); - size_t workspace_size = 0; - - fused_attn_max_512_bwd_impl( - batch, num_head, q_max_seqlen, kv_max_seqlen, head_dim, attn_scale, p_dropout, qkv_layout, - mask_type, bias_type, devPtrQ, devPtrK, devPtrV, devPtrS, devPtrdQ, devPtrdK, devPtrdV, - devPtrdO, devPtrdS, devPtrdBias, devPtrQCuSeqlens, devPtrKVCuSeqlens, workspace->data.dptr, - &workspace_size, get_cudnn_dtype(q_type), stream, handle); - - if (workspace_size > 0) { - if (workspace->data.dptr == nullptr) { - workspace->data.shape = {workspace_size}; - workspace->data.dtype = DType::kByte; - return; - } - } else if (workspace_size == 0) { - workspace->data.shape = {1}; - workspace->data.dtype = DType::kByte; - return; - } else { - NVTE_ERROR("Unexpected workspace_size."); - } -} -} // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h b/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h deleted file mode 100644 index 1e59d4dc8f..0000000000 --- a/transformer_engine/common/fused_attn/fused_attn_f16_max512_seqlen.h +++ /dev/null @@ -1,41 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file fused_attn_fp16_bf16_max_seqlen_512.h - * \brief Functions for fused attention for half precision with seqlen <= 512 - */ - -#ifndef TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_ -#define TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_ - -#include - -#include "common/common.h" -#include "transformer_engine/fused_attn.h" - -namespace transformer_engine { -void fused_attn_max_512_fwd(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, bool is_training, - float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, - const Tensor *input_Bias, Tensor *output_O, - NVTETensorPack *Aux_CTX_Tensors, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, const Tensor *rng_state, Tensor *workspace, - cudaStream_t stream, cudnnHandle_t handle); - -void fused_attn_max_512_bwd(size_t batch, size_t num_head, size_t q_max_seqlen, - size_t kv_max_seqlen, size_t head_dim, float attn_scale, - float p_dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, const Tensor *input_Q, const Tensor *input_K, - const Tensor *input_V, const Tensor *input_dO, Tensor *output_S, - Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, - Tensor *output_dBias, const Tensor *q_cu_seqlens, - const Tensor *kv_cu_seqlens, Tensor *workspace, cudaStream_t stream, - cudnnHandle_t handle); -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_COMMON_FUSED_ATTN_FUSED_ATTN_MAX_512_H_ diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index d97f388459..eab1ae02e6 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -15,1648 +15,14 @@ namespace fused_attn { using namespace transformer_engine; -std::unordered_map tensor_name_to_uid = {{"Q", 1}, - {"K", 2}, - {"V", 3}, - {"O", 4}, - {"S", 5}, - {"B", 6}, - {"DROPOUT_SCALE", 7}, - {"S_CONST", 8}, - {"MNK_OVERRIDE", 9}, - {"dQ", 11}, - {"dK", 12}, - {"dV", 13}, - {"dO", 14}, - {"MASK_VAL", 15}, - {"dS", 16}, - {"O_SEQLEN", 17}, - {"M", 18}, - {"Z", 19}, - {"descaleQ", 20}, - {"descaleK", 21}, - {"descaleV", 22}, - {"descaleS", 23}, - {"scaleS", 24}, - {"amaxS", 25}, - {"amaxO", 26}, - {"QKV_RAGGED", 27}, - {"O_RAGGED", 28}, - {"K_TRANSPOSE", 29}, - {"AttnScale", 30}, - {"scaleO", 31}, - {"Z_INV", 32}, - {"descaleO", 33}, - {"descaledO", 34}, - {"descaledS", 35}, - {"descaledQ", 36}, - {"descaledK", 37}, - {"descaledV", 38}, - {"scaledS", 39}, - {"scaledQ", 40}, - {"scaledK", 41}, - {"scaledV", 42}, - {"amaxdS", 43}, - {"amaxdQ", 44}, - {"amaxdK", 45}, - {"amaxdV", 46}, - {"V_TRANSPOSE", 47}, - {"AttnScale_dS_K", 48}, - {"AttnScale_dSTranspose_Q", 49}, - {"DROPOUT_SCALE_dOVt_OdO", 50}, - {"DROPOUT_OFFSET", 51}, - {"DROPOUT_SEED", 52}, - {"VIRTUAL", 80}}; - -static cudnn_frontend::Tensor createAmax(const std::string& amax_tensor_name, - const cudnn_frontend::Tensor& prevBlockOutputTensor, - std::vector* ops) { - int64_t amax_dim[4] = {1, 1, 1, 1}; - int64_t amax_stride[4] = {1, 1, 1, 1}; - auto amaxTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid[amax_tensor_name], amax_dim, - amax_stride, false, false); - - // Define the amax descriptor - auto reductionDesc = cudnn_frontend::ReductionDescBuilder() - .setMathPrecision(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_AMAX) - .build(); - - // Create a reduction amax Node - auto reduction_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(prevBlockOutputTensor) - .setyDesc(amaxTensor) - .setreductionDesc(reductionDesc) - .build(); - ops->push_back(std::move(reduction_op)); - return amaxTensor; -} - -static cudnn_frontend::Tensor createScale(const cudnn_frontend::Tensor& prevBlockOutputTensor, - const std::string& scale_tensor_name, - cudnnDataType_t tensorType, bool isOutputVirtual, - bool isScaleByValue, - std::vector* ops, - const std::string& output_tensor_name = "") { - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - int64_t output_dim[4]; - int64_t output_stride[4]; - - for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - output_stride[i] = prevBlockOutputTensor.getStride()[i]; - } - - auto scaleTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], - scale_dim, scale_stride, false, isScaleByValue); // is by value - - int64_t outputUID = - isOutputVirtual ? tensor_name_to_uid["VIRTUAL"] + tensor_name_to_uid[scale_tensor_name] + 5000 - : tensor_name_to_uid[output_tensor_name]; - auto afterScaleKTensor = tensor_create(tensorType, outputUID, output_dim, output_stride, - isOutputVirtual, false); // is virtual - - // Define the scale descriptor - auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a Scale Node - auto scale_op = - binary_pw_op_create(prevBlockOutputTensor, scaleTensor, afterScaleKTensor, scaleDesc); - - ops->push_back(std::move(scale_op)); - return afterScaleKTensor; -} - -static cudnn_frontend::Tensor createScale(const cudnn_frontend::Tensor& prevBlockOutputTensor, - const cudnn_frontend::Tensor& scaleTensor, - cudnnDataType_t tensorType, bool isOutputVirtual, - bool isScaleByValue, - std::vector* ops, - int UID_offset, - const std::string& output_tensor_name = "") { - int64_t output_dim[4]; - int64_t output_stride[4]; - for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - output_stride[i] = prevBlockOutputTensor.getStride()[i]; - } - - int64_t outputUID = isOutputVirtual ? tensor_name_to_uid["VIRTUAL"] + UID_offset - : tensor_name_to_uid[output_tensor_name]; - auto afterScaleTensor = tensor_create(tensorType, outputUID, output_dim, output_stride, - isOutputVirtual, false); // is virtual - - // Define the scale descriptor - auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a Scale Node - auto scale_op = - binary_pw_op_create(prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); - - ops->push_back(std::move(scale_op)); - return afterScaleTensor; -} - -static cudnn_frontend::Tensor createScaleWithOffset( - const cudnn_frontend::Tensor& prevBlockOutputTensor, const std::string& scale_tensor_name, - NVTE_QKV_Layout layout, cudnnDataType_t tensorType, bool isOutputVirtual, bool isScaleByValue, - std::vector* ops, - std::shared_ptr offsetTensor, - const std::string& output_tensor_name = "") { - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - int64_t output_dim[4]; - int64_t output_stride[4]; - // If output tensor is dQ, dK, or dV, we need to generate QKV interleaved strides - if (output_tensor_name == "dQ" || output_tensor_name == "dK" || output_tensor_name == "dV") { - for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - } - generateMatrixStrides(output_dim[0], output_dim[1], output_dim[2], - 0 /*s_kv = 0 for placeholder*/, output_dim[3], output_stride, layout, - NVTE_QKV_Matrix::NVTE_Q_Matrix); - } else { - // Otherwise output dim and stride should be the same as prev block dim and stride - for (int i = 0; i < 4; i++) { - output_dim[i] = prevBlockOutputTensor.getDim()[i]; - output_stride[i] = prevBlockOutputTensor.getStride()[i]; - } - } - - auto scaleTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid[scale_tensor_name], - scale_dim, scale_stride, false, isScaleByValue); // is by value - - cudnnDataType_t outputDataType = isOutputVirtual ? CUDNN_DATA_FLOAT : tensorType; - int64_t outputUID = - isOutputVirtual ? tensor_name_to_uid["VIRTUAL"] + tensor_name_to_uid[scale_tensor_name] + 7000 - : tensor_name_to_uid[output_tensor_name]; - auto afterScaleTensor = - tensor_create_with_offset(outputDataType, outputUID, output_dim, output_stride, - isOutputVirtual, false, offsetTensor); // is virtual - - // Define the scale descriptor - auto scaleDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a Scale Node - auto scale_op = - binary_pw_op_create(prevBlockOutputTensor, scaleTensor, afterScaleTensor, scaleDesc); - - ops->push_back(std::move(scale_op)); - return afterScaleTensor; -} - -static cudnn_frontend::Tensor createSoftmaxForward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, std::vector* ops, - const cudnn_frontend::Tensor& prevBlockOutputTensor, bool isTraining) { - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t afterReduction_dim[4] = {b, h, s_q, 1}; - int64_t afterReduction_stride[4] = {h * s_q, s_q, 1, 1}; - - // max (x) (M tensor) - auto afterMaxReductionTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], afterReduction_dim, - afterReduction_stride, !isTraining, false); // not virtual if training is true, - // virtual if training is false - // x - max(x) - auto afterSubtractionTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 151, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // e^(x - max(x)) - auto afterExponentTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 152, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual; - // sum (e^(x - max(x))) (Z tensor) - auto zTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["Z"], afterReduction_dim, - afterReduction_stride, true, false); // is virtual - // 1 / sum (e^(x - max(x))) (Z_INV tensor) - auto zInvTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], afterReduction_dim, - afterReduction_stride, !isTraining, false); // not virtual if training is true, - // virtual if training is false - // Final softmax output (After exponent * Z_INV) - auto beforeDropoutTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 153, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the reduction descriptor - auto reductionMaxDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_MAX) - .build(); - - // Create a reduction max Node - auto reductionMax_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(prevBlockOutputTensor) - .setyDesc(afterMaxReductionTensor) - .setreductionDesc(reductionMaxDesc) - .build(); - - // Define the subtract descriptor - auto subtractDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - - // Create a subtract Node - auto subtract_op = binary_pw_op_create(prevBlockOutputTensor, afterMaxReductionTensor, - afterSubtractionTensor, subtractDesc); - - // Define the exponent descriptor - auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); - - // Create a exponent Node - auto exponent_op = unary_pw_op_create(afterSubtractionTensor, afterExponentTensor, exponentDesc); - - // Define the reduction descriptor - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - // Create a reduction add Node - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(afterExponentTensor) - .setyDesc(zTensor) - .setreductionDesc(reductionAddDesc) - .build(); - - // Define the reciprocal descriptor - auto reciprocalDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_RECIPROCAL); - - // Create a reciprocal Node - auto reciprocal_op = unary_pw_op_create(zTensor, zInvTensor, reciprocalDesc); - - // Define the pw multiply descriptor - auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply Node - auto mutliply_op = - binary_pw_op_create(afterExponentTensor, zInvTensor, beforeDropoutTensor, multiplyDesc); - - ops->push_back(std::move(reductionMax_op)); - ops->push_back(std::move(subtract_op)); - ops->push_back(std::move(exponent_op)); - ops->push_back(std::move(reductionAdd_op)); - ops->push_back(std::move(reciprocal_op)); - ops->push_back(std::move(mutliply_op)); - - return beforeDropoutTensor; -} - -static cudnn_frontend::Tensor createDropoutForward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, double probability, - std::vector* ops, - const cudnn_frontend::Tensor& beforeDropoutTensor) { - NVTE_CHECK(ops->size() > 0, "Dropout DAG constructed incorrectly as the first one"); - - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // Mask for the dropout - auto dropoutMaskTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 250, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - auto dropoutSeedTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], - scale_dim, scale_stride, false, false); // is by value - auto dropoutOffsetTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], - scale_dim, scale_stride, false, false); // is by value - - // After dropout tensor befor scale - auto beforeDropoutScaleTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setId(tensor_name_to_uid["VIRTUAL"] + 201) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual(true) - .setByValue(false) - .setReorderType(cudnn_frontend::TensorReordering_t::F16x16) - .build(); - // Scale after dropout - auto scaleDropoutTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], - scale_dim, scale_stride, false, true); // is by value - // After Scale - auto afterDropout_before_quan_S = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the reduction descriptor - auto rngDesc = cudnn_frontend::RngDescBuilder() - .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) - .setBernoulliDistProbability(1.0 - probability) - .build(); - - // Create a rng Node - auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) - .setyDesc(dropoutMaskTensor) - .setSeedDesc(dropoutSeedTensor) - .setOffsetDesc(dropoutOffsetTensor) - .setRngDesc(rngDesc) - .build(); - - // Define the multiply mask descriptor - auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto maskMul_op = binary_pw_op_create(beforeDropoutTensor, dropoutMaskTensor, - beforeDropoutScaleTensor, maskMulDesc); - - // Define the multiply scale descriptor - auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto scaleMul_op = binary_pw_op_create(beforeDropoutScaleTensor, scaleDropoutTensor, - afterDropout_before_quan_S, scaleMulDesc); - - ops->push_back(std::move(rng_op)); - ops->push_back(std::move(maskMul_op)); - ops->push_back(std::move(scaleMul_op)); - - return afterDropout_before_quan_S; -} - -static cudnn_frontend::Tensor createDropoutBackward( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, double probability, - std::vector* ops, const cudnn_frontend::Tensor& beforeDropoutTensor, - const cudnn_frontend::Tensor& dropoutMaskTensor) { - NVTE_CHECK(ops->size() > 0, "Dropout DAG constructed incorrectly as the first one"); - - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - auto dropoutSeedTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_SEED"], - scale_dim, scale_stride, false, false); // is by value - auto dropoutOffsetTensor = tensor_create(CUDNN_DATA_INT64, tensor_name_to_uid["DROPOUT_OFFSET"], - scale_dim, scale_stride, false, false); // is by value - - // After dropout tensor befor scale - auto beforeDropoutScaleTensor = - cudnn_frontend::TensorBuilder() - .setDim(4, afterBMM1_dim) - .setStride(4, afterBMM1_stride) - .setId(tensor_name_to_uid["VIRTUAL"] + 201) - .setAlignment(16) // 16B alignment is needed to run a tensor core engine - .setDataType(CUDNN_DATA_FLOAT) - .setVirtual(true) - .setByValue(false) - .setReorderType(cudnn_frontend::TensorReordering_t::F16x16) - .build(); - // Scale after dropout (1 / (1 - p)) - auto scaleDropoutTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["DROPOUT_SCALE"], - scale_dim, scale_stride, false, true); // is by value - // After Scale - auto afterDropout_before_quan_S = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 202, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the reduction descriptor - auto rngDesc = cudnn_frontend::RngDescBuilder() - .setRngDistribution(CUDNN_RNG_DISTRIBUTION_BERNOULLI) - .setBernoulliDistProbability(1.0 - probability) - .build(); - - // Create a rng Node - auto rng_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RNG_DESCRIPTOR) - .setyDesc(dropoutMaskTensor) - .setSeedDesc(dropoutSeedTensor) - .setOffsetDesc(dropoutOffsetTensor) - .setRngDesc(rngDesc) - .build(); - - // Define the multiply mask descriptor - auto maskMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto maskMul_op = binary_pw_op_create(beforeDropoutTensor, dropoutMaskTensor, - beforeDropoutScaleTensor, maskMulDesc); - - // Define the multiply scale descriptor - auto scaleMulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto scaleMul_op = binary_pw_op_create(beforeDropoutScaleTensor, scaleDropoutTensor, - afterDropout_before_quan_S, scaleMulDesc); - - ops->push_back(std::move(rng_op)); - ops->push_back(std::move(maskMul_op)); - ops->push_back(std::move(scaleMul_op)); - - return afterDropout_before_quan_S; -} - -static cudnn_frontend::Tensor createSoftmaxBackward(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - std::vector* ops, - const cudnn_frontend::Tensor& dyTensor) { - NVTE_CHECK(ops->size() > 0, "Softmax backward constructed incorrectly as the first one"); - - int64_t dx_dim[4] = {b, h, s_q, s_kv}; - int64_t dx_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - int64_t M_Z_dim[4] = {b, h, s_q, 1}; - int64_t M_Z_stride[4] = {h * s_q, s_q, 1, 1}; - - // Creating all tensors - auto MTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["M"], M_Z_dim, M_Z_stride, - false, false); // not virtual - auto ZInvTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["Z_INV"], M_Z_dim, - M_Z_stride, false, false); // not virtual - auto dxAfterSubtractionTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 252, dx_dim, dx_stride, true, - false); // is virtual - auto dxAfterExponentiation = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 253, - dx_dim, dx_stride, true, false); // is virtual - auto dxBeforeDropout_QKt_Tensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 254, dx_dim, dx_stride, true, - false); // is virtual - - // Creating all ops - // sub (dy - M) - auto subtractionDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - auto subtractionOp = - binary_pw_op_create(dyTensor, MTensor, dxAfterSubtractionTensor, subtractionDesc); - - // Define the exponent descriptor - auto exponentDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_EXP); - - // Create a exponent Node. (exp(dy - M)) - auto exponentOp = - unary_pw_op_create(dxAfterSubtractionTensor, dxAfterExponentiation, exponentDesc); - - // Define the pw multiply descriptor - auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply Node - auto mutliplyOp = binary_pw_op_create(dxAfterExponentiation, ZInvTensor, - dxBeforeDropout_QKt_Tensor, multiplyDesc); - - ops->push_back(std::move(subtractionOp)); - ops->push_back(std::move(exponentOp)); - ops->push_back(std::move(mutliplyOp)); - - return dxBeforeDropout_QKt_Tensor; -} - -static cudnn_frontend::Tensor createQKBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, std::vector* ops, - const cudnn_frontend::Tensor& qTensor, const cudnn_frontend::Tensor& kTensor, - const cudnn_frontend::Tensor& mnkOverride, - std::shared_ptr QKVRaggedOffsetTensor) { - // Creates the necessary tensor descriptors - int64_t k_transpose_dim[4] = {b, h, d, s_kv}; - int64_t k_transpose_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_transpose_stride, layout, - NVTE_QKV_Matrix::NVTE_K_Matrix_Transpose); - - int64_t s_dim[4] = {b, h, s_q, s_kv}; - int64_t s_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - auto kTransposeTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["K_TRANSPOSE"], - k_transpose_dim, k_transpose_stride, false, - false, QKVRaggedOffsetTensor); // is virtual - - // First GEMM output - auto afterQKTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 1, s_dim, - s_stride, true, false); // is virtual - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(-2000000) - .build(); - - // Create reshape node for K -> K.T - auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(kTensor) - .setyDesc(kTransposeTensor) - .build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(qTensor) - .setbMatDesc(kTransposeTensor) - .setcMatDesc(afterQKTensor) - .setmOverrideDesc(mnkOverride) - .setnOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(reshape_op)); - ops->push_back(std::move(matmulOp)); - - return afterQKTensor; -} - -static cudnn_frontend::Tensor createSVBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, std::vector* ops, - const cudnn_frontend::Tensor& softmaxTensor, const cudnn_frontend::Tensor& mnkOverride, - std::shared_ptr QKVRaggedOffsetTensor) { - NVTE_CHECK(ops->size() > 0, "BMM2 op constructed incorrectly as the first one"); - - int64_t v_dim[4] = {b, h, s_kv, d}; - int64_t v_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - - int64_t o_dim[4] = {b, h, s_q, d}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - - auto vTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["V"], v_dim, v_stride, - false, false, QKVRaggedOffsetTensor); - // Second fprop GEMM output - auto oTensor = tensor_create(tensorType, tensor_name_to_uid["VIRTUAL"] + 300, o_dim, o_stride, - true, false); // is virtual - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder().setComputeType(CUDNN_DATA_FLOAT).build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(softmaxTensor) - .setbMatDesc(vTensor) - .setcMatDesc(oTensor) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(matmulOp)); - - return oTensor; -} - -static cudnn_frontend::Tensor createSdOBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, cudnnDataType_t tensorType, - std::vector* ops, - const cudnn_frontend::Tensor& softmaxTensor, - const cudnn_frontend::Tensor& dOTensor, - const cudnn_frontend::Tensor& mnkOverride) { - NVTE_CHECK(ops->size() > 0, "BMM2 op constructed incorrectly as the first one"); - - int64_t s_dim_transpose[4] = {b, h, s_kv, s_q}; - int64_t s_stride_transpose[4] = {h * s_kv * s_q, s_kv * s_q, 1, s_kv}; - - int64_t v_dim[4] = {b, h, s_kv, d}; - int64_t v_stride[4] = {h * s_kv * d, d, h * d, 1}; - - auto sTransposeTensor = - tensor_create(tensorType, tensor_name_to_uid["VIRTUAL"] + 499, s_dim_transpose, - s_stride_transpose, true, false); // is virtual - // S.T * dO - auto dVTensor_before_dequan_S = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 500, v_dim, v_stride, true, - false); // is virtual - - // Create reshape node for softmax -> softmax.T - auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(softmaxTensor) - .setyDesc(sTransposeTensor) - .build(); - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(sTransposeTensor) - .setbMatDesc(dOTensor) - .setcMatDesc(dVTensor_before_dequan_S) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(reshape_op)); - ops->push_back(std::move(matmulOp)); - - return dVTensor_before_dequan_S; -} - -static cudnn_frontend::Tensor createdOVBMM( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - cudnnDataType_t tensorType, std::vector* ops, - const cudnn_frontend::Tensor& dOTensor, const cudnn_frontend::Tensor& mnkOverride, - std::shared_ptr QKVRaggedOffsetTensor) { - // Creates the necessary tensor descriptors - int64_t v_dim[4] = {b, h, s_kv, d}; - int64_t v_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, v_stride, layout, NVTE_QKV_Matrix::NVTE_V_Matrix); - - int64_t v_transpose_dim[4] = {b, h, d, s_kv}; - int64_t v_transpose_stride[4]; - v_transpose_stride[0] = v_stride[0]; - v_transpose_stride[1] = v_stride[1]; - v_transpose_stride[2] = v_stride[3]; - v_transpose_stride[3] = v_stride[2]; - - int64_t s_dim[4] = {b, h, s_q, s_kv}; - int64_t s_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, s_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - auto vTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["V"], v_dim, v_stride, - false, false, QKVRaggedOffsetTensor); - auto vTransposeTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["V_TRANSPOSE"], - v_transpose_dim, v_transpose_stride, false, - false, QKVRaggedOffsetTensor); // is virtual - - // dO * V.T - auto afterdOVTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 600, s_dim, - s_stride, true, false); // is virtual - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); - - // Create reshape node for V -> V.T - auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(vTensor) - .setyDesc(vTransposeTensor) - .build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dOTensor) - .setbMatDesc(vTransposeTensor) - .setcMatDesc(afterdOVTensor) - .setmOverrideDesc(mnkOverride) - .setnOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(reshape_op)); - ops->push_back(std::move(matmulOp)); - - return afterdOVTensor; -} - -static cudnn_frontend::Tensor createdOAndORowReductionChain( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - std::vector* ops, const cudnn_frontend::Tensor& O_after_dequan, - const cudnn_frontend::Tensor& dO_after_dequan, - const cudnn_frontend::Tensor& dropoutScale_dOVt_OdO_Tensor) { - int64_t o_dim[4] = {b, h, s_q, d}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - int64_t o_dim_row_sum[4] = {b, h, s_q, 1}; - int64_t o_dim_row_sum_stride[4] = {s_q * h, s_q, 1, 1}; - - auto O_dO_after_pointwise_multiply = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 700, o_dim, o_stride, true, - false); // is virtual - auto O_dO_after_dropout_scale = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 701, o_dim, o_stride, true, - false); // is virtual - auto O_dO_after_rowsum = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 702, o_dim_row_sum, - o_dim_row_sum_stride, true, false); // is virtual - - // Define the pw multiply descriptor - auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply Node - auto mutliply_op = binary_pw_op_create(O_after_dequan, dO_after_dequan, - O_dO_after_pointwise_multiply, multiplyDesc); - - // Create multiply node with dropout scale - auto dropout_scale_multiply_op = - binary_pw_op_create(O_dO_after_pointwise_multiply, dropoutScale_dOVt_OdO_Tensor, - O_dO_after_dropout_scale, multiplyDesc); - - // Define the reduction descriptor - auto reductionAddDesc = cudnn_frontend::ReductionDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setReductionOp(CUDNN_REDUCE_TENSOR_ADD) - .build(); - - // Create a reduction add Node - auto reductionAdd_op = - cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_REDUCTION_DESCRIPTOR) - .setxDesc(O_dO_after_dropout_scale) - .setyDesc(O_dO_after_rowsum) - .setreductionDesc(reductionAddDesc) - .build(); - - ops->push_back(std::move(mutliply_op)); - ops->push_back(std::move(dropout_scale_multiply_op)); - ops->push_back(std::move(reductionAdd_op)); - - return O_dO_after_rowsum; -} - -static cudnn_frontend::Tensor createBiasSubtractionSoftmaxMulChain( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, NVTE_QKV_Layout layout, - std::vector* ops, const cudnn_frontend::Tensor& dS_after_dropout, - const cudnn_frontend::Tensor& AfterDropout_before_quan_S, - const cudnn_frontend::Tensor& O_dO_after_rowsum, const cudnn_frontend::Tensor& attnScale) { - int64_t o_dim[4] = {b, h, s_q, s_kv}; - int64_t o_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, o_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - auto dS_minus_O_dO = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 800, o_dim, - o_stride, true, false); // is virtual - auto AfterAttnScale_before_dS = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 801, o_dim, o_stride, true, - false); // is virtual - auto S_mul_dS_minus_O_dO = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 802, - o_dim, o_stride, true, false); // is virtual - - // Define the pw subtraction descriptor - auto subDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_SUB); - - // Create a subtraction Node - auto sub_op = binary_pw_op_create(dS_after_dropout, O_dO_after_rowsum, dS_minus_O_dO, subDesc); - - // Define the pw multiplication descriptor - auto multiplyDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // dS_minus_O_dO * attnScale - auto mutliply_attn_scale_op = - binary_pw_op_create(dS_minus_O_dO, attnScale, AfterAttnScale_before_dS, multiplyDesc); - - // AfterDropout_before_quan_S * AfterAttnScale_before_dS - auto mutliply_op = binary_pw_op_create(AfterDropout_before_quan_S, AfterAttnScale_before_dS, - S_mul_dS_minus_O_dO, multiplyDesc); - - ops->push_back(std::move(sub_op)); - ops->push_back(std::move(mutliply_attn_scale_op)); - ops->push_back(std::move(mutliply_op)); - - return S_mul_dS_minus_O_dO; -} - -static cudnn_frontend::Tensor createdSKBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, std::vector* ops, - const cudnn_frontend::Tensor& dSTensor, - const cudnn_frontend::Tensor& kTensor, - const cudnn_frontend::Tensor& mnkOverride) { - // Creates the necessary tensor descriptors - int64_t after_dSK_dim[4] = {b, h, s_kv, d}; - int64_t after_dSK_stride[4] = {h * s_kv * d, d, h * d, 1}; - // dS * K - auto After_dS_K = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 875, - after_dSK_dim, after_dSK_stride, true, false); // is virtual - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dSTensor) - .setbMatDesc(kTensor) - .setcMatDesc(After_dS_K) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(matmulOp)); - - return After_dS_K; -} - -static cudnn_frontend::Tensor createdSQBMM(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, - int64_t d, NVTE_QKV_Layout layout, - std::vector* ops, - const cudnn_frontend::Tensor& dSTensor, - const cudnn_frontend::Tensor& qTensor, - const cudnn_frontend::Tensor& mnkOverride) { - // Creates the necessary tensor descriptors - int64_t dS_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, dS_stride, layout, NVTE_QKV_Matrix::NVTE_S_Matrix); - - int64_t dS_transpose_dim[4] = {b, h, s_kv, s_q}; - int64_t dS_transpose_stride[4]; - dS_transpose_stride[0] = dS_stride[0]; - dS_transpose_stride[1] = dS_stride[1]; - dS_transpose_stride[2] = dS_stride[3]; - dS_transpose_stride[3] = dS_stride[2]; - - int64_t after_dSTranspose_Q_dim[4] = {b, h, s_kv, d}; - int64_t after_dSTranspose_Q_stride[4] = {h * s_kv * d, d, h * d, 1}; - - auto dSTransposeTensor = - tensor_create(CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["VIRTUAL"] + 650, dS_transpose_dim, - dS_transpose_stride, true, false); // is virtual - - // dS.T * Q - auto After_dSTranspose_Q = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 651, after_dSTranspose_Q_dim, - after_dSTranspose_Q_stride, true, false); // is virtual - - // Create reshape node for V -> V.T - auto reshape_op = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_RESHAPE_DESCRIPTOR) - .setxDesc(dSTensor) - .setyDesc(dSTransposeTensor) - .build(); - - // Define the matmul desc - auto matmulDesc = cudnn_frontend::MatMulDescBuilder() - .setComputeType(CUDNN_DATA_FLOAT) - .setPaddingValue(0) - .build(); - - // Create a matmul Node - auto matmulOp = cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_MATMUL_DESCRIPTOR) - .setaMatDesc(dSTransposeTensor) - .setbMatDesc(qTensor) - .setcMatDesc(After_dSTranspose_Q) - .setmOverrideDesc(mnkOverride) - .setkOverrideDesc(mnkOverride) - .setmatmulDesc(matmulDesc) - .build(); - - ops->push_back(std::move(reshape_op)); - ops->push_back(std::move(matmulOp)); - - return After_dSTranspose_Q; -} - -// fused attention FWD FP8 with FE 0.9 -void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, - bool isTraining, float attnScale, float dropoutProbability, - NVTE_QKV_Layout layout, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, - void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, - void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, - void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, - cudnnDataType_t tensorType, void* workspace_ptr, - size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle_) { - try { - FADescriptor descriptor{b, - h, - s_q, - s_kv, - d, - attnScale, - isTraining, - dropoutProbability, - layout, - NVTE_Bias_Type::NVTE_NO_BIAS, - NVTE_Mask_Type::NVTE_PADDING_MASK, - tensorType, - false}; - - using CacheType = std::map; - static thread_local CacheType fa_fprop_cache; - - // Get plan from cache if cache is available, otherwise create one - auto get_plan = [&](CacheType& cache, const FADescriptor& descriptor) { - // If hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto plan = it->second; - return plan; - } - - // Otherwise, build the op_graph and the plan. Then update cache - std::vector all_ops; - std::vector ops; - - NVTE_CHECK(dropoutProbability == 0.0f || isTraining, - "Dropout probability should be 0.0f for inference mode"); - NVTE_CHECK(dropoutProbability != 1.0f, "Dropout probability cannot be 1.0"); - - int64_t raggedDim[4] = {b + 1, 1, 1, 1}; - int64_t raggedStride[4] = {1, 1, 1, 1}; - // Create offset tensors - auto QKVOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], - raggedDim, raggedStride, false, false); - auto ORaggedOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], - raggedDim, raggedStride, false, false); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - // Create override tensors - auto seqlenMNKTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], - seqlen_dim, seqlen_stride, false, false); - - // Create shared ptrs to ragged offset tensors - // for multiple tensors to use ragged offset - std::shared_ptr QKVRaggedOffsetTensorPtr = - std::make_shared(std::move(QKVOffsetTensor)); - std::shared_ptr ORaggedOffsetTensorPtr = - std::make_shared(std::move(ORaggedOffsetTensor)); - - // Create Q and K tensors that are used in different places - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - - int64_t k_dim[4] = {b, h, s_kv, d}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - - auto qTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["Q"], q_dim, q_stride, - false, false, QKVRaggedOffsetTensorPtr); - auto kTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["K"], k_dim, k_stride, - false, false, QKVRaggedOffsetTensorPtr); - - // Q * K.T - auto afterQKTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, qTensor, - kTensor, seqlenMNKTensor, QKVRaggedOffsetTensorPtr); - - // QK.T * attn scale - auto AfterAttnScale_before_dequan_Q_tensor = - createScale(afterQKTensor, // input tensor - "AttnScale", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - true, // scale is by value - &ops); - - // QK.T * attn scale * dequant_Q - auto AfterAttnScale_before_dequan_K_tensor = - createScale(AfterAttnScale_before_dequan_Q_tensor, // input tensor - "descaleQ", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // QK.T * attn scale * dequant_Q * dequant_K - auto AfterAttnScale_tensor = - createScale(AfterAttnScale_before_dequan_K_tensor, // input tensor - "descaleK", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - auto BeforeDropoutTensor = - createSoftmaxForward(b, h, s_q, s_kv, &ops, AfterAttnScale_tensor, isTraining); - - auto AfterDropout_before_quan_S = - createDropoutForward(b, h, s_q, s_kv, dropoutProbability, &ops, BeforeDropoutTensor); - - // Amax for S - createAmax("amaxS", BeforeDropoutTensor, &ops); - - // After softmax * dropout * scale S -> fp8 input to next bmm with V - auto AfterMultiplyDropout = createScale(AfterDropout_before_quan_S, // input tensor - "scaleS", // scale tensor - tensorType, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // After softmax * Dropout * V - auto OTensor_before_dequan_S_tensor = - createSVBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, AfterMultiplyDropout, - seqlenMNKTensor, QKVRaggedOffsetTensorPtr); - - // O * dequant_S - auto OTensor_before_dequan_V_tensor = - createScale(OTensor_before_dequan_S_tensor, // input tensor - "descaleS", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // O * dequant_S * dequant_V - auto OTensor_before_quan_O_tensor = - createScale(OTensor_before_dequan_V_tensor, // input tensor - "descaleV", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // O * dequant_S * dequant_V * scale O - auto OTensor = createScaleWithOffset(OTensor_before_quan_O_tensor, // input tensor - "scaleO", // scale tensor - layout, // qkv layout - tensorType, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - ORaggedOffsetTensorPtr, // ragged offset - "O"); - - // Amax for O - createAmax("amaxO", OTensor_before_quan_O_tensor, &ops); - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_fprop: No config returned by the heuristics"); - } - - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle_) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; // end of get_plan - - auto plan = get_plan(fa_fprop_cache, descriptor); - size_t wkspace_size = static_cast(plan.getWorkspaceSize()); - - // Exit to request upper level API to allocate memory if needed - if (workspace_ptr == nullptr) { - *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. - NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); - - int32_t* qkv_ragged_offset = - reinterpret_cast(reinterpret_cast(workspace_ptr) + wkspace_size); - int32_t* o_ragged_offset = reinterpret_cast(reinterpret_cast(workspace_ptr) + - wkspace_size + (b + 1) * sizeof(int32_t)); - int32_t* actual_seqlens_q = reinterpret_cast( - reinterpret_cast(workspace_ptr) + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); - // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV - dim3 blockDims(128); - dim3 gridDims((b + blockDims.x) / blockDims.x); - cu_seqlens_to_offsets<<>>( - b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, - o_ragged_offset); - NVTE_CHECK_CUDA(cudaGetLastError()); - void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); - void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); - void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); - - float dropoutScale = 1.0f / (1.0f - dropoutProbability); - - std::set> data_ptrs; - data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["AttnScale"], &attnScale)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["O"], devPtrO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleK"], devPtrDescaleK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleV"], devPtrDescaleV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleS"], devPtrDescaleS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaleS"], devPtrScaleS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaleO"], devPtrScaleO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxO"], devPtrAmaxO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxS"], devPtrAmaxS)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); - - // If training, then we need to write out M and Z_INV - if (isTraining) { - data_ptrs.emplace(std::pair(tensor_name_to_uid["M"], devPtrM)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["Z_INV"], devPtrZInv)); - } - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(data_ptrs) - .build(); - - NVTE_CHECK_CUDNN(cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException& e) { - struct cudaDeviceProp prop; - NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); - - // This example is only for GH100 cards (cudnn Version >= 8900) - if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) && - (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH || - e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { - std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; - } else { - std::cout << "[ERROR] Exception " << e.what() << std::endl; - } - } -} - -// fused attention BWD FP8 with FE 0.9 -void fused_attn_fp8_bwd_impl( - int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d, float attnScale, - float dropoutProbability, NVTE_QKV_Layout layout, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrdQ, void* devPtrdK, - void* devPtrdV, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, - void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledS, - void* devPtrScaleS, void* devPtrScaledS, void* devPtrScaledQ, void* devPtrScaledK, - void* devPtrScaledV, void* devPtrAmaxdS, void* devPtrAmaxdQ, void* devPtrAmaxdK, - void* devPtrAmaxdV, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, - void* devPtrDropoutOffset, cudnnDataType_t tensorType, void* workspace_ptr, - size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle_) { - try { - FADescriptor descriptor{b, - h, - s_q, - s_kv, - d, - attnScale, - false, - dropoutProbability, - layout, - NVTE_Bias_Type::NVTE_NO_BIAS, - NVTE_Mask_Type::NVTE_PADDING_MASK, - tensorType, - false}; - - using CacheType = std::map; - static thread_local CacheType fa_bprop_cache; - - // Get plan from cache if cache is available, otherwise create one - auto get_plan = [&](CacheType& cache, const FADescriptor& descriptor) { - // If hit, return - auto it = cache.find(descriptor); - if (it != cache.end()) { - auto plan = it->second; - return plan; - } - - // Otherwise, build the op_graph and the plan. Then update cache - std::vector all_ops; - std::vector ops; - - NVTE_CHECK(dropoutProbability != 1.0f, "Dropout probability cannot be 1.0"); - - int64_t raggedDim[4] = {b + 1, 1, 1, 1}; - int64_t raggedStride[4] = {1, 1, 1, 1}; - // Create offset tensors - auto QKVOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["QKV_RAGGED"], - raggedDim, raggedStride, false, false); - auto ORaggedOffsetTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["O_RAGGED"], - raggedDim, raggedStride, false, false); - - // Create shared ptrs to ragged offset tensors for multiple tensors - std::shared_ptr QKVRaggedOffsetTensorPtr = - std::make_shared(std::move(QKVOffsetTensor)); - std::shared_ptr ORaggedOffsetTensorPtr = - std::make_shared(std::move(ORaggedOffsetTensor)); - - // Create Q and K tensors that are used in different places - int64_t q_dim[4] = {b, h, s_q, d}; - int64_t q_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_QKV_Matrix::NVTE_Q_Matrix); - - int64_t k_dim[4] = {b, h, s_kv, d}; - int64_t k_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, k_stride, layout, NVTE_QKV_Matrix::NVTE_K_Matrix); - - auto qTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["Q"], q_dim, q_stride, - false, false, QKVRaggedOffsetTensorPtr); - auto kTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["K"], k_dim, k_stride, - false, false, QKVRaggedOffsetTensorPtr); - - int64_t scale_dim[4] = {1, 1, 1, 1}; - int64_t scale_stride[4] = {1, 1, 1, 1}; - - // Create attnScale tensor for multiple ops to use - auto attnScaleTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["AttnScale"], - scale_dim, scale_stride, false, true); // is by value - - // Create descale Q K dO dS global tensors since they are used in multiple places - auto descaleQTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleQ"], - scale_dim, scale_stride, false, false); - auto descaleKTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaleK"], - scale_dim, scale_stride, false, false); - auto descaledOTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledO"], - scale_dim, scale_stride, false, false); - auto descaledSTensor = tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["descaledS"], - scale_dim, scale_stride, false, false); - - int64_t seqlen_dim[4] = {b, 1, 1, 1}; - int64_t seqlen_stride[4] = {1, 1, 1, 1}; - // Create MNK override tensor - auto seqlenMNKTensor = tensor_create(CUDNN_DATA_INT32, tensor_name_to_uid["MNK_OVERRIDE"], - seqlen_dim, seqlen_stride, false, false); - - int64_t O_dim[4] = {b, h, s_q, d}; - int64_t O_stride[4]; - generateMatrixStrides(b, h, s_q, s_kv, d, O_stride, layout, NVTE_QKV_Matrix::NVTE_O_Matrix); - // Create O and loss tensor - auto OTensor = tensor_create_with_offset(tensorType, tensor_name_to_uid["O"], O_dim, O_stride, - false, false, ORaggedOffsetTensorPtr); - // dO is used in multiple places and E5M2 - auto dOTensor = - tensor_create_with_offset(CUDNN_DATA_FP8_E5M2, tensor_name_to_uid["dO"], O_dim, O_stride, - false, false, ORaggedOffsetTensorPtr); - - // Q * K.T - auto afterQKTensor = createQKBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, qTensor, - kTensor, seqlenMNKTensor, QKVRaggedOffsetTensorPtr); - - // QK.T * attn scale - auto AfterAttnScale_before_dequan_Q_tensor = - createScale(afterQKTensor, // input tensor - attnScaleTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - true, // scale is by value - &ops, 1999 /*UID offset*/); - - // QK.T * attn scale * dequant_Q - auto AfterAttnScale_before_dequan_K_tensor = - createScale(AfterAttnScale_before_dequan_Q_tensor, // input tensor - descaleQTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2000 /*UID offset*/); - - // QK.T * attn scale * dequant_Q * dequant_K - auto AfterAttnScale_tensor = - createScale(AfterAttnScale_before_dequan_K_tensor, // input tensor - descaleKTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2001 /*UID offset*/); - - auto beforeDropout_QKt_Tensor = - createSoftmaxBackward(b, h, s_q, s_kv, &ops, AfterAttnScale_tensor); - - int64_t afterBMM1_dim[4] = {b, h, s_q, s_kv}; - int64_t afterBMM1_stride[4] = {h * s_q * s_kv, s_q * s_kv, s_kv, 1}; - - // mask for the dropout. Used in different places - auto dropoutMaskTensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 200, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - auto AfterDropout_before_quan_S = createDropoutBackward( - b, h, s_q, s_kv, dropoutProbability, &ops, beforeDropout_QKt_Tensor, dropoutMaskTensor); - - // After softmax * scale S -> fp8 input to next bmm with V - auto AfterMultiply = createScale(AfterDropout_before_quan_S, // input tensor - "scaleS", // scale tensor - tensorType, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // After softmax * dO - auto dVTensor_before_dequan_S = createSdOBMM(b, h, s_q, s_kv, d, tensorType, &ops, - AfterMultiply, dOTensor, seqlenMNKTensor); - - // O * dequant_S - auto dVTensor_before_dequan_dO = createScale(dVTensor_before_dequan_S, // input tensor - "descaleS", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // O * dequant_S * dequant_dO - auto dVTensor_before_quan_dV = createScale(dVTensor_before_dequan_dO, // input tensor - descaledOTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2002 /*UID offset*/); - - // O * dequant_S * dequant_dO * scale dV - auto dVTensor = createScaleWithOffset(dVTensor_before_quan_dV, // input tensor - "scaledV", // scale tensor - layout, // qkv layout - CUDNN_DATA_FP8_E5M2, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - QKVRaggedOffsetTensorPtr, // ragged offset - "dV" /*Output tensor name*/); - - // Amax for dV - createAmax("amaxdV", dVTensor_before_quan_dV, &ops); - - auto dS_before_dequan_dO_Tensor = - createdOVBMM(b, h, s_q, s_kv, d, layout, tensorType, &ops, dOTensor, seqlenMNKTensor, - QKVRaggedOffsetTensorPtr); - - // dS * dequant_dO - auto dS_before_dequan_V = createScale(dS_before_dequan_dO_Tensor, // input tensor - descaledOTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2003 /*UID offset*/); - - // O * dequant_S * dequant_dV - auto dS_after_dequan = createScale(dS_before_dequan_V, // input tensor - "descaleV", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // RNG Multiply - auto beforeDropoutScale_dOVt_Tensor = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 350, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - // After dropout mask and scale - auto dS_after_dropout = - tensor_create(CUDNN_DATA_FLOAT, tensor_name_to_uid["VIRTUAL"] + 351, afterBMM1_dim, - afterBMM1_stride, true, false); // is virtual - - // Define the multiply mask descriptor - auto mulDesc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL); - - // Create a multiply mask Node - auto maskMul_op = binary_pw_op_create(dS_after_dequan, dropoutMaskTensor, - beforeDropoutScale_dOVt_Tensor, mulDesc); - - ops.push_back(std::move(maskMul_op)); - - // scale after dropout for dO and O chain - auto dropoutScale_dOVt_OdO_Tensor = - tensor_create(tensorType, tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], scale_dim, - scale_stride, false, true); // is by value - - // Create a multiply dropout scale Node - auto mul_dropout_scale_op = binary_pw_op_create( - beforeDropoutScale_dOVt_Tensor, dropoutScale_dOVt_OdO_Tensor, dS_after_dropout, mulDesc); - - ops.push_back(std::move(mul_dropout_scale_op)); - - // O * dequant_O - auto O_after_dequan_Tensor = createScale(OTensor, // input tensor - "descaleO", // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // dO * dequant_dO - auto dO_after_dequan_Tensor = createScale(dOTensor, // input tensor - descaledOTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2004 /*UID offset*/); - - // row reduction sum[(dO * dequant_dO) * (O * dequant_O) * (1 - p)] - auto O_dO_after_rowsum = - createdOAndORowReductionChain(b, h, s_q, s_kv, d, layout, &ops, O_after_dequan_Tensor, - dO_after_dequan_Tensor, dropoutScale_dOVt_OdO_Tensor); - - // (dS_after_dropout - O_dO_after_rowsum) * AfterDropout_before_quan_S * attnScale - auto S_mul_dS_minus_O_dO = createBiasSubtractionSoftmaxMulChain( - b, h, s_q, s_kv, d, layout, &ops, dS_after_dropout, AfterDropout_before_quan_S, - O_dO_after_rowsum, attnScaleTensor); - - // S_mul_dS_minus_O_dO * scaledS - auto S_mul_dS_minus_O_dO_after_quan_dS = - createScale(S_mul_dS_minus_O_dO, // input tensor - "scaledS", // scale tensor - CUDNN_DATA_FP8_E5M2, // output tensor type - true, // output is virtual - false, // scale is by value - &ops); - - // Amax for dS - createAmax("amaxdS", S_mul_dS_minus_O_dO, &ops); - - // dS @ K - auto After_dS_K = createdSKBMM(b, h, s_q, s_kv, d, &ops, S_mul_dS_minus_O_dO_after_quan_dS, - kTensor, seqlenMNKTensor); - - // (dS * K) * descale dS - auto After_dS_K_before_dequan_K = createScale(After_dS_K, // input tensor - descaledSTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2006 /*UID offset*/); - - // (dS * K) * descale dS * descale K - auto After_dS_K_before_quan_dQ = createScale(After_dS_K_before_dequan_K, // input tensor - descaleKTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2007 /*UID offset*/); - - // (dS * K) * descale dS * descale K * scale dQ - auto dQ = createScaleWithOffset(After_dS_K_before_quan_dQ, // input tensor - "scaledQ", // scale tensor - layout, // qkv layout - CUDNN_DATA_FP8_E5M2, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - QKVRaggedOffsetTensorPtr, // ragged offset - "dQ"); - - // Amax for dQ - createAmax("amaxdQ", After_dS_K_before_quan_dQ, &ops); - - // dS.T @ Q - auto After_dSTranspose_Q = - createdSQBMM(b, h, s_q, s_kv, d, layout, &ops, S_mul_dS_minus_O_dO_after_quan_dS, qTensor, - seqlenMNKTensor); - - // (dS.T * Q) * descale dS - auto After_dSTranspose_Q_before_dequan_Q = - createScale(After_dSTranspose_Q, // input tensor - descaledSTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2009 /*UID offset*/); - - // (dS.T * Q) * descale dS * descale Q - auto After_dSTranspose_Q_before_quan_dK = - createScale(After_dSTranspose_Q_before_dequan_Q, // input tensor - descaleQTensor, // scale tensor - CUDNN_DATA_FLOAT, // output tensor type - true, // output is virtual - false, // scale is by value - &ops, 2010 /*UID offset*/); - - // (dS.T * Q) * descale dS * descale Q * scale dK - auto dK = createScaleWithOffset(After_dSTranspose_Q_before_quan_dK, // input tensor - "scaledK", // scale tensor - layout, // qkv layout - CUDNN_DATA_FP8_E5M2, // output tensor type - false, // output not virtual - false, // scale is by value - &ops, - QKVRaggedOffsetTensorPtr, // ragged offset - "dK"); - - // Amax for dK - createAmax("amaxdK", After_dSTranspose_Q_before_quan_dK, &ops); - - for (unsigned int i = 0; i < ops.size(); i++) { - all_ops.push_back(&ops[i]); - } - - // Create an Operation Graph - auto opGraph = cudnn_frontend::OperationGraphBuilder() - .setHandle(handle_) - .setOperationGraph(all_ops.size(), all_ops.data()) - .build(); - - cudnn_frontend::EngineConfigList filtered_configs; - auto statuses = cudnn_frontend::get_heuristics_list<1>( - {"heuristics_instant"}, opGraph, allowAllConfig, filtered_configs, true); - - if (filtered_configs.size() == 0) { - cudnn_frontend::set_error_and_throw_exception( - nullptr, CUDNN_STATUS_NOT_SUPPORTED, - "run_mha_bprop: No config returned by the heuristics"); - } - - auto plan = cudnn_frontend::ExecutionPlanBuilder() - .setHandle(handle_) - .setEngineConfig(filtered_configs[0], opGraph.getTag()) - .build(); - cache.insert({descriptor, plan}); - return plan; - }; - - auto plan = get_plan(fa_bprop_cache, descriptor); - size_t wkspace_size = static_cast(plan.getWorkspaceSize()); - - // Exit to request upper level API to allocate memory if needed - if (workspace_ptr == nullptr) { - *workspace_size = wkspace_size + ((b + 1) * 2 + b) * sizeof(int32_t); - return; - } - - // cuDNN stream check needs to be moved here to support dummy kernel calls with - // null streams for sizing the cuDNN workspace. - NVTE_CHECK_CUDNN(cudnnSetStream(handle_, stream)); - - int32_t* qkv_ragged_offset = - reinterpret_cast(reinterpret_cast(workspace_ptr) + wkspace_size); - int32_t* o_ragged_offset = reinterpret_cast(reinterpret_cast(workspace_ptr) + - wkspace_size + (b + 1) * sizeof(int32_t)); - int32_t* actual_seqlens_q = reinterpret_cast( - reinterpret_cast(workspace_ptr) + wkspace_size + (b + 1) * 2 * sizeof(int32_t)); - // FP8 currently only supports self-attention, so doesn't use devPtrcuSeqlensKV - dim3 blockDims(128); - dim3 gridDims((b + blockDims.x) / blockDims.x); - cu_seqlens_to_offsets<<>>( - b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, - o_ragged_offset); - NVTE_CHECK_CUDA(cudaGetLastError()); - void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); - void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); - void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); - - std::set> data_ptrs; - float dropoutScale = 1.0f / (1.0f - dropoutProbability); - float dropoutScale_dOVt_OdO = 1.0f - dropoutProbability; - data_ptrs.emplace(std::pair(tensor_name_to_uid["Q"], devPtrQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["K_TRANSPOSE"], devPtrK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["V"], devPtrV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["V_TRANSPOSE"], devPtrV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dQ"], devPtrdQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dK"], devPtrdK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dV"], devPtrdV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["dO"], devPtrdO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["AttnScale"], &attnScale)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_SCALE"], &dropoutScale)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["DROPOUT_SCALE_dOVt_OdO"], - &dropoutScale_dOVt_OdO)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_SEED"], devPtrDropoutSeed)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["DROPOUT_OFFSET"], devPtrDropoutOffset)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["M"], devPtrM)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["Z_INV"], devPtrZInv)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["O"], devPtrO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleQ"], devPtrDescaleQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleK"], devPtrDescaleK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleV"], devPtrDescaleV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleS"], devPtrDescaleS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaledS"], devPtrDescaledS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaleO"], devPtrDescaleO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["descaledO"], devPtrDescaledO)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaleS"], devPtrScaleS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledS"], devPtrScaledS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledQ"], devPtrScaledQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledK"], devPtrScaledK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["scaledV"], devPtrScaledV)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdS"], devPtrAmaxdS)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdQ"], devPtrAmaxdQ)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdK"], devPtrAmaxdK)); - data_ptrs.emplace(std::pair(tensor_name_to_uid["amaxdV"], devPtrAmaxdV)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["QKV_RAGGED"], devPtrQKVRaggedOffset)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["O_RAGGED"], devPtrORaggedOffset)); - data_ptrs.emplace( - std::pair(tensor_name_to_uid["MNK_OVERRIDE"], devPtrMNKOverride)); - - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(workspace_ptr) - .setDataPointers(data_ptrs) - .build(); - NVTE_CHECK_CUDNN(cudnnBackendExecute(handle_, plan.get_raw_desc(), variantPack.get_raw_desc())); - } catch (cudnn_frontend::cudnnException& e) { - struct cudaDeviceProp prop; - NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); - - // This example is only for GH100 cards (cudnn Version >= 8900) - if (!((prop.major == 9 && prop.minor == 0 && CUDNN_VERSION >= 8900)) && - (e.getCudnnStatus() == CUDNN_STATUS_ARCH_MISMATCH || - e.getCudnnStatus() == CUDNN_STATUS_NOT_SUPPORTED)) { - std::cout << "Example is only supported for GH100 (cuDNN >= 8900) GPUs" << std::endl; - } else { - std::cout << "[ERROR] Exception " << e.what() << std::endl; - } - } -} - // fused attention FWD FP8 with FE 1.0+ -void fused_attn_fp8_fwd_impl_v1( +void fused_attn_fp8_fwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, bool is_training, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void* devPtrQ, void* devPtrK, void* devPtrV, - void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrZInv, void* devPtrO, void* devPtrDescaleQ, + void* devPtrSoftmaxOffset, void* devPtrM, void* devPtrO, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleS, void* devPtrScaleS, void* devPtrScaleO, void* devPtrAmaxO, void* devPtrAmaxS, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, @@ -2080,26 +446,26 @@ void fused_attn_fp8_fwd_impl_v1( } // fused attention BWD FP8 with FE 1.0+ -void fused_attn_fp8_bwd_impl_v1( +void fused_attn_fp8_bwd_impl( int64_t b, int64_t h, int64_t hg, int64_t s_q, int64_t s_kv, int64_t d_qk, int64_t d_v, float scaling_factor, float dropout_probability, NVTE_QKV_Layout qkv_layout, NVTE_QKV_Format o_format, NVTE_QKV_Format do_format, NVTE_QKV_Layout dqkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, - bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, - void* devPtrZInv, void* devPtrO, void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, - void* devPtrdK, void* devPtrdV, void* devPtrdSoftmaxOffset, void* devPtrDescaleQ, - void* devPtrDescaleK, void* devPtrDescaleV, void* devPtrDescaleO, void* devPtrDescaledO, - void* devPtrDescaleS, void* devPtrDescaledP, void* devPtrScaleS, void* devPtrScaledP, - void* devPtrScaledQ, void* devPtrScaledK, void* devPtrScaledV, void* devPtrAmaxdP, - void* devPtrAmaxdQ, void* devPtrAmaxdK, void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, - void* devPtrdO_f16, void* devPtrdO_t, void* devPtrDescaleQ_t, void* devPtrDescaleK_t, - void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, void* devPtrcuSeqlensKV, - void* devPtrDropoutSeed, void* devPtrDropoutOffset, cudnn_frontend::DataType_t qkv_tensor_type, - cudnn_frontend::DataType_t o_tensor_type, cudnn_frontend::DataType_t do_tensor_type, - cudnn_frontend::DataType_t dqkv_tensor_type, NVTEScalingMode scaling_mode, - NVTE_QKV_Format qkv_scale_inv_format, NVTE_QKV_Format do_scale_inv_format, void* workspace, - size_t* workspace_size, cudaStream_t stream, cudnnHandle_t handle) { + bool deterministic, void* devPtrQ, void* devPtrK, void* devPtrV, void* devPtrM, void* devPtrO, + void* devPtrdO, void* devPtrSoftmaxOffset, void* devPtrdQ, void* devPtrdK, void* devPtrdV, + void* devPtrdSoftmaxOffset, void* devPtrDescaleQ, void* devPtrDescaleK, void* devPtrDescaleV, + void* devPtrDescaleO, void* devPtrDescaledO, void* devPtrDescaleS, void* devPtrDescaledP, + void* devPtrScaleS, void* devPtrScaledP, void* devPtrScaledQ, void* devPtrScaledK, + void* devPtrScaledV, void* devPtrAmaxdP, void* devPtrAmaxdQ, void* devPtrAmaxdK, + void* devPtrAmaxdV, void* devPtrQ_t, void* devPtrK_t, void* devPtrdO_f16, void* devPtrdO_t, + void* devPtrDescaleQ_t, void* devPtrDescaleK_t, void* devPtrDescaledO_t, void* devPtrcuSeqlensQ, + void* devPtrcuSeqlensKV, void* devPtrDropoutSeed, void* devPtrDropoutOffset, + cudnn_frontend::DataType_t qkv_tensor_type, cudnn_frontend::DataType_t o_tensor_type, + cudnn_frontend::DataType_t do_tensor_type, cudnn_frontend::DataType_t dqkv_tensor_type, + NVTEScalingMode scaling_mode, NVTE_QKV_Format qkv_scale_inv_format, + NVTE_QKV_Format do_scale_inv_format, void* workspace, size_t* workspace_size, + cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; const auto cudnn_runtime_version = cudnnGetVersion(); bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS); @@ -2760,19 +1126,12 @@ void fused_attn_fp8_fwd( devPtrSoftmaxOffset = input_SoftmaxOffset->data.dptr; } void* devPtrM = nullptr; - void* devPtrZInv = nullptr; if (Aux_CTX_Tensors->size == 0) { int i = 0; Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_M->data.dptr = nullptr; output_M->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; output_M->data.dtype = DType::kFloat32; - if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - output_ZInv->data.dptr = nullptr; - output_ZInv->data.shape = {batch, num_attn_heads, max_seqlen_q, 1}; - output_ZInv->data.dtype = DType::kFloat32; - } Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = nullptr; output_rng_state->data.shape = {2}; @@ -2788,11 +1147,6 @@ void fused_attn_fp8_fwd( int i = 0; Tensor* output_M = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); devPtrM = output_M->data.dptr; - devPtrZInv = nullptr; - if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - Tensor* output_ZInv = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); - devPtrZInv = output_ZInv->data.dptr; - } Tensor* output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]); output_rng_state->data.dptr = rng_state->data.dptr; if (softmax_type != NVTE_VANILLA_SOFTMAX) { @@ -2819,25 +1173,17 @@ void fused_attn_fp8_fwd( NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); if ((qkv_format == NVTE_QKV_Format::NVTE_BSHD) || (qkv_format == NVTE_QKV_Format::NVTE_SBHD) || (qkv_format == NVTE_QKV_Format::NVTE_BHSD)) { - fused_attn::fused_attn_fp8_fwd_impl_v1( + fused_attn::fused_attn_fp8_fwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, is_training, attn_scale, p_dropout, qkv_layout, o_format, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, - devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrZInv, devPtrO, devPtrDescaleQ, devPtrDescaleK, + devPtrV, devPtrSoftmaxOffset, devPtrM, devPtrO, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), input_Q->scaling_mode, qkv_scale_inv_format, workspace->data.dptr, &workspace_size, stream, handle); - } else if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - fused_attn::fused_attn_fp8_fwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, is_training, attn_scale, - p_dropout, qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, - devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleS, devPtrScaleS, devPtrScaleO, - devPtrAmaxO, devPtrAmaxS, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, - devPtrDropoutOffset, get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, - stream, handle); } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); + NVTE_ERROR("FP8 fused attention only supports qkv_format=BSHD, SBHD, or BHSD.\n"); } if (workspace_size > 0) { @@ -2862,11 +1208,11 @@ void fused_attn_fp8_bwd( NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor* input_Q, const Tensor* input_K, const Tensor* input_V, const Tensor* input_O, const Tensor* input_dO, - const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_ZInv, - const Tensor* input_S, const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, - const Tensor* output_dQ, const Tensor* output_dK, const Tensor* output_dV, - Tensor* output_dSoftmaxOffset, const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, - const Tensor* rng_state, Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { + const Tensor* input_dO_f16, const Tensor* input_M, const Tensor* input_S, + const Tensor* input_SoftmaxOffset, Tensor* input_output_dP, const Tensor* output_dQ, + const Tensor* output_dK, const Tensor* output_dV, Tensor* output_dSoftmaxOffset, + const Tensor* cu_seqlens_q, const Tensor* cu_seqlens_kv, const Tensor* rng_state, + Tensor* workspace, cudaStream_t stream, cudnnHandle_t handle) { using namespace transformer_engine; void* devPtrQ = input_Q->data.dptr; void* devPtrK = input_K->data.dptr; @@ -2899,7 +1245,6 @@ void fused_attn_fp8_bwd( } void* devPtrM = input_M->data.dptr; - void* devPtrZInv = (input_ZInv != nullptr) ? input_ZInv->data.dptr : nullptr; void *devPtrScaleS = nullptr, *devPtrDescaleS = nullptr, *devPtrAmaxdP = nullptr, *devPtrScaledP = nullptr, *devPtrDescaledP = nullptr; @@ -2949,34 +1294,22 @@ void fused_attn_fp8_bwd( NVTE_QKV_Format dqkv_format = nvte_get_qkv_format(dqkv_layout); if ((dqkv_format == NVTE_QKV_Format::NVTE_BSHD) || (dqkv_format == NVTE_QKV_Format::NVTE_SBHD) || (dqkv_format == NVTE_QKV_Format::NVTE_BHSD)) { - fused_attn::fused_attn_fp8_bwd_impl_v1( + fused_attn::fused_attn_fp8_bwd_impl( batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v, attn_scale, p_dropout, qkv_layout, o_format, do_format, dqkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal, deterministic, - devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrSoftmaxOffset, - devPtrdQ, devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, devPtrDescaleK, - devPtrDescaleV, devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, - devPtrScaleS, devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, - devPtrAmaxdQ, devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, + devPtrQ, devPtrK, devPtrV, devPtrM, devPtrO, devPtrdO, devPtrSoftmaxOffset, devPtrdQ, + devPtrdK, devPtrdV, devPtrdSoftmaxOffset, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, + devPtrDescaleO, devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, + devPtrScaledP, devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, + devPtrAmaxdK, devPtrAmaxdV, devPtrQ_t, devPtrK_t, devPtrdO_f16, devPtrdO_t, devPtrDescaleQ_t, devPtrDescaleK_t, devPtrDescaledO_t, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, get_cudnn_fe_dtype(QKV_type), get_cudnn_fe_dtype(O_type), get_cudnn_fe_dtype(dO_type), get_cudnn_fe_dtype(dQKV_type), input_dO->scaling_mode, qkv_scale_inv_format, do_scale_inv_format, workspace->data.dptr, &workspace_size, stream, handle); - } else if (dqkv_layout == NVTE_QKV_Layout::NVTE_T3HD) { - // remove this when cuDNN FE supports FP8 + THD - NVTE_CHECK(input_ZInv != nullptr && input_ZInv->data.dptr != nullptr, - "ZInv tensor required for FP8 fused attention backward with T3HD layout."); - fused_attn::fused_attn_fp8_bwd_impl( - batch, num_attn_heads, max_seqlen_q, max_seqlen_kv, head_dim_qk, attn_scale, p_dropout, - qkv_layout, devPtrQ, devPtrK, devPtrV, devPtrM, devPtrZInv, devPtrO, devPtrdO, devPtrdQ, - devPtrdK, devPtrdV, devPtrDescaleQ, devPtrDescaleK, devPtrDescaleV, devPtrDescaleO, - devPtrDescaledO, devPtrDescaleS, devPtrDescaledP, devPtrScaleS, devPtrScaledP, - devPtrScaledQ, devPtrScaledK, devPtrScaledV, devPtrAmaxdP, devPtrAmaxdQ, devPtrAmaxdK, - devPtrAmaxdV, devPtrcuSeqlensQ, devPtrcuSeqlensKV, devPtrDropoutSeed, devPtrDropoutOffset, - get_cudnn_dtype(QKV_type), workspace->data.dptr, &workspace_size, stream, handle); } else { - NVTE_ERROR("FP8 fused attention only supports qkv_layout=t3hd or qkv_format=bshd/sbhd. \n"); + NVTE_ERROR("FP8 fused attention only supports dqkv_format=BSHD, SBHD, or BHSD.\n"); } if (workspace_size > 0) { diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.h b/transformer_engine/common/fused_attn/fused_attn_fp8.h index aaf5039eeb..b9660128ca 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.h +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.h @@ -5,7 +5,7 @@ ************************************************************************/ /*! \file fused_attn_fp8.h - * \brief Functions for fused attention for FP8 with seqlen <= 512 + * \brief Functions for fused attention for FP8 */ #include "transformer_engine/fused_attn.h" @@ -34,9 +34,9 @@ void fused_attn_fp8_bwd( NVTE_Softmax_Type softmax_type, size_t window_size_left, size_t window_size_right, bool bottom_right_diagonal, bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO, - const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_ZInv, - const Tensor *input_S, const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, - const Tensor *output_dQ, const Tensor *output_dK, const Tensor *output_dV, - Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, - const Tensor *rng_state, Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); + const Tensor *input_dO_f16, const Tensor *input_M, const Tensor *input_S, + const Tensor *input_SoftmaxOffset, Tensor *input_output_dP, const Tensor *output_dQ, + const Tensor *output_dK, const Tensor *output_dV, Tensor *output_dSoftmaxOffset, + const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *rng_state, + Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle); } // namespace transformer_engine diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index f37eeb0c68..3e628b6581 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -411,20 +411,6 @@ cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDe return pw_op_created; } -// convert cu_seqlens_q to qkv/o_ragged_offset and actual_seqlens_q -__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q, - int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset, - int32_t *o_ragged_offset) { - size_t tid = blockIdx.x * blockDim.x + threadIdx.x; - if (tid < b) { - actual_seqlens_q[tid] = cu_seqlens_q[tid + 1] - cu_seqlens_q[tid]; - } - if (tid < b + 1) { - qkv_ragged_offset[tid] = cu_seqlens_q[tid] * 3 * h * d; - o_ragged_offset[tid] = cu_seqlens_q[tid] * h * d; - } -} - // convert cu_seqlens to actual_seqlens __global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, int32_t const *const q_cu_seqlens, diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h index c3736a6c65..41656062a4 100644 --- a/transformer_engine/common/fused_attn/utils.h +++ b/transformer_engine/common/fused_attn/utils.h @@ -333,10 +333,6 @@ struct FADescriptor_v1 { } }; -__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q, - int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset, - int32_t *o_ragged_offset); - __global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b, int32_t const *const q_cu_seqlens, int32_t const *const kv_cu_seqlens, int32_t *q_seqlens, diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index 8aff85450a..7e516af97b 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -5,7 +5,6 @@ ************************************************************************/ #include -#include #include #include @@ -21,187 +20,102 @@ namespace fused_router { template __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, const IndexType* tokens_per_expert, - int total_num_tokens, int num_experts, - int num_rows, int num_cols, int topk, float coeff, - DataType* aux_loss, float* Const_buf) { -#if __CUDA_ARCH__ >= 900 - // Using cooperative_groups to manage the cluster - namespace cg = cooperative_groups; - cg::cluster_group cluster = cg::this_cluster(); - int thread_id = cg::this_grid().thread_rank(); - int lane_id = thread_id % kThreadsPerWarp; - int warp_id = thread_id / kThreadsPerWarp; - int warp_num = blockDim.x * gridDim.x / kThreadsPerWarp; - // Only 1 block in the cluster - int block_id = cluster.block_rank(); - int block_num = cluster.dim_blocks().x; - int cluster_id = blockIdx.x / block_num; - if (cluster_id > 0) return; // Only use the cluster 0 - - extern __shared__ float shmem_aux_loss[]; - CompType* aggregated_probs_per_expert = reinterpret_cast(shmem_aux_loss); - // Clear the shmem - for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { - aggregated_probs_per_expert[i] = CompType(0); - } - __syncthreads(); - - /** - * Section: Reduce the probs to the aggregated_probs_per_expert - * 1. reduce on the block - * 2. reduce on the cluster - */ - // Loop: for all positions in each row - for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { - CompType tmp = CompType(0); - // Loop: for all rows that this warp is responsible for - for (int j = warp_id; j < num_rows; j += warp_num) { - tmp += CompType(probs[j * num_cols + i]); - } - atomicAdd(&aggregated_probs_per_expert[i], tmp); + int total_num_tokens, int num_rows, int num_cols, + int topk, float coeff, float* Coeff_buf) { + // ----------------------------------------------------------------------- + // 1) Write the CPU-computed coefficient into a device buffer to re-use in BWD + // ----------------------------------------------------------------------- + if (threadIdx.x == 0 && blockIdx.x == 0) { + Coeff_buf[0] = coeff; } - cluster.sync(); - // The block 0 will reduce the results of all blocks - if (block_id == 0) { - for (int i = 1; i < block_num; i++) { - // Map the shared memory of the block i to the current block - CompType* dst_smem = reinterpret_cast(cluster.map_shared_rank(shmem_aux_loss, i)); - for (int j = threadIdx.x; j < num_cols; j += blockDim.x) { - atomicAdd(&aggregated_probs_per_expert[j], dst_smem[j]); - } - } - } - cluster.sync(); - /** - * Section: aggregated_probs_per_expert * tokens_per_expert - * In-place update on shmem - */ - if (block_id == 0) { - for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { - aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]); - } - __syncthreads(); + // ----------------------------------------------------------------------- + // 2) Each CTA computes a partial dot-product: + // Sigma_col ( Sigma_row probs[row, col] ) * tokens_per_expert[col] + // ----------------------------------------------------------------------- + CompType thread_sum = CompType(0); - if (warp_id == 0) { - /** - * Section: Reduce to get the sum of aggregated_probs_per_expert - */ - CompType intermediate_result = - warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id); - __syncwarp(); + // Grid-stride over rows so that every row is processed exactly once. + // Each thread processes a subset of columns. + for (int col = threadIdx.x; col < num_cols; col += blockDim.x) { + CompType col_sum = CompType(0); - if (lane_id == 0) { - /** - * Section: Compute the aux_loss - */ - float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; - aux_loss[0] = static_cast(intermediate_result * C_coeff); - Const_buf[0] = C_coeff; - } + // Accumulate probs over the rows assigned to this CTA (grid-stride). + for (int row = blockIdx.x; row < num_rows; row += gridDim.x) { + col_sum += CompType(probs[row * num_cols + col]); } - } -#else - // Use Only 1 block/1024 threads to avoid the grid sync - if (blockIdx.x > 0) return; - int warp_num = blockDim.x / kThreadsPerWarp; - int warp_id = threadIdx.x / kThreadsPerWarp; - int lane_id = threadIdx.x % kThreadsPerWarp; - extern __shared__ float shmem_aux_loss[]; - CompType* aggregated_probs_per_expert = reinterpret_cast(shmem_aux_loss); - // Clear the shmem - for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { - aggregated_probs_per_expert[i] = CompType(0); - } - __syncthreads(); + // Multiply by the token count for this expert. + col_sum *= CompType(tokens_per_expert[col]); - /** - * Section: Reduce the probs to the aggregated_probs_per_expert - */ - // Loop: for all positions in each row - for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { - CompType tmp = CompType(0); - // Loop: for all rows that this warp is responsible for - for (int j = warp_id; j < num_rows; j += warp_num) { - tmp += CompType(probs[j * num_cols + i]); - } - atomicAdd(&aggregated_probs_per_expert[i], tmp); + // Accumulate the per-column contribution into the thread-local sum. + thread_sum += col_sum; } - __syncthreads(); - /** - * Section: aggregated_probs_per_expert * tokens_per_expert - * In-place update on shmem - */ - for (int i = threadIdx.x; i < num_cols; i += blockDim.x) { - aggregated_probs_per_expert[i] *= CompType(tokens_per_expert[i]); - } + // ----------------------------------------------------------------------- + // 3) Block-level reduction of thread_sum using warp_reduce_on_shmem + // ----------------------------------------------------------------------- + extern __shared__ float shmem[]; + CompType* shmem_block = reinterpret_cast(shmem); + shmem_block[threadIdx.x] = thread_sum; __syncthreads(); + const int warp_id = threadIdx.x / kThreadsPerWarp; + const int lane_id = threadIdx.x % kThreadsPerWarp; if (warp_id == 0) { - /** - * Section: Reduce to get the sum of aggregated_probs_per_expert - */ - CompType intermediate_result = - warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id); - __syncwarp(); - + CompType block_sum = warp_reduce_on_shmem(shmem_block, static_cast(blockDim.x), + ReduceFuncType::SUM, lane_id); if (lane_id == 0) { - /** - * Section: Compute the aux_loss - */ - float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; - aux_loss[0] = static_cast(intermediate_result * C_coeff); - Const_buf[0] = C_coeff; + atomicAdd(&Coeff_buf[1], static_cast(block_sum * coeff)); } } -#endif } +// Small kernel to convert the float accumulator to the output DataType. +template +__global__ void convert_accum_to_output(const float* Coeff_buf, DataType* aux_loss) { + aux_loss[0] = static_cast(Coeff_buf[1]); +} + +/* ------------------------------------------------------------------------- + * Kernel launcher -- simplified (no cluster launch). + * ------------------------------------------------------------------------- */ template void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, const IndexType* tokens_per_expert, int total_num_tokens, int num_experts, int num_rows, int num_cols, int topk, float coeff, - DataType* aux_loss, float* Const_buf, + DataType* aux_loss, float* Coeff_buf, cudaStream_t stream) { - if (cuda::sm_arch(cuda::current_device()) >= 90) { - cudaLaunchConfig_t config = {0}; - int cluster_size = 8; - config.gridDim = cluster_size; - config.blockDim = 1024; - config.dynamicSmemBytes = sizeof(CompType) * num_cols; - config.stream = stream; - - // Update the max cluster size based on the device - NVTE_CHECK_CUDA(cudaOccupancyMaxPotentialClusterSize( - &cluster_size, - reinterpret_cast(fused_moe_aux_loss_forward_kernel), &config)); - - cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeClusterDimension; - attribute[0].val.clusterDim.x = cluster_size; - attribute[0].val.clusterDim.y = 1; - attribute[0].val.clusterDim.z = 1; - config.numAttrs = 1; - config.attrs = attribute; + NVTE_CHECK(num_experts == num_cols, "Number of experts (", num_experts, + ") must be equal to number of input columns (", num_cols, ")."); + + // Round up to a multiple of warp size for correct warp shuffles. + const int block_size = ((std::min(1024, num_cols) + static_cast(kThreadsPerWarp) - 1) / + static_cast(kThreadsPerWarp)) * + static_cast(kThreadsPerWarp); + const int grid_size = cuda::sm_count() * 2; + + // One CompType per thread in shared memory. + const size_t smem_size = block_size * sizeof(CompType); + check_shared_memory_capacity_num_experts(smem_size, num_experts); + + // Compute final coefficient and zero the float accumulator (Coeff_buf[1]) before launch. + const float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; + NVTE_CHECK_CUDA(cudaMemsetAsync(Coeff_buf + 1, 0, sizeof(float), stream)); + fused_moe_aux_loss_forward_kernel + <<>>(probs, tokens_per_expert, total_num_tokens, + num_rows, num_cols, topk, C_coeff, Coeff_buf); + NVTE_CHECK_CUDA(cudaGetLastError()); - NVTE_CHECK_CUDA(cudaLaunchKernelEx( - &config, fused_moe_aux_loss_forward_kernel, probs, tokens_per_expert, - total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf)); - } else { - size_t smem_size = sizeof(CompType) * num_cols; - fused_moe_aux_loss_forward_kernel - <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts, - num_rows, num_cols, topk, coeff, aux_loss, Const_buf); - NVTE_CHECK_CUDA(cudaGetLastError()); - } + // Convert the float accumulator to the output DataType. + convert_accum_to_output<<<1, 1, 0, stream>>>(Coeff_buf, aux_loss); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_expert, int total_num_tokens, int num_experts, int num_rows, int num_cols, - int topk, float coeff, Tensor& aux_loss, Tensor& Const_buf, + int topk, float coeff, Tensor& aux_loss, Tensor& Coeff_buf, cudaStream_t stream) { TE_ROUTER_PROBS_TYPE_SWITCH_ALL( probs.data.dtype, DataType, @@ -212,7 +126,7 @@ void fused_moe_aux_loss_forward(const Tensor& probs, const Tensor& tokens_per_ex reinterpret_cast(tokens_per_expert.data.dptr), total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, reinterpret_cast(aux_loss.data.dptr), - reinterpret_cast(Const_buf.data.dptr), stream););); + reinterpret_cast(Coeff_buf.data.dptr), stream););); } template @@ -269,13 +183,13 @@ void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_p void nvte_fused_moe_aux_loss_forward(const NVTETensor probs, const NVTETensor tokens_per_expert, int total_num_tokens, int num_experts, int num_rows, int num_cols, int topk, float coeff, NVTETensor aux_loss, - NVTETensor Const_buf, cudaStream_t stream) { + NVTETensor Coeff_buf, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_moe_aux_loss_forward); using namespace transformer_engine; fused_router::fused_moe_aux_loss_forward( *convertNVTETensorCheck(probs), *convertNVTETensorCheck(tokens_per_expert), total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, *convertNVTETensorCheck(aux_loss), - *convertNVTETensorCheck(Const_buf), stream); + *convertNVTETensorCheck(Coeff_buf), stream); } void nvte_fused_moe_aux_loss_backward(const NVTETensor Const_buf, diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 144aea1a07..8589d7045d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -318,6 +318,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + NVTE_CHECK(!inputA->row_scaled_nvfp4 && !inputB->row_scaled_nvfp4, + "cuBLAS GEMM does not support row-scaled NVFP4 inputs."); + // Tensor dims in row-major order const int A0 = inputA->flat_first_dim(); const int A1 = inputA->flat_last_dim(); diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h index 912dc32d35..d9d2786623 100644 --- a/transformer_engine/common/include/transformer_engine/fused_attn.h +++ b/transformer_engine/common/include/transformer_engine/fused_attn.h @@ -156,11 +156,9 @@ enum NVTE_Softmax_Type { enum NVTE_Fused_Attn_Backend { /*! No supported backend */ NVTE_No_Backend = -1, - /*! cuDNN-based FP16/BF16 fused attention for <= 512 sequence length */ - NVTE_F16_max512_seqlen = 0, /*! cuDNN-based FP16/BF16 fused attention for any sequence length */ NVTE_F16_arbitrary_seqlen = 1, - /*! cuDNN-based FP8 fused attention for <= 512 sequence length */ + /*! cuDNN-based FP8 fused attention */ NVTE_FP8 = 2, }; @@ -233,16 +231,6 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * - D = Dropout(S) * - O = D * Transpose(V) * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - | | | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | | | | | | - | | | BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | | | | | | - | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | - \endverbatim - * * Notes: * * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` @@ -264,7 +252,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( * \param[in,out] S The S tensor. * \param[out] O The output O tensor. * \param[out] Aux_CTX_Tensors Auxiliary output tensors when training, - * e.g. M, ZInv, rng_state. + * e.g. softmax stats, optional Max, rng_state. * \param[in] cu_seqlens_q Cumulative sequence lengths for Q, [batch_size + 1]. * \param[in] cu_seqlens_kv Cumulative sequence lengths for K and V, [batch_size + 1]. * \param[in] cu_seqlens_q_padded Cumulative sequence offsets for Q, [batch_size + 1]. @@ -311,16 +299,6 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso /*! \brief Compute the backward of the dot product attention with separate Q, K and V. * - * Support Matrix: - \verbatim - | backend | precision | qkv layout | bias | mask | dropout | sequence length | head_dim | - | 0 | FP16/BF16 | BS3HD,SB3HD,BSHD_BS2HD,SBHD_SB2HD | NO/POST_SCALE_BIAS | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | <= 512, % 64 == 0 | 64 | - | 1 | FP16/BF16 | BS3HD,SB3HD,BSH3D,SBH3D | NO/POST_SCALE_BIAS/ALIBI | NO/PADDING/CAUSAL/PADDING_CAUSAL_MASK | Yes | > 512, % 64 == 0 | <= 128, % 8 == 0 | - | | | BSHD_BS2HD,BSHD_BSH2D,SBHD_SB2HD,SBHD_SBH2D | | | | | | - | | | BSHD_BSHD_BSHD,SBHD_SBHD_SBHD | | | | | | - | 2 | FP8 | T3HD | NO_BIAS | PADDING_MASK | Yes | <= 512, % 64 == 0 | 64 | - \endverbatim - * * Notes: * * Tensors `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` @@ -342,7 +320,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso * \param[in] S The S tensor. * \param[in,out] dP The gradient of the P tensor. * \param[in] Aux_CTX_Tensors Auxiliary tensors from context when in training mode, - * e.g. M, ZInv, rng_state. + * e.g. softmax stats, optional Max, rng_state. * \param[out] dQ The gradient of the Q tensor. * \param[out] dK The gradient of the K tensor. * \param[out] dV The gradient of the V tensor. diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index bf9394c988..9fe692dd2d 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -440,7 +440,7 @@ void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTens /*! \brief Grouped Scaled Bias add for grouped GEMM outputs. * * output[row,col] += bias[col] * scale[row], where biases are per-group -* and scales are per-token (per-row across all groups). +* and scales are per-row across all groups. * Requires uniform last-dimension across all output tensors and bias tensors. */ void nvte_grouped_scaled_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index b7461a85d1..e9a6f4f735 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -72,6 +72,7 @@ enum NVTETensorParam { kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */ kNVTEWithGEMMSwizzledScales = 7, /*!< Whether scaling factors are in format expected by GEMM */ + kNVTERowScaledNVFP4 = 8, /*!< Whether an NVFP4 tensor uses row scaling */ kNVTENumTensorParams }; @@ -765,6 +766,11 @@ class TensorWrapper { nvte_set_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &val, sizeof(val)); } + void set_row_scaled_nvfp4(bool row_scaled_nvfp4) { + const auto val = static_cast(row_scaled_nvfp4); + nvte_set_tensor_param_v2(tensor_, kNVTERowScaledNVFP4, &val, sizeof(val)); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -801,6 +807,12 @@ class TensorWrapper { return static_cast(val); } + bool get_row_scaled_nvfp4() const { + uint8_t val = 0; + nvte_get_tensor_param_v2(tensor_, kNVTERowScaledNVFP4, &val, sizeof(val), nullptr); + return static_cast(val); + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 67b6f87067..0d0b2fd37f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -478,6 +478,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. + row_scaled_activation : bool, default = False + If set to `True`, forward activation quantizers emit row-scaled + NVFP4 tensors. In this mode, rowwise ``amax`` metadata is stored + as a vector with one FP32 value per tensor row. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -491,6 +495,7 @@ class NVFP4BlockScaling(Recipe): os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1" ) disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" + row_scaled_activation: bool = os.getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION", "0") == "1" fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -534,6 +539,7 @@ def __repr__(self) -> str: f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " f"backward_override={self.backward_override}, " + f"row_scaled_activation={self.row_scaled_activation}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1261879a8b..1a52d76019 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -852,6 +852,9 @@ void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const vo case kNVTEWithGEMMSwizzledScales: t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); break; + case kNVTERowScaledNVFP4: + t.row_scaled_nvfp4 = static_cast(*reinterpret_cast(buf)); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -932,6 +935,9 @@ void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, vo case kNVTEWithGEMMSwizzledScales: *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); break; + case kNVTERowScaledNVFP4: + *reinterpret_cast(buf) = static_cast(t->row_scaled_nvfp4); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index a5ec2306b1..c462b30147 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -67,7 +67,7 @@ void quantize_transpose_vector_blockwise_fp4( SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, - const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, const SimpleTensor &noop_tensor, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index d3d3dceca9..cf9821f1a9 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -316,7 +316,7 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, template + bool kApplyStochasticRounding, bool kIs2DBlockScaling, bool kRowScaledNVFP4> __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, @@ -509,8 +509,19 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]; } // Step 2.4: Compute scale - ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale_multiplier); - float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + const size_t row_idx = block_idx_y * kTileDim + r_s; + float row_global_encode_scale = global_encode_scale; + if constexpr (kRowScaledNVFP4) { + row_global_encode_scale = + row_idx < num_rows ? ComputeGlobalEncodeScaleFP4(global_amax[row_idx]) : 1.0f; + } + const float row_global_encode_scale_multiplier = + kRowScaledNVFP4 ? row_global_encode_scale * fp4_max_inv : global_encode_scale_multiplier; + const float row_global_decode_scale = + kRowScaledNVFP4 ? 1.0f / row_global_encode_scale : global_decode_scale; + ScaleType scale_inv = + ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); // Step 2.5: Write scale_inv bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { @@ -708,7 +719,7 @@ void quantize_transpose_vector_blockwise_fp4( SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, - const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); #if CUDA_VERSION >= 12080 @@ -722,6 +733,10 @@ void quantize_transpose_vector_blockwise_fp4( NVTE_CHECK(return_identity || !use_2d_quantization, "2D block quantization is only supported when return_identity is true."); + NVTE_CHECK(!row_scaled_nvfp4 || (return_identity && !return_transpose), + "Row-scaled NVFP4 quantization only supports rowwise quantization."); + NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -801,35 +816,41 @@ void quantize_transpose_vector_blockwise_fp4( TRANSFORMER_ENGINE_SWITCH_CONDITION( use_2d_quantization, kIs2DBlockScaling, - size_t smem_bytes = kSMemSize * sizeof(InputType); - auto kernel = block_scaled_1d_cast_transpose_kernel< - kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, - float, InputType, OutputType, ScaleType, kSwizzledScale, - kApplyStochasticRounding, kIs2DBlockScaling>; - if (smem_bytes >= 48 * 1024) { - cudaError_t err = cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_bytes); - NVTE_CHECK(err == cudaSuccess, - "Failed to set dynamic shared memory size."); - } kernel<<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(global_amax.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, - num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, - scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, - noop_ptr);) // kIs2DBlockScaling - ) // kApplyStochasticRounding - ) // kSwizzledScale - ) // kAligned - ) // kReturnTranspose - ) // kReturnIdentity - ) // OutputType - ) // InputType + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, kRowScaledNVFP4, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, + float, InputType, OutputType, ScaleType, kSwizzledScale, + kApplyStochasticRounding, kIs2DBlockScaling, + kRowScaledNVFP4>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared memory size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), + row_length, num_rows, scale_stride_x, scale_stride_y, + scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, + epsilon, rng_state, + noop_ptr);) // kRowScaledNVFP4 + ) // kIs2DBlockScaling + ) // kApplyStochasticRounding + ) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); #else diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index fdfa47da8f..ef7687e3e9 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -79,7 +79,6 @@ .value("NVTE_Paged_KV_THD_SBHD_SBHD", NVTE_QKV_Layout::NVTE_Paged_KV_THD_SBHD_SBHD) \ .value("NVTE_BHSD_BHSD_BHSD", NVTE_QKV_Layout::NVTE_BHSD_BHSD_BHSD); \ pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) \ - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) \ .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) \ .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8) \ .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend); \ diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 40d02f40e1..489bfde997 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -353,10 +353,7 @@ def abstract( config.window_size, ).get_fused_attn_backend() - if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: - softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) - softmax_dtype = q_dtype - elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: + if backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: # cuDNN 9.6 reduces the required softmax shape if get_cudnn_version() >= (9, 6, 0): if config.qkv_layout.is_thd(): diff --git a/transformer_engine/jax/cpp_extensions/router.py b/transformer_engine/jax/cpp_extensions/router.py index f2affacdaa..0ae267cbf3 100644 --- a/transformer_engine/jax/cpp_extensions/router.py +++ b/transformer_engine/jax/cpp_extensions/router.py @@ -401,7 +401,7 @@ def abstract(probs_aval, tokens_per_expert_aval, topk, coeff): del topk, coeff, tokens_per_expert_aval i_dtype = dtypes.canonicalize_dtype(probs_aval.dtype) aux_loss_aval = probs_aval.update(shape=(), dtype=i_dtype) - const_buf_aval = probs_aval.update(shape=(1,), dtype=jnp.float32) + const_buf_aval = probs_aval.update(shape=(2,), dtype=jnp.float32) return aux_loss_aval, const_buf_aval @staticmethod diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index 76f2d92891..ed136d7b9e 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -28,7 +28,6 @@ NVTE_Fused_Attn_Backend GetFusedAttnBackend( /* NOTE: PrepareFusedAttnForwardAuxTensors unifies the auxiliary tensor pack logic from the fused attention forward kernels in: - - common/fused_attn/fused_attn_f16_max512_seqlen.cu lines 594-634 and 773-812 - common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu lines 1270-1281 and 1348-1359 */ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch, @@ -40,7 +39,6 @@ void PrepareFusedAttnForwardAuxTensors(NVTETensorPack *tensor_pack, const size_t void *bias_buf = nullptr, void *softmax_offset_buf = nullptr) { // all backends need softmax but expect different shapes/dtypes - // start with the max512 sequence length softmax shape/dtype and correct later tensor_pack->size = 1; NVTETensor &softmax_aux = tensor_pack->tensors[0]; NVTEBasicTensor softmax_aux_data; @@ -127,15 +125,6 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_ q_max_seqlen, kv_max_seqlen, dtype, dummy_bias_type, dummy_backend, softmax_buf, rng_state_buf, bias_buf, softmax_offset_buf); - - // correct softmax shape for max512 sequence length kernel - if (backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) { - NVTEBasicTensor softmax_aux_data = - nvte_get_tensor_param(tensor_pack->tensors[0], kNVTERowwiseData); - softmax_aux_data.shape.data[3] = kv_max_seqlen; // {B,H,Qs,1} -> {B,H,Qs,Ks} - softmax_aux_data.dtype = static_cast(dtype); - nvte_set_tensor_param(&(tensor_pack->tensors[0]), kNVTERowwiseData, &softmax_aux_data); - } } pybind11::tuple GetFusedAttnForwardWorkspaceSizes( diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index b002643942..70d0403b3e 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -189,7 +189,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) { pybind11::enum_(m, "NVTE_Fused_Attn_Backend", pybind11::module_local()) .value("NVTE_No_Backend", NVTE_Fused_Attn_Backend::NVTE_No_Backend) - .value("NVTE_F16_max512_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) .value("NVTE_F16_arbitrary_seqlen", NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen) .value("NVTE_FP8", NVTE_Fused_Attn_Backend::NVTE_FP8); diff --git a/transformer_engine/jax/csrc/extensions/router.cpp b/transformer_engine/jax/csrc/extensions/router.cpp index c81671f104..79daec3f07 100644 --- a/transformer_engine/jax/csrc/extensions/router.cpp +++ b/transformer_engine/jax/csrc/extensions/router.cpp @@ -177,12 +177,12 @@ Error_Type FusedMoEAuxLossForwardFFI(cudaStream_t stream, std::vector{static_cast(num_tokens), static_cast(num_experts)}; auto tpe_dtype = convert_ffi_datatype_to_te_dtype(tokens_per_expert_buf.element_type()); auto tpe_shape = std::vector{static_cast(num_experts)}; - auto scalar_shape = std::vector{1}; auto probs_tensor = TensorWrapper(probs_buf.untyped_data(), probs_shape, dtype); auto tpe_tensor = TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype); - auto aux_loss_tensor = TensorWrapper(aux_loss_buf->untyped_data(), scalar_shape, dtype); - auto const_buf_tensor = TensorWrapper(const_buf->untyped_data(), scalar_shape, DType::kFloat32); + auto aux_loss_tensor = TensorWrapper(aux_loss_buf->untyped_data(), std::vector{1}, dtype); + auto const_buf_tensor = + TensorWrapper(const_buf->untyped_data(), std::vector{2}, DType::kFloat32); nvte_fused_moe_aux_loss_forward(probs_tensor.data(), tpe_tensor.data(), num_tokens, num_experts, num_tokens, num_experts, static_cast(topk), @@ -219,16 +219,16 @@ Error_Type FusedMoEAuxLossBackwardFFI(cudaStream_t stream, auto num_tokens = static_cast(grad_probs_dims[0]); auto num_experts = static_cast(grad_probs_dims[1]); - auto scalar_shape = std::vector{1}; auto tpe_dims = tokens_per_expert_buf.dimensions(); auto tpe_shape = std::vector{static_cast(tpe_dims[0])}; auto grad_probs_shape = std::vector{static_cast(num_tokens), static_cast(num_experts)}; - auto const_buf_tensor = TensorWrapper(const_buf_in.untyped_data(), scalar_shape, DType::kFloat32); + auto const_buf_tensor = + TensorWrapper(const_buf_in.untyped_data(), std::vector{2}, DType::kFloat32); auto tpe_tensor = TensorWrapper(tokens_per_expert_buf.untyped_data(), tpe_shape, tpe_dtype); auto grad_aux_loss_tensor = - TensorWrapper(grad_aux_loss_buf.untyped_data(), scalar_shape, grad_dtype); + TensorWrapper(grad_aux_loss_buf.untyped_data(), std::vector{1}, grad_dtype); auto grad_probs_tensor = TensorWrapper(grad_probs_buf->untyped_data(), grad_probs_shape, grad_dtype); diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 4104820a1c..79ebbd4afa 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -1859,31 +1859,12 @@ def backward(ctx, d_out, *_args): class FusedAttention(torch.nn.Module): - """Dot product attention, with multiple backends: - - 1. FusedAttnBackend["F16_max512_seqlen"] - cuDNN based fused attention for FP16/BF16 and <=512 sequence length. - 2. FusedAttnBackend["F16_arbitrary_seqlen"] - cuDNN based fused attention for FP16/BF16 and any sequence length. - - Support matrix: - - | backend | 1 | 2 | - | flash based | no | yes | - | cuDNN based | yes | yes | - | qkv dtype | fp16/bf16 | fp16/bf16 | - | attn_type | self/cross | self/cross | - | qkv_layout | | | - | - (q,k,v) | sb3hd, bs3hd | sb3hd, bs3hd, sbh3d, bsh3d | - | | sbhd_sb2hd, bshd_bs2hd | sbhd_sb2hd, bshd_bs2hd | - | | bshd_bshd_bshd | sbhd_sbh2d, bshd_bsh2d | - | | | sbhd_sbhd_sbhd, bshd_bshd_bshd | - | mask_type | causal/padding/no_mask | causal/padding/no_mask | - | bias_type | post_scale_bias/no_bias | post_scale_bias/alibi/no_bias | - | dropout | yes | yes | - | max_seqlen | <=512, multiple of 64 | any, multiple of 64 | - | head_dim | 64 | <=128, multiple of 8 | - | output dtype | fp16/bf16 | fp16/bf16 | + """Dot product attention using cuDNN attention: + + FusedAttnBackend["F16_arbitrary_seqlen"] + cuDNN attention for FP16/BF16 with any sequence length. + FusedAttnBackend["FP8"] + cuDNN attention for FP8 with any sequence length. """ def __init__( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 7b10593acf..32eb1b597a 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -980,10 +980,7 @@ def cp_p2p_fwd_fused_attn( ) if fp8: - if qkv_layout != "t3hd": - softmax_lse_per_step, rng_states = aux_ctx_tensors - else: - softmax_lse_per_step, _, rng_states = aux_ctx_tensors + softmax_lse_per_step, rng_states = aux_ctx_tensors else: softmax_lse_per_step, rng_states, *rest = aux_ctx_tensors attn_bias = rest[0] if len(rest) > 0 else None @@ -1169,17 +1166,7 @@ def cp_p2p_bwd_fused_attn( section, ): """Per-tile backward call of CP P2P with FusedAttention backend""" - if fp8: - if qkv_layout == "t3hd": - aux_tensors = [ - softmax_lse, - softmax_lse, - rng_states[cp_size - step - 1], - ] - else: - aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] - else: - aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] + aux_tensors = [softmax_lse, rng_states[cp_size - step - 1]] max_seqlen_q_ = max_seqlen_q max_seqlen_kv_ = max_seqlen_kv @@ -1195,17 +1182,7 @@ def cp_p2p_bwd_fused_attn( attn_mask_type_ = "padding" if "padding" in attn_mask_type else "no_mask" elif section == "upper-triangle": q_part, out_part, dout_part = [x.contiguous() for x in [q_part, out_part, dout_part]] - if fp8: - if qkv_layout == "t3hd": - aux_tensors = [ - softmax_lse_, - softmax_lse_, - rng_states[cp_size - step - 1], - ] - else: - aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] - else: - aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] + aux_tensors = [softmax_lse_, rng_states[cp_size - step - 1]] max_seqlen_q_ = max_seqlen_q // 2 cu_seqlens_q_padded_ = None if cu_seqlens_q_padded is None else cu_seqlens_q_padded // 2 @@ -3223,10 +3200,7 @@ def forward( **fp8_meta_kwargs, ) if fp8: - if qkv_layout != "t3hd": - softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors - else: - softmax_lse_per_step[i], _, rng_states[i] = aux_ctx_tensors + softmax_lse_per_step[i], rng_states[i] = aux_ctx_tensors else: softmax_lse_per_step[i], rng_states[i], *_ = aux_ctx_tensors if return_max_logit: @@ -3588,17 +3562,10 @@ def backward(ctx, dout, *_args): out_part = out.select(seq_dim_o, i).contiguous() dout_part = dout.select(seq_dim_o, i).contiguous() if ctx.use_fused_attention: - if ctx.fp8 and ctx.qkv_layout == "t3hd": - aux_ctx_tensors = [ - softmax_lse_per_step[i], - softmax_lse_per_step[i], - rng_states[i], - ] - else: - aux_ctx_tensors = [ - softmax_lse_per_step[i], - rng_states[i], - ] + aux_ctx_tensors = [ + softmax_lse_per_step[i], + rng_states[i], + ] fused_attn_backend = tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen fp8_meta_kwargs = {} new_qkv_layout = ctx.qkv_layout diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index ed87423534..7df5daabe5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -1217,10 +1217,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "Disabling FusedAttention as dbias calculation is not supported for 111s" ) use_fused_attention = False - elif not fu_core_attention_bias_requires_grad: - # max512 backend will only support [1, h, s, s] - os.environ["NVTE_FUSED_ATTN_BACKEND"] = "1" - # Filter: cuDNN support fused_attention_backend = None if use_fused_attention: @@ -1254,32 +1250,6 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt logger.debug("Disabling FusedAttention as no backend supports the provided input") use_fused_attention = False fused_attention_backend = None - if ( - use_fused_attention - and window_size is not None - and (window_size[0] != -1 or window_size[1] not in [-1, 0]) - and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - ): - logger.debug( - "Disabling FusedAttention as only sub-backend %s does not support " - "slidng window attention", - int(fused_attention_backend), - ) - use_fused_attention = False - fused_attention_backend = None - if ( - use_fused_attention - and fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - and fu_core_attention_bias_type == "post_scale_bias" - and fu_core_attention_bias_shape != "1hss" - ): - logger.debug( - "Disabling FusedAttention as cuDNN sub-backend 0 only supports post_scale_bias in" - " [1, H, S, S] shape" - ) - use_fused_attention = False - fused_attention_backend = None - # Filter: Determinism # backend | deterministic # --------------------------------------------- diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py index 01e139da46..2ce939430d 100644 --- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py +++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py @@ -98,13 +98,12 @@ } FusedAttnBackend = { - "F16_max512_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen, "F16_arbitrary_seqlen": NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen, "FP8": NVTE_Fused_Attn_Backend.NVTE_FP8, "No_Backend": NVTE_Fused_Attn_Backend.NVTE_No_Backend, } -BACKEND_F16m512_FP8_THREADS_PER_CTA = 128 +BACKEND_FP8_THREADS_PER_CTA = 128 BACKEND_F16arb_ELTS_PER_THREADS = 16 META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT @@ -249,22 +248,17 @@ def fused_attn_fwd( if is_training is False, aux_ctx_tensors = None softmax-related tensors: - 1. if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"] - softmax: torch.Tensor - Softmax(Q*K.T) - shape [batch_size, num_heads, max_seqlen_q, max_seqlen_kv], dtype float32 - 2. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] + 1. if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] softmaxStats: torch.Tensor log(sum(e^(x - max(x)))), where x=Q*K.T shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - 3. if fused_attention_backend == FusedAttnBackend["FP8"] - M: torch.Tensor - max(Q*K.T) + Max: torch.Tensor, only when return_max_logit is True shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - ZInv: torch.Tensor, only allocated for T3HD path - 1/sum(e^(x - max(x))), where x=Q*K.T + 2. if fused_attention_backend == FusedAttnBackend["FP8"] + softmaxStats: torch.Tensor + log(sum(e^(x - max(x)))), where x=Q*K.T shape [batch_size, num_heads, max_seqlen_q, 1], dtype float32 - rng_state: torch.Tensor, optional, if backend is not F16_max512_seqlen + rng_state: torch.Tensor state of the random number generator; [seed, offset], dtype uint64 max_logit : if return_max_logit = True, shape [h] and same data type as O; otherwise None @@ -299,19 +293,13 @@ def fused_attn_fwd( f" q.dtype={q.dtype}, backend={fused_attention_backend}." ) - # BF16/FP16 fused attention API from fmha_v1 apex - if fused_attention_backend == FusedAttnBackend["F16_max512_seqlen"]: - rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_kv + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA - # BF16/FP16 fused attention API from fmha_v2 - elif fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: + if fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]: rng_elts_per_thread = BACKEND_F16arb_ELTS_PER_THREADS # FP8 fused attention API from fmha_v2 elif fused_attention_backend == FusedAttnBackend["FP8"]: rng_elts_per_thread = ( - max_seqlen_q * max_seqlen_q + BACKEND_F16m512_FP8_THREADS_PER_CTA - 1 - ) // BACKEND_F16m512_FP8_THREADS_PER_CTA + max_seqlen_q * max_seqlen_q + BACKEND_FP8_THREADS_PER_CTA - 1 + ) // BACKEND_FP8_THREADS_PER_CTA else: raise ValueError(f"Unsupported backend {fused_attention_backend}") @@ -472,7 +460,7 @@ def fused_attn_bwd( in torch.dtype aux_ctx_tensors : List[torch.Tensor] auxiliary output tensors of the forward pass when its is_training is True, - e.g. aux_ctx_tensors = [M, ZInv, rng_state] + e.g. aux_ctx_tensors = [S, Max, rng_state] fused_attention_backend : tex.NVTE_Fused_Attn_Backend please see FusedAttention module for details on supported backends. cu_seqlens_q_padded : torch.Tensor, default = None @@ -566,13 +554,12 @@ def fused_attn_bwd( f" q.dtype={q.dtype}, backend={fused_attention_backend}." ) - if fused_attention_backend != FusedAttnBackend["F16_max512_seqlen"]: - if len(aux_ctx_tensors) < 1: - raise ValueError( - "aux_ctx_tensors must contain rng_state as its last element," - f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" - f" for backend={fused_attention_backend}." - ) + if len(aux_ctx_tensors) < 1: + raise ValueError( + "aux_ctx_tensors must contain rng_state as its last element," + f" but got len(aux_ctx_tensors)={len(aux_ctx_tensors)}" + f" for backend={fused_attention_backend}." + ) output_tensors = tex.fused_attn_bwd( max_seqlen_q, diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 6f3553bf94..edf2c1e1c2 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -15,6 +15,8 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.grouped_tensor_storage import GroupedTensorStorage +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -69,6 +71,38 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 +def _is_nvfp4_row_scaled_tensor(tensor: torch.Tensor) -> bool: + """Whether tensor carries row-scaled NVFP4 global amax metadata.""" + return isinstance(tensor, NVFP4TensorStorage) and tensor._row_scaled_nvfp4 + + +def _nvfp4_row_scaled_gemm_inputs( + A: NVFP4TensorStorage, + B: NVFP4TensorStorage, + *, + transa: bool, +) -> Tuple[NVFP4TensorStorage, NVFP4TensorStorage, torch.Tensor]: + """Return GEMM aliases and FP32 output scales for row-scaled NVFP4.""" + A_metadata = A.get_metadata() + weight_amax = A._amax_rowwise if transa else A._amax_columnwise + assert weight_amax is not None and weight_amax.numel() == 1 + A_metadata["amax_rowwise" if transa else "amax_columnwise"] = weight_amax.new_ones(1) + A_metadata["row_scaled_nvfp4"] = False + + B_metadata = B.get_metadata() + rhs_rowwise_amax = B._amax_rowwise + assert rhs_rowwise_amax is not None + B_metadata["amax_rowwise"] = rhs_rowwise_amax.new_ones(1) + B_metadata["row_scaled_nvfp4"] = False + + assert rhs_rowwise_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32 + return ( + NVFP4TensorStorage(**A_metadata), + NVFP4TensorStorage(**B_metadata), + (rhs_rowwise_amax * weight_amax).view(-1, 1), + ) + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -174,7 +208,65 @@ def general_gemm( "beta": beta, } - out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + if not _is_nvfp4_row_scaled_tensor(A) and not _is_nvfp4_row_scaled_tensor(B): + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + else: + if _is_nvfp4_row_scaled_tensor(A): + raise NotImplementedError("Row-scaled NVFP4 GEMM does not support row-scaled A.") + assert layout[1] == "N", "Row-scaled NVFP4 GEMM currently supports N-layout B only." + if grad: + raise RuntimeError( + "Row-scaled NVFP4 GEMM currently supports fprop only. " + "Backward NVFP4 gradient quantizers should use scalar global amax." + ) + assert not gelu, "Row-scaled NVFP4 GEMM currently does not support fused GELU." + assert not accumulate, "Row-scaled NVFP4 GEMM currently does not support accumulation." + assert ( + quantization_params is None + ), "Row-scaled NVFP4 GEMM currently does not support output quantization." + assert ub is None, "Row-scaled NVFP4 GEMM currently does not support CommOverlap." + assert ( + extra_output is None + ), "Row-scaled NVFP4 GEMM currently does not support extra output." + assert not bulk_overlap, "Row-scaled NVFP4 GEMM currently does not support bulk overlap." + assert out is None or ( + isinstance(out, torch.Tensor) and not is_custom(out) + ), "Row-scaled NVFP4 GEMM currently supports only plain torch.Tensor outputs." + assert isinstance( + A, NVFP4TensorStorage + ), "Row-scaled NVFP4 GEMM currently requires NVFP4 A." + # cuBLAS folds NVFP4 global amax values into GEMM alpha. Keep the row-scaled + # recipe's global scales out of alpha and apply them in FP32 below. + gemm_A, gemm_B, rowwise_global_scales = _nvfp4_row_scaled_gemm_inputs(A, B, transa=transa) + + requested_out, requested_out_dtype = out, out_dtype + fp32_out = ( + torch.empty_like(requested_out, dtype=torch.float32) + if requested_out is not None + else None + ) + gemm_args = list(args) + gemm_args[0] = gemm_A # A + gemm_args[2] = gemm_B # B + gemm_args[4] = fp32_out # out + gemm_args[5] = None # quantization_params + gemm_args[6] = TE_DType[torch.float32] # out_dtype + gemm_args[7] = None # bias + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*gemm_args, **kwargs) + out_2d = out.reshape(-1, out.shape[-1]) + + assert rowwise_global_scales.dtype == torch.float32 and out.dtype == torch.float32 + assert rowwise_global_scales.numel() == out_2d.shape[0] + + out_2d.mul_(rowwise_global_scales) + if bias is not None: + out_2d.add_(bias.to(dtype=torch.float32)) + + if requested_out is not None: + requested_out.copy_(out.to(dtype=requested_out.dtype)) + out = requested_out + elif requested_out_dtype is not None and requested_out_dtype != torch.float32: + out = out.to(dtype=requested_out_dtype) if debug_quantizer is not None: out = debug_quantizer.process_gemm_output(out) @@ -229,6 +321,44 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] + if any(_is_nvfp4_row_scaled_tensor(tensor) for tensor in A): + raise NotImplementedError("Row-scaled NVFP4 grouped GEMM does not support row-scaled A.") + if any(_is_nvfp4_row_scaled_tensor(tensor) for tensor in B): + assert D_dtype is None, "Row-scaled NVFP4 grouped GEMM currently does not support D_dtype." + if single_output: + assert ( + m_splits is not None + ), "Row-scaled NVFP4 grouped GEMM requires m_splits with single output." + out_init = out[0] if single_output else None + if single_output: + start_idx = 0 + out_views = [] + for i in range(num_gemms): + size = m_splits[i] + out_views.append(out_init[start_idx : start_idx + size]) + start_idx += size + else: + out_views = out + for i in range(num_gemms): + if out_views[i].numel() == 0: + continue + general_gemm( + A[i], + B[i], + quantization_params=quantization_params[i], + out_dtype=out_views[i].dtype, + out=out_views[i], + gelu=gelu, + accumulate=accumulate, + layout=layout, + bias=bias[i] if use_bias else None, + use_split_accumulator=use_split_accumulator, + grad=grad, + ) + if single_output: + out = out_init + return out, grad_bias, gelu_input + if isinstance(quantization_params[0], DebugQuantizer): assert not gelu, "GELU not supported in debug mode" if single_output: @@ -350,6 +480,13 @@ def general_grouped_gemm_for_grouped_tensor( if is_discrete_in and is_discrete_out: raise ValueError("Both A and out are discrete. This is not supported yet.") + if isinstance(A, GroupedTensorStorage) and A.row_scaled_nvfp4: + raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") + if isinstance(B, GroupedTensorStorage) and B.row_scaled_nvfp4: + raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") + if isinstance(out, GroupedTensorStorage) and out.row_scaled_nvfp4: + raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") + if is_discrete_out: # wgrad case. grouped_gemm_impl = tex.te_general_grouped_gemm_for_discrete_out diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 8e3bcdd5b3..8f5b8294e8 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -320,6 +320,8 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; + // Whether tensors emitted by this quantizer use row-scaled NVFP4 metadata. + bool row_scaled_nvfp4; int rht_matrix_random_sign_mask_t; at::Tensor rht_matrix; diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 2df3b66553..cab9fab30a 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -42,8 +42,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; @@ -154,8 +155,9 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index e6781bd58a..7e8018b3fd 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -271,22 +271,17 @@ std::vector fused_attn_fwd( nvte_set_tensor_param(&nvte_aux_tensor_pack.tensors[i], kNVTERowwiseData, &temp_data); }; // allocate memory for nvte_aux_tensor_pack.tensors - // f16_max512 : S [b, h, sq, skv] - // f16_arbitrary: - // return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] - // return_max_logit=true: S [b, h, sq, 1], Max [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] - // fp8 : M [b, h, sq, 1], optional ZInv [b, h, sq, 1] (T3HD path), rng_state [2] + // f16_arbitrary: S [b, h, sq, 1]/[tq, h, 1], (optional) Max [b, h, sq, 1]/[tq, h, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1] + // fp8 : S [b, h, sq, 1], rng_state [2] size_t i = 0; at::Tensor output_tensor; - // intermediate softmax tensor, S or M (for fp8) + // intermediate softmax stats tensor S output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); set_tensor_param(i++, output_tensor); - // fp8 T3HD has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor - if (((qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) && - qkv_layout == NVTE_QKV_Layout::NVTE_T3HD) || - return_max_logit) { + // return_max_logit=true allocates Max after S + if (return_max_logit) { output_tensor = allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])), static_cast(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false); diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 0cf2025f1b..4a78dde388 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -152,8 +152,9 @@ std::vector dact_dbias( } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_DACT_AMAX_NVFP4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 50fe4c109e..9e1f381bfe 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -798,7 +798,13 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = quantizer_cpp_list[0]->rowwise_usage; + const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; + if (row_scaled_nvfp4) { + NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); + NVTE_CHECK(!columnwise_usage, + "Row-scaled NVFP4 bulk allocation does not support columnwise usage."); + } const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; @@ -828,6 +834,16 @@ std::tuple, std::vector, bool> bulk_alloc } return fp4_shape; }; + auto flat_first_dim = [](const std::vector &shape) -> size_t { + if (shape.empty()) { + return 1; + } + size_t rows = 1; + for (size_t i = 0; i + 1 < shape.size(); ++i) { + rows *= shape[i]; + } + return rows; + }; // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; @@ -866,7 +882,11 @@ std::tuple, std::vector, bool> bulk_alloc // Note: Multi-quantize kernel does not require contiguous amaxes. const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); - buffer_size = offset + 4; + size_t amax_size = 4; + if (row_scaled_nvfp4) { + amax_size *= flat_first_dim(rowwise_data_shapes[i]); + } + buffer_size = offset + amax_size; } // Allocate full buffer @@ -879,8 +899,12 @@ std::tuple, std::vector, bool> bulk_alloc data_offsets[i], torch::kUInt8)); rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + std::vector amax_shape{1}; + if (row_scaled_nvfp4) { + amax_shape = {flat_first_dim(rowwise_data_shapes[i])}; + } amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -960,9 +984,10 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - tensor_py_list.emplace_back(NVFP4TensorClass( - rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise, - amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales)); + tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, + columnwise_scale, amax_rowwise, amax_columnwise, + fp4_dtype, quantizer_py_list[i], + with_gemm_swizzled_scales, row_scaled_nvfp4)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -979,11 +1004,12 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, columnwise_usage ? columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + tensor_wrapper.set_row_scaled_nvfp4(row_scaled_nvfp4); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_rowwise_list[i])); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, @@ -1455,7 +1481,16 @@ std::vector split_quantize(const at::Tensor &tensor, return detail::IsNVFP4Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_NVFP4; - quantization_method = QuantizationMethod::FUSED_NVFP4; + const bool has_row_scaled_nvfp4 = + std::any_of(quantizer_cpp_list.begin(), quantizer_cpp_list.end(), + [](const std::unique_ptr &quantizer) { + return static_cast(quantizer.get())->row_scaled_nvfp4; + }); + if (has_row_scaled_nvfp4) { + quantization_method = QuantizationMethod::UNFUSED; + } else { + quantization_method = QuantizationMethod::FUSED_NVFP4; + } } } @@ -1492,7 +1527,8 @@ std::vector split_quantize(const at::Tensor &tensor, bool contiguous_data_and_scale = false; std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); - if (!input_shape.empty() && input_shape.back() % 128 != 0) { + if (quantization_method == QuantizationMethod::FUSED_NVFP4 && !input_shape.empty() && + input_shape.back() % 128 != 0) { static std::once_flag once_unfused_nvfp4_fallback_warning; std::call_once(once_unfused_nvfp4_fallback_warning, []() { NVTE_WARN( diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index fb4c7aa1c9..4887b59c28 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,8 +120,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // TE kernel supports amax output @@ -357,8 +358,9 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // TE kernel supports amax output diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 94625c0f12..4df64d8e26 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -148,7 +148,7 @@ std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, // Create the output tensor at::Tensor aux_loss = at::empty({}, at::dtype(probs.scalar_type()).device(at::kCUDA)); - at::Tensor Const_buf = at::empty({}, at::dtype(at::kFloat).device(at::kCUDA)); + at::Tensor Const_buf = at::empty({2}, at::dtype(at::kFloat).device(at::kCUDA)); auto probs_cu = makeTransformerEngineTensor(probs); auto tokens_per_expert_cu = makeTransformerEngineTensor(tokens_per_expert); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index da91e5c170..8f2de325ae 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1696,6 +1696,7 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + this->row_scaled_nvfp4 = quantizer.attr("row_scaled_nvfp4").cast(); // Get amax reduction group if needed for NVFP4 AG const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); @@ -1747,6 +1748,12 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); + const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + if (row_scaled_nvfp4) { + NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); + NVTE_CHECK(!columnwise_usage, + "Row-scaled NVFP4 quantization does not support columnwise usage."); + } const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1760,9 +1767,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + const int64_t amax_rows = row_scaled_nvfp4 ? static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, bit32_tensor_opts); + amax_rowwise = at::empty({amax_rows}, bit32_tensor_opts); } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), @@ -1805,6 +1813,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve kwargs["fp4_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); @@ -1833,6 +1842,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve kwargs["fp4_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args.ptr(), kwargs.ptr()); @@ -1850,7 +1860,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, rowwise_scale_inv_shape); - out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -1865,6 +1875,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve std::vector{1}); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(out_py)}; @@ -1892,6 +1903,12 @@ std::pair NVFP4Quantizer::create_grouped_tenso std::optional rowwise_amax; std::optional columnwise_amax; const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + if (row_scaled_nvfp4) { + NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage."); + NVTE_CHECK(!columnwise_usage, + "Row-scaled NVFP4 grouped quantization does not support columnwise usage."); + } const int64_t total_data_elements = total_elements / 2; @@ -1900,7 +1917,9 @@ std::pair NVFP4Quantizer::create_grouped_tenso const auto scale_shape = get_scale_shape(logical_shape_vec, false); const int64_t total_scale_elements = static_cast(product(scale_shape)); rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); - rowwise_amax = at::empty({static_cast(num_tensors)}, float_opts); + const int64_t amax_elements = row_scaled_nvfp4 ? static_cast(logical_first_dim) + : static_cast(num_tensors); + rowwise_amax = at::empty({amax_elements}, float_opts); } if (columnwise_usage) { @@ -1958,6 +1977,7 @@ std::pair NVFP4Quantizer::create_grouped_tenso kwargs["last_dims"] = py::none(); kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(); kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm; + kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr()); if (result == nullptr) { PyErr_Print(); @@ -1975,15 +1995,22 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); // Register amax pointer from quantized tensor - void* amax_ptr = quantized_tensor.amax(); + auto rowwise_amax = quantized_tensor.get_amax(); + auto columnwise_amax = quantized_tensor.get_columnwise_amax(); + + void* amax_ptr = rowwise_amax.data_ptr; + std::vector amax_shape = convertShape(rowwise_amax.shape); if (amax_ptr == nullptr) { - amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr; + amax_ptr = columnwise_amax.data_ptr; + amax_shape = convertShape(columnwise_amax.shape); } NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor."); - out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_ptr, DType::kFloat32, amax_shape); // Zero out amax - NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream())); + const size_t amax_numel = product(amax_shape); + NVTE_CHECK_CUDA( + cudaMemsetAsync(amax_ptr, 0, amax_numel * sizeof(float), at::cuda::getCurrentCUDAStream())); return {std::move(out_cpp), std::move(out_py)}; } @@ -2031,6 +2058,13 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( } } const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1; + const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + if (row_scaled_nvfp4) { + NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); + NVTE_CHECK(!columnwise_usage, + "Row-scaled NVFP4 quantization does not support columnwise usage."); + } + tensor.attr("_row_scaled_nvfp4") = py::cast(row_scaled_nvfp4); // Coerce row-wise data if (rowwise_usage) { @@ -2048,11 +2082,12 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; } - if (!amax_rowwise) { + const int64_t amax_rows = row_scaled_nvfp4 ? static_cast(flat_first_dim) : 1; + if (!amax_rowwise || amax_rowwise->numel() != amax_rows) { const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, opts); + amax_rowwise = at::empty({amax_rows}, opts); tensor.attr("_amax_rowwise") = *amax_rowwise; } } else { // rowwise_usage == false @@ -2118,7 +2153,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, getTensorShape(*rowwise_scale_inv)); - out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, getTensorShape(*amax_rowwise)); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -2133,6 +2168,7 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( std::vector{1}); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); this->set_quantization_params(&out_cpp); return {std::move(out_cpp), std::move(tensor)}; @@ -2241,6 +2277,18 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } size_t cols = input.size(input.ndim() - 1); + const bool row_scaled_nvfp4 = out.get_row_scaled_nvfp4(); + if (row_scaled_nvfp4) { + NVTE_CHECK(!this->with_rht, "Row-scaled NVFP4 quantization does not support RHT."); + NVTE_CHECK(!this->with_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); + NVTE_CHECK(!this->stochastic_rounding, + "Row-scaled NVFP4 quantization does not support stochastic rounding."); + NVTE_CHECK(!this->with_amax_reduction, + "Row-scaled NVFP4 quantization does not support amax reduction."); + NVTE_CHECK(cols % 16 == 0, "Row-scaled NVFP4 quantization requires last dim divisible by 16."); + } + // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT bool eligible_for_rht_cast_fusion = input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; @@ -2307,7 +2355,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Use with_post_rht_amax=true instead."); } } else { // Without RHT - if (compute_amax) { + if (compute_amax && !row_scaled_nvfp4) { // Amax pointers auto rowwise_amax_ptr = out.get_amax().data_ptr; auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; @@ -2408,6 +2456,8 @@ void NVFP4Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, } void NVFP4Quantizer::quantize_with_amax(TensorWrapper& input, TensorWrapper& out) { + NVTE_CHECK(!out.get_row_scaled_nvfp4(), + "quantize_with_amax is not supported for row-scaled NVFP4 quantization."); // Update output tensor amaxes with input tensor amax auto input_amax_ptr = input.amax(); auto output_rowwise_amax_ptr = out.get_amax().data_ptr; diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index e13554a98c..37ab0b0535 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -134,6 +134,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); + const bool row_scaled_nvfp4 = tensor.attr("_row_scaled_nvfp4").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -163,6 +164,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) // Scale layout ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + ret.set_row_scaled_nvfp4(row_scaled_nvfp4); // Quantizer state quantizer->set_quantization_params(&ret); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index dd01ae05d3..12f8ef8f5b 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -350,9 +350,17 @@ def __init__( pow_2_scales: bool = False, eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), + row_scaled_nvfp4: bool = False, with_rht: bool = False, with_random_sign_mask: bool = True, ): + if row_scaled_nvfp4: + if not rowwise: + raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") + if columnwise: + raise ValueError( + "Row-scaled NVFP4 reference quantization does not support columnwise usage." + ) super().__init__(rowwise=rowwise, columnwise=columnwise) self.internal = True @@ -360,6 +368,7 @@ def __init__( self.pow_2_scales = pow_2_scales self.eps = eps self.quant_tile_shape = quant_tile_shape + self.row_scaled_nvfp4 = row_scaled_nvfp4 self.with_rht = with_rht self.with_random_sign_mask = with_random_sign_mask @@ -447,6 +456,7 @@ def _quantize_blockwise_reference( tile_len_y: int, *, pow_2_scales: bool, + row_scaled_nvfp4: bool = False, eps: float, # pylint: disable=unused-argument ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -488,6 +498,9 @@ def _quantize_blockwise_reference( decode_scale.to(torch.float32), ) else: + if row_scaled_nvfp4: + global_amax = global_amax.to(torch.float32).view(m, 1, 1) + global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax) global_encode_scale = torch.min( global_encode_scale, @@ -497,8 +510,15 @@ def _quantize_blockwise_reference( dtype=torch.float32, ), ) - if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): - global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + if global_encode_scale.numel() == 1: + if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32): + global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32) + else: + global_encode_scale = torch.where( + global_encode_scale == 0.0, + torch.ones_like(global_encode_scale), + global_encode_scale, + ) global_decode_scale = torch.div(1.0, global_encode_scale) global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX) @@ -609,6 +629,8 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ raise ValueError( f"MXFP4 only supports 1x32 tile shape, got {self.quant_tile_shape}" ) + if self.row_scaled_nvfp4: + raise ValueError("Row-scaled NVFP4 is only supported for NVFP4 (non-pow2) mode.") # TODO(etsykunov): Fix bug where global_amax_row and # global_amax_col are not defined # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32) @@ -625,13 +647,22 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ if self.with_rht else tensor.t().contiguous() ) - # Compute amax for rowwise and columnwise paths separately - global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) - global_amax_col = ( - torch.max(torch.abs(col_input)).to(torch.float32).view(1) - if self.columnwise_usage - else global_amax_row - ) + if self.row_scaled_nvfp4: + if self.quant_tile_shape != (1, 16): + raise ValueError( + "Row-scaled NVFP4 only supports NVFP4 1x16 tile shape, " + f"got {self.quant_tile_shape}" + ) + global_amax_row = torch.max(torch.abs(row_input), dim=1).values.to(torch.float32) + global_amax_col = global_amax_row + else: + # Compute amax for rowwise and columnwise paths separately + global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1) + global_amax_col = ( + torch.max(torch.abs(col_input)).to(torch.float32).view(1) + if self.columnwise_usage + else global_amax_row + ) transpose_scales = False @@ -648,6 +679,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[ self.quant_tile_shape[1], self.quant_tile_shape[0], pow_2_scales=self.pow_2_scales, + row_scaled_nvfp4=self.row_scaled_nvfp4, eps=self.eps, ) if transpose_scales: @@ -868,7 +900,11 @@ def qgemm( partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col else: partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row - alpha = torch.div(partial_alpha, factor).squeeze(-1) + if partial_alpha.numel() > 1 and partial_alpha.numel() == high_precision_x.shape[0]: + partial_alpha = partial_alpha.view(-1, 1) + else: + partial_alpha = partial_alpha.squeeze(-1) + alpha = torch.div(partial_alpha, factor) M, K = high_precision_x.shape N, K_w = high_precision_w.shape diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 9956fb77ec..e9f009d93d 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -1375,6 +1375,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=qparams.random_hadamard_transform, with_2d_quantization=qparams.fp4_2d_quantization, stochastic_rounding=qparams.stochastic_rounding, + row_scaled_nvfp4=self.recipe.row_scaled_activation and idx % 3 != 1, ) return [_make_quantizer(idx) for idx in range(self.num_quantizers)] @@ -1389,6 +1390,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer: with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, + row_scaled_nvfp4=False, ) for _ in range(self.num_quantizers) ] diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py index ab0c7484fc..f28f972b58 100644 --- a/transformer_engine/pytorch/tensor/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/grouped_tensor.py @@ -92,6 +92,7 @@ def __new__( requires_grad: bool = False, stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, + row_scaled_nvfp4: bool = False, ): if ( shapes is not None @@ -164,6 +165,7 @@ def __new__( scale_inv_offsets=scale_inv_offsets, columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, with_gemm_swizzled_scales=with_gemm_swizzled_scales, + row_scaled_nvfp4=row_scaled_nvfp4, ) return instance @@ -195,6 +197,7 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non dst.logical_shape = src.logical_shape dst.quantized_tensors = src.quantized_tensors dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales + dst.row_scaled_nvfp4 = src.row_scaled_nvfp4 def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor: """Create a wrapper of the same type and tensor metadata as src.""" diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 65678aa347..285a7f030a 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -128,6 +128,9 @@ class NVFP4Quantizer(Quantizer): """Stochastic rounding, only applicable for gradients.""" stochastic_rounding: bool + """Whether emitted NVFP4 tensors store one FP32 amax per row.""" + row_scaled_nvfp4: bool + """RHT matrix random sign mask""" rht_matrix_random_sign_mask_t: int rht_matrix: torch.Tensor @@ -143,6 +146,7 @@ def __init__( with_post_rht_amax: bool = False, with_2d_quantization: bool = False, stochastic_rounding: bool = False, + row_scaled_nvfp4: bool = False, with_random_sign_mask: bool = True, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) @@ -153,6 +157,7 @@ def __init__( self.amax_reduction_group = amax_reduction_group self.with_2d_quantization = with_2d_quantization self.stochastic_rounding = stochastic_rounding + self.row_scaled_nvfp4 = row_scaled_nvfp4 self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht( with_random_sign_mask, torch.cuda.current_device() ) @@ -198,6 +203,7 @@ def copy(self) -> NVFP4Quantizer: with_post_rht_amax=self.with_post_rht_amax, with_2d_quantization=self.with_2d_quantization, stochastic_rounding=self.stochastic_rounding, + row_scaled_nvfp4=self.row_scaled_nvfp4, ) quantizer.internal = self.internal quantizer.optimize_for_gemm = self.optimize_for_gemm @@ -212,6 +218,8 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor: def is_quantizable(self, inp: torch.Tensor) -> bool: """Returns whether or not given inp can be quantized""" + if self.row_scaled_nvfp4: + return False if inp.ndim < 2: return False if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0: @@ -313,6 +321,11 @@ def make_empty( f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by" f" {NVFP4_BLOCK_SCALING_SIZE}" ) + if self.row_scaled_nvfp4: + if not self.rowwise_usage: + raise ValueError("Row-scaled NVFP4 quantization requires rowwise usage.") + if self.columnwise_usage: + raise ValueError("Row-scaled NVFP4 quantization does not support columnwise usage.") # Allocate FP4 data data = None @@ -329,8 +342,11 @@ def make_empty( scale_inv = torch.empty( scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory ) - # Allocate per tensor scale inverse. FP32 format. - amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory) + # Allocate global amax metadata. Row-scaled NVFP4 stores one value per row. + amax_rows = flat_first_dim if self.row_scaled_nvfp4 else 1 + amax_rowwise = torch.zeros( + amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory + ) # Allocate FP8 data transpose if needed columnwise_data = None @@ -371,6 +387,7 @@ def make_empty( quantizer=self, requires_grad=requires_grad, with_gemm_swizzled_scales=False, + row_scaled_nvfp4=self.row_scaled_nvfp4, ) def calibrate(self, tensor: torch.Tensor) -> None: @@ -431,6 +448,7 @@ def __new__( fp4_dtype: TE_DType, quantizer: Quantizer, with_gemm_swizzled_scales: bool, + row_scaled_nvfp4: bool = False, **kwargs, ): instance = super().__new__( @@ -445,6 +463,7 @@ def __new__( quantizer, with_gemm_swizzled_scales, *args, + row_scaled_nvfp4=row_scaled_nvfp4, **kwargs, ) return instance diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 485b32328b..ac56d334bc 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -72,6 +72,7 @@ def _initialize_storage_fields( requires_grad: bool = False, stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, + row_scaled_nvfp4: bool = False, ) -> None: """ Initialize a GroupedTensor. @@ -147,6 +148,7 @@ def _initialize_storage_fields( # Used as a convenience. instance.quantized_tensors = None instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales + instance.row_scaled_nvfp4 = row_scaled_nvfp4 def __new__( cls, @@ -172,6 +174,7 @@ def __new__( requires_grad: bool = False, stride: Optional[List[int]] = None, with_gemm_swizzled_scales: bool = False, + row_scaled_nvfp4: bool = False, ): instance = object.__new__(cls) cls._initialize_storage_fields( @@ -197,6 +200,7 @@ def __new__( requires_grad=requires_grad, stride=stride, with_gemm_swizzled_scales=with_gemm_swizzled_scales, + row_scaled_nvfp4=row_scaled_nvfp4, ) return instance @@ -371,6 +375,7 @@ def clear(self) -> None: self.columnwise_scale_inv_offsets = None self.tensor_shapes = [] self.fake_dtype = torch.float32 + self.row_scaled_nvfp4 = False def __repr__(self) -> str: """String representation of the GroupedTensorStorage.""" @@ -539,6 +544,7 @@ def copy(self) -> "GroupedTensorStorage": scale_inv_offsets=self.scale_inv_offsets, columnwise_scale_inv_offsets=self.columnwise_scale_inv_offsets, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, + row_scaled_nvfp4=self.row_scaled_nvfp4, ) @staticmethod @@ -649,6 +655,7 @@ def make_grouped_tensor( scale = None scale_inv_offsets = None columnwise_scale_inv_offsets = None + row_scaled_nvfp4 = False if no_quantization: assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" if rowwise_usage: @@ -707,6 +714,19 @@ def make_grouped_tensor( # Amax buffer for delayed scaling - one per tensor amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().nvfp4(): + row_scaled_nvfp4 = quantizer.row_scaled_nvfp4 + if row_scaled_nvfp4: + if not rowwise_usage: + raise ValueError( + "Row-scaled NVFP4 grouped quantization requires rowwise usage." + ) + if columnwise_usage: + raise ValueError( + "Row-scaled NVFP4 grouped quantization does not support columnwise usage." + ) + total_amax_elements = ( + sum(math.prod(s[:-1]) for s in shape) if row_scaled_nvfp4 else num_tensors + ) if rowwise_usage: # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte) @@ -720,8 +740,7 @@ def make_grouped_tensor( total_scale_elements += math.prod(scale_inv_shape) scale_inv_offsets.append(total_scale_elements) scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) - # Amax buffer - one per tensor - amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + amax = torch.empty(total_amax_elements, dtype=torch.float32, device=device) if columnwise_usage: # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed) @@ -738,7 +757,6 @@ def make_grouped_tensor( columnwise_scale_inv = torch.empty( total_columnwise_scale_elements, dtype=torch.uint8, device=device ) - # Columnwise amax buffer - one per tensor columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) elif quantizer._get_compatible_recipe().float8_block_scaling(): if rowwise_usage: @@ -824,6 +842,7 @@ def make_grouped_tensor( with_gemm_swizzled_scales=( quantizer.optimize_for_gemm if quantizer is not None else False ), + row_scaled_nvfp4=row_scaled_nvfp4, ) grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() return grouped_tensor @@ -936,6 +955,14 @@ def split_into_quantized_tensors( cum += math.prod(scale_shape) columnwise_scale_inv_offsets.append(cum) self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets + nvfp4_rowwise_amax_offsets = None + row_scaled_nvfp4 = self.row_scaled_nvfp4 + if recipe.nvfp4() and row_scaled_nvfp4: + cum = 0 + nvfp4_rowwise_amax_offsets = [0] + for i in range(self.num_tensors): + cum += math.prod(self.tensor_shapes[i][:-1]) + nvfp4_rowwise_amax_offsets.append(cum) for i in range(self.num_tensors): quantizer = self.quantizer @@ -1128,9 +1155,13 @@ def split_into_quantized_tensors( cscale_shape ) - # Extract amax - one per tensor if self.amax is not None: - amax_rowwise = self.amax[i : i + 1] + if nvfp4_rowwise_amax_offsets is not None: + amax_start = nvfp4_rowwise_amax_offsets[i] + amax_end = nvfp4_rowwise_amax_offsets[i + 1] + amax_rowwise = self.amax[amax_start:amax_end] + else: + amax_rowwise = self.amax[i : i + 1] if self.columnwise_amax is not None: amax_columnwise = self.columnwise_amax[i : i + 1] @@ -1152,6 +1183,7 @@ def split_into_quantized_tensors( fp4_dtype=quantizer.dtype, quantizer=quantizer, with_gemm_swizzled_scales=quantizer.optimize_for_gemm, + row_scaled_nvfp4=row_scaled_nvfp4, ) result.append(tensor) diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index 70699ad71a..e51acb71e5 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -97,6 +97,8 @@ class NVFP4TensorStorage(QuantizedTensorStorage): # Whether scaling factors are in the swizzled format expected by # GEMM _with_gemm_swizzled_scales: bool + # Whether this NVFP4 tensor uses row-scaled amax metadata + _row_scaled_nvfp4: bool def __new__( cls, @@ -111,6 +113,7 @@ def __new__( with_gemm_swizzled_scales: bool, *args, fake_dtype: Optional[torch.dtype] = None, + row_scaled_nvfp4: bool = False, **kwargs, ): if cls is NVFP4TensorStorage: @@ -128,6 +131,7 @@ def __new__( instance._amax_rowwise = amax_rowwise instance._amax_columnwise = amax_columnwise instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales + instance._row_scaled_nvfp4 = row_scaled_nvfp4 return instance @@ -152,6 +156,8 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None: raise RuntimeError("FP4 dtype mismatch in copy_from_storage") if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales: raise RuntimeError("Scale layout mismatch in copy_from_storage") + if self._row_scaled_nvfp4 != src._row_scaled_nvfp4: + raise RuntimeError("Rowwise amax scaling mode mismatch in copy_from_storage") def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): if dst is not None and src_tensor is not None: @@ -176,6 +182,7 @@ def get_metadata(self) -> Dict[str, Any]: "fp4_dtype": self._fp4_dtype, "quantizer": self._quantizer, "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales, + "row_scaled_nvfp4": self._row_scaled_nvfp4, "fake_dtype": self._dtype, } @@ -308,6 +315,7 @@ def view(self, shape: torch.Size): quantizer=self._quantizer, fp4_dtype=self._fp4_dtype, with_gemm_swizzled_scales=self._with_gemm_swizzled_scales, + row_scaled_nvfp4=self._row_scaled_nvfp4, fake_dtype=self._dtype, )