Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 253 additions & 84 deletions tests/pytorch/test_fusible_ops.py

Large diffs are not rendered by default.

72 changes: 72 additions & 0 deletions tests/pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import os
import random
import subprocess
from collections.abc import Iterable
from contextlib import contextmanager
from typing import Optional, Sequence, Tuple, Dict, Any, List
from packaging.version import Version as PkgVersion
Expand All @@ -27,6 +28,7 @@
check_set_window_size,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
from transformer_engine.pytorch.module.base import get_dummy_wgrad


def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype:
Expand Down Expand Up @@ -477,3 +479,73 @@ def run_distributed(
msg += f"\n--- stderr ---\n{stderr_tail}"
raise AssertionError(msg)
return result


class MegatronTrainingHelper:
"""Test-side stand-in for the Megatron-Core DDP / MegatronFSDP wrapper.
Megatron's DDP wrapper (and MegatronFSDP) owns the per-parameter
``main_grad`` buffer and the ``overwrite_main_grad`` /
``grad_added_to_main_grad`` attributes that coordinate
``fuse_wgrad_accumulation`` with TE modules. These helpers reproduce the
relevant slice of that protocol so TE tests can exercise the
accumulate-into-``main_grad`` code path without pulling in the full
Megatron-Core dependency.
"""

@staticmethod
def init_main_grad_buffers(
weight_params: Iterable[torch.nn.Parameter],
*,
fill_value: float,
overwrite_main_grad: bool,
zero_out_wgrad: bool = False,
dtype: torch.dtype = torch.float32,
) -> None:
"""Allocate ``main_grad`` and stamp the wrapper attributes on each
param, mirroring what the Megatron DDP/FSDP wrapper does before
backward."""
for wp in weight_params:
wp.main_grad = torch.full(wp.size(), fill_value, device=wp.device, dtype=dtype)
wp.overwrite_main_grad = overwrite_main_grad
wp.zero_out_wgrad = zero_out_wgrad
wp.grad_added_to_main_grad = False

@staticmethod
def verify_main_grad_accumulation(
weight_params: Iterable[torch.nn.Parameter],
*,
expected_main_grads: Iterable[torch.Tensor],
rtol: float = 0.0,
atol: float = 0.0,
) -> None:
"""Check that backward produced what the Megatron wrapper expects:
each ``main_grad`` matches ``expected_main_grads``,
``grad_added_to_main_grad`` was flipped to ``True`` so the wrapper's
post-backward hooks won't double-accumulate, and ``param.grad`` was
replaced by the cached dummy tensor (so a wrapper hook that did
``main_grad += grad`` would be a no-op rather than double-counting).
"""
for wp, expected in zip(weight_params, expected_main_grads):
torch.testing.assert_close(wp.main_grad.to(expected), expected, rtol=rtol, atol=atol)

assert wp.grad_added_to_main_grad is True, (
"weight.grad_added_to_main_grad was not flipped to True; "
"the Megatron DDP/FSDP wrapper hook will double-accumulate."
)

# ``.grad`` should be the cached dummy tensor returned by
# ``get_dummy_wgrad`` -- shared storage, not the real wgrad.
expected_dummy = get_dummy_wgrad(list(wp.size()), wp.dtype)
assert (
wp.grad is not None
), "weight.grad is None; the Megatron protocol expects a dummy tensor stand-in here."
assert wp.grad.data_ptr() == expected_dummy.data_ptr(), (
"weight.grad does not share storage with the cached dummy "
"wgrad; downstream wrapper hooks risk double-accumulating."
)
if getattr(wp, "zero_out_wgrad", False):
assert torch.all(wp.grad == 0), (
"weight.zero_out_wgrad=True but the dummy weight.grad "
"was not zeroed; downstream hooks reading .grad would "
"see stale bytes from the previous step."
)
94 changes: 94 additions & 0 deletions transformer_engine/pytorch/ops/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from __future__ import annotations
import functools
import math
from importlib.metadata import PackageNotFoundError, version as get_pkg_version
from typing import Optional

Expand Down Expand Up @@ -88,6 +89,99 @@ def get_fp8_meta_from_fp8_tensor(tensor: Float8Tensor) -> tuple[FP8TensorMeta, i
return fp8_meta, 0


def get_main_grad_from_param(
weight_param: torch.nn.Parameter,
*,
op_label: str = "",
) -> torch.Tensor:
"""Refresh ``main_grad`` from FSDP (if applicable) and return it.
Used by Megatron-LM-style wgrad fusion paths
(``accumulate_into_main_grad=True``) to obtain the buffer the wgrad GEMM
will write into.
Raises if the parameter does not have a ``main_grad`` attribute or if it
is ``None``.
"""
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
if not hasattr(weight_param, "main_grad") or weight_param.main_grad is None:
prefix = f"{op_label} " if op_label else ""
raise RuntimeError(
f"{prefix}operation is configured with accumulate_into_main_grad=True, "
"but weight parameter does not have a valid main_grad attribute"
)
return weight_param.main_grad


def get_accumulate_flag_in_param(weight_param: torch.nn.Parameter) -> bool:
"""Return whether the wgrad GEMM should accumulate into ``main_grad``.

Returns ``False`` (i.e. overwrite) when the parameter has
``overwrite_main_grad=True`` (used in Megatron-FSDP), and ``True``
otherwise.
"""
return not getattr(weight_param, "overwrite_main_grad", False)


def view_main_grad_as_grouped_buffer(
main_grad: torch.Tensor,
num_groups: int,
weight_shape: tuple[int, ...],
*,
label: str = "",
) -> torch.Tensor:
"""Return ``main_grad`` viewed as ``(num_groups, *weight_shape)`` without copy.
Raises if the numel doesn't match or if the existing stride pattern does
not allow a zero-copy view to the grouped layout.
"""
grouped_shape = (num_groups, *weight_shape)
if tuple(main_grad.shape) == grouped_shape:
return main_grad
prefix = f"{label} " if label else "Grouped weight "
if main_grad.numel() != math.prod(grouped_shape):
raise RuntimeError(
f"{prefix}main_grad expected shape {grouped_shape} or matching numel, "
f"but got shape {tuple(main_grad.shape)}"
)
try:
return main_grad.view(grouped_shape)
except RuntimeError as e:
raise RuntimeError(
f"{prefix}main_grad must be viewable as {grouped_shape} without copy, "
f"but got shape {tuple(main_grad.shape)} and stride "
f"{tuple(main_grad.stride())}"
) from e


def get_dummy_wgrads_for_params(
weight_params: list[torch.nn.Parameter],
) -> list[Optional[torch.Tensor]]:
"""Build dummy ``.grad`` placeholders for Megatron-LM wgrad-fusion params.

For each parameter that exposes ``grad_added_to_main_grad``, set the flag
to ``True`` and return a dummy wgrad tensor (zeroed if
``zero_out_wgrad`` is also set on the parameter). For parameters without
the flag, the corresponding entry is ``None``.

The returned list has the same length and order as ``weight_params``.
"""
from ..module.base import get_dummy_wgrad # pylint: disable=import-outside-toplevel

out: list[Optional[torch.Tensor]] = []
for wp in weight_params:
if hasattr(wp, "grad_added_to_main_grad"):
wp.grad_added_to_main_grad = True
out.append(
get_dummy_wgrad(
list(wp.size()),
wp.dtype,
zero=getattr(wp, "zero_out_wgrad", False),
)
)
else:
out.append(None)
return out


def validate_grouped_mlp_dims(fc1, glu_op, fc2) -> None:
"""Validate FC1 / scaled GLU / FC2 dimensions for fused grouped MLP."""

Expand Down
32 changes: 11 additions & 21 deletions transformer_engine/pytorch/ops/basic/basic_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
_2X_ACC_FPROP,
_2X_ACC_DGRAD,
_2X_ACC_WGRAD,
get_dummy_wgrad,
)
from ...tensor import Quantizer
from ...tensor.float8_tensor import Float8Quantizer
Expand All @@ -36,7 +35,13 @@
devices_match,
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_dequantize, is_quantized_tensor
from .._common import (
get_accumulate_flag_in_param,
get_dummy_wgrads_for_params,
get_main_grad_from_param,
is_quantized_tensor,
maybe_dequantize,
)


def _wait_async(handle: Optional[Any]) -> None:
Expand Down Expand Up @@ -1060,16 +1065,9 @@ def op_backward(
grad_weight = None
if ctx.weight_requires_grad and accumulate_into_main_grad:
weight_param = self.weight
if hasattr(weight_param, "__fsdp_param__"):
weight_param.main_grad = weight_param.get_main_grad()
accumulate_into_main_grad = not getattr(weight_param, "overwrite_main_grad", False)
if not hasattr(weight_param, "main_grad"):
raise RuntimeError(
"BasicLinear op is configured with "
"accumulate_into_main_grad=True, "
"but weight parameter does not have main_grad attribute"
)
grad_weight = weight_param.main_grad.detach()
main_grad = get_main_grad_from_param(weight_param, op_label="BasicLinear")
accumulate_into_main_grad = get_accumulate_flag_in_param(weight_param)
grad_weight = main_grad.detach()
else:
accumulate_into_main_grad = False

Expand Down Expand Up @@ -1099,14 +1097,6 @@ def op_backward(
# Megatron-LM wgrad fusion
# Note: Return dummy tensor for grad weight if needed.
if accumulate_into_main_grad:
grad_weight = None
weight_param = self.weight
if hasattr(weight_param, "grad_added_to_main_grad"):
weight_param.grad_added_to_main_grad = True
grad_weight = get_dummy_wgrad(
list(weight_param.size()),
weight_param.dtype,
zero=getattr(weight_param, "zero_out_wgrad", False),
)
grad_weight = get_dummy_wgrads_for_params([self.weight])[0]

return grad_input, [grad_weight]
Loading
Loading