Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 1 addition & 40 deletions tests/pytorch/distributed/fsdp2_tests/run_fsdp2_fused_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,13 +228,6 @@ def test_fused_adam_fp8_master_weights_no_meta(recipe_name):
"""
recipe = get_recipe_from_string(recipe_name)

if recipe_name in ("MXFP8BlockScaling", "Float8BlockScaling", "NVFP4BlockScaling"):
pytest.xfail(
f"{recipe_name}: FSDP2 all-gather hooks for block-scaling QuantizedTensor "
"subclasses fail when parameters are initialized on CUDA. "
"Use device='meta' + reset_parameters() after sharding."
)

world_size, device = _get_dist_info()

model = _build_model(fp8_init=True, recipe=recipe, use_meta_device=False)
Expand Down Expand Up @@ -604,12 +597,6 @@ def test_safetensors_fp32_export(recipe_name):
- Saved tensor shapes match expected (unsharded) shapes
"""
recipe = get_recipe_from_string(recipe_name)
if recipe_name == "MXFP8BlockScaling":
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
"MXFP8 quantized tensors, causing illegal memory access. "
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
)

from safetensors.torch import load_file, save_file
from torch.distributed.checkpoint.state_dict import (
Expand Down Expand Up @@ -692,40 +679,14 @@ def test_dcp_output_parity(recipe_name, async_save):
"""
recipe = get_recipe_from_string(recipe_name)

if recipe_name == "MXFP8BlockScaling":
pytest.xfail(
"MXFP8BlockScaling: FusedAdam CUDA kernel does not support "
"MXFP8 quantized tensors, causing illegal memory access: "
"/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh:92 in function "
"multi_tensor_apply: CUDA Error: an illegal memory access was encountered. "
"Fixed by https://github.com/NVIDIA/TransformerEngine/pull/2789."
)

if recipe_name == "NVFP4BlockScaling":
pytest.xfail(
"NVFP4BlockScaling: DCP load_state_dict triggers reset_sharded_param() "
"which calls data_ptr() on NVFP4Tensor wrapper subclass with invalid storage"
)

if (
recipe_name == "Float8BlockScaling"
and not async_save
and torch.cuda.get_device_capability()[0] == 12
):
if recipe_name == "Float8BlockScaling" and torch.cuda.get_device_capability()[0] == 12:
pytest.xfail(
"Float8BlockScaling is failing on SM120 with RuntimeError: "
"transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu:534 "
"in function quantize_transpose_vector_blockwise: Assertion failed: pow2_scale. On "
"Blackwell and newer, the FP8 block scaling recipe is emulated with MXFP8, which "
"requires using power of two scaling factors."
)
if recipe_name == "Float8BlockScaling" and async_save:
pytest.xfail(
"Float8BlockScaling: async DCP save/load round-trip produces different model "
"outputs — quantization metadata (scales) is not correctly persisted through "
"async distributed checkpointing. On SM120, additionally fails with pow2_scale "
"assertion in quantize_transpose_vector_blockwise."
)

import torch.distributed.checkpoint as dcp

Expand Down
9 changes: 0 additions & 9 deletions tests/pytorch/distributed/fsdp2_tests/run_fsdp2_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,20 +379,11 @@ def test_distributed(recipe_name, fp8_init, sharding_dims, layer_type):
"sending only 1 tensor (scale is per-tensor metadata). Fix: concatenate MXFP8 "
"data and scale_inv into a single buffer in pre_all_gather, split in post."
)

if recipe_name == "Float8BlockScaling" and fp8_init:
pytest.xfail(
"Float8BlockScaling + fp8_init: scale inverse padding is not handled "
"correctly during FSDP2 all-gather slice ops."
)
if recipe_name == "NVFP4BlockScaling" and fp8_init and layer_type == "TransformerLayer":
pytest.xfail(
"NVFP4BlockScaling + fp8_init + TransformerLayer: "
"_check_fp8_fsdp2_allgather numerical error compounds across multiple "
"linear layers in the transformer block (up to ~1e-2 max abs diff). "
"LayerNormLinear passes with relaxed tolerances. "
"NVFP4 + FSDP2 training is validated by run_fsdp2_fused_adam.py."
)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

Expand Down
50 changes: 50 additions & 0 deletions tests/pytorch/test_quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,6 +616,56 @@ def test_identity_op(
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, dx_ref, **tols)

@pytest.mark.parametrize("quantization", _quantization_list)
def test_cpu_dequantize(
self,
*,
quantization: str,
shape: Iterable[int] = (128, 128),
dtype: torch.dtype = torch.bfloat16,
) -> None:
"""Dequantize on a CPU-resident QuantizedTensor."""

# Construct a quantized tensor on CUDA.
_, x_cuda = make_reference_and_test_tensors(
shape=shape,
quantization=quantization,
test_dtype=dtype,
requires_grad=False,
)
assert isinstance(x_cuda, QuantizedTensor)
assert x_cuda.device.type == "cuda"

# Reference: dequantize on CUDA, then move the dense result to CPU.
ref_cpu = x_cuda.dequantize().to(device="cpu")

# Move the QuantizedTensor itself to CPU and dequantize there.
# ``.cpu()`` routes through ``aten._to_copy.default`` so all inner
# buffers (data, scales, amax) are moved to CPU.
x_cpu = x_cuda.cpu()
assert isinstance(x_cpu, QuantizedTensor)
assert x_cpu.device.type == "cpu"
for attr in (
"_data",
"_rowwise_data",
"_columnwise_data",
"_rowwise_scale_inv",
"_columnwise_scale_inv",
"_amax_rowwise",
"_amax_columnwise",
):
buf = getattr(x_cpu, attr, None)
if buf is not None:
assert buf.device.type == "cpu", f"{attr} did not move to CPU"

# Dequantize the CPU tensor. Implementation may bounce through CUDA
# internally, but must return a CPU tensor.
y_cpu = x_cpu.dequantize()
assert y_cpu.device.type == "cpu"
assert y_cpu.dtype == ref_cpu.dtype
assert y_cpu.shape == ref_cpu.shape
torch.testing.assert_close(y_cpu, ref_cpu, rtol=0, atol=0)

@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("dim", [0, 1])
def test_chunk(
Expand Down
11 changes: 10 additions & 1 deletion transformer_engine/common/util/pybind_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,16 @@
.value("kBFloat16", transformer_engine::DType::kBFloat16) \
.value("kFloat8E4M3", transformer_engine::DType::kFloat8E4M3) \
.value("kFloat8E5M2", transformer_engine::DType::kFloat8E5M2) \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1); \
.value("kFloat4E2M1", transformer_engine::DType::kFloat4E2M1) \
.def("__reduce_ex__", \
[](transformer_engine::DType self, pybind11::object /*protocol*/) { \
return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \
pybind11::make_tuple(static_cast<int>(self))); \
}) \
.def("__reduce__", [](transformer_engine::DType self) { \
return pybind11::make_tuple(pybind11::type::of(pybind11::cast(self)), \
pybind11::make_tuple(static_cast<int>(self))); \
}); \
pybind11::enum_<NVTE_Bias_Type>(m, "NVTE_Bias_Type", pybind11::module_local()) \
.value("NVTE_NO_BIAS", NVTE_Bias_Type::NVTE_NO_BIAS) \
.value("NVTE_PRE_SCALE_BIAS", NVTE_Bias_Type::NVTE_PRE_SCALE_BIAS) \
Expand Down
58 changes: 58 additions & 0 deletions transformer_engine/pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,66 @@
from transformer_engine.pytorch.tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor import Float8BlockwiseQTensor
from transformer_engine.pytorch.tensor import NVFP4Tensor
from transformer_engine.pytorch.tensor.float8_tensor import (
_make_float8_tensor_in_reduce_ex,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import (
_make_mxfp8_tensor_in_reduce_ex,
)
from transformer_engine.pytorch.tensor.nvfp4_tensor import (
_make_nvfp4_tensor_in_reduce_ex,
)
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
_make_float8_blockwise_tensor_in_reduce_ex,
)

try:
torch._dynamo.config.error_on_nested_jit_trace = False
except AttributeError:
pass # error_on_nested_jit_trace was added in PyTorch 2.2.0

# To allow for safe unpickling of QuantizedTensors when using DCP
# checkpointing with FSDP2. ``tex.DType`` (the pybind11 enum) has its
# ``__reduce_ex__`` / ``__reduce__`` overridden in the C++ binding (see
# ``transformer_engine/common/util/pybind_helper.h``) so its pickle
# stream encodes as ``(tex.DType, (int,))`` and only the class itself
# needs to be allow-listed below.
try:
from torch.serialization import add_safe_globals
import transformer_engine_torch as tex

add_safe_globals(
[
# Storage mixins (used during pickling of internal-only tensors)
QuantizedTensorStorage,
Float8TensorStorage,
MXFP8TensorStorage,
NVFP4TensorStorage,
Float8BlockwiseQTensorStorage,
# Quantizer types embedded in metadata
Quantizer,
Float8Quantizer,
Float8CurrentScalingQuantizer,
MXFP8Quantizer,
NVFP4Quantizer,
Float8BlockQuantizer,
# pybind11 enum used as Quantizer.dtype
tex.DType,
# __reduce_ex__ reconstructors (module-level functions).
_make_float8_tensor_in_reduce_ex,
_make_mxfp8_tensor_in_reduce_ex,
_make_nvfp4_tensor_in_reduce_ex,
_make_float8_blockwise_tensor_in_reduce_ex,
]
)
except (ImportError, AttributeError):
import warnings as _warnings

_warnings.warn(
"transformer_engine: torch.serialization.add_safe_globals is "
"unavailable on this PyTorch version (added in 2.4). DCP "
"checkpointing of QuantizedTensor weights with FSDP2 will not "
"work; upgrade to PyTorch >= 2.4 to enable it.",
RuntimeWarning,
stacklevel=2,
)
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from ..quantized_tensor import QuantizedTensor, QuantizedTensorStorage, Quantizer
from ..tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.nvfp4_tensor import NVFP4Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
Expand Down Expand Up @@ -1641,7 +1642,9 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
raise RuntimeError("Weight quantizer has not been initialized")
quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled())
quantizer.internal = False
if is_dtensor and isinstance(quantizer, Float8CurrentScalingQuantizer):
if is_dtensor and isinstance(
quantizer, (Float8CurrentScalingQuantizer, NVFP4Quantizer)
):
device_mesh = dtensor_param.device_mesh
amax_reduction_group = (
device_mesh.get_group(mesh_dim="shard")
Expand Down
62 changes: 57 additions & 5 deletions transformer_engine/pytorch/quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,9 +552,26 @@ def half(self) -> torch.Tensor:
# pylint: disable=missing-function-docstring
return self.dequantize(dtype=torch.float16)

def cpu(self, memory_format=torch.preserve_format) -> torch.Tensor:
def cpu(self, memory_format=torch.preserve_format) -> QuantizedTensor:
"""Move tensor to CPU while preserving the QuantizedTensor type.

Routes through ``aten._to_copy.default`` so the subclass-preserving
handler in ``__torch_dispatch__`` runs (rather than dequantizing).

"""
# pylint: disable=missing-function-docstring
return self.dequantize().cpu(memory_format=memory_format)
return self.to(device=torch.device("cpu"), memory_format=memory_format)

def untyped_storage(self) -> torch.UntypedStorage:
"""Return an empty UntypedStorage on the tensor's device.

``QuantizedTensor`` is a ``_make_wrapper_subclass`` and has no real
backing storage of its own; the actual bytes live in the inner
buffers (e.g. ``_rowwise_data`` / ``_columnwise_data``) which are
an implementation detail of the quantization scheme. Need to define
this method to avoid DCP staging errors with FSDP2.
"""
return torch.UntypedStorage(0, device=self.device)

def expand_as(self, other: torch.Tensor) -> torch.Tensor:
# pylint: disable=missing-function-docstring
Expand Down Expand Up @@ -608,6 +625,36 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
dst.copy_(src)
return None

# _to_copy op (used by .to(device=...), .cpu(), DCP staging).
# Preserve the QuantizedTensor subclass and move all internal
# buffers (data, scales, etc.) to the requested device.
if func == torch.ops.aten._to_copy.default:
tensor = args[0]
kw = dict(kwargs) if kwargs else {}
dtype = kw.get("dtype", None)
if dtype is None or dtype == tensor.dtype:
target_device = kw.get("device", tensor.device) or tensor.device
target_device = torch.device(target_device)
pin_memory = bool(kw.get("pin_memory", False))
non_blocking = bool(kw.get("non_blocking", False))
new_metadata = {"device": target_device}
# Update tensor storage metadata
for key, value in tensor.get_metadata().items():
if isinstance(value, torch.Tensor):
value = value.to(device=target_device, non_blocking=non_blocking)
if pin_memory and target_device.type == "cpu":
value = value.pin_memory()
new_metadata[key] = value
# Update torch Tensor metadata
new_metadata.update(
{
"dtype": tensor.dtype,
"shape": tensor.shape,
"requires_grad": tensor.requires_grad,
}
)
return type(tensor)(**new_metadata)

# View op
if func == torch.ops.aten.view.default:
raise NotImplementedError("{cls.__name__} class does not support tensor views")
Expand Down Expand Up @@ -748,14 +795,19 @@ def make_like(
"""Create new quantized tensor

By default, new tensor has the same attributes and underlying
data. This function is intended to create view of tensors.

data. This function is intended to create a view of ``tensor``,
"""
shape = shape if shape is not None else tensor.shape
dtype = dtype if dtype is not None else tensor.dtype
kwargs = tensor.get_metadata()
kwargs["fake_dtype"] = dtype
return cls(shape=shape, dtype=dtype, requires_grad=requires_grad, **kwargs)
return cls(
shape=shape,
dtype=dtype,
requires_grad=requires_grad,
device=tensor.device,
**kwargs,
)

def to_dtype(self, dtype: torch.dtype) -> QuantizedTensor:
"""Create `QuantizedTensor` with given nominal dtype
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/tensor/_quantization_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def forward(
kwargs = tensor.get_metadata()
for key, val in init_kwargs.items():
kwargs[key] = val
kwargs["device"] = tensor.device
return type(tensor)(tensor.shape, tensor.dtype, **kwargs)

@staticmethod
Expand Down
Loading
Loading