Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions docs/envvars.rst
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,9 @@ Attention Backend Selection

.. envvar:: NVTE_FUSED_ATTN_BACKEND

:Type: ``int`` (0, 1, or 2)
:Type: ``int`` (1 or 2)
:Default: Auto-selected
:Description: Force a specific FusedAttention backend. ``0`` = F16_max512_seqlen (cuDNN, ≤512 seq len), ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration.
:Description: Force a specific FusedAttention backend. ``1`` = F16_arbitrary_seqlen (cuDNN, any seq len), ``2`` = FP8 backend. If not set, the backend is automatically selected based on the input configuration.

.. envvar:: NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT

Expand Down Expand Up @@ -281,6 +281,12 @@ Kernel Configuration
:Default: ``0``
:Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations.

.. envvar:: NVTE_NVFP4_ROW_SCALED_ACTIVATION

:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar.

Torch Compilation and Fusion
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Expand Down
171 changes: 126 additions & 45 deletions tests/cpp/operator/test_cast_nvfp4_transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,14 @@ void quantize_nvfp4_1d(float (*OP)(const float),
block_amax = std::max(block_amax, std::abs(elt));
}

// 2. Compute E4M3 scaling factor
// Compute per-block encoding/decoding scaling factor
const float S_dec_b = block_amax / 6.0f;

// Scale & Store per-block decoding scaling factor
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
// Compute and store the per-block FP8 decode scale
const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f));
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(fminf(S_dec_b, Numeric_Traits<float>::maxNorm));
const float S_dec_b_fp32 = static_cast<float>(S_dec_b_fp8);

// Compute "correct" per-block encoding scaling factor
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);

const size_t scale_idx = i * scales_stride + block_X;
scales[scale_idx] = S_dec_b_fp8;
Expand Down Expand Up @@ -317,11 +315,31 @@ void compute_ref(float (*OP)(const float),
const size_t scales_stride,
const size_t scales_stride_t,
const bool use_fast_math,
const bool use_2d_quantization = false)
const bool use_2d_quantization = false,
std::vector<float> *rowwise_amax = nullptr)
{
std::vector<InputType> input_t = create_transpose(input, rows, cols);

if (use_2d_quantization) {
if (rowwise_amax != nullptr) {
rowwise_amax->resize(rows, 0.0f);
for (size_t row = 0; row < rows; ++row) {
float row_amax = 0.0f;
for (size_t col = 0; col < cols; ++col) {
row_amax = fmaxf(row_amax, fabsf(static_cast<float>(input[row * cols + col])));
}
(*rowwise_amax)[row] = row_amax;
quantize_nvfp4(OP,
input + row * cols,
output + row * (cols / 2),
scales + row * scales_stride,
1,
cols,
scales_stride,
row_amax,
use_fast_math,
use_2d_quantization);
}
} else if (use_2d_quantization) {
// Step 1: Compute mathematical 8×8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
Expand Down Expand Up @@ -504,13 +522,12 @@ void print_detailed_tensor_comparison(const std::string& name,

void compareResults_nvfp4(const Tensor &test,
const void *ref, const void *ref_t, const int rows, const int cols,
double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) {
double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true,
bool dump_data = false, bool compare_columnwise = true) {
if (if_on_gpus) test.to_cpu();

const fp4e2m1 *test_data = test.rowwise_cpu_dptr<fp4e2m1>();
const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr<fp4e2m1>();
const fp4e2m1 *ref_data = reinterpret_cast<const fp4e2m1*>(ref);
const fp4e2m1 *ref_data_t = reinterpret_cast<const fp4e2m1*>(ref_t);

// Print detailed element-by-element comparison
// print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols);
Expand All @@ -519,17 +536,33 @@ void compareResults_nvfp4(const Tensor &test,
// Optionally dump tensor data to files for detailed analysis
if (dump_data) {
dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols);
dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows);
}

compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol);
compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
if (compare_columnwise) {
const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr<fp4e2m1>();
const fp4e2m1 *ref_data_t = reinterpret_cast<const fp4e2m1*>(ref_t);
if (dump_data) {
dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows);
}
compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
}
}

void compare_rowwise_amax(const Tensor &output, const std::vector<float> &ref_amax) {
const std::vector<float> test_amax_data = output.tensor_amax_values();
ASSERT_EQ(test_amax_data.size(), ref_amax.size());
for (size_t row = 0; row < ref_amax.size(); ++row) {
ASSERT_EQ(test_amax_data[row], ref_amax[row])
<< "Row-scaled amax mismatch at row " << row;
}
}

template <typename InputType>
void performTest(float (*OP)(const float),
const std::vector<size_t>& shape,
const bool use_fast_math) {
const bool use_fast_math,
const bool row_scaled_nvfp4 = false) {
using namespace test;

DType itype = TypeInfo<InputType>::dtype;
Expand All @@ -556,7 +589,7 @@ void performTest(float (*OP)(const float),
const size_t scales_stride_t = blocks_X_t;

Tensor input("input", shape, itype);
Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING);
Tensor output("output", shape, otype, true, !row_scaled_nvfp4, NVTE_NVFP4_1D_SCALING);

std::unique_ptr<fp4e2m1x2[]> ref_output = std::make_unique<fp4e2m1x2[]>(rows * (cols / 2));
std::unique_ptr<fp4e2m1x2[]> ref_output_t = std::make_unique<fp4e2m1x2[]>(cols * (rows / 2));
Expand All @@ -567,26 +600,44 @@ void performTest(float (*OP)(const float),

// Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues
const float amax = 448.0f * 6.0f * 8.0f;

// Set 2nd stage NVFP4 scaling factor
output.set_tensor_amax(amax);
output.set_tensor_amax_columnwise(amax);

std::vector<float> ref_rowwise_amax;
bool use_2d_quantization = false;
if (row_scaled_nvfp4) {
output.set_tensor_amax_shape({rows});
output.set_row_scaled_nvfp4(true);
compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
0.0f,
rows,
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization,
&ref_rowwise_amax);
} else {
// Set 2nd stage NVFP4 scaling factor
output.set_tensor_amax(amax);
output.set_tensor_amax_columnwise(amax);
compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
amax,
rows,
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization);
}

compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
amax,
rows,
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization);
// Initialize stochastic rounding
Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123; // rng_seed
Expand Down Expand Up @@ -629,23 +680,25 @@ void performTest(float (*OP)(const float),
const double rtol = 1.0E-6;

// Set dump_data=true to enable dumping tensor data to files for analysis
compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false);

const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr<fp8e4m3>();
const fp8e4m3* ref_scales_ptr = ref_scales.get();
const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr<fp8e4m3>();
const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get();
compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true,
false, !row_scaled_nvfp4);

size_t scale_mismatches_num = 0;
compare_scaling_factors<fp8e4m3>("scales", output.rowwise_cpu_scale_inv_ptr<fp8e4m3>(),
ref_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
scale_mismatches_num);

compare_scaling_factors<fp8e4m3>("scales_t", output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(),
ref_scales_t.get(),
unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t,
scale_mismatches_num);
if (!row_scaled_nvfp4) {
compare_scaling_factors<fp8e4m3>("scales_t", output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(),
ref_scales_t.get(),
unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t,
scale_mismatches_num);
}

if (row_scaled_nvfp4) {
compare_rowwise_amax(output, ref_rowwise_amax);
}
}

std::vector<std::vector<size_t>> tensor_dims = {
Expand Down Expand Up @@ -678,6 +731,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
<std::tuple<ActivationType,
std::vector<size_t>,
transformer_engine::DType,
bool,
bool>> {};

TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
Expand All @@ -693,6 +747,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
const auto tensor_dims = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const bool use_fast_math = std::get<3>(GetParam());
const bool row_scaled_nvfp4 = std::get<4>(GetParam());

// Skip tests if the input tensor is 1D
if (tensor_dims.size() < 2) {
Expand All @@ -710,7 +765,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
}

TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
performTest<InputType>(OP, tensor_dims, use_fast_math);
performTest<InputType>(OP, tensor_dims, use_fast_math, row_scaled_nvfp4);
);
}

Expand All @@ -733,6 +788,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(tensor_dims),
::testing::Values(DType::kBFloat16),
::testing::Values(false),
::testing::Values(false)),
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
Expand All @@ -746,3 +802,28 @@ INSTANTIATE_TEST_SUITE_P(
}
return name;
});

INSTANTIATE_TEST_SUITE_P(
OperatorTestRowScaled,
FusedCastTransposeNVFP4TestSuite,
::testing::Combine(
::testing::Values(ActivationType::Identity),
::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]),
::testing::Values(DType::kBFloat16, DType::kFloat32),
::testing::Values(false),
::testing::Values(true)),
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
const auto& shape = std::get<1>(info.param);
for (const auto& s: shape) {
name += "X" + std::to_string(s);
}
name += "X" + test::typeName(std::get<2>(info.param));
if (std::get<3>(info.param)) {
name += "X_FAST_SCALING";
}
if (std::get<4>(info.param)) {
name += "XROW_SCALED";
}
return name;
});
Loading
Loading