Skip to content

Commit dfe5b7d

Browse files
janekb04pre-commit-ci[bot]Copilottimmoon10
authored
[Common][Pytorch] Add support for the FP8 Block Scaling (ie. Deepseek) recipe on Blackwell (#2157)
* Update to_string(NVTEScalingMode) to include block scaling Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Add `nvte_swizzle_block_scaling_to_mxfp8_scaling_factors` Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Convert FP8 block scaling tensors to MXFP8 tensors on Blackwell and newer in GEMM Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Allow Blackwell and newer in Deepseek recipe compatbility check Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Allow data_rows % 4 != 0 in 1d kernel Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Load scaling factors in unswizzled order in 1d kernel Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Enforce use of power of two scaling Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Skip the FP8 block scaling exact GEMM test on Blackwell Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Skip further tests with pow_2_scales=False Signed-off-by: Jan Bielak <jbielak@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Initial implementation of tensor conversion for grouped gemm Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Skip non power of two scaling cpp unit tests Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Fix handling of all gather Signed-off-by: Jan Bielak <jbielak@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Jan Bielak <jbielak@nvidia.com> * Use compute capability 10.0 for logic with Blackwell Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Apply suggestions from code review Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Jan Bielak <jbielak@nvidia.com> Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
1 parent b840898 commit dfe5b7d

14 files changed

Lines changed: 553 additions & 35 deletions

tests/cpp/operator/test_cast_float8blockwise.cu

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,12 @@ TEST_P(FusedCastFloat8BlockwiseTestSuite, TestFusedCastFloat8Blockwise) {
501501
q_opts.amax_epsilon = eps;
502502
q_opts.block_scaling_dim = 2u;
503503

504+
// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
505+
// which requires using power of two scaling factors. Skip unsupported tests.
506+
if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
507+
GTEST_SKIP();
508+
}
509+
504510
if (colwise && matrix_size.size() < 2) {
505511
// test_common Tensor initialization code does not
506512
// handle this case.
@@ -552,6 +558,12 @@ TEST_P(FusedCastFloat8VectorwiseTestSuite, TestFusedCastFloat8Vectorwise) {
552558
q_opts.amax_epsilon = eps;
553559
q_opts.block_scaling_dim = 1u;
554560

561+
// On Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8,
562+
// which requires using power of two scaling factors. Skip unsupported tests.
563+
if (getDeviceComputeCapability() >= blackwellComputeCapability && !force_pow_2) {
564+
GTEST_SKIP();
565+
}
566+
555567
if (colwise && matrix_size.size() < 2) {
556568
// test_common Tensor initialization code does not
557569
// handle this case.

tests/pytorch/test_float8_blockwise_gemm_exact.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import transformer_engine_torch as tex
99

1010
from transformer_engine.pytorch.constants import TE_DType
11+
from transformer_engine.pytorch.utils import get_device_compute_capability
1112
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
1213
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
1314
Float8BlockQuantizer,
@@ -19,7 +20,8 @@
1920

2021
def fp8_blockwise_gemm_supported() -> bool:
2122
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
22-
return supported
23+
emulated = get_device_compute_capability() >= (10, 0)
24+
return supported and not emulated
2325

2426

2527
def cublas_gemm_fp8_blockwise_case(

tests/pytorch/test_float8_blockwise_scaling_exact.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import transformer_engine_torch as tex
1313
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
1414
from transformer_engine.common.recipe import Float8BlockScaling
15+
from transformer_engine.pytorch.utils import get_device_compute_capability
1516
from transformer_engine.pytorch.constants import TE_DType
1617
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
1718
Float8BlockQuantizer,
@@ -32,6 +33,7 @@
3233
if tensor_dump_dir_env is not None:
3334
TENSOR_DUMP_DIR = pathlib.Path(tensor_dump_dir_env)
3435
recipe_available, reason_for_no_recipe = FP8GlobalStateManager.is_fp8_block_scaling_available()
36+
recipe_emulated = get_device_compute_capability() >= (10, 0)
3537

3638

3739
class GetRecipes:
@@ -218,6 +220,12 @@ def check_quantization_block_tiling_versus_reference(
218220
pow_2_scales: bool,
219221
tile_size: Tuple[int, int],
220222
) -> None:
223+
if recipe_emulated and not pow_2_scales:
224+
pytest.skip(
225+
"On Blackwell and newer, the FP8 block scaling recipe is emulated "
226+
"with MXFP8, which requires using power of two scaling factors."
227+
)
228+
221229
te_dtype = TE_DType[quant_dtype]
222230
if tile_size == (1, 128):
223231
block_scaling_dim = 1
@@ -409,6 +417,12 @@ def test_quantization_block_tiling_extrema_versus_reference(
409417
tile_size: Tuple[int, int],
410418
extrema_high: bool,
411419
) -> None:
420+
if recipe_emulated and not pow_2_scales:
421+
pytest.skip(
422+
"On Blackwell and newer, the FP8 block scaling recipe is emulated "
423+
"with MXFP8, which requires using power of two scaling factors."
424+
)
425+
412426
# This test runs a single tile through a quantizer as a way to test
413427
# branch coverage of scale computation.
414428
te_dtype = TE_DType[quant_dtype]

transformer_engine/common/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ list(APPEND transformer_engine_SOURCES
127127
util/multi_stream.cpp
128128
util/rtc.cpp
129129
swizzle/swizzle.cu
130+
swizzle/swizzle_block_scaling.cu
130131
fused_softmax/scaled_masked_softmax.cu
131132
fused_softmax/scaled_upper_triang_masked_softmax.cu
132133
fused_softmax/scaled_aligned_causal_masked_softmax.cu

transformer_engine/common/include/transformer_engine/swizzle.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,26 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud
4444
void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs,
4545
const size_t num_tensors, cudaStream_t stream);
4646

47+
/*! \brief Swizzling FP8 block scaling scaling factors into mxfp8 interleaved layout for GEMM
48+
*
49+
* \param[in] input Input FP8 block scaling tensor with GEMM_READY scale_inv.
50+
* \param[in,out] output Output mxfp8 tensor which hosts swizzled scale_inv.
51+
* \param[in] stream CUDA stream used for the operation.
52+
*
53+
* This function is used for emulating the FP8 block scaling recipe on Blackwell and newer as it
54+
* not natively supported by cublasLt on architectures other than Hopper.
55+
56+
* Requirements:
57+
* - input is an FP8 block scaling tensor
58+
* - input has rowwise usage
59+
* - input.scale_inv is in GEMM_READY format
60+
* - output is an MXFP8 tensor
61+
* - output has rowwise usage
62+
* - output.scale_inv has appropriate shape
63+
* */
64+
void nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(const NVTETensor input, NVTETensor output,
65+
cudaStream_t stream);
66+
4767
#ifdef __cplusplus
4868
} // extern "C"
4969
#endif

0 commit comments

Comments
 (0)