Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions qa/L0_pytorch_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_custom_recipe.xml $TE_PATH/tests/pytorch/test_custom_recipe.py || test_fail "test_custom_recipe.py"
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
Expand Down
55 changes: 25 additions & 30 deletions tests/pytorch/distributed/run_numerics_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4
from transformer_engine.pytorch.custom_recipes import utils
from run_layer_with_overlap import _compare_tensors

Expand Down Expand Up @@ -52,44 +52,39 @@ def get_nvfp4_quantizer_factory():
"""
Create a quantizer factory for NVFP4 reference implementation.

This factory returns NVFP4QuantizerRef instances with RHT and 2D quantization
enabled.
Linear/grouped-linear weight slots get 2D (16x16) quantization without RHT;
every other slot (input, gradient, boundary slots with ``role is None``,
and any unknown tensor type) gets 1D (1x16) quantization with RHT.

Mirrors the canonical "branch on what we care about, default fall-through"
pattern from
``transformer_engine.pytorch.custom_recipes.quantization_recipes_base``;
every slot gets a real :class:`NVFP4QuantizerRef` (``CustomRecipeState``
rejects ``None`` returns).

Returns:
A factory function that takes a role string and returns a quantizer instance
A factory function that takes a QuantizerRole and returns a quantizer instance
"""

def factory(role):
if role == "linear_input":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True, # RHT enabled for input
)
elif role == "linear_weight":
return quantization_nvfp4.NVFP4QuantizerRef(
is_weight = (
role is not None
and role.module_type in ("linear", "grouped_linear")
and role.tensor_type == "weight"
)
if is_weight:
return quantization_ref_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16), # 2D quantization for weight
quant_tile_shape=(16, 16),
pow_2_scales=False,
with_rht=False,
)
elif role == "linear_output":
# Output quantization not used
return None
elif role == "linear_grad_output":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True, # RHT enabled for grad_output
)
elif role == "linear_grad_input":
# Grad input quantization not used
return None
else:
# For any other roles, return None
return None
return quantization_ref_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=True,
)

return factory

Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils


Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_group_quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
Expand Down
51 changes: 24 additions & 27 deletions tests/pytorch/nvfp4/test_nvfp4_module_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
from transformer_engine.pytorch.custom_recipes import quantization_nvfp4
from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4
from transformer_engine.pytorch.custom_recipes import utils


Expand Down Expand Up @@ -76,40 +76,37 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo
with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights)

Returns:
A factory function that takes a role string and returns a quantizer instance
A factory function that takes a QuantizerRole (or None for boundary slots)
and returns a quantizer instance.
"""

# Boundary slots (output, grad_input) get role=None from Linear.get_quantizer_roles
# when no consumer is configured. CustomRecipeState rejects None returns from
# qfactory, so we return a valid quantizer for those slots; it is harmless because
# the GEMM outputs in the high-precision activation dtype, not in NVFP4.
def _default_quantizer():
return quantization_ref_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)

def factory(role):
if role == "linear_input":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)
elif role == "linear_weight":
return quantization_nvfp4.NVFP4QuantizerRef(
if role is None:
return _default_quantizer()
if role.tensor_type == "input":
return _default_quantizer()
if role.tensor_type == "weight":
return quantization_ref_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16),
pow_2_scales=False,
with_rht=False,
)
elif role == "linear_output":
# Output quantization not used
return None
elif role == "linear_grad_output":
return quantization_nvfp4.NVFP4QuantizerRef(
dtype=utils.Fp4Formats.E2M1,
quant_tile_shape=(1, 16),
pow_2_scales=False,
with_rht=with_rht,
)
elif role == "linear_grad_input":
# Grad input quantization not used
return None
else:
# For any other roles, return None
return None
if role.tensor_type == "grad_output":
return _default_quantizer()
return _default_quantizer()

return factory

Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.common.recipe import NVFP4BlockScaling
from transformer_engine.pytorch.constants import TE_DType
Expand Down
2 changes: 1 addition & 1 deletion tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.pytorch import NVFP4Quantizer
from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef
from transformer_engine.pytorch.custom_recipes import utils
from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.common.recipe import NVFP4BlockScaling
Expand Down
17 changes: 10 additions & 7 deletions tests/pytorch/test_backward_override.py
Original file line number Diff line number Diff line change
Expand Up @@ -400,23 +400,26 @@ def _snapshot_backward_ctx_state(
) -> tuple[str, bool, object, bool]:
if output.grad_fn is None:
raise RuntimeError("Output tensor has no grad_fn; cannot inspect backward context state.")
# ``Linear`` packs backward state into ``grad_fn.backward_objects``
# (``LinearBwdArgs``); other linear-like modules still set the attributes
# directly on the autograd ctx.
state_holder = getattr(output.grad_fn, "backward_objects", output.grad_fn)
required_attrs = (
"backward_override",
"fp8",
"grad_output_quantizer",
"reduce_and_update_bwd_fp8_tensors",
)
missing_attrs = [attr for attr in required_attrs if not hasattr(output.grad_fn, attr)]
missing_attrs = [attr for attr in required_attrs if not hasattr(state_holder, attr)]
if missing_attrs:
raise RuntimeError(
"grad_fn does not expose required backward context attributes: "
f"{', '.join(missing_attrs)}."
f"Backward context does not expose required attributes: {', '.join(missing_attrs)}."
)
return (
getattr(output.grad_fn, "backward_override"),
bool(getattr(output.grad_fn, "fp8")),
getattr(output.grad_fn, "grad_output_quantizer"),
bool(getattr(output.grad_fn, "reduce_and_update_bwd_fp8_tensors")),
getattr(state_holder, "backward_override"),
bool(getattr(state_holder, "fp8")),
getattr(state_holder, "grad_output_quantizer"),
bool(getattr(state_holder, "reduce_and_update_bwd_fp8_tensors")),
)


Expand Down
Loading
Loading