From ad4b3fd19bb7b6e1536b819fc6ca6425a4d57ce2 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 4 May 2026 10:40:47 -0700 Subject: [PATCH 1/2] [PyTorch][Core] Fix CUBLAS GGEMM when weight dims are not divisible by 128 (#2954) * fix CUBLAS for GPT oss sizes Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * ci fix Signed-off-by: Varun Thumbe * Fix test case dimensions in test_numerics.py Total dim should be divisible by 128 Signed-off-by: vthumbe1503 --------- Signed-off-by: Varun Thumbe Signed-off-by: vthumbe1503 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 69 ++++++++++++++----- .../common/cast/mxfp8/swizzle.cuh | 7 +- .../common/gemm/cublaslt_grouped_gemm.cu | 65 +++++++++++++++-- 3 files changed, 117 insertions(+), 24 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 5eef7f151d..a718ea2a8a 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -3084,7 +3084,10 @@ def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a): weight_tensors = [torch.randn(n, k, dtype=dtype, device=device) for _ in range(z)] if use_mxfp8: grouped_A = _make_grouped_tensor_quantized_mxfp8( - weight_tensors, is_a=True, transposed=transa, device=device + weight_tensors, + rowwise=transa, + columnwise=not transa, + device=device, ) else: grouped_A = _make_grouped_tensor_uniform(z, n, k, device, dtype) @@ -3138,36 +3141,61 @@ def _make_zero_tokens_grouped_tensor(logical_last_dim, is_a): def _make_grouped_tensor_quantized_mxfp8( tensors: List[torch.Tensor], *, - is_a: bool, - transposed: bool, + rowwise: bool, + columnwise: bool, device: torch.device, - optimize_for_gemm: bool = True, + is_weight: bool = False, ) -> GroupedTensor: + """Create a quantized MXFP8 GroupedTensor from a list of per-expert tensors. + + For weights (uniform per-expert shape), we generally won't keep it swizzled since we + might need for future dequantize operations. Swizzling is done internally within + general_grouped_gemm_for_grouped_tensor call. + + For non-weight tensors (inputs / grad_outputs), we still pass + ``first_dims`` and keep ``optimize_for_gemm=True``; so the kernel must emit the + already-swizzled layout up front. + """ if not tensors: raise ValueError("Expected non-empty tensor list for grouped quantization.") - if is_a: - rowwise = transposed - columnwise = not transposed - else: - rowwise = not transposed - columnwise = transposed quantizer = MXFP8Quantizer( fp8_dtype=tex.DType.kFloat8E4M3, rowwise=rowwise, columnwise=columnwise, ) - quantizer.optimize_for_gemm = optimize_for_gemm + quantizer.optimize_for_gemm = not is_weight grouped_input = torch.cat(tensors, dim=0) - first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device) + if is_weight: + first_dims = None + else: + first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device=device) return tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims) +def _per_tensor_quantize_mxfp8( + tensors: List[torch.Tensor], + *, + rowwise: bool, + columnwise: bool, +) -> List: + """Quantize each tensor individually with MXFP8. + Used to build reference discrete inputs for grouped GEMM. + """ + quantizer = MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=rowwise, + columnwise=columnwise, + ) + return [quantizer(t) for t in tensors] + + @pytest.mark.parametrize( "shape", [ (1, 128, 128, 512), (8, 1024, 128, 512), (16, 4096, 128, 512), + (2, 256, 2880, 2880), ], ) @pytest.mark.parametrize("accumulate", [False, True]) @@ -3208,12 +3236,21 @@ def test_grouped_gemm_grouped_tensor_mxfp8( transa = layout[0] == "T" transb = layout[1] == "T" - grouped_A = _make_grouped_tensor_quantized_mxfp8(A, is_a=True, transposed=transa, device="cuda") + a_is_weight = all(t.shape == A[0].shape for t in A) + a_rowwise, a_columnwise = transa, not transa + b_rowwise, b_columnwise = not transb, transb + grouped_A = _make_grouped_tensor_quantized_mxfp8( + A, + rowwise=a_rowwise, + columnwise=a_columnwise, + device="cuda", + is_weight=a_is_weight, + ) grouped_B = _make_grouped_tensor_quantized_mxfp8( - B, is_a=False, transposed=transb, device="cuda" + B, rowwise=b_rowwise, columnwise=b_columnwise, device="cuda" ) - A_fp8 = grouped_A.split_into_quantized_tensors() - B_fp8 = grouped_B.split_into_quantized_tensors() + A_fp8 = _per_tensor_quantize_mxfp8(A, rowwise=a_rowwise, columnwise=a_columnwise) + B_fp8 = _per_tensor_quantize_mxfp8(B, rowwise=b_rowwise, columnwise=b_columnwise) general_grouped_gemm( A_fp8, diff --git a/transformer_engine/common/cast/mxfp8/swizzle.cuh b/transformer_engine/common/cast/mxfp8/swizzle.cuh index 7648e3f5cb..e3876eb908 100644 --- a/transformer_engine/common/cast/mxfp8/swizzle.cuh +++ b/transformer_engine/common/cast/mxfp8/swizzle.cuh @@ -16,6 +16,9 @@ namespace dispatch { namespace mxfp8 { namespace swizzle { +constexpr size_t GEMM_SWIZZLED_SCALE_TILE_DIM_X = 4; +constexpr size_t GEMM_SWIZZLED_SCALE_TILE_DIM_Y = 128; + /*! \brief Convert compact scale indices into GEMM swizzled scale index * * MXFP8 GEMM expects scaling factors to be in a "swizzled" order @@ -25,8 +28,8 @@ namespace swizzle { * */ __device__ __forceinline__ size_t gemm_swizzled_scale_idx(size_t i, size_t j, size_t num_tiles_X) { - constexpr size_t TILE_DIM_X = 4; // Tile dim in scale buffer - constexpr size_t TILE_DIM_Y = 128; + constexpr size_t TILE_DIM_X = GEMM_SWIZZLED_SCALE_TILE_DIM_X; + constexpr size_t TILE_DIM_Y = GEMM_SWIZZLED_SCALE_TILE_DIM_Y; constexpr size_t TILE_SIZE = TILE_DIM_X * TILE_DIM_Y; const size_t tile_idx_X = j / TILE_DIM_X; const size_t tile_idx_Y = i / TILE_DIM_Y; diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index ed2275b442..6a7af158e5 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -15,6 +15,7 @@ #include #include +#include "../cast/mxfp8/swizzle.cuh" #include "../common.h" #include "../util/cuda_runtime.h" #include "../util/handle_manager.h" @@ -330,6 +331,7 @@ struct GroupedOperandSelection { NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; bool with_gemm_swizzled_scales = false; bool trans = false; + bool rowwise = true; }; constexpr int kMaxGroups = 64; @@ -613,6 +615,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: sel.dptr = static_cast(t->columnwise_data.dptr); sel.scale_inv = t->columnwise_scale_inv.dptr; sel.dtype = col_dtype; + sel.rowwise = false; sel.shape = create_shape_info(t, swap_dims); }; @@ -621,6 +624,7 @@ inline GroupedOperandSelection select_grouped_operand(const transformer_engine:: sel.dptr = static_cast(t->data.dptr); sel.scale_inv = t->scale_inv.dptr; sel.dtype = row_dtype; + sel.rowwise = true; sel.shape = create_shape_info(t, /*swap_dims=*/false); }; @@ -846,6 +850,45 @@ __forceinline__ __device__ int64_t compute_grouped_tensor_offset(const TensorSha } } +__forceinline__ __device__ int64_t padded_mxfp8_scale_inv_bytes(int64_t first, int64_t last, + bool rowwise) { + namespace mxfp8_swizzle = transformer_engine::dispatch::mxfp8::swizzle; + constexpr int64_t kMxfp8BlockSize = 32; + // x is the dimension along which quantization is applied, y is other dimension + const int64_t scale_tile_y = static_cast(mxfp8_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_Y); + const int64_t scale_tile_x = static_cast(mxfp8_swizzle::GEMM_SWIZZLED_SCALE_TILE_DIM_X); + // Padded byte size of the swizzled MXFP8 scale_inv for a single tensor with data + // shape (first, last). Rowwise scales use rows=first, cols=last; columnwise + // scales swap the orientation since they are stored in column-major order. + const int64_t scale_dim_y = rowwise ? first : last; + const int64_t padded_scale_dim_y = + ((scale_dim_y + scale_tile_y - 1) / scale_tile_y) * scale_tile_y; + const int64_t data_dim_x = rowwise ? last : first; + const int64_t scale_dim_x = (data_dim_x + kMxfp8BlockSize - 1) / kMxfp8BlockSize; + const int64_t padded_scale_dim_x = + ((scale_dim_x + scale_tile_x - 1) / scale_tile_x) * scale_tile_x; + // MXFP8 scales are E8M0 (1 byte per element), so element count == byte count. + return padded_scale_dim_y * padded_scale_dim_x; +} + +// Device helper: byte offset into a contiguous grouped MXFP8 scale_inv buffer for +// tensor `idx`. Each expert's scale_inv is expected to be padded +// to the 128x4 swizzled layout. +__forceinline__ __device__ int64_t compute_grouped_tensor_mxfp8_scale_inv_offset( + const TensorShapeInfo &meta, size_t idx, bool rowwise) { + if (meta.first_dims != nullptr || meta.last_dims != nullptr) { + int64_t cumsum = 0; + for (size_t i = 0; i < idx; i++) { + const int64_t f = meta.first_dims ? meta.first_dims[i] : meta.uniform_first; + const int64_t l = meta.last_dims ? meta.last_dims[i] : meta.uniform_last; + cumsum += padded_mxfp8_scale_inv_bytes(f, l, rowwise); + } + return cumsum; + } + return static_cast(idx) * + padded_mxfp8_scale_inv_bytes(meta.uniform_first, meta.uniform_last, rowwise); +} + // Linear scan to find which tensor contains the given row. // Returns the tensor index and writes the exclusive end-row of that tensor to *out_tensor_row_end. __forceinline__ __device__ int find_tensor_for_row(const int64_t *first_dims, int64_t uniform_first, @@ -977,7 +1020,8 @@ __global__ void setup_grouped_gemm_kernel( size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr, // Scale inputs: for tensor scaling, pass float* and set mxfp8_base to nullptr // For MXFP8, pass nullptr for tensor_scale and set mxfp8_base - float *a_scale_base, float *b_scale_base, NVTEScalingMode scaling_mode, size_t num_tensors, + float *a_scale_base, float *b_scale_base, bool a_rowwise, bool b_rowwise, + NVTEScalingMode scaling_mode, size_t num_tensors, MultiTensorGroupGemmInputArgs a_multi_tensor_args, MultiTensorGroupGemmOutputArgs c_multi_tensor_args, MultiTensorGroupGemmOutputArgs d_multi_tensor_args) { @@ -1038,12 +1082,13 @@ __global__ void setup_grouped_gemm_kernel( // Fill scale pointers (per-matrix). // The interpretation of the scale buffers depends on the shared scaling recipe: - // NVTE_MXFP8_1D_SCALING : E8M0 byte stream; offset = data_offset / 32 elements // otherwise : one float per tensor, indexed by tensor index if (a_scale_base) { if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + const int64_t a_scale_offset = + compute_grouped_tensor_mxfp8_scale_inv_offset(A_meta, idx, a_rowwise); a_scale_inv_ptrs[idx] = reinterpret_cast( - static_cast(static_cast(a_scale_base)) + a_offset / 32); + static_cast(static_cast(a_scale_base)) + a_scale_offset); } else { a_scale_inv_ptrs[idx] = static_cast(a_scale_base) + idx; } @@ -1052,8 +1097,10 @@ __global__ void setup_grouped_gemm_kernel( } if (b_scale_base) { if (scaling_mode == NVTE_MXFP8_1D_SCALING) { + const int64_t b_scale_offset = + compute_grouped_tensor_mxfp8_scale_inv_offset(B_meta, idx, b_rowwise); b_scale_inv_ptrs[idx] = reinterpret_cast( - static_cast(static_cast(b_scale_base)) + b_offset / 32); + static_cast(static_cast(b_scale_base)) + b_scale_offset); } else { b_scale_inv_ptrs[idx] = static_cast(b_scale_base) + idx; } @@ -1116,14 +1163,19 @@ inline void launch_grouped_gemm_setup( // A and B share the same scaling recipe (validated in validate_grouped_gemm_inputs). // Pass scale buffers as void* and let the kernel interpret them via scaling_mode. + + // Scale rowwise flag for MXFP8/NVFP4: to calculate scale_inv padding based offsets + // within kernel. Ignored for tensor scaling. + const bool a_rowwise = A_sel.rowwise; + const bool b_rowwise = B_sel.rowwise; setup_grouped_gemm_kernel<<>>( ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols, ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, ws.a_scale_inv_ptrs, ws.b_scale_inv_ptrs, A_sel.dptr, B_sel.dptr, c_base, d_base, A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, c_elem_size, d_elem_size, static_cast(alpha_tensor->data.dptr), static_cast(beta_tensor->data.dptr), reinterpret_cast(A_sel.scale_inv), - reinterpret_cast(B_sel.scale_inv), A_sel.scaling_mode, num_tensors, - a_multi_tensor_args, c_multi_tensor_args, d_multi_tensor_args); + reinterpret_cast(B_sel.scale_inv), a_rowwise, b_rowwise, A_sel.scaling_mode, + num_tensors, a_multi_tensor_args, c_multi_tensor_args, d_multi_tensor_args); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -1276,6 +1328,7 @@ void nvte_grouped_gemm_with_discrete_inputA(const NVTETensor *A_list, size_t num choose_grouped_operand_storage(static_cast(transa), /*is_A=*/true, mxfp8, is_fp8, non_tn_fp8_ok, A_list_info.all_row, A_list_info.all_col, "A"); A_sel.trans = choice.trans; + A_sel.rowwise = choice.use_rowwise; if (choice.use_rowwise) { NVTE_CHECK(A_list_info.all_row, "Grouped GEMM: A_list is missing row-wise data"); A_sel.dtype = A_list_info.row_dtype; From 528f16c5067a50c5a4ec2b8f4c466d3372536323 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 4 May 2026 18:06:37 -0400 Subject: [PATCH 2/2] [PyTorch] Guard/document single parameter feature for grouped linear (#2955) * Better documentation for single param and envvar guard Signed-off-by: Kirthi Shankar Sivamani * fix doc Signed-off-by: ksivamani * Fix test envvar Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani Signed-off-by: ksivamani --- qa/L0_pytorch_debug_unittest/test.sh | 2 +- qa/L0_pytorch_unittest/test.sh | 6 ++-- .../pytorch/module/grouped_linear.py | 12 +++++-- .../pytorch/ops/basic/grouped_linear.py | 10 ++++++ transformer_engine/pytorch/utils.py | 31 +++++++++++++++++++ 5 files changed, 55 insertions(+), 6 deletions(-) diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index ce65bc4305..3efa462628 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -36,7 +36,7 @@ NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" +NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index a8f8cf8754..22636828f9 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -24,7 +24,7 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" @@ -37,11 +37,11 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py" +NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/test_grouped_tensor.xml $TE_PATH/tests/pytorch/test_grouped_tensor.py || test_fail "test_grouped_tensor.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 NVTE_CUTEDSL_FUSED_GROUPED_MLP=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_backward_override.xml $TE_PATH/tests/pytorch/test_backward_override.py || test_fail "test_backward_override.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index cab8abae11..4ae7b47b9b 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -31,6 +31,7 @@ clear_tensor_data, init_method_constant, requires_grad, + resolve_grouped_linear_single_param_flags, get_nvtx_range_context, ) from ..distributed import ( @@ -659,11 +660,15 @@ class GroupedLinear(TransformerEngineBaseModule): single_grouped_weight : bool, default = False If set to ``True``, grouped weights are stored as a single grouped parameter instead of one parameter per GEMM. - EXPERIMENTAL and subject to change. + EXPERIMENTAL and subject to change. Gated by the + ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var + is not set this argument is forced to ``False`` with a warning. single_grouped_bias : bool, default = False If set to ``True``, grouped biases are stored as a single grouped bias instead of one bias per GEMM. - EXPERIMENTAL and subject to change. + EXPERIMENTAL and subject to change. Gated by the + ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var + is not set this argument is forced to ``False`` with a warning. Notes ----- @@ -712,6 +717,9 @@ def __init__( self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name self.save_original_input = save_original_input + single_grouped_weight, single_grouped_bias = resolve_grouped_linear_single_param_flags( + single_grouped_weight, single_grouped_bias + ) self.single_grouped_weight = single_grouped_weight self.single_grouped_bias = single_grouped_bias if ub_overlap_rs or ub_overlap_ag: diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index b503cb186b..a86abb1325 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -30,6 +30,7 @@ canonicalize_dtype, clear_tensor_data, devices_match, + resolve_grouped_linear_single_param_flags, round_up_to_nearest_multiple, ) from .._common import is_quantized_tensor, maybe_dequantize @@ -78,11 +79,17 @@ class GroupedLinear(BasicOperation): ``main_grad`` instead of accumulating. single_grouped_weight : bool, default = ``False`` Store all expert weights as one ``GroupedTensor`` parameter ``weight``. + EXPERIMENTAL and subject to change. Gated by the + ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var + is not set this argument is forced to ``False`` with a warning. delay_wgrad_compute : bool, default = ``False`` Whether to delay weight gradient computation single_grouped_bias : bool, default = ``False`` If ``True`` (and ``bias=True``), store all expert biases as one ``GroupedTensor`` parameter named ``bias`` instead of ``bias0``..``bias{N-1}``. + EXPERIMENTAL and subject to change. Gated by the + ``NVTE_GROUPED_LINEAR_SINGLE_PARAM`` environment variable: if the env var + is not set this argument is forced to ``False`` with a warning. scale_bias : bool, default = ``False`` If ``True`` (and ``bias=True``), expects a probability tensor as an additional extra input and adds ``bias * scales`` instead of ``bias`` @@ -123,6 +130,9 @@ def __init__( self.num_groups: int = num_groups self.in_features: int = in_features self.out_features: int = out_features + single_grouped_weight, single_grouped_bias = resolve_grouped_linear_single_param_flags( + single_grouped_weight, single_grouped_bias + ) self.single_grouped_weight: bool = single_grouped_weight self.single_grouped_bias: bool = single_grouped_bias self.use_bias: bool = bias diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index a76f205acc..250daec67f 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -7,6 +7,7 @@ import functools import math import os +import warnings from typing import Any, Callable, List, Optional, Sequence, Tuple, Union from contextlib import nullcontext import numpy as np @@ -81,6 +82,36 @@ def get_device_compute_capability() -> Tuple[int, int]: return _get_device_compute_capability(torch.cuda.current_device()) +def resolve_grouped_linear_single_param_flags( + single_grouped_weight: bool, + single_grouped_bias: bool, +) -> Tuple[bool, bool]: + """Gate ``single_grouped_weight`` / ``single_grouped_bias`` on ``NVTE_GROUPED_LINEAR_SINGLE_PARAM``.""" + if not (single_grouped_weight or single_grouped_bias): + return single_grouped_weight, single_grouped_bias + + env_enabled = int(os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0")) > 0 + if not env_enabled: + warnings.warn( + f"GroupedLinear was constructed with single_grouped_weight={single_grouped_weight} " + f"and single_grouped_bias={single_grouped_bias}, but the " + "NVTE_GROUPED_LINEAR_SINGLE_PARAM environment variable is not set. " + "Disabling single grouped weight/bias and falling back to per-expert parameters.", + UserWarning, + stacklevel=3, + ) + return False, False + + warnings.warn( + "GroupedLinear is using single_grouped_weight/single_grouped_bias. " + "This feature is experimental, may change in future " + "releases, and is known to be non-deterministic in certain cases.", + UserWarning, + stacklevel=3, + ) + return single_grouped_weight, single_grouped_bias + + def attention_mask_func( attention_scores: torch.Tensor, attention_mask: torch.Tensor ) -> torch.Tensor: