diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 47507dc384..7691582f97 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -7,6 +7,7 @@ from collections.abc import Iterable, Sequence import functools import io +import os import math import random from typing import Optional @@ -42,7 +43,6 @@ ) from transformer_engine.pytorch.tensor.grouped_tensor import GroupedTensor from transformer_engine.pytorch.cpp_extensions.gemm import general_grouped_gemm_for_grouped_tensor -from transformer_engine.pytorch.module.base import get_dummy_wgrad import transformer_engine_torch as tex # Import utility functions @@ -51,6 +51,7 @@ assert_close_grads, dtype_tols, make_recipe, + MegatronTrainingHelper, quantization_tols, reset_rng_states, ) @@ -212,76 +213,6 @@ def to_cpu(tensor: Optional[torch.Tensor]) -> Optional[torch.Tensor]: return out -class MegatronTrainingHelper: - """Test-side stand-in for the Megatron-Core DDP / MegatronFSDP wrapper. - Megatron's DDP wrapper (and MegatronFSDP) owns the per-parameter - ``main_grad`` buffer and the ``overwrite_main_grad`` / - ``grad_added_to_main_grad`` attributes that coordinate - ``fuse_wgrad_accumulation`` with TE modules. These helpers reproduce the - relevant slice of that protocol so TE tests can exercise the - accumulate-into-``main_grad`` code path without pulling in the full - Megatron-Core dependency. 
- """ - - @staticmethod - def init_main_grad_buffers( - weight_params: Iterable[torch.nn.Parameter], - *, - fill_value: float, - overwrite_main_grad: bool, - zero_out_wgrad: bool = False, - dtype: torch.dtype = torch.float32, - ) -> None: - """Allocate ``main_grad`` and stamp the wrapper attributes on each - param, mirroring what the Megatron DDP/FSDP wrapper does before - backward.""" - for wp in weight_params: - wp.main_grad = torch.full(wp.size(), fill_value, device=wp.device, dtype=dtype) - wp.overwrite_main_grad = overwrite_main_grad - wp.zero_out_wgrad = zero_out_wgrad - wp.grad_added_to_main_grad = False - - @staticmethod - def verify_main_grad_accumulation( - weight_params: Iterable[torch.nn.Parameter], - *, - expected_main_grads: Iterable[torch.Tensor], - rtol: float = 0.0, - atol: float = 0.0, - ) -> None: - """Check that backward produced what the Megatron wrapper expects: - each ``main_grad`` matches ``expected_main_grads``, - ``grad_added_to_main_grad`` was flipped to ``True`` so the wrapper's - post-backward hooks won't double-accumulate, and ``param.grad`` was - replaced by the cached dummy tensor (so a wrapper hook that did - ``main_grad += grad`` would be a no-op rather than double-counting). - """ - for wp, expected in zip(weight_params, expected_main_grads): - torch.testing.assert_close(wp.main_grad.to(expected), expected, rtol=rtol, atol=atol) - - assert wp.grad_added_to_main_grad is True, ( - "weight.grad_added_to_main_grad was not flipped to True; " - "the Megatron DDP/FSDP wrapper hook will double-accumulate." - ) - - # ``.grad`` should be the cached dummy tensor returned by - # ``get_dummy_wgrad`` -- shared storage, not the real wgrad. - expected_dummy = get_dummy_wgrad(list(wp.size()), wp.dtype) - assert ( - wp.grad is not None - ), "weight.grad is None; the Megatron protocol expects a dummy tensor stand-in here." 
- assert wp.grad.data_ptr() == expected_dummy.data_ptr(), ( - "weight.grad does not share storage with the cached dummy " - "wgrad; downstream wrapper hooks risk double-accumulating." - ) - if getattr(wp, "zero_out_wgrad", False): - assert torch.all(wp.grad == 0), ( - "weight.zero_out_wgrad=True but the dummy weight.grad " - "was not zeroed; downstream hooks reading .grad would " - "see stale bytes from the previous step." - ) - - class TestSequentialContainer: """Tests for sequential container""" @@ -2098,6 +2029,8 @@ def test_dropout( @pytest.mark.parametrize("input_requires_grad", (False, True)) @pytest.mark.parametrize("weight_requires_grad", (False, True)) @pytest.mark.parametrize("delay_wgrad_compute", (False, True)) + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("single_grouped_bias", (False, True)) def test_grouped_linear( self, *, @@ -2113,9 +2046,17 @@ def test_grouped_linear( input_requires_grad: bool, weight_requires_grad: bool, delay_wgrad_compute: bool, + single_grouped_weight: bool, + single_grouped_bias: bool, ) -> None: """Grouped GEMM""" - + if os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0") == "0" and ( + single_grouped_weight or single_grouped_bias + ): + pytest.skip( + "single_grouped_weight/single_grouped_bias requires" + " NVTE_GROUPED_LINEAR_SINGLE_PARAM=1" + ) # Split sizes split_sizes = [split_alignment * i for i in range(group_size)] random.shuffle(split_sizes) @@ -2136,6 +2077,18 @@ def test_grouped_linear( if quantization is not None and dtype not in (torch.bfloat16, torch.float16): pytest.skip("Quantized group GEMM is only supported with BF16/FP16") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") + if ( + single_grouped_weight + and quantized_weight + and quantization in ("fp8_delayed_scaling", "fp8_current_scaling") + ): + pytest.skip( + "single_grouped_weight does not support FP8 delayed/current scaling " + "with 
quantized_model_init" + ) + # Random data x_ref, x_test = make_reference_and_test_tensors( in_shape, @@ -2194,12 +2147,26 @@ def test_grouped_linear( device=device, dtype=dtype, delay_wgrad_compute=delay_wgrad_compute, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, ) with torch.no_grad(): + if single_grouped_weight: + op_weights = op.weight.quantized_tensors + if op_weights is None: + op_weights = op.weight.split_into_quantized_tensors() + if single_grouped_bias: + op_bias_parts = op.bias.split_into_quantized_tensors() for group_idx in range(group_size): - getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) + if single_grouped_weight: + op_weights[group_idx].copy_(ws_test[group_idx]) + else: + getattr(op, f"weight{group_idx}").copy_(ws_test[group_idx]) if bias: - getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) + if single_grouped_bias: + op_bias_parts[group_idx].reshape(-1).copy_(bs_test[group_idx]) + else: + getattr(op, f"bias{group_idx}").copy_(bs_test[group_idx]) del ws_test, bs_test for param in op.parameters(): param.requires_grad_(requires_grad=weight_requires_grad) @@ -2227,20 +2194,222 @@ def test_grouped_linear( torch.testing.assert_close(dx_test, x_ref.grad, **tols) else: assert x_test.grad is None - for group_idx in range(group_size): - w_test = getattr(op, f"weight{group_idx}") + if single_grouped_weight: if weight_requires_grad: - dw_test = w_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols) + dw_test_all = op.weight.grad.to(dtype=torch.float64, device="cpu") + w_ref_grad = torch.stack([w.grad for w in ws_ref], dim=0) + torch.testing.assert_close(dw_test_all, w_ref_grad, **tols) else: - assert w_test.grad is None - if bias: - b_test = getattr(op, f"bias{group_idx}") + assert op.weight.grad is None + else: + for group_idx in range(group_size): + w_test = getattr(op, f"weight{group_idx}") if weight_requires_grad: - db_test = 
b_test.grad.to(dtype=torch.float64, device="cpu") - torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols) + dw_test = w_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(dw_test, ws_ref[group_idx].grad, **tols) else: - assert b_test.grad is None + assert w_test.grad is None + if bias: + if single_grouped_bias: + if weight_requires_grad: + db_test_all = op.bias.grad.to(dtype=torch.float64, device="cpu") + b_ref_grad = torch.stack([b.grad for b in bs_ref], dim=0) + torch.testing.assert_close(db_test_all, b_ref_grad, **tols) + else: + assert op.bias.grad is None + else: + for group_idx in range(group_size): + b_test = getattr(op, f"bias{group_idx}") + if weight_requires_grad: + db_test = b_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(db_test, bs_ref[group_idx].grad, **tols) + else: + assert b_test.grad is None + + @pytest.mark.parametrize("dtype", (torch.bfloat16, torch.float16)) + @pytest.mark.parametrize( + "quantization", + [None] + (["mxfp8"] if mxfp8_available else []), + ) + @pytest.mark.parametrize("quantized_weight", (False, True)) + @pytest.mark.parametrize("bias", (False, True)) + @pytest.mark.parametrize("single_grouped_weight", (False, True)) + @pytest.mark.parametrize("single_grouped_bias", (False, True)) + @pytest.mark.parametrize("accumulate_into_main_grad", (False, True)) + def test_grouped_linear_cuda_graph_safe( + self, + *, + dtype: torch.dtype, + quantization: Optional[str], + quantized_weight: bool, + bias: bool, + single_grouped_weight: bool, + single_grouped_bias: bool, + accumulate_into_main_grad: bool, + device: torch.device = "cuda", + group_size: int = 4, + in_features: int = 128, + out_features: int = 128, + split_alignment: int = 128, + token_padding: int = 256, + ) -> None: + """GroupedLinear forward+backward should be CUDA graph capturable. 
+ + Exercises the grouped-tensor / cublas-grouped-gemm path which uses + GPU-resident split offsets and is the only flow safe to capture. + """ + if os.environ.get("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0") == "0" and ( + single_grouped_weight or single_grouped_bias + ): + pytest.skip( + "single_grouped_weight/single_grouped_bias requires" + " NVTE_GROUPED_LINEAR_SINGLE_PARAM=1" + ) + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM CUDA-graph-safe path requires SM100+ (Blackwell)") + # Skip invalid configurations + if quantization is None and quantized_weight: + pytest.skip("quantized_weight requires a quantization recipe") + if single_grouped_bias and not bias: + pytest.skip("single_grouped_bias requires bias=True") + + # Split sizes (statically pinned for graph capture) + split_sizes = [split_alignment * (i + 1) for i in range(group_size)] + random.shuffle(split_sizes) + split_sizes = torch.tensor(split_sizes, dtype=torch.int, device=device) + # Pad input tokens to validate the sync-free flow + in_shape = (split_sizes.sum().item() + token_padding, in_features) + out_shape = (in_shape[0], out_features) + + recipe = make_recipe(quantization) + with te.quantized_model_init(enabled=quantized_weight, recipe=recipe): + op = te_ops.GroupedLinear( + group_size, + in_features, + out_features, + bias=bias, + device=device, + dtype=dtype, + accumulate_into_main_grad=accumulate_into_main_grad, + single_grouped_weight=single_grouped_weight, + single_grouped_bias=single_grouped_bias, + ) + + def _weight_params() -> list[torch.nn.Parameter]: + if single_grouped_weight: + return [op.weight] + return [getattr(op, f"weight{i}") for i in range(group_size)] + + def _bias_params() -> list[torch.nn.Parameter]: + if not bias: + return [] + if single_grouped_bias: + return [op.bias] + return [getattr(op, f"bias{i}") for i in range(group_size)] + + def _init_main_grads(value: float = 0.0) -> None: + if not accumulate_into_main_grad: + return + with 
torch.no_grad(): + for w in _weight_params(): + if getattr(w, "main_grad", None) is None: + w.main_grad = torch.empty(w.size(), device=device, dtype=torch.float32) + w.main_grad.fill_(value) + + def _collect_main_grads() -> list[torch.Tensor]: + return [w.main_grad.detach().clone() for w in _weight_params()] + + def _zero_param_grads() -> None: + for param in op.parameters(): + if param.grad is None: + param.grad = torch.zeros_like(param) + else: + param.grad.zero_() + + static_split_sizes = split_sizes.clone() + + def train_step( + x: torch.Tensor, + dy: torch.Tensor, + out_buf: torch.Tensor, + *, + use_graphed: bool, + ) -> torch.Tensor: + with te.autocast(enabled=quantization is not None, recipe=recipe): + out = ( + graphed_module(x, static_split_sizes) + if use_graphed + else op(x, static_split_sizes) + ) + out.backward(dy) + out_buf.copy_(out) + return out_buf + + _init_main_grads(0.0) + + static_x = torch.randn(in_shape, device=device, dtype=dtype, requires_grad=True) + static_dy = torch.randn(out_shape, device=device, dtype=dtype) + static_out_buf = torch.empty(out_shape, device=device, dtype=dtype) + + graphed_module = te.make_graphed_callables( + op, + (static_x, static_split_sizes), + num_warmup_iters=3, + enabled=quantization is not None, + recipe=recipe, + ) + + # Replace static buffers with fresh data (graph captures must replay + # against new inputs without re-recording). + fresh_x = torch.randn_like(static_x) + fresh_dy = torch.randn_like(static_dy) + with torch.no_grad(): + static_x.copy_(fresh_x) + static_dy.copy_(fresh_dy) + + # Reset grads & main_grads so the captured iteration starts fresh. 
+ _zero_param_grads() + _init_main_grads(0.5) + if static_x.grad is not None: + static_x.grad.zero_() + + # Replay the graph + graph_out = ( + train_step(static_x, static_dy, static_out_buf, use_graphed=True).detach().clone() + ) + torch.cuda.synchronize() + graph_dx = static_x.grad.detach().clone() + if accumulate_into_main_grad: + graph_main_grads = _collect_main_grads() + graph_param_grads: list[torch.Tensor] = [] + else: + graph_main_grads = [] + graph_param_grads = [param.grad.detach().clone() for param in op.parameters()] + + # Reference: same op invoked eagerly with the same fresh inputs and + # the same starting grad/main_grad state. + _zero_param_grads() + _init_main_grads(0.5) + static_x.grad.zero_() + + expected_x = fresh_x.detach().clone().requires_grad_(True) + expected_dy = fresh_dy.detach().clone() + with te.autocast(enabled=quantization is not None, recipe=recipe): + expected_out = op(expected_x, static_split_sizes) + expected_out.backward(expected_dy) + + tols = dtype_tols(dtype) + if quantization is not None: + tols = quantization_tols(quantization) + + assert_close(graph_out, expected_out, **tols) + assert_close(graph_dx, expected_x.grad, **tols) + if accumulate_into_main_grad: + for g, w in zip(graph_main_grads, _weight_params()): + assert_close(g, w.main_grad, **tols) + else: + for g, param in zip(graph_param_grads, op.parameters()): + assert_close(g, param.grad, **tols) @pytest.mark.parametrize("in_shape", ((71, 192), (5, 7, 128))) @pytest.mark.parametrize("input_requires_grad", (False, True)) diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 8f8852edc2..c7cbe78a6d 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -8,6 +8,7 @@ import os import random import subprocess +from collections.abc import Iterable from contextlib import contextmanager from typing import Optional, Sequence, Tuple, Dict, Any, List from packaging.version import Version as PkgVersion @@ -27,6 +28,7 @@ check_set_window_size, ) from 
transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend +from transformer_engine.pytorch.module.base import get_dummy_wgrad def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: @@ -477,3 +479,73 @@ def run_distributed( msg += f"\n--- stderr ---\n{stderr_tail}" raise AssertionError(msg) return result + + +class MegatronTrainingHelper: + """Test-side stand-in for the Megatron-Core DDP / MegatronFSDP wrapper. + Megatron's DDP wrapper (and MegatronFSDP) owns the per-parameter + ``main_grad`` buffer and the ``overwrite_main_grad`` / + ``grad_added_to_main_grad`` attributes that coordinate + ``fuse_wgrad_accumulation`` with TE modules. These helpers reproduce the + relevant slice of that protocol so TE tests can exercise the + accumulate-into-``main_grad`` code path without pulling in the full + Megatron-Core dependency. + """ + + @staticmethod + def init_main_grad_buffers( + weight_params: Iterable[torch.nn.Parameter], + *, + fill_value: float, + overwrite_main_grad: bool, + zero_out_wgrad: bool = False, + dtype: torch.dtype = torch.float32, + ) -> None: + """Allocate ``main_grad`` and stamp the wrapper attributes on each + param, mirroring what the Megatron DDP/FSDP wrapper does before + backward.""" + for wp in weight_params: + wp.main_grad = torch.full(wp.size(), fill_value, device=wp.device, dtype=dtype) + wp.overwrite_main_grad = overwrite_main_grad + wp.zero_out_wgrad = zero_out_wgrad + wp.grad_added_to_main_grad = False + + @staticmethod + def verify_main_grad_accumulation( + weight_params: Iterable[torch.nn.Parameter], + *, + expected_main_grads: Iterable[torch.Tensor], + rtol: float = 0.0, + atol: float = 0.0, + ) -> None: + """Check that backward produced what the Megatron wrapper expects: + each ``main_grad`` matches ``expected_main_grads``, + ``grad_added_to_main_grad`` was flipped to ``True`` so the wrapper's + post-backward hooks won't double-accumulate, and ``param.grad`` was + replaced by the cached dummy tensor (so a wrapper 
hook that did + ``main_grad += grad`` would be a no-op rather than double-counting). + """ + for wp, expected in zip(weight_params, expected_main_grads): + torch.testing.assert_close(wp.main_grad.to(expected), expected, rtol=rtol, atol=atol) + + assert wp.grad_added_to_main_grad is True, ( + "weight.grad_added_to_main_grad was not flipped to True; " + "the Megatron DDP/FSDP wrapper hook will double-accumulate." + ) + + # ``.grad`` should be the cached dummy tensor returned by + # ``get_dummy_wgrad`` -- shared storage, not the real wgrad. + expected_dummy = get_dummy_wgrad(list(wp.size()), wp.dtype) + assert ( + wp.grad is not None + ), "weight.grad is None; the Megatron protocol expects a dummy tensor stand-in here." + assert wp.grad.data_ptr() == expected_dummy.data_ptr(), ( + "weight.grad does not share storage with the cached dummy " + "wgrad; downstream wrapper hooks risk double-accumulating." + ) + if getattr(wp, "zero_out_wgrad", False): + assert torch.all(wp.grad == 0), ( + "weight.zero_out_wgrad=True but the dummy weight.grad " + "was not zeroed; downstream hooks reading .grad would " + "see stale bytes from the previous step." + ) diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index beef6fe52f..9325d87ae7 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -6,6 +6,7 @@ from __future__ import annotations import functools +import math from importlib.metadata import PackageNotFoundError, version as get_pkg_version from typing import Optional @@ -88,6 +89,99 @@ def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, i return fp8_meta, 0 +def get_main_grad_from_param( + weight_param: torch.nn.Parameter, + *, + op_label: str = "", +) -> torch.Tensor: + """Refresh ``main_grad`` from FSDP (if applicable) and return it. 
+ Used by Megatron-LM-style wgrad fusion paths + (``accumulate_into_main_grad=True``) to obtain the buffer the wgrad GEMM + will write into. + Raises if the parameter does not have a ``main_grad`` attribute or if it + is ``None``. + """ + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad") or weight_param.main_grad is None: + prefix = f"{op_label} " if op_label else "" + raise RuntimeError( + f"{prefix}operation is configured with accumulate_into_main_grad=True, " + "but weight parameter does not have a valid main_grad attribute" + ) + return weight_param.main_grad + + +def get_accumulate_flag_in_param(weight_param: torch.nn.Parameter) -> bool: + """Return whether the wgrad GEMM should accumulate into ``main_grad``. + + Returns ``False`` (i.e. overwrite) when the parameter has + ``overwrite_main_grad=True`` (used in Megatron-FSDP), and ``True`` + otherwise. + """ + return not getattr(weight_param, "overwrite_main_grad", False) + + +def view_main_grad_as_grouped_buffer( + main_grad: torch.Tensor, + num_groups: int, + weight_shape: tuple[int, ...], + *, + label: str = "", +) -> torch.Tensor: + """Return ``main_grad`` viewed as ``(num_groups, *weight_shape)`` without copy. + Raises if the numel doesn't match or if the existing stride pattern does + not allow a zero-copy view to the grouped layout. 
+ """ + grouped_shape = (num_groups, *weight_shape) + if tuple(main_grad.shape) == grouped_shape: + return main_grad + prefix = f"{label} " if label else "Grouped weight " + if main_grad.numel() != math.prod(grouped_shape): + raise RuntimeError( + f"{prefix}main_grad expected shape {grouped_shape} or matching numel, " + f"but got shape {tuple(main_grad.shape)}" + ) + try: + return main_grad.view(grouped_shape) + except RuntimeError as e: + raise RuntimeError( + f"{prefix}main_grad must be viewable as {grouped_shape} without copy, " + f"but got shape {tuple(main_grad.shape)} and stride " + f"{tuple(main_grad.stride())}" + ) from e + + +def get_dummy_wgrads_for_params( + weight_params: list[torch.nn.Parameter], +) -> list[Optional[torch.Tensor]]: + """Build dummy ``.grad`` placeholders for Megatron-LM wgrad-fusion params. + + For each parameter that exposes ``grad_added_to_main_grad``, set the flag + to ``True`` and return a dummy wgrad tensor (zeroed if + ``zero_out_wgrad`` is also set on the parameter). For parameters without + the flag, the corresponding entry is ``None``. + + The returned list has the same length and order as ``weight_params``. 
+ """ + from ..module.base import get_dummy_wgrad # pylint: disable=import-outside-toplevel + + out: list[Optional[torch.Tensor]] = [] + for wp in weight_params: + if hasattr(wp, "grad_added_to_main_grad"): + wp.grad_added_to_main_grad = True + out.append( + get_dummy_wgrad( + list(wp.size()), + wp.dtype, + zero=getattr(wp, "zero_out_wgrad", False), + ) + ) + else: + out.append(None) + return out + + def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None: """Validate FC1 / scaled GLU / FC2 dimensions for fused grouped MLP.""" diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 46d52f7ff3..41f0855f1d 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -24,7 +24,6 @@ _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, - get_dummy_wgrad, ) from ...tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer @@ -36,7 +35,13 @@ devices_match, ) from ..op import BasicOperation, OperationContext -from .._common import maybe_dequantize, is_quantized_tensor +from .._common import ( + get_accumulate_flag_in_param, + get_dummy_wgrads_for_params, + get_main_grad_from_param, + is_quantized_tensor, + maybe_dequantize, +) def _wait_async(handle: Optional[Any]) -> None: @@ -1060,16 +1065,9 @@ def op_backward( grad_weight = None if ctx.weight_requires_grad and accumulate_into_main_grad: weight_param = self.weight - if hasattr(weight_param, "__fsdp_param__"): - weight_param.main_grad = weight_param.get_main_grad() - accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) - if not hasattr(weight_param, "main_grad"): - raise RuntimeError( - "BasicLinear op is configured with " - "accumulate_into_main_grad=True, " - "but weight parameter does not have main_grad attribute" - ) - grad_weight = weight_param.main_grad.detach() + main_grad = get_main_grad_from_param(weight_param, op_label="BasicLinear") + 
accumulate_into_main_grad = get_accumulate_flag_in_param(weight_param) + grad_weight = main_grad.detach() else: accumulate_into_main_grad = False @@ -1099,14 +1097,6 @@ def op_backward( # Megatron-LM wgrad fusion # Note: Return dummy tensor for grad weight if needed. if accumulate_into_main_grad: - grad_weight = None - weight_param = self.weight - if hasattr(weight_param, "grad_added_to_main_grad"): - weight_param.grad_added_to_main_grad = True - grad_weight = get_dummy_wgrad( - list(weight_param.size()), - weight_param.dtype, - zero=getattr(weight_param, "zero_out_wgrad", False), - ) + grad_weight = get_dummy_wgrads_for_params([self.weight])[0] return grad_input, [grad_weight] diff --git a/transformer_engine/pytorch/ops/basic/grouped_linear.py b/transformer_engine/pytorch/ops/basic/grouped_linear.py index a86abb1325..e698c2697f 100644 --- a/transformer_engine/pytorch/ops/basic/grouped_linear.py +++ b/transformer_engine/pytorch/ops/basic/grouped_linear.py @@ -14,28 +14,36 @@ import torch import transformer_engine_torch as tex -from ...cpp_extensions import general_grouped_gemm +from ...cpp_extensions import general_grouped_gemm, general_grouped_gemm_for_grouped_tensor from ...distributed import CudaRNGStatesTracker from ...module._common import WeightGradStore from ...module.base import ( _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD, - get_dummy_wgrad, ) from ...quantization import FP8GlobalStateManager, Recipe +from ...quantized_tensor import QuantizedTensorStorage from ...tensor import MXFP8Quantizer, MXFP8Tensor, Quantizer from ...utils import ( canonicalize_device, canonicalize_dtype, clear_tensor_data, devices_match, + get_device_compute_capability, resolve_grouped_linear_single_param_flags, round_up_to_nearest_multiple, ) -from .._common import is_quantized_tensor, maybe_dequantize +from .._common import ( + get_accumulate_flag_in_param, + get_dummy_wgrads_for_params, + get_main_grad_from_param, + is_quantized_tensor, + maybe_dequantize, + 
view_main_grad_as_grouped_buffer, +) from ..op import BasicOperation, OperationContext -from ...tensor import GroupedTensor +from ...tensor import GroupedTensor, GroupedTensorStorage from ...triton.grouped_dbias_dscales import ( compute_grouped_dbias, compute_grouped_dbias_dscales, @@ -242,7 +250,7 @@ def backward_dw(self) -> None: else: # Fused MXFP8 grouped MLP saves `GroupedTensor` activations for wgrad. clear_tensor_data( - activations.data, + activations.rowwise_data, activations.columnwise_data, activations.scale_inv, activations.columnwise_scale_inv, @@ -724,6 +732,151 @@ def op_backward(self, *args, **kwargs): "It overrides `fuser_backward` instead of `op_backward`." ) + @staticmethod + def _is_graph_safe_path_supported( + *, + with_quantized_compute: bool, + input_quantizers: Sequence[Optional[Quantizer]], + dtype: torch.dtype, + ) -> bool: + """Whether the graph-safe grouped-tensor flow can be used. + + * The graph-safe path dispatches to ``general_grouped_gemm_for_grouped_tensor``, + which is backed by ``nvte_grouped_gemm_with_discrete_inputA`` in the common + library. That kernel requires Blackwell (SM100) or newer with cuBLAS 13.3+. + * Quantized compute is currently MXFP8-only; every other quantization + recipe (fp8 delayed / current scaling, fp8 block scaling, NVFP4, ...) + falls back to the legacy flow. + * Unquantized compute supports BF16/FP16 only -- FP32 is excluded + because the cublasLt grouped GEMM doesn't support it. + """ + if get_device_compute_capability() < (10, 0): + return False + if with_quantized_compute: + return all(isinstance(q, MXFP8Quantizer) for q in input_quantizers) + return dtype in (torch.bfloat16, torch.float16) + + def _get_grouped_weight_for_gemm( + self, + weight_param: GroupedTensor, + weight_quantizers: list[Optional[Quantizer]], + columnwise_usage: bool, + with_quantized_compute: bool, + dtype: torch.dtype, + ) -> GroupedTensor: + """Prepare weights for ``general_grouped_gemm_for_grouped_tensor``. 
+ Supports MXFP8/BF16/FP16 compute paths. + """ + num_groups = self.num_groups + is_weight_quantized = weight_param.quantizer is not None + if is_weight_quantized and with_quantized_compute: + # GGEMM can use it as it is + return weight_param + if is_weight_quantized and not with_quantized_compute: + # This use-case isn't optimized yet. Involves a per-group + # dequantize loop and a torch.stack copy. + weight_parts = weight_param.quantized_tensors + if weight_parts is None: + weight_parts = weight_param.split_into_quantized_tensors() + dequantized = [maybe_dequantize(w, dtype) for w in weight_parts] + weight_data = torch.stack(dequantized, dim=0).contiguous() + return GroupedTensor( + shape=(num_groups * self.out_features, self.in_features), + dtype=dtype, + num_tensors=num_groups, + shapes=[(self.out_features, self.in_features)] * num_groups, + quantizer=None, + data=weight_data.reshape(-1), + ) + if not with_quantized_compute: + # Make sure that weight param is the correct dtype, + # otherwise cast it to the correct dtype. + if weight_param.rowwise_data.dtype == dtype: + return weight_param + weight_data = weight_param.rowwise_data.to(dtype=dtype) + return GroupedTensor( + shape=(num_groups * self.out_features, self.in_features), + dtype=dtype, + num_tensors=num_groups, + shapes=[(self.out_features, self.in_features)] * num_groups, + quantizer=None, + data=weight_data.reshape(-1), + ) + # Quantized compute path, use the fused group quantize kernel. 
+ weight_quantizer = weight_quantizers[0] + weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + return tex.group_quantize( + weight_param.rowwise_data.view(weight_param.logical_shape), + weight_quantizer, + num_groups, + None, + ) + + def _get_discrete_weights_for_gemm( + self, + weight_params: Optional[GroupedTensor] | list[torch.Tensor], + weight_quantizers: list[Optional[Quantizer]], + columnwise_usage: bool, + with_quantized_compute: bool, + dtype: torch.dtype, + ) -> list[torch.Tensor]: + """Prepare weights for ``general_grouped_gemm_for_grouped_tensor``. + Returns a Python list, which dispatches the GEMM to ``discrete_in`` mode. + """ + out: list[torch.Tensor] = [] + for w, quantizer in zip(weight_params, weight_quantizers): + if not with_quantized_compute: + w = maybe_dequantize(w, dtype) + elif with_quantized_compute and not is_quantized_tensor(w): + quantizer.set_usage(rowwise=True, columnwise=columnwise_usage) + w = quantizer(w) + out.append(w) + return out + + def _get_weight_tensors(self) -> list[torch.nn.Parameter]: + """Return the weight parameters in registration order. + + Length is 1 when ``single_grouped_weight=True`` (one + ``GroupedTensor`` parameter), otherwise ``num_groups``. + """ + if self.single_grouped_weight: + return [self.weight] + return [getattr(self, f"weight{idx}") for idx in range(self.num_groups)] + + def _get_grouped_bias_for_gemm( + self, + dtype: torch.dtype, + ) -> Optional[torch.Tensor]: + """Build a uniform GroupedTensor of per-group biases for the cublas + grouped GEMM. + + Each group expects a (1, out_features) bias vector. Returns ``None`` + when no additive bias is configured. + """ + if not self.has_bias: + return None + num_groups = self.num_groups + + if self.single_grouped_bias: + # Already a contiguous (num_groups * out_features) buffer. 
+ bias_data = self.bias.rowwise_data + if bias_data.dtype != dtype: + bias_data = bias_data.to(dtype=dtype) + else: + bias_list = [ + maybe_dequantize(getattr(self, f"bias{idx}"), dtype) for idx in range(num_groups) + ] + bias_data = torch.stack(bias_list, dim=0).contiguous() + + return GroupedTensor( + shape=(num_groups, self.out_features), + dtype=dtype, + num_tensors=num_groups, + shapes=[(1, self.out_features)] * num_groups, + quantizer=None, + data=bias_data.reshape(-1), + ) + def fuser_forward( self, basic_op_ctxs: list[OperationContext], @@ -735,7 +888,6 @@ def fuser_forward( basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: num_groups = self.num_groups - has_bias = self.has_bias weight_param = self.weight if self.single_grouped_weight else self.weight0 device = weight_param.device @@ -753,13 +905,11 @@ def fuser_forward( # Quantizers input_quantizers = [None] * num_groups weight_quantizers = [None] * num_groups - grad_output_quantizers = [None] * num_groups with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() if with_quantized_compute: for group_idx in range(num_groups): input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx) weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1) - grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx) # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -767,17 +917,151 @@ def fuser_forward( else: dtype = weight_param.dtype - # Extract split sizes from extra input + # Extract split sizes from extra input. Keep on GPU for graph safety. 
split_sizes = basic_op_extra_inputs[0][0] - split_sizes_int = [int(s) for s in split_sizes.tolist()] - if len(split_sizes_int) != num_groups: - raise ValueError(f"Expected {num_groups} splits, but got {len(split_sizes_int)}.") + if int(split_sizes.numel()) != num_groups: + raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") + if split_sizes.dtype != torch.int64: + split_sizes = split_sizes.to(dtype=torch.int64) + if split_sizes.device != device: + split_sizes = split_sizes.to(device=device) # Extract scales tensor for bias scaling scales = None if self._scale_bias: scales = basic_op_extra_inputs[0][1] + # Dispatch: graph-safe GroupedTensor flow whenever it can be used. + # See ``_is_graph_safe_path_supported`` for the gating rationale -- + # in short it requires Blackwell (SM100+) plus a supported dtype / + # quantization recipe. Otherwise we fall back to the legacy + # ``tex.split_quantize`` + ``general_grouped_gemm`` flow. + use_grouped_tensor_path = self._is_graph_safe_path_supported( + with_quantized_compute=with_quantized_compute, + input_quantizers=input_quantizers, + dtype=dtype, + ) + + if use_grouped_tensor_path: + out, tensors_to_save = self._fuser_forward_grouped_tensor( + input_=input_, + split_sizes=split_sizes, + scales=scales, + with_quantized_compute=with_quantized_compute, + input_quantizers=input_quantizers, + weight_quantizers=weight_quantizers, + dtype=dtype, + input_requires_grad=input_requires_grad, + weight_requires_grad=weight_requires_grad, + device=device, + ) + else: + out, tensors_to_save = self._fuser_forward_split_quantize( + input_=input_, + split_sizes=split_sizes, + scales=scales, + with_quantized_compute=with_quantized_compute, + input_quantizers=input_quantizers, + weight_quantizers=weight_quantizers, + dtype=dtype, + input_requires_grad=input_requires_grad, + weight_requires_grad=weight_requires_grad, + device=device, + ) + + # Save tensors and autograd metadata on the basic-op context. 
+ self.fuser_forward_save_ctx( + basic_op_ctxs=basic_op_ctxs, + input_=input_, + tensors_to_save=[tensors_to_save], + requires_grad=[ctx.requires_grad], + basic_op_extra_inputs=basic_op_extra_inputs, + prev_op_grad_output_quantizer=prev_op_grad_output_quantizer, + next_op_input_quantizer=next_op_input_quantizer, + basic_op_kwargs=basic_op_kwargs, + use_grouped_tensor_path=use_grouped_tensor_path, + ) + + return out, [()] + + def fuser_forward_save_ctx( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, # pylint: disable=unused-argument + tensors_to_save: list[ + tuple[Optional[torch.Tensor | QuantizedTensorStorage | GroupedTensorStorage], ...] + ], + *, + requires_grad: list[bool], + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], # pylint: disable=unused-argument + prev_op_grad_output_quantizer: Optional[Quantizer], # pylint: disable=unused-argument + next_op_input_quantizer: Optional[Quantizer], # pylint: disable=unused-argument + basic_op_kwargs: list[dict[str, Any]], # pylint: disable=unused-argument + use_grouped_tensor_path: bool, + ) -> None: + """ + Save tensors and autograd metadata in context. 
+ """ + if not requires_grad[0]: + return + + ctx = basic_op_ctxs[0] + ctx.save_for_backward(*tensors_to_save[0]) + + num_groups = self.num_groups + weight_param = self.weight if self.single_grouped_weight else self.weight0 + + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + input_quantizers = [None] * num_groups + weight_quantizers = [None] * num_groups + grad_output_quantizers = [None] * num_groups + if with_quantized_compute: + for group_idx in range(num_groups): + input_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx) + weight_quantizers[group_idx] = self.get_quantizer("forward", 2 * group_idx + 1) + grad_output_quantizers[group_idx] = self.get_quantizer("backward", group_idx) + + ctx.use_grouped_tensor_path = use_grouped_tensor_path + ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizers = input_quantizers + ctx.weight_quantizers = weight_quantizers + ctx.grad_output_quantizers = grad_output_quantizers + ctx.grad_input_quantizers = None + # ``split_sizes`` and ``base_split_offsets`` are routed through + # ``save_for_backward`` (see ``_fuser_forward_split_quantize`` and + # ``_fuser_forward_grouped_tensor`` for the saved-tensor layout). + if torch.is_autocast_enabled(): + ctx.dtype = torch.get_autocast_dtype("cuda") + else: + ctx.dtype = weight_param.dtype + ctx.input_requires_grad = requires_grad[0] + ctx.weight_requires_grad = requires_grad[0] and weight_param.requires_grad + + # ================================================================== + # Legacy `tex.split_quantize` + `general_grouped_gemm` flow. + # ``m_splits`` is needed on CPU here, so this flow is NOT cuda-graphable. 
+ # ================================================================== + def _fuser_forward_split_quantize( + self, + *, + input_: torch.Tensor, + split_sizes: torch.Tensor, + scales: Optional[torch.Tensor], + with_quantized_compute: bool, + input_quantizers: list[Optional[Quantizer]], + weight_quantizers: list[Optional[Quantizer]], + dtype: torch.dtype, + input_requires_grad: bool, + weight_requires_grad: bool, + device: torch.device, + ) -> tuple[torch.Tensor, tuple[Optional[torch.Tensor], ...]]: + """Legacy ``tex.split_quantize`` + ``general_grouped_gemm`` flow.""" + num_groups = self.num_groups + has_bias = self.has_bias + + # Need CPU split sizes for split_quantize / general_grouped_gemm. + split_sizes_int = [int(s) for s in split_sizes.tolist()] + # Extract params if self.single_grouped_weight: weights = self.weight.quantized_tensors @@ -787,26 +1071,15 @@ def fuser_forward( weights = [getattr(self, f"weight{idx}") for idx in range(num_groups)] bs = None if has_bias: - if self.single_grouped_bias: - bias_parts = self.bias.quantized_tensors - if bias_parts is None: - bias_parts = self.bias.split_into_quantized_tensors() - bs = [maybe_dequantize(p.reshape(-1), dtype) for p in bias_parts] - else: - bs = [ - maybe_dequantize(getattr(self, f"bias{idx}"), dtype) - for idx in range(num_groups) - ] - - # Convert weight dtype if needed - ws = [] - for w, quantizer in zip(weights, weight_quantizers): - if not with_quantized_compute: - w = maybe_dequantize(w, dtype) - elif with_quantized_compute and not is_quantized_tensor(w): - quantizer.set_usage(rowwise=True, columnwise=input_requires_grad) - w = quantizer(w) - ws.append(w) + bs = self._get_bias_tensors(dtype) + + ws = self._get_discrete_weights_for_gemm( + weights, + weight_quantizers, + columnwise_usage=input_requires_grad, + with_quantized_compute=with_quantized_compute, + dtype=dtype, + ) # Split input tensor and convert dtypes if needed x = maybe_dequantize(input_, dtype) @@ -839,9 +1112,6 @@ def fuser_forward( 
) # Add bias * scales when scale_bias is enabled - # TODO(vthumbe): Need to use GroupedBiasAdd kernel here. - # Would be done as part of larger refactor for GroupedLinear + GroupedTensor - # integration. if self._scale_bias and has_bias: scales_splits = torch.split(scales, split_sizes_int) out_splits = torch.split(out, split_sizes_int) @@ -863,24 +1133,147 @@ def fuser_forward( for x in xs: x.update_usage(rowwise_usage=False, columnwise_usage=True) - # Save state for backward pass - if ctx.requires_grad: - saved = [split_sizes] + # Build the tuple of tensors to save for backward. Layout: + # [split_sizes, base_split_offsets, split_points, + # (scales if scale_bias), *xs, *ws] + # ``base_split_offsets`` and ``split_points`` are unused on the + # split-quantize backward path but are included as ``None`` so the + # saved-tensor layout matches the graph-safe + # ``_fuser_forward_grouped_tensor`` path (and the fused MLP forward). + saved: list[Optional[torch.Tensor]] = [split_sizes, None, None] + if self._scale_bias: + saved.append(scales) + saved.extend(xs) + saved.extend(ws) + return out, tuple(saved) + + def _fuser_forward_grouped_tensor( + self, + *, + input_: torch.Tensor, + split_sizes: torch.Tensor, + scales: Optional[torch.Tensor], + with_quantized_compute: bool, + input_quantizers: list[Optional[Quantizer]], + weight_quantizers: list[Optional[Quantizer]], + dtype: torch.dtype, + input_requires_grad: bool, + weight_requires_grad: bool, + device: torch.device, + ) -> tuple[torch.Tensor, tuple[Optional[torch.Tensor], ...]]: + """Graph-safe GroupedTensor forward path (pure compute). + Returns ``(output, tensors_to_save)``. ``split_sizes``, + ``base_split_offsets`` and ``split_points`` are returned so that + ``fuser_forward_save_ctx`` can call ``save_for_backward`` on them. 
+ """ + num_groups = self.num_groups + has_bias = self.has_bias + + base_split_offsets = tex.splits_to_offsets(split_sizes, 1) + split_points = base_split_offsets[1:].to(dtype=torch.int) + + # Flatten to 2D so the first dim is the total token count. + original_shape = list(input_.size()) + x = maybe_dequantize(input_, dtype).reshape(-1, self.in_features) + total_tokens = x.size(0) + + # Build the input GroupedTensor. + if with_quantized_compute: + input_quantizer = input_quantizers[0] + input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) + input_quantizer.optimize_for_gemm = True + grouped_x = tex.group_quantize(x, input_quantizer, num_groups, split_sizes) + else: + # No quantize: wrap the contiguous high-precision buffer. + grouped_x = GroupedTensor( + shape=(total_tokens, self.in_features), + dtype=dtype, + num_tensors=num_groups, + quantizer=None, + data=x.reshape(-1), + first_dims=split_sizes, + tensor_offsets=base_split_offsets * self.in_features, + ) + + # Build the weight GroupedTensor / list. + if self.single_grouped_weight: + # GroupedTensor + grouped_weights = self._get_grouped_weight_for_gemm( + self.weight, + weight_quantizers, + columnwise_usage=input_requires_grad, + with_quantized_compute=with_quantized_compute, + dtype=dtype, + ) + else: + # Discrete weights + grouped_weights = self._get_discrete_weights_for_gemm( + [getattr(self, f"weight{idx}") for idx in range(num_groups)], + weight_quantizers, + columnwise_usage=input_requires_grad, + with_quantized_compute=with_quantized_compute, + dtype=dtype, + ) + + # Allocate output buffer and wrap as a GroupedTensor view. 
+ out_shape = original_shape[:-1] + [self.out_features] + out = torch.empty(out_shape, dtype=dtype, device=device) + grouped_out = GroupedTensor( + shape=(total_tokens, self.out_features), + dtype=dtype, + num_tensors=num_groups, + quantizer=None, + data=out.reshape(-1), + first_dims=split_sizes, + tensor_offsets=base_split_offsets * self.out_features, + ) + + # Bias: hand off to the grouped GEMM (graph-safe, fused). Plain bias + # uses ``bias=``; scaled bias also passes per-token ``bias_scale=``. + grouped_bias = None + bias_scale: Optional[torch.Tensor] = None + if has_bias: + # Bias always needs to be passed as a GroupedTensor for the grouped GEMM. + grouped_bias = self._get_grouped_bias_for_gemm(dtype) if self._scale_bias: - saved.append(scales) - saved.extend(xs) - saved.extend(ws) - ctx.save_for_backward(*saved) - ctx.with_quantized_compute = with_quantized_compute - ctx.input_quantizers = input_quantizers - ctx.weight_quantizers = weight_quantizers - ctx.grad_output_quantizers = grad_output_quantizers - ctx.grad_input_quantizers = None - ctx.dtype = dtype - ctx.input_requires_grad = input_requires_grad - ctx.weight_requires_grad = weight_requires_grad + bias_scale = scales.reshape(-1) + if bias_scale.dtype != torch.float32: + bias_scale = bias_scale.to(dtype=torch.float32) + + # Forward grouped GEMM. + general_grouped_gemm_for_grouped_tensor( + grouped_weights, + grouped_x, + grouped_out, + layout="TN", + use_split_accumulator=_2X_ACC_FPROP, + bias=grouped_bias, + bias_scale=bias_scale, + ) - return out, [()] + if not input_requires_grad: + grouped_weights = None if self.single_grouped_weight else [None] * num_groups + + if not weight_requires_grad: + grouped_x = None + + # Build the tuple of tensors to save for backward. 
Layout: + # [split_sizes, base_split_offsets, split_points, + # (scales if _scale_bias), grouped_x, *weights] + if grouped_x is not None: + if with_quantized_compute: + # only columnwise data is needed for wgrad + grouped_x.rowwise_data = None + grouped_x.scale_inv = None + saved: list[Optional[torch.Tensor]] = [split_sizes, base_split_offsets, split_points] + if self._scale_bias: + saved.append(scales) + saved.append(grouped_x) + if self.single_grouped_weight: + saved.append(grouped_weights) + else: + saved.extend(grouped_weights) + return out, tuple(saved) def fuser_backward( self, @@ -892,16 +1285,43 @@ def fuser_backward( torch.Tensor, Iterable[Iterable[Optional[torch.Tensor]]], Iterable[Iterable[Optional[torch.Tensor]]], + ]: + ctx = basic_op_ctxs[0] + # Dispatch to the path used in forward (saved as ``ctx.use_grouped_tensor_path``). + if getattr(ctx, "use_grouped_tensor_path", False): + return self._fuser_backward_grouped_tensor( + ctx=ctx, + grad_output=grad_output, + ) + return self._fuser_backward_split_quantize( + ctx=ctx, + grad_output=grad_output, + ) + + def _fuser_backward_split_quantize( + self, + *, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], ]: num_groups = self.num_groups has_bias = self.has_bias - weight_param = self.weight if self.single_grouped_weight else self.weight0 - device = weight_param.device + weights = self._get_weight_tensors() + device = weights[0].device - # Saved tensors from forward pass - ctx = basic_op_ctxs[0] + # Saved tensors from forward pass. Layout: + # [split_sizes, base_split_offsets, split_points, + # (scales if _scale_bias), *xs, *ws] + # ``base_split_offsets`` and ``split_points`` are unused on this path + # but are present so the saved-tensor layout matches the graph-safe + # path (and the fused MLP forward). 
saved_tensors = ctx.saved_tensors - split_sizes, saved_tensors = saved_tensors[0], saved_tensors[1:] + split_sizes = saved_tensors[0] + saved_tensors = saved_tensors[3:] scales = None if self._scale_bias: scales, saved_tensors = saved_tensors[0], saved_tensors[1:] @@ -941,58 +1361,43 @@ def fuser_backward( dbias_packed = compute_grouped_dbias(dy_2d, offsets, num_groups) grad_biases = [dbias_packed[idx].to(dtype=ctx.dtype) for idx in range(num_groups)] - # Initialize grad weight buffers + # Initialize grad weight buffers. accumulate_into_main_grad = self._accumulate_into_main_grad grad_weights = [None] * num_groups + final_weight_grads: list[Optional[torch.Tensor]] = ( + [None] if self.single_grouped_weight else [None] * num_groups + ) if ctx.weight_requires_grad: - if accumulate_into_main_grad: - # Megatron-LM wgrad fusion - # Note: Get grad tensors from params so we can - # accumulate directly into it. - if self.single_grouped_weight: - if hasattr(weight_param, "__fsdp_param__"): - weight_param.main_grad = weight_param.get_main_grad() - main_grad = weight_param.main_grad - if isinstance(main_grad, GroupedTensor): - grad_weights = main_grad.quantized_tensors - if grad_weights is None: - grad_weights = main_grad.split_into_quantized_tensors() - else: - # main_grad may be [num_groups, out, in] or a flat buffer. - # Canonicalize to grouped layout before slicing per-group views. 
- weight_shape = (self.out_features, self.in_features) - grouped_shape = (num_groups, *weight_shape) - if main_grad.shape != grouped_shape: - if main_grad.numel() != math.prod(grouped_shape): - raise RuntimeError( - "GroupedLinear expected grouped weight main_grad to have " - f"shape {grouped_shape} or matching numel, " - f"but got shape {tuple(main_grad.shape)}" - ) - main_grad = main_grad.reshape(grouped_shape) - grad_weights = [main_grad[idx] for idx in range(num_groups)] - accumulate_into_main_grad = not getattr( - weight_param, "overwrite_main_grad", False + weight_shape = (self.out_features, self.in_features) + grouped_shape = (num_groups, *weight_shape) + if self.single_grouped_weight: + if accumulate_into_main_grad: + # Megatron-LM wgrad fusion: GEMM accumulates into the + # parameter's ``main_grad`` directly. + main_grad = get_main_grad_from_param(weights[0], op_label="GroupedLinear") + main_grad = view_main_grad_as_grouped_buffer( + main_grad, num_groups, weight_shape, label="GroupedLinear weight" ) + final_weight_grads[0] = main_grad + grad_weights = [main_grad[idx] for idx in range(num_groups)] + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) else: - for group_idx in range(num_groups): - weight_param = getattr(self, f"weight{group_idx}") - if hasattr(weight_param, "__fsdp_param__"): - weight_param.main_grad = weight_param.get_main_grad() - grad_weights[group_idx] = weight_param.main_grad - accumulate_into_main_grad = not getattr( - self.weight0, "overwrite_main_grad", False + final_weight_grads[0] = torch.empty( + grouped_shape, dtype=ctx.dtype, device=device ) + grad_weights = [final_weight_grads[0][idx] for idx in range(num_groups)] else: - weight_shape = (self.out_features, self.in_features) - for group_idx in range(num_groups): - grad_weights[group_idx] = torch.empty( - weight_shape, - dtype=ctx.dtype, - device=device, - ) - else: - accumulate_into_main_grad = False + if accumulate_into_main_grad: + grad_weights = [ + 
get_main_grad_from_param(w, op_label="GroupedLinear") for w in weights + ] + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) + else: + grad_weights = [ + torch.empty(weight_shape, dtype=ctx.dtype, device=device) + for _ in range(num_groups) + ] + final_weight_grads = list(grad_weights) # Perform dgrad GEMMs grad_input = None @@ -1050,54 +1455,14 @@ def fuser_backward( if not delay_wgrad: clear_tensor_data(*xs) - # Megatron-LM wgrad fusion - # Note: Return dummy tensor for grad weight if needed. - if accumulate_into_main_grad: - grad_weights = [None] * num_groups - if self.single_grouped_weight: - if hasattr(weight_param, "grad_added_to_main_grad"): - weight_param.grad_added_to_main_grad = True - grad_weight = get_dummy_wgrad( - list(weight_param.size()), - weight_param.dtype, - zero=getattr(weight_param, "zero_out_wgrad", False), - ) - else: - grad_weight = None - # Be mindful of param registration order. - if has_bias: - if self.single_grouped_bias: - final_bias_grads = torch.stack(grad_biases, dim=0).to(ctx.dtype) - grad_params = [grad_weight, final_bias_grads] - else: - grad_params = grad_biases + [grad_weight] - else: - grad_params = [grad_weight] - grad_extra = (None, grad_scales) if self._scale_bias else (None,) - return grad_input, [grad_params], [grad_extra] - for group_idx in range(num_groups): - weight_param = getattr(self, f"weight{group_idx}") - if hasattr(weight_param, "grad_added_to_main_grad"): - weight_param.grad_added_to_main_grad = True - grad_weights[group_idx] = get_dummy_wgrad( - list(weight_param.size()), - weight_param.dtype, - zero=getattr(weight_param, "zero_out_wgrad", False), - ) - - if self.single_grouped_weight: - grad_weight = None - if ctx.weight_requires_grad: - if delay_wgrad: - grad_weight = None - else: - grad_weight = torch.stack(grad_weights, dim=0) - final_weight_grads = [grad_weight] - else: - if delay_wgrad and ctx.weight_requires_grad and not accumulate_into_main_grad: - final_weight_grads = [None] * 
num_groups - else: - final_weight_grads = grad_weights + # Megatron-LM wgrad fusion: regardless of overwrite vs. accumulate, + # signal that ``main_grad`` already carries the wgrad and replace + # ``.grad`` with a dummy so DDP/FSDP hooks won't add ``.grad`` into + # ``main_grad`` again. + if ctx.weight_requires_grad and self._accumulate_into_main_grad: + final_weight_grads = get_dummy_wgrads_for_params(weights) + elif ctx.weight_requires_grad and delay_wgrad: + final_weight_grads = [None] if self.single_grouped_weight else [None] * num_groups if not has_bias: grad_params = list(final_weight_grads) @@ -1112,3 +1477,214 @@ def fuser_backward( grad_extra = (None, grad_scales) if self._scale_bias else (None,) return grad_input, [grad_params], [grad_extra] + + # ================================================================== + # Graph-safe backward: counterpart of `_fuser_forward_grouped_tensor`. + # ================================================================== + def _fuser_backward_grouped_tensor( + self, + *, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[ + torch.Tensor, + Iterable[Iterable[Optional[torch.Tensor]]], + Iterable[Iterable[Optional[torch.Tensor]]], + ]: + num_groups = self.num_groups + has_bias = self.has_bias + weights = self._get_weight_tensors() + device = weights[0].device + dtype = ctx.dtype + + with_quantized_compute = bool(getattr(ctx, "with_quantized_compute", False)) + + # Saved tensors from forward pass + # Layout: [split_sizes, base_split_offsets, split_points, + # (scales if _scale_bias), grouped_x, *weights] + # ``split_points`` is unused on this path but is present so the + # saved-tensor layout matches the fused MLP forward (which needs it + # for the cuDNN grouped GEMM kernel). 
+ saved_tensors = ctx.saved_tensors + split_sizes = saved_tensors[0] + base_split_offsets = saved_tensors[1] + saved_tensors = saved_tensors[3:] + scales = None + if self._scale_bias: + scales, saved_tensors = saved_tensors[0], saved_tensors[1:] + grouped_x, saved_tensors = saved_tensors[0], saved_tensors[1:] + if self.single_grouped_weight: + ws, saved_tensors = saved_tensors[0], saved_tensors[1:] + else: + ws, saved_tensors = saved_tensors[:num_groups], saved_tensors[num_groups:] + + # Flatten grad_output to 2D (total_tokens, out_features) + # to figure out total tokens. + dy_2d = grad_output.reshape(-1, self.out_features) + total_tokens = dy_2d.size(0) + + # Build the grad_output GroupedTensor. + # Optionally get dbias is fusion available with bgrad_group_quantize + dbias_packed = None + if with_quantized_compute: + grad_output_quantizer = ctx.grad_output_quantizers[0] + grad_output_quantizer.set_usage( + rowwise=ctx.input_requires_grad, columnwise=ctx.weight_requires_grad + ) + grad_output_quantizer.optimize_for_gemm = True + + if has_bias and not self._scale_bias: + grouped_dy, dbias_packed = tex.bgrad_group_quantize( + dy_2d, grad_output_quantizer, num_groups, split_sizes + ) + else: + grouped_dy = tex.group_quantize( + dy_2d, grad_output_quantizer, num_groups, split_sizes + ) + else: + dy_2d = maybe_dequantize(dy_2d, dtype) + # Wrap BF16/FP16 buffer as a GroupedTensor for grouped gemm + grouped_dy = GroupedTensor( + shape=(total_tokens, self.out_features), + dtype=dtype, + num_tensors=num_groups, + quantizer=None, + data=dy_2d.reshape(-1), + first_dims=split_sizes, + tensor_offsets=base_split_offsets * self.out_features, + ) + + # Bias Grads compute if not already computed in bgrad_group_quantize + final_bias_grads: Optional[torch.Tensor] = None + grad_scales: Optional[torch.Tensor] = None + if has_bias: + if self._scale_bias: + bias_packed = torch.stack(self._get_bias_tensors(dtype)) + scales_f32 = scales.to(dtype=torch.float32) + dbias_packed, grad_scales 
= compute_grouped_dbias_dscales( + dy_2d, + scales_f32, + bias_packed, + offsets=base_split_offsets, + ) + elif dbias_packed is None: + # BF16/FP16 path + dbias_packed = compute_grouped_dbias(dy_2d, base_split_offsets, num_groups) + if self.single_grouped_bias: + final_bias_grads = [dbias_packed.to(dtype=dtype)] + else: + final_bias_grads = [dbias_packed[idx].to(dtype=dtype) for idx in range(num_groups)] + + # ---- dgrad GEMM ---------------------------------------------------- + grad_input = None + if ctx.input_requires_grad: + grad_input_shape = list(grad_output.size())[:-1] + [self.in_features] + grad_input = torch.empty(grad_input_shape, dtype=dtype, device=device) + grouped_grad_input = GroupedTensor( + shape=(total_tokens, self.in_features), + dtype=dtype, + num_tensors=num_groups, + quantizer=None, + data=grad_input.reshape(-1), + first_dims=split_sizes, + tensor_offsets=base_split_offsets * self.in_features, + ) + general_grouped_gemm_for_grouped_tensor( + ws, + grouped_dy, + grouped_grad_input, + layout="NN", + use_split_accumulator=_2X_ACC_DGRAD, + ) + + # params init for wgrad GEMM + accumulate_into_main_grad = False + weight_shape = (self.out_features, self.in_features) + wgrad_output: Any = None + grouped_wgrad: Optional[GroupedTensor] = None + final_weight_grads: list[Optional[torch.Tensor]] = ( + [None] if self.single_grouped_weight else [None] * num_groups + ) + + # Get the right wgrad buffers for grouped gemm. + # Can be a GroupedTensor or list of tensors based on single_grouped_weight. + if ctx.weight_requires_grad: + if self.single_grouped_weight: + if self._accumulate_into_main_grad: + # Main-grad fusion: GEMM writes directly into ``main_grad``. + # ``overwrite_main_grad`` only flips the GEMM's + # ``accumulate`` flag. 
+ main_grad = get_main_grad_from_param(weights[0], op_label="GroupedLinear") + main_grad = view_main_grad_as_grouped_buffer( + main_grad, num_groups, weight_shape, label="GroupedLinear weight" + ) + grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data( + num_tensors=num_groups, + tensor_shape=weight_shape, + rowwise_data=main_grad.view(-1), + dtype=main_grad.dtype, + ) + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) + else: + grouped_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_groups, + shapes=[weight_shape] * num_groups, + quantizer=None, + device=device, + dtype=dtype, + ) + final_weight_grads[0] = grouped_wgrad.rowwise_data.view(num_groups, *weight_shape) + wgrad_output = grouped_wgrad + else: + if self._accumulate_into_main_grad: + final_weight_grads = [ + get_main_grad_from_param(w, op_label="GroupedLinear") for w in weights + ] + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) + else: + final_weight_grads = [ + torch.empty(weight_shape, dtype=dtype, device=device) + for _ in range(num_groups) + ] + wgrad_output = final_weight_grads + + # wgrad GEMM + delay_wgrad = ( + ctx.weight_requires_grad + and self.wgrad_store is not None + and self.wgrad_store.delay_wgrad_compute() + ) + if ctx.weight_requires_grad: + wgrad_gemm = functools.partial( + general_grouped_gemm_for_grouped_tensor, + layout="NT", + accumulate=accumulate_into_main_grad, + use_split_accumulator=_2X_ACC_WGRAD, + ) + if delay_wgrad: + self.wgrad_store.put([grouped_x, grouped_dy, wgrad_output], wgrad_gemm) + else: + wgrad_gemm(grouped_x, grouped_dy, wgrad_output) + + # Megatron-LM wgrad fusion: regardless of overwrite vs. accumulate, + # signal that ``main_grad`` already carries the wgrad and replace + # ``.grad`` with a dummy so DDP/FSDP hooks won't add ``.grad`` into + # ``main_grad`` again. 
+ if ctx.weight_requires_grad and self._accumulate_into_main_grad: + final_weight_grads = get_dummy_wgrads_for_params(weights) + elif ctx.weight_requires_grad and delay_wgrad: + final_weight_grads = [None] if self.single_grouped_weight else [None] * num_groups + + # Assemble grad params in parameter registration order and return. + if not has_bias: + grad_params = final_weight_grads + elif self.single_grouped_bias: + grad_params = final_weight_grads + final_bias_grads + else: + if self.single_grouped_weight: + grad_params = final_bias_grads + final_weight_grads + else: + grad_params = final_weight_grads + final_bias_grads + + grad_extra = (None, grad_scales) if self._scale_bias else (None,) + return grad_input, [grad_params], [grad_extra] diff --git a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py index b07ebb73eb..320c7c39e5 100644 --- a/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py @@ -7,7 +7,6 @@ from __future__ import annotations from collections.abc import Callable import functools -import math import os from typing import Optional @@ -25,11 +24,15 @@ from .._common import ( _cudnn_frontend_version_supported, fuse_grouped_mlp_ops, + get_accumulate_flag_in_param, + get_dummy_wgrads_for_params, + get_main_grad_from_param, maybe_dequantize, + view_main_grad_as_grouped_buffer, validate_grouped_mlp_dims, ) from ...cpp_extensions import general_grouped_gemm_for_grouped_tensor -from ...module.base import _2X_ACC_WGRAD, get_dummy_wgrad +from ...module.base import _2X_ACC_WGRAD from ...triton.grouped_dbias_dscales import compute_grouped_dbias_dscales @@ -149,44 +152,32 @@ def _compute_grad_params( Returns the grad_params list in parameter registration order. """ - # Allocate grad buffers, determine accumulate flag + # Allocate grad buffers, determine accumulate flag. 
accumulate_into_main_grad = False grouped_wgrad = None wgrad_output = None + op_label = f"Grouped MLP fused backward ({label})" if label else "Grouped MLP fused backward" + weights = fc_op._get_weight_tensors() if fc_op.single_grouped_weight: w_list = [None] if ctx.weight_requires_grad: - weight_param = fc_op.weight if fc_op._accumulate_into_main_grad: - if hasattr(weight_param, "__fsdp_param__"): - weight_param.main_grad = weight_param.get_main_grad() - main_grad = weight_param.main_grad - grouped_shape = (num_groups, *weight_shape) - if main_grad.shape != grouped_shape: - if main_grad.numel() != math.prod(grouped_shape): - raise RuntimeError( - f"Grouped MLP fused backward expected {label} main_grad to have " - f"shape {grouped_shape} or matching numel, " - f"but got shape {tuple(main_grad.shape)}" - ) - try: - main_grad = main_grad.view(grouped_shape) - except RuntimeError as e: - raise RuntimeError( - f"Grouped MLP fused backward requires {label} main_grad to be " - f"viewable as {grouped_shape} without copy, but got shape" - f" {tuple(main_grad.shape)} and stride" - f" {tuple(main_grad.stride())}" - ) from e - accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) + # Main-grad fusion: GEMM writes directly into ``main_grad``. + # ``overwrite_main_grad`` only flips the GEMM's ``accumulate`` + # flag (overwrite vs. accumulate); it does not change the + # output buffer. 
+ main_grad = get_main_grad_from_param(weights[0], op_label=op_label) + main_grad = view_main_grad_as_grouped_buffer( + main_grad, num_groups, weight_shape, label=f"{op_label} weight" + ) grouped_wgrad = GroupedTensor.make_grouped_tensor_from_rowwise_data( num_tensors=num_groups, tensor_shape=weight_shape, rowwise_data=main_grad, dtype=main_grad.dtype, ) - - if grouped_wgrad is None: + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) + else: grouped_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_groups, shapes=[weight_shape] * num_groups, @@ -195,19 +186,17 @@ def _compute_grad_params( dtype=dtype, ) wgrad_output = grouped_wgrad + w_list = [grouped_wgrad.rowwise_data.view(num_groups, *weight_shape)] else: w_list = [None] * num_groups if ctx.weight_requires_grad: if fc_op._accumulate_into_main_grad: - for idx in range(num_groups): - wp = getattr(fc_op, f"weight{idx}") - if hasattr(wp, "__fsdp_param__"): - wp.main_grad = wp.get_main_grad() - w_list[idx] = wp.main_grad - accumulate_into_main_grad = not getattr(fc_op.weight0, "overwrite_main_grad", False) + w_list = [get_main_grad_from_param(w, op_label=op_label) for w in weights] + accumulate_into_main_grad = get_accumulate_flag_in_param(weights[0]) else: - for idx in range(num_groups): - w_list[idx] = torch.empty(weight_shape, dtype=dtype, device=device) + w_list = [ + torch.empty(weight_shape, dtype=dtype, device=device) for _ in range(num_groups) + ] wgrad_output = w_list if ctx.weight_requires_grad: @@ -237,34 +226,11 @@ def _compute_grad_params( else: gemm_fn(grouped_x, grouped_dy, wgrad_output) - # Extract results, mark accumulated if needed - if fc_op.single_grouped_weight: - packed_wgrad = None - if not delay_wgrad: - packed_wgrad = grouped_wgrad.rowwise_data.view(num_groups, *weight_shape) - if fc_op._accumulate_into_main_grad and hasattr( - weight_param, "grad_added_to_main_grad" - ): - weight_param.grad_added_to_main_grad = True - packed_wgrad = get_dummy_wgrad( - 
list(weight_param.size()), - weight_param.dtype, - zero=getattr(weight_param, "zero_out_wgrad", False), - ) - w_list = [packed_wgrad] - else: - if delay_wgrad or fc_op._accumulate_into_main_grad: - w_list = [None] * num_groups - if fc_op._accumulate_into_main_grad: - for idx in range(num_groups): - wp = getattr(fc_op, f"weight{idx}") - if hasattr(wp, "grad_added_to_main_grad"): - wp.grad_added_to_main_grad = True - w_list[idx] = get_dummy_wgrad( - list(wp.size()), - wp.dtype, - zero=getattr(wp, "zero_out_wgrad", False), - ) + # Need to return dummy wgrads for Megatron-LM wgrad fusion if grad is already added + if fc_op._accumulate_into_main_grad: + w_list = get_dummy_wgrads_for_params(weights) + elif delay_wgrad: + w_list = [None] if fc_op.single_grouped_weight else [None] * num_groups # Assemble grad_params in parameter registration order. if not fc_op.has_bias: @@ -372,18 +338,15 @@ def fuser_backward( grad_output = grad_output.reshape(-1, fc2_weight_shape[0]) out_shape = list(grad_output.size()) num_groups = fc1_op.num_groups - fc1_weight_param = fc1_op.weight if fc1_op.single_grouped_weight else fc1_op.weight0 - device = fc1_weight_param.device + device = fc1_op._get_weight_tensors()[0].device dtype = fc1_ctx.dtype - # Saved tensors from FC1 forward + # Saved tensors from FC1 forward. 
+ # Layout: [split_sizes, base_split_offsets, split_points, + # grouped_fc1_x, *fc1_weights] saved_tensors = fc1_ctx.saved_tensors - split_sizes, split_points, saved_tensors = ( - saved_tensors[0], - saved_tensors[1], - saved_tensors[2:], - ) - + split_sizes, base_split_offsets, split_points = saved_tensors[:3] + grouped_fc1_x, saved_tensors = saved_tensors[3], saved_tensors[4:] if fc1_op.single_grouped_weight: grouped_fc1_weight, saved_tensors = saved_tensors[0], saved_tensors[1:] else: @@ -392,21 +355,21 @@ def fuser_backward( saved_tensors[num_groups:], ) - ( - fc1_x_col_data, - fc1_x_col_scale, - fc1_x_tensor_offsets, - ), saved_tensors = ( - saved_tensors[:3], - saved_tensors[3:], - ) - # Saved tensors from scaled SwiGLU forward swiglu_in, scales = swiglu_ctx.saved_tensors - # Saved tensors from FC2 forward - saved_tensors = fc2_ctx.saved_tensors - _, saved_tensors = saved_tensors[0], saved_tensors[1:] # Assume same split sizes as FC1 + # Saved tensors from FC2 forward. + # Layout: [split_sizes, base_split_offsets, split_points, + # (fc2_scales if _scale_bias), + # grouped_fc2_x, *fc2_weights] + scale_bias = fc2_op._scale_bias and fc2_op.has_bias + saved_tensors = fc2_ctx.saved_tensors[3:] + if fc2_op._scale_bias: + # Saved for the unfused backward path, which reads its own + # per-op scales here. The fused backward below currently reuses + # the SwiGLU ``scales``. 
+ saved_tensors = saved_tensors[1:] + grouped_fc2_x, saved_tensors = saved_tensors[0], saved_tensors[1:] if fc2_op.single_grouped_weight: grouped_fc2_weight, saved_tensors = saved_tensors[0], saved_tensors[1:] else: @@ -415,53 +378,19 @@ def fuser_backward( saved_tensors[num_groups:], ) - ( - fc2_x_col_data, - fc2_x_col_scale, - fc2_x_tensor_offsets, - ), saved_tensors = ( - saved_tensors[:3], - saved_tensors[3:], - ) - # Group splits if int(split_sizes.numel()) != num_groups: raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") - scale_bias = fc2_op._scale_bias and fc2_op.has_bias - grouped_fc1_x = None - if fc1_ctx.weight_requires_grad: - grouped_fc1_x = GroupedTensor( - shape=(out_shape[0], fc1_weight_shape[1]), - dtype=dtype, - num_tensors=num_groups, - quantizer=fc1_ctx.input_quantizer, - columnwise_data=fc1_x_col_data, - columnwise_scale_inv=fc1_x_col_scale, - first_dims=split_sizes, - tensor_offsets=fc1_x_tensor_offsets, - with_gemm_swizzled_scales=True, - ) - - grouped_fc2_x = None - if fc2_ctx.weight_requires_grad: - grouped_fc2_x = GroupedTensor( - shape=(out_shape[0], fc2_weight_shape[1]), - dtype=dtype, - num_tensors=num_groups, - quantizer=fc2_ctx.input_quantizer, - columnwise_data=fc2_x_col_data, - columnwise_scale_inv=fc2_x_col_scale, - first_dims=split_sizes, - tensor_offsets=fc2_x_tensor_offsets, - with_gemm_swizzled_scales=True, - ) + if not fc1_ctx.weight_requires_grad: + grouped_fc1_x = None + if not fc2_ctx.weight_requires_grad: + grouped_fc2_x = None # Split grad output tensor and convert dtypes if needed - fc2_ctx.grad_output_quantizer.set_usage( - rowwise=True, columnwise=fc2_ctx.weight_requires_grad - ) - fc2_ctx.grad_output_quantizer.optimize_for_gemm = True + fc2_grad_output_quantizer = fc2_ctx.grad_output_quantizers[0] + fc2_grad_output_quantizer.set_usage(rowwise=True, columnwise=fc2_ctx.weight_requires_grad) + fc2_grad_output_quantizer.optimize_for_gemm = True output_fc2_dbias = fc2_op.has_bias 
fc2_dbias_packed = None fc2_dy = None @@ -476,14 +405,14 @@ def fuser_backward( if output_fc2_dbias and not scale_bias: grouped_fc2_dy, fc2_dbias_packed = tex.bgrad_group_quantize( fc2_dy, - fc2_ctx.grad_output_quantizer, + fc2_grad_output_quantizer, num_groups, split_sizes, ) else: grouped_fc2_dy = tex.group_quantize( fc2_dy, - fc2_ctx.grad_output_quantizer, + fc2_grad_output_quantizer, num_groups, split_sizes, ) @@ -600,7 +529,7 @@ def fuser_backward( fc2_dy, scales_f32, bias_packed, - offsets=fc1_ctx.base_split_offsets, + offsets=base_split_offsets, dscales=grad_scales, ) fc2_dbias_packed_result = fc2_dbias_packed_result.to(dtype=dtype) @@ -629,12 +558,12 @@ def fuser_backward( fc1_bias_grads = [dbias_2d[group_idx] for group_idx in range(num_groups)] # FC1 grad output for dgrad and wgrad GEMMs - fc1_dy_tensor_offsets = fc1_ctx.base_split_offsets * fc1_weight_shape[0] + fc1_dy_tensor_offsets = base_split_offsets * fc1_weight_shape[0] grouped_fc1_dy = GroupedTensor( shape=(out_shape[0], fc1_weight_shape[0]), dtype=dtype, num_tensors=num_groups, - quantizer=fc1_ctx.grad_output_quantizer, + quantizer=fc1_ctx.grad_output_quantizers[0], data=fc1_dy_row_data, columnwise_data=fc1_dy_col_data, scale_inv=fc1_dy_row_scale, @@ -668,7 +597,7 @@ def fuser_backward( and fc2_op.wgrad_store.delay_wgrad_compute() ): clear_tensor_data( - grouped_fc2_x.data, + grouped_fc2_x.rowwise_data, grouped_fc2_x.columnwise_data, grouped_fc2_x.scale_inv, grouped_fc2_x.columnwise_scale_inv, @@ -764,7 +693,7 @@ def fuser_backward( and fc1_op.wgrad_store.delay_wgrad_compute() ): clear_tensor_data( - grouped_fc1_x.data, + grouped_fc1_x.rowwise_data, grouped_fc1_x.columnwise_data, grouped_fc1_x.scale_inv, grouped_fc1_x.columnwise_scale_inv, diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index c06e212e87..382fecfd07 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ 
b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -9,9 +9,13 @@ import torch -from ...module.base import get_dummy_wgrad from ...utils import clear_tensor_data from ..basic import BasicLinear, MakeExtraOutput +from .._common import ( + get_accumulate_flag_in_param, + get_dummy_wgrads_for_params, + get_main_grad_from_param, +) from ..op import FusedOperation, FusibleOperation, OperationContext @@ -57,16 +61,9 @@ def fuser_backward( grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: weight_param = linear_op.weight - if hasattr(weight_param, "__fsdp_param__"): - weight_param.main_grad = weight_param.get_main_grad() - accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) - if not hasattr(weight_param, "main_grad"): - raise RuntimeError( - "BasicLinear op is configured with " - "accumulate_into_main_grad=True, " - "but weight parameter does not have main_grad attribute" - ) - grad_weight = weight_param.main_grad.detach() + main_grad = get_main_grad_from_param(weight_param, op_label="BasicLinear") + accumulate_into_main_grad = get_accumulate_flag_in_param(weight_param) + grad_weight = main_grad.detach() else: accumulate_into_main_grad = False @@ -99,15 +96,7 @@ def fuser_backward( # Megatron-LM wgrad fusion # Note: Return dummy tensor for grad weight if needed. 
if accumulate_into_main_grad: - grad_weight = None - weight_param = linear_op.weight - if hasattr(weight_param, "grad_added_to_main_grad"): - weight_param.grad_added_to_main_grad = True - grad_weight = get_dummy_wgrad( - list(weight_param.size()), - weight_param.dtype, - zero=getattr(weight_param, "zero_out_wgrad", False), - ) + grad_weight = get_dummy_wgrads_for_params([linear_op.weight])[0] return grad_input, [(), (grad_weight,)], [(), ()] diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index 709073e6f8..b48c2e6d52 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -9,9 +9,13 @@ import torch -from ...module.base import get_dummy_wgrad from ...utils import clear_tensor_data from ..basic import BasicLinear, ConstantScale +from .._common import ( + get_accumulate_flag_in_param, + get_dummy_wgrads_for_params, + get_main_grad_from_param, +) from ..op import FusedOperation, FusibleOperation, OperationContext @@ -58,16 +62,9 @@ def fuser_backward( grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: weight_param = linear_op.weight - if hasattr(weight_param, "__fsdp_param__"): - weight_param.main_grad = weight_param.get_main_grad() - accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) - if not hasattr(weight_param, "main_grad"): - raise RuntimeError( - "BasicLinear op is configured with " - "accumulate_into_main_grad=True, " - "but weight parameter does not have main_grad attribute" - ) - grad_weight = weight_param.main_grad.detach() + main_grad = get_main_grad_from_param(weight_param, op_label="BasicLinear") + accumulate_into_main_grad = get_accumulate_flag_in_param(weight_param) + grad_weight = main_grad.detach() else: accumulate_into_main_grad = False @@ -99,15 +96,7 @@ def fuser_backward( # Megatron-LM wgrad fusion # 
Note: Return dummy tensor for grad weight if needed. if accumulate_into_main_grad: - grad_weight = None - weight_param = linear_op.weight - if hasattr(weight_param, "grad_added_to_main_grad"): - weight_param.grad_added_to_main_grad = True - grad_weight = get_dummy_wgrad( - list(weight_param.size()), - weight_param.dtype, - zero=getattr(weight_param, "zero_out_wgrad", False), - ) + grad_weight = get_dummy_wgrads_for_params([linear_op.weight])[0] return grad_input, [(grad_weight,), ()], [(), ()] diff --git a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py index 599e5f96ae..91db2ff9b7 100644 --- a/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py +++ b/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py @@ -160,10 +160,9 @@ def fuser_forward( if int(split_sizes.numel()) != num_groups: raise ValueError(f"Expected {num_groups} splits, but got {int(split_sizes.numel())}.") split_sizes = split_sizes.to(dtype=torch.int64, device=device) - base_offsets = tex.splits_to_offsets(split_sizes, 1) - split_points = base_offsets[1:].to(dtype=torch.int) - fc1_x_tensor_offsets = base_offsets * fc1_weight_shape[1] - fc2_x_tensor_offsets = base_offsets * fc2_weight_shape[1] + base_split_offsets = tex.splits_to_offsets(split_sizes, 1) + split_points = base_split_offsets[1:].to(dtype=torch.int) + fc2_x_tensor_offsets = base_split_offsets * fc2_weight_shape[1] # Extract post-scales from extra input scales = basic_op_extra_inputs[1][0] @@ -452,27 +451,35 @@ def fuser_forward( # Save state for backward pass if requires_grad: mark_grouped_tensor(grouped_fc1_x, swiglu_in, scales, grouped_fc2_x) - fc1_input_tensors = ( - grouped_fc1_x.columnwise_data, - grouped_fc1_x.columnwise_scale_inv, - fc1_x_tensor_offsets, - ) - # FC1 + + # Save the input ``GroupedTensor``s themselves for the activations. 
+ for grouped_fc_x in (grouped_fc1_x, grouped_fc2_x): + if grouped_fc_x is not None: + grouped_fc_x.rowwise_data = None + grouped_fc_x.scale_inv = None + + # FC1 saved-tensor layout. + # [split_sizes, base_split_offsets, split_points, + # grouped_fc1_x, *fc1_weight_tensors] fc1_weight_tensors = ( [grouped_fc1_weight] if fc1_op.single_grouped_weight else grouped_fc1_weight ) fc1_ctx.save_for_backward( - split_sizes, split_points, *fc1_weight_tensors, *fc1_input_tensors + split_sizes, + base_split_offsets, + split_points, + grouped_fc1_x, + *fc1_weight_tensors, ) + fc1_ctx.use_grouped_tensor_path = True fc1_ctx.with_quantized_compute = True - fc1_ctx.input_quantizer = fc1_input_quantizer - fc1_ctx.weight_quantizer = fc1_weight_quantizer - fc1_ctx.grad_output_quantizer = fc1_grad_output_quantizer + fc1_ctx.input_quantizers = [fc1_input_quantizer] + fc1_ctx.weight_quantizers = [fc1_weight_quantizer] + fc1_ctx.grad_output_quantizers = [fc1_grad_output_quantizer] fc1_ctx.grad_input_quantizers = None fc1_ctx.dtype = dtype fc1_ctx.input_requires_grad = input_requires_grad fc1_ctx.weight_requires_grad = weight_requires_grad - fc1_ctx.base_split_offsets = base_offsets # Scaled SwiGLU swiglu_ctx.save_for_backward(swiglu_in, scales) @@ -480,25 +487,31 @@ def fuser_forward( swiglu_ctx.extra_input_requires_grad = True swiglu_ctx.dtype = dtype - # FC2 state - if grouped_fc2_x is not None: - fc2_input_tensors = ( - grouped_fc2_x.columnwise_data, - grouped_fc2_x.columnwise_scale_inv, - fc2_x_tensor_offsets, - ) - else: - fc2_input_tensors = (None, None, None) - - if fc2_op.single_grouped_weight: - fc2_ctx.save_for_backward(split_sizes, grouped_fc2_weight, *fc2_input_tensors) - else: - fc2_ctx.save_for_backward(split_sizes, *grouped_fc2_weight, *fc2_input_tensors) - + # FC2 saved-tensor layout. 
Matches the unfused + # ``GroupedLinear._fuser_forward_grouped_tensor`` layout so the + # unfused backward (basic/grouped_linear.py) can consume the same + # ctx when the fused backward is unavailable. + # [split_sizes, base_split_offsets, split_points, + # (fc2_scales if _scale_bias), + # grouped_fc2_x, *fc2_weight_tensors] + fc2_weight_tensors = ( + [grouped_fc2_weight] if fc2_op.single_grouped_weight else grouped_fc2_weight + ) + fc2_saved: list[Optional[torch.Tensor]] = [ + split_sizes, + base_split_offsets, + split_points, + ] + if fc2_op._scale_bias: + fc2_saved.append(fc2_scales) + fc2_saved.append(grouped_fc2_x) + fc2_saved.extend(fc2_weight_tensors) + fc2_ctx.save_for_backward(*fc2_saved) + fc2_ctx.use_grouped_tensor_path = True fc2_ctx.with_quantized_compute = True - fc2_ctx.input_quantizer = fc2_input_quantizer - fc2_ctx.weight_quantizer = fc2_weight_quantizer - fc2_ctx.grad_output_quantizer = fc2_grad_output_quantizer + fc2_ctx.input_quantizers = [fc2_input_quantizer] + fc2_ctx.weight_quantizers = [fc2_weight_quantizer] + fc2_ctx.grad_output_quantizers = [fc2_grad_output_quantizer] fc2_ctx.grad_input_quantizers = None fc2_ctx.dtype = dtype fc2_ctx.input_requires_grad = input_requires_grad diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index fbaf69d75d..7d67815f9a 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -17,14 +17,19 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, fill_userbuffers_buffer_for_all_gather, - get_dummy_wgrad, get_ub, ) from ...quantized_tensor import Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data from ..basic import BasicLinear, Bias, ReduceScatter -from .._common import maybe_dequantize, is_quantized_tensor +from .._common import ( + 
get_accumulate_flag_in_param, + get_dummy_wgrads_for_params, + get_main_grad_from_param, + is_quantized_tensor, + maybe_dequantize, +) from ..op import FusedOperation, FusibleOperation, OperationContext @@ -519,16 +524,9 @@ def fuser_backward( grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: weight_param = linear_op.weight - if hasattr(weight_param, "__fsdp_param__"): - weight_param.main_grad = weight_param.get_main_grad() - accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False) - if not hasattr(weight_param, "main_grad"): - raise RuntimeError( - "BasicLinear op is configured with " - "accumulate_into_main_grad=True, " - "but weight parameter does not have main_grad attribute" - ) - grad_weight = weight_param.main_grad.detach() + main_grad = get_main_grad_from_param(weight_param, op_label="UserbuffersBackwardLinear") + accumulate_into_main_grad = get_accumulate_flag_in_param(weight_param) + grad_weight = main_grad.detach() else: accumulate_into_main_grad = False @@ -563,15 +561,7 @@ def fuser_backward( # Megatron-LM wgrad fusion # Note: Return dummy tensor for grad weight if needed. 
if accumulate_into_main_grad: - grad_weight = None - weight_param = linear_op.weight - if hasattr(weight_param, "grad_added_to_main_grad"): - weight_param.grad_added_to_main_grad = True - grad_weight = get_dummy_wgrad( - list(weight_param.size()), - weight_param.dtype, - zero=getattr(weight_param, "zero_out_wgrad", False), - ) + grad_weight = get_dummy_wgrads_for_params([linear_op.weight])[0] # Return gradients grad_params = [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py index 5f12c3ed8c..485b32328b 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py @@ -303,6 +303,51 @@ def get_dtype(self) -> torch.dtype: return self.fake_dtype + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], "GroupedTensorStorage"]: + """Prepare the tensor base for saving for backward.""" + tensors = [ + self.rowwise_data, + self.columnwise_data, + self.scale_inv, + self.columnwise_scale_inv, + self.amax, + self.columnwise_amax, + self.scale, + self.first_dims, + self.last_dims, + self.tensor_offsets, + ] + self.rowwise_data = None + self.columnwise_data = None + self.scale_inv = None + self.columnwise_scale_inv = None + self.amax = None + self.columnwise_amax = None + self.scale = None + self.first_dims = None + self.last_dims = None + self.tensor_offsets = None + self.quantized_tensors = None + return tensors, self + + def restore_from_saved( + self, tensors: list[Optional[torch.Tensor]] + ) -> list[Optional[torch.Tensor]]: + """Restore the tensor base data from the saved tensors list.""" + self.rowwise_data = tensors[0] + self.columnwise_data = tensors[1] + self.scale_inv = tensors[2] + self.columnwise_scale_inv = tensors[3] + self.amax = tensors[4] + self.columnwise_amax = tensors[5] + self.scale = tensors[6] + 
self.first_dims = tensors[7] + self.last_dims = tensors[8] + self.tensor_offsets = tensors[9] + return tensors[10:] + def clear(self) -> None: """ Reset tensor data and clear all buffers.