diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 22636828f9..c35dc4c063 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -26,6 +26,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" NVTE_GROUPED_LINEAR_SINGLE_PARAM=1 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_custom_recipe.xml $TE_PATH/tests/pytorch/test_custom_recipe.py || test_fail "test_custom_recipe.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" diff --git a/tests/pytorch/distributed/run_numerics_exact.py b/tests/pytorch/distributed/run_numerics_exact.py index 0f3d2cbbf0..15ae2dae63 100644 --- a/tests/pytorch/distributed/run_numerics_exact.py +++ b/tests/pytorch/distributed/run_numerics_exact.py @@ -22,7 +22,7 @@ ) from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.constants import NVFP4_BLOCK_SCALING_SIZE -from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4 from transformer_engine.pytorch.custom_recipes import utils from run_layer_with_overlap import _compare_tensors @@ -52,44 +52,39 @@ def get_nvfp4_quantizer_factory(): """ Create a quantizer factory for NVFP4 reference implementation. - This factory returns NVFP4QuantizerRef instances with RHT and 2D quantization - enabled. + Linear/grouped-linear weight slots get 2D (16x16) quantization without RHT; + every other slot (input, gradient, boundary slots with ``role is None``, + and any unknown tensor type) gets 1D (1x16) quantization with RHT. + + Mirrors the canonical "branch on what we care about, default fall-through" + pattern from + ``transformer_engine.pytorch.custom_recipes.quantization_recipes_base``; + every slot gets a real :class:`NVFP4QuantizerRef` (``CustomRecipeState`` + rejects ``None`` returns). Returns: - A factory function that takes a role string and returns a quantizer instance + A factory function that takes a QuantizerRole and returns a quantizer instance """ def factory(role): - if role == "linear_input": - return quantization_nvfp4.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, # RHT enabled for input - ) - elif role == "linear_weight": - return quantization_nvfp4.NVFP4QuantizerRef( + is_weight = ( + role is not None + and role.module_type in ("linear", "grouped_linear") + and role.tensor_type == "weight" + ) + if is_weight: + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(16, 16), # 2D quantization for weight + quant_tile_shape=(16, 16), pow_2_scales=False, with_rht=False, ) - elif role == "linear_output": - # Output quantization not used - return None - elif role == "linear_grad_output": - return quantization_nvfp4.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, # RHT enabled for grad_output - ) - elif role == "linear_grad_input": - # Grad input quantization not used - return None - else: - # For any other roles, return None - return None + return quantization_ref_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ) return factory diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index b939336275..a7ea4f089f 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -9,7 +9,7 @@ from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py index 7bf288fff7..20a91bf6fe 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize.py @@ -13,7 +13,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling diff --git a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py index cf2ae50ee9..d46a874695 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py +++ b/tests/pytorch/nvfp4/test_nvfp4_group_quantize_graph_safe.py @@ -6,7 +6,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling diff --git a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py index a96fea3af0..b57b78eb13 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_module_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_module_exact.py @@ -6,7 +6,7 @@ import torch import transformer_engine.pytorch as te from transformer_engine.common import recipe -from transformer_engine.pytorch.custom_recipes import quantization_nvfp4 +from transformer_engine.pytorch.custom_recipes import quantization_ref_nvfp4 from transformer_engine.pytorch.custom_recipes import utils @@ -76,40 +76,37 @@ def get_nvfp4_quantizer_factory(with_rht: bool = False, with_2d_quantization: bo with_2d_quantization: Whether to use 2D quantization (16x16 tiles for weights) Returns: - A factory function that takes a role string and returns a quantizer instance + A factory function that takes a QuantizerRole (or None for boundary slots) + and returns a quantizer instance. """ + # Boundary slots (output, grad_input) get role=None from Linear.get_quantizer_roles + # when no consumer is configured. CustomRecipeState rejects None returns from + # qfactory, so we return a valid quantizer for those slots; it is harmless because + # the GEMM outputs in the high-precision activation dtype, not in NVFP4. + def _default_quantizer(): + return quantization_ref_nvfp4.NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=with_rht, + ) + def factory(role): - if role == "linear_input": - return quantization_nvfp4.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=with_rht, - ) - elif role == "linear_weight": - return quantization_nvfp4.NVFP4QuantizerRef( + if role is None: + return _default_quantizer() + if role.tensor_type == "input": + return _default_quantizer() + if role.tensor_type == "weight": + return quantization_ref_nvfp4.NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16) if with_2d_quantization else (1, 16), pow_2_scales=False, with_rht=False, ) - elif role == "linear_output": - # Output quantization not used - return None - elif role == "linear_grad_output": - return quantization_nvfp4.NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=with_rht, - ) - elif role == "linear_grad_input": - # Grad input quantization not used - return None - else: - # For any other roles, return None - return None + if role.tensor_type == "grad_output": + return _default_quantizer() + return _default_quantizer() return factory diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index 0824a5e7bc..53569d90d9 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -7,7 +7,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.common.recipe import NVFP4BlockScaling from transformer_engine.pytorch.constants import TE_DType diff --git a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py index 795721df04..2d159dbf6a 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_rht_quantize_exact.py @@ -12,7 +12,7 @@ import transformer_engine.pytorch as te import transformer_engine_torch as tex from transformer_engine.pytorch import NVFP4Quantizer -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils from transformer_engine.pytorch.constants import TE_DType from transformer_engine.common.recipe import NVFP4BlockScaling diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index c7c5a5b99d..43e9587d95 100644 --- a/tests/pytorch/test_backward_override.py +++ b/tests/pytorch/test_backward_override.py @@ -400,23 +400,26 @@ def _snapshot_backward_ctx_state( ) -> tuple[str, bool, object, bool]: if output.grad_fn is None: raise RuntimeError("Output tensor has no grad_fn; cannot inspect backward context state.") + # ``Linear`` packs backward state into ``grad_fn.backward_objects`` + # (``LinearBwdArgs``); other linear-like modules still set the attributes + # directly on the autograd ctx. + state_holder = getattr(output.grad_fn, "backward_objects", output.grad_fn) required_attrs = ( "backward_override", "fp8", "grad_output_quantizer", "reduce_and_update_bwd_fp8_tensors", ) - missing_attrs = [attr for attr in required_attrs if not hasattr(output.grad_fn, attr)] + missing_attrs = [attr for attr in required_attrs if not hasattr(state_holder, attr)] if missing_attrs: raise RuntimeError( - "grad_fn does not expose required backward context attributes: " - f"{', '.join(missing_attrs)}." + f"Backward context does not expose required attributes: {', '.join(missing_attrs)}." ) return ( - getattr(output.grad_fn, "backward_override"), - bool(getattr(output.grad_fn, "fp8")), - getattr(output.grad_fn, "grad_output_quantizer"), - bool(getattr(output.grad_fn, "reduce_and_update_bwd_fp8_tensors")), + getattr(state_holder, "backward_override"), + bool(getattr(state_holder, "fp8")), + getattr(state_holder, "grad_output_quantizer"), + bool(getattr(state_holder, "reduce_and_update_bwd_fp8_tensors")), ) diff --git a/tests/pytorch/test_custom_recipe.py b/tests/pytorch/test_custom_recipe.py index 536d43adc0..62a6291797 100644 --- a/tests/pytorch/test_custom_recipe.py +++ b/tests/pytorch/test_custom_recipe.py @@ -17,8 +17,16 @@ GroupedLinear, Float8CurrentScalingQuantizer, ) +from transformer_engine.pytorch.quantization import QuantizerRole import transformer_engine.pytorch.ops as te_ops -from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import ( +from transformer_engine.pytorch.custom_recipes.quantization_recipes_base import ( + current_scaling_quantizer_factory, + mxfp8_quantizer_factory, + float8_block_scaling_quantizer_factory, + nvfp4_quantizer_factory, + delayed_scaling_quantizer_factory, +) +from transformer_engine.pytorch.custom_recipes.quantization_ref_nvfp4 import ( nvfp4_ref_rht_2d_quantizer_factory, ) @@ -91,9 +99,9 @@ def test_custom_recipe_sanity(module_type): # Single factory: map roles to quantizers def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -119,18 +127,18 @@ def test_custom_recipe_grouped_linear_sanity(): num_gemms = 3 in_features = 64 out_features = 64 - batch = 32 - base = batch // num_gemms - rem = batch % num_gemms - m_splits = [base + (1 if i < rem else 0) for i in range(num_gemms)] + # Each per-GEMM M dim must be a multiple of 16 to satisfy cuBLAS FP8 GEMM's + # leading-dimension alignment requirement on Hopper (sm_90). + m_splits = [16] * num_gemms + batch = sum(m_splits) model = GroupedLinear(num_gemms, in_features, out_features, params_dtype=torch.bfloat16).cuda() inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -190,9 +198,9 @@ def test_custom_recipe_matches_current_scaling(): # Custom: single factory returning quantizers per role to match Float8CurrentScaling def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -210,7 +218,7 @@ def quantizer_factory(role): assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3 assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3 assert cus_bwd_go.dtype == tex.DType.kFloat8E5M2 - assert cus_bwd_gi.dtype == tex.DType.kFloat8E5M2 + assert cus_bwd_gi.dtype == tex.DType.kFloat8E4M3 # role=None fallback loss_custom = (out_custom.float() * scale.view(1, -1)).sum() loss_custom.backward() @@ -247,9 +255,9 @@ def test_custom_recipe_ops_linear_2_1_layout(): inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) def quantizer_factory(role): - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") - if role in ("linear_grad_output", "linear_grad_input"): + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device="cuda") return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") @@ -272,44 +280,47 @@ def test_custom_recipe_factory_invocation_counts_and_cycling(): in_features = 64 out_features = 64 - batch = 8 + # batch must be a multiple of 16 to satisfy cuBLAS FP8 GEMM's leading-dim + # alignment requirement on Hopper (sm_90). + batch = 16 op = Linear(in_features, out_features, params_dtype=torch.bfloat16) inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) - # Counters per role + # Counters per tensor_type. The output (fwd) and grad_input (bwd) + # slots have role=None by default (unknown consumer), so we count + # those separately. counts = { - "linear_input": 0, - "linear_weight": 0, - "linear_output": 0, - "linear_grad_output": 0, - "linear_grad_input": 0, + "input": 0, + "weight": 0, + "grad_output": 0, + None: 0, } def quantizer_factory(role): - if role in counts: - counts[role] += 1 - if role in ("linear_input", "linear_weight", "linear_output"): + if role is None: + counts[None] += 1 return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) - if role in ("linear_grad_output", "linear_grad_input"): + assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" + assert role.module_type == "linear" + if role.tensor_type in counts: + counts[role.tensor_type] += 1 + if role.tensor_type == "grad_output": return Float8CurrentScalingQuantizer(tex.DType.kFloat8E5M2, device=torch.device("cuda")) return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) custom = recipe.CustomRecipe(qfactory=quantizer_factory) - # Run fwd+bwd once; for a single GEMM, expect forward to build 3 quantizers (cycled from 1 factory), - # and backward to build 2 quantizers (cycled from 1 factory). with autocast(enabled=True, recipe=custom): out = op(inp) loss = out.float().sum() loss.backward() - # Single GEMM: forward should request input, weight, output; backward grad_output, grad_input - assert counts["linear_input"] == 1 - assert counts["linear_weight"] == 1 - assert counts["linear_output"] == 1 - assert counts["linear_grad_output"] == 1 - assert counts["linear_grad_input"] == 1 + # Forward: input, weight, output(None); backward: grad_output, grad_input(None) + assert counts["input"] == 1 + assert counts["weight"] == 1 + assert counts["grad_output"] == 1 + assert counts[None] == 2, f"Expected 2 None roles (output + grad_input), got {counts[None]}" def test_factories_return_distinct_instances_and_buffers(): @@ -317,9 +328,15 @@ def test_factories_return_distinct_instances_and_buffers(): if not torch.cuda.is_available() or not available: pytest.skip(f"FP8 unsupported on this device: {reason}") - # Two calls should produce distinct quantizer objects and distinct tensor buffers + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + + # Two calls should produce distinct quantizer objects with distinct + # scale/amax buffers (Float8Quantizer / delayed-scaling is the class + # that owns persistent per-quantizer state; current scaling has none). def factory(): - return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device=torch.device("cuda")) + scale = torch.ones(1, dtype=torch.float32, device="cuda") + amax = torch.zeros(1, dtype=torch.float32, device="cuda") + return Float8Quantizer(scale=scale, amax=amax, fp8_dtype=tex.DType.kFloat8E4M3) q1 = factory() q2 = factory() @@ -331,3 +348,1504 @@ def factory(): # Mutating one should not affect the other q1.scale.fill_(123.0) assert not torch.equal(q1.scale, q2.scale) + + +def _run_linear_fwd_bwd(model, inp, recipe): + """Run forward + backward with a given recipe and return (output, inp.grad, param grads).""" + with autocast(enabled=True, recipe=recipe): + out = model(inp) + loss = out.float().sum() + loss.backward() + param_grads = {n: p.grad.clone() for n, p in model.named_parameters() if p.grad is not None} + return out.clone(), inp.grad.clone(), param_grads + + +def _make_pair(in_features=128, out_features=128, batch=32, seed=42): + """Create a pair of identical Linear models and matching inputs.""" + torch.manual_seed(seed) + model_ref = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + model_cus = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + model_cus.load_state_dict(model_ref.state_dict()) + + base_inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16) + inp_ref = base_inp.clone().detach().requires_grad_(True) + inp_cus = base_inp.clone().detach().requires_grad_(True) + return model_ref, model_cus, inp_ref, inp_cus + + +def _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus): + """Assert exact match of outputs and all gradients.""" + assert torch.allclose( + out_ref, out_cus, rtol=0.0, atol=0.0 + ), f"Forward mismatch: max diff = {(out_ref - out_cus).abs().max()}" + assert torch.allclose( + grad_ref, grad_cus, rtol=0.0, atol=0.0 + ), f"Input grad mismatch: max diff = {(grad_ref - grad_cus).abs().max()}" + for name in pgrads_ref: + assert torch.allclose(pgrads_ref[name], pgrads_cus[name], rtol=0.0, atol=0.0), ( + f"Param grad '{name}' mismatch: max diff = " + f"{(pgrads_ref[name] - pgrads_cus[name]).abs().max()}" + ) + + +def test_factory_matches_delayed_scaling(): + """delayed_scaling_quantizer_factory should produce bit-identical results + to the built-in DelayedScaling recipe.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd(model_ref, inp_ref, recipe.DelayedScaling()) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=delayed_scaling_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_current_scaling(): + """current_scaling_quantizer_factory should produce bit-identical results + to the built-in Float8CurrentScaling recipe.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.Float8CurrentScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=current_scaling_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_mxfp8(): + """mxfp8_quantizer_factory should produce bit-identical results + to the built-in MXFP8BlockScaling recipe.""" + available, reason = te.is_mxfp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"MXFP8 unsupported: {reason}") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.MXFP8BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=mxfp8_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_block_scaling(): + """float8_block_scaling_quantizer_factory should produce bit-identical results + to the built-in Float8BlockScaling recipe.""" + available = te.is_fp8_block_scaling_available() + if not torch.cuda.is_available() or not available: + pytest.skip("Float8 block scaling unsupported on this device") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.Float8BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=float8_block_scaling_quantizer_factory) + ) + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_factory_matches_nvfp4(): + """nvfp4_quantizer_factory should produce bit-identical results + to the built-in NVFP4BlockScaling recipe.""" + available = te.is_nvfp4_available() + if not torch.cuda.is_available() or not available: + pytest.skip("NVFP4 unsupported on this device") + + model_ref, model_cus, inp_ref, inp_cus = _make_pair() + + out_ref, grad_ref, pgrads_ref = _run_linear_fwd_bwd( + model_ref, inp_ref, recipe.NVFP4BlockScaling() + ) + out_cus, grad_cus, pgrads_cus = _run_linear_fwd_bwd( + model_cus, inp_cus, recipe.CustomRecipe(qfactory=nvfp4_quantizer_factory) + ) + + _assert_match(out_ref, out_cus, grad_ref, grad_cus, pgrads_ref, pgrads_cus) + + +def test_custom_recipe_quantization_targets(): + """Validate fine-grained per-module quantization targeting via QuantizerRole. + + Four transformer layers, each assembled at a different abstraction level. + The default recipe is NVFP4; specific modules are overridden: + + Layer 0 - ``TransformerLayer`` (name="tl0") -> all MXFP8 + Layer 1 - ``TransformerLayer`` (name="tl1") -> NVFP4 (default), + except fc2 overridden to MXFP8 + Layer 2 - ``MultiheadAttention`` + ``LayerNormMLP`` + (name prefix "tl2") -> NVFP4 (default), + except qkv and fc1 overridden to Float8 block-scaling + Layer 3 - Individual blocks (name prefix "tl3") -> NVFP4 (default), + except proj overridden to Float8 current-scaling + + The test validates that: + * The factory receives QuantizerRole objects with correct names + * Different quantizer types are dispatched per module + * Forward + backward complete successfully through all four layers + """ + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + if not te.is_mxfp8_available(): + pytest.skip("MXFP8 unsupported on this device") + if not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + if not te.is_fp8_block_scaling_available(): + pytest.skip("Float8 block scaling unsupported on this device") + + torch.manual_seed(42) + + H = 64 # hidden_size + FFN = 64 # ffn_hidden_size + NH = 4 # num_heads + KV = H // NH # kv_channels + B = 4 # batch + S = 8 # seq_len + common = dict(params_dtype=torch.bfloat16, bias=False) + + # Layer 0: TransformerLayer -> MXFP8 + tl0 = te.TransformerLayer( + H, + FFN, + NH, + hidden_dropout=0.0, + attention_dropout=0.0, + name="tl0", + **common, + ).cuda() + + # Layer 1: TransformerLayer -> NVFP4 default, fc2 overridden to MXFP8 + tl1 = te.TransformerLayer( + H, + FFN, + NH, + hidden_dropout=0.0, + attention_dropout=0.0, + name="tl1", + **common, + ).cuda() + + # Layer 2: MHA + LayerNormMLP -> NVFP4 default, qkv and fc1 to block-scaling + tl2_mha = te.MultiheadAttention( + H, + NH, + KV, + attention_dropout=0.0, + input_layernorm=True, + return_bias=True, + name="tl2.self_attention", + **common, + ).cuda() + tl2_mlp = LayerNormMLP(H, FFN, name="tl2.layernorm_mlp", **common).cuda() + + # Layer 3: Individual blocks with DPA -> NVFP4 default, proj to current-scaling + tl3_qkv = LayerNormLinear(H, 3 * H, name="tl3.qkv", **common).cuda() + tl3_dpa = te.DotProductAttention(NH, KV, attention_dropout=0.0, name="tl3.core_attention") + tl3_proj = Linear(H, H, name="tl3.proj", **common).cuda() + tl3_fc1 = LayerNormLinear(H, FFN, name="tl3.fc1", **common).cuda() + tl3_fc2 = Linear(FFN, H, name="tl3.fc2", **common).cuda() + + # ------------------------------------------------------------------ + # Recording + dispatching factory + # ------------------------------------------------------------------ + recorded_roles = [] + + def targeting_factory(role): + recorded_roles.append(role) + + if role is None: + return nvfp4_quantizer_factory(role) + + assert isinstance(role, QuantizerRole), f"Expected QuantizerRole, got {type(role)}" + + # Layer 0 (tl0.*): all MXFP8 + if role.name.startswith("tl0"): + return mxfp8_quantizer_factory(role) + + # Layer 1 (tl1.*): NVFP4 default, but fc2 overridden to MXFP8 + if role.name == "tl1.layernorm_mlp.fc2": + return mxfp8_quantizer_factory(role) + + # Layer 2: block scaling for qkv and fc1, rest falls through to default + if role.name == "tl2.self_attention.layernorm_linear_qkv": + return float8_block_scaling_quantizer_factory(role) + if role.name == "tl2.layernorm_mlp.fc1": + return float8_block_scaling_quantizer_factory(role) + + # Layer 3: current-scaling for proj, rest falls through to default + if role.name == "tl3.proj": + return current_scaling_quantizer_factory(role) + + # Default: NVFP4 + return nvfp4_quantizer_factory(role) + + custom_recipe = recipe.CustomRecipe(qfactory=targeting_factory) + + # ------------------------------------------------------------------ + # Forward + backward + # ------------------------------------------------------------------ + inp = torch.randn(S, B, H, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + with autocast(enabled=True, recipe=custom_recipe): + # Layer 0 & 1: TransformerLayer + h = tl1(tl0(inp)) + + # Layer 2: MHA + residual + LayerNormMLP + residual + attn_out, _ = tl2_mha(h) + h = h + attn_out + h = h + tl2_mlp(h) + + # Layer 3: individual blocks with DPA + residual = h + qkv = tl3_qkv(h).view(S, B, 3, NH, KV) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + attn = tl3_dpa(q, k, v).view(S, B, H) + h = residual + tl3_proj(attn) + residual = h + h = residual + tl3_fc2(torch.nn.functional.gelu(tl3_fc1(h))) + + loss = h.float().sum() + loss.backward() + + # ------------------------------------------------------------------ + # Assertions + # ------------------------------------------------------------------ + + assert inp.grad is not None, "Input gradient is None" + + # -- Name propagation check -- + # The factory dispatches on role.name, so if a TE module fails to propagate + # names (e.g. TransformerLayer -> MHA -> LayerNormLinear) the factory would + # silently fall through to the default recipe. The quantizer-type assertions + # below would catch that too, but checking names explicitly gives a clearer + # error message pointing at the broken name rather than a wrong quantizer type. + role_names = {r.name for r in recorded_roles if r is not None} + + def _tl_names(prefix): + """Expected role names for a standard TransformerLayer with given prefix.""" + return { + f"{prefix}.self_attention.layernorm_linear_qkv", + f"{prefix}.self_attention.proj", + f"{prefix}.layernorm_mlp.fc1", + f"{prefix}.layernorm_mlp.fc2", + } + + all_expected = ( + _tl_names("tl0") + | _tl_names("tl1") + | _tl_names("tl2") + | {"tl3.qkv", "tl3.proj", "tl3.fc1", "tl3.fc2"} + ) + missing = all_expected - role_names + assert not missing, ( + f"Expected module names not seen in QuantizerRole.name: {missing}\n" + f"Recorded names: {sorted(role_names)}" + ) + + for r in recorded_roles: + if r is not None and r.module_type: + assert r.module_type in ( + "linear", + "dpa", + ), f"Unexpected module_type={r.module_type} for role {r}" + + # -- Quantizer-type checks -- + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer + + def _check_q(mod, expected_cls, label=""): + q = mod.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] + assert isinstance(q, expected_cls), ( + f"{mod.name}{' (' + label + ')' if label else ''}: " + f"expected {expected_cls.__name__}, got {type(q).__name__}" + ) + + # Layer 0: all MXFP8 + _check_q(tl0.self_attention.layernorm_qkv, MXFP8Quantizer) + _check_q(tl0.self_attention.proj, MXFP8Quantizer) + + # Layer 1: NVFP4 default, fc2 overridden to MXFP8 + _check_q(tl1.self_attention.layernorm_qkv, NVFP4Quantizer, "default") + _check_q(tl1.self_attention.proj, NVFP4Quantizer, "default") + assert any( + r is not None and r.name == "tl1.layernorm_mlp.fc2" and r.tensor_type == "input" + for r in recorded_roles + ), "tl1.layernorm_mlp.fc2 input role not recorded" + + # Layer 2: block-scaling on qkv and fc1, NVFP4 on proj and fc2 + _check_q(tl2_mha.layernorm_qkv, Float8BlockQuantizer) + _check_q(tl2_mha.proj, NVFP4Quantizer, "default") + + # Layer 3: current-scaling on proj, NVFP4 on everything else + _check_q(tl3_proj, Float8CurrentScalingQuantizer) + for mod in [tl3_qkv, tl3_fc1, tl3_fc2]: + _check_q(mod, NVFP4Quantizer, "default") + + +def test_grouped_linear_module_type_dispatch(): + """Verify GroupedLinear emits module_type='grouped_linear' so factories can + distinguish it from regular Linear (critical for MoE mixed-recipe dispatch).""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + + torch.manual_seed(0) + + num_gemms = 2 + in_features = 64 + out_features = 64 + # Each per-GEMM M dim must be a multiple of 16 to satisfy cuBLAS FP8 GEMM's + # leading-dimension alignment requirement on Hopper (sm_90). + batch = 32 + m_splits = [batch // num_gemms] * num_gemms + + model = GroupedLinear( + num_gemms, in_features, out_features, params_dtype=torch.bfloat16, name="experts" + ).cuda() + inp = torch.randn(batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + recorded_roles = [] + + def recording_factory(role): + recorded_roles.append(role) + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=recording_factory) + + with autocast(enabled=True, recipe=custom_recipe): + out = model(inp, m_splits) + loss = out.float().sum() + loss.backward() + + non_none = [r for r in recorded_roles if r is not None] + assert len(non_none) > 0, "No QuantizerRole objects recorded" + for r in non_none: + assert isinstance(r, QuantizerRole) + assert ( + r.module_type == "grouped_linear" + ), f"Expected module_type='grouped_linear', got '{r.module_type}'" + assert r.name == "experts", f"Expected name='experts', got '{r.name}'" + + fwd_types = {r.tensor_type for r in non_none if r.tensor_type in ("input", "weight")} + bwd_types = {r.tensor_type for r in non_none if r.tensor_type == "grad_output"} + assert "input" in fwd_types, "Missing 'input' tensor_type in forward roles" + assert "weight" in fwd_types, "Missing 'weight' tensor_type in forward roles" + assert "grad_output" in bwd_types, "Missing 'grad_output' tensor_type in backward roles" + + +def test_delayed_scaling_request_wiring(): + """Shared buffers, correct views, Float8Quantizer instances.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + from transformer_engine.pytorch.quantization import ( + DelayedScalingRequest, + CustomRecipeState, + ) + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + from transformer_engine.common.recipe import Format + + def ds_factory(role): + return DelayedScalingRequest(fp8_format=Format.HYBRID, amax_history_len=16) + + custom_recipe = recipe.CustomRecipe(qfactory=ds_factory) + + # 3 quantizers (input, weight, output) like a Linear fwd + state = CustomRecipeState( + custom_recipe, + mode="forward", + num_quantizers=3, + roles=[ + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="output"), + ], + ) + quantizers = state.make_quantizers() + + # All quantizers should be Float8Quantizer + assert len(quantizers) == 3 + for q in quantizers: + assert isinstance(q, Float8Quantizer), f"Expected Float8Quantizer, got {type(q).__name__}" + + # Managed state should exist + assert state._has_delayed_scaling + assert state.scale is not None + assert state.amax_history is not None + + # Shared buffers: scale shape = (3,), amax_history shape = (16, 3) + assert state.scale.shape == (3,) + assert state.amax_history.shape == (16, 3) + + # Each quantizer's scale should be a view into the shared buffer + for i, q in enumerate(quantizers): + assert q.scale.data_ptr() == state.scale[i].data_ptr() + + # Each quantizer's amax should be a view into amax_history[0] + for i, q in enumerate(quantizers): + assert q.amax.data_ptr() == state.amax_history[0][i].reshape((1,)).data_ptr() + + # Inner recipe should be a DelayedScaling + inner = state._inner_delayed_scaling_recipe + assert isinstance(inner, recipe.DelayedScaling) + assert inner.amax_history_len == 16 + assert inner.fp8_format == Format.HYBRID + + +def test_custom_recipe_mixed_ds_and_stateless(): + """Mix DelayedScalingRequest + stateless quantizers in same CustomRecipeState.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + from transformer_engine.pytorch.quantization import ( + DelayedScalingRequest, + CustomRecipeState, + ) + from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer + from transformer_engine.common.recipe import Format + + def mixed_factory(role): + # Only weight gets delayed scaling, rest get current scaling + if role is not None and role.tensor_type == "weight": + return DelayedScalingRequest(fp8_format=Format.HYBRID) + return Float8CurrentScalingQuantizer(tex.DType.kFloat8E4M3, device="cuda") + + custom_recipe = recipe.CustomRecipe(qfactory=mixed_factory) + + # 3 quantizers: input(current), weight(DS), output(current) + state = CustomRecipeState( + custom_recipe, + mode="forward", + num_quantizers=3, + roles=[ + QuantizerRole(module_type="linear", tensor_type="input"), + QuantizerRole(module_type="linear", tensor_type="weight"), + QuantizerRole(module_type="linear", tensor_type="output"), + ], + ) + quantizers = state.make_quantizers() + assert len(quantizers) == 3 + + # Slot 0 (input): current scaling + assert isinstance(quantizers[0], Float8CurrentScalingQuantizer) + # Slot 1 (weight): delayed scaling + assert isinstance(quantizers[1], Float8Quantizer) + # Slot 2 (output): current scaling + assert isinstance(quantizers[2], Float8CurrentScalingQuantizer) + + # Only 1 DS request => shared buffers have size 1 + assert state._has_delayed_scaling + assert state.scale.shape == (1,) + assert state.amax_history.shape == (1024, 1) + + +def test_custom_recipe_ds_multi_step(): + """amax_history updates across multiple forward steps.""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + from transformer_engine.pytorch.quantization import DelayedScalingRequest + from transformer_engine.common.recipe import Format + + def ds_factory(role): + return DelayedScalingRequest(fp8_format=Format.HYBRID) + + in_features = 128 + out_features = 128 + batch = 32 + num_steps = 3 + + torch.manual_seed(99) + model = Linear(in_features, out_features, params_dtype=torch.bfloat16, bias=False).cuda() + custom = recipe.CustomRecipe(qfactory=ds_factory) + + amax_snapshots = [] + for step in range(num_steps): + inp = torch.randn( + batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + with autocast(enabled=True, recipe=custom): + out = model(inp) + loss = out.float().sum() + loss.backward() + + # Capture amax_history snapshot + fwd_state = model.fp8_meta["scaling_fwd"] + amax_snapshots.append(fwd_state.amax_history.clone()) + + # After 3 steps, amax_history should have been updated at least once + # The first row (amax_history[0]) should differ from the initial zeros + # after the first step + assert not torch.all(amax_snapshots[0] == 0), "amax_history should be updated after first step" + + +# ---------------------------------------------------------------------- +# State preservation across role-driven rebuilds +# ---------------------------------------------------------------------- +# +# Setting ``output_quantizer_role`` / ``grad_input_quantizer_role`` to a +# different value flips ``fp8_meta_tensors_initialized = False`` so the +# next ``set_meta_tensor`` call rebuilds the recipe state and quantizers +# with up-to-date roles. That rebuild MUST preserve persistent training +# buffers (delayed scaling's ``scale`` / ``amax_history``); otherwise +# checkpointed amax history is silently destroyed on the first forward +# pass after ``load_state_dict`` (when MHA wires boundary roles for the +# first time on the freshly-loaded module). The buffers must also be +# preserved by tensor-object identity, not just by value: the +# ``FP8GlobalStateManager`` reduction buffer holds a direct reference to +# the tensor created at first init, so any rebuild that allocates fresh +# tensors would break amax all-reduce. + + +def test_role_change_preserves_delayed_scaling_state(): + """Built-in DelayedScaling: role-driven rebuild preserves scale / amax_history. + + Stashes sentinel values into the buffers, forces a rebuild via the role + setter, and verifies values + tensor-object identity survive. + """ + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + torch.manual_seed(0) + model = Linear(64, 64, params_dtype=torch.bfloat16, bias=False).cuda() + inp = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True) + fp8_recipe = recipe.DelayedScaling(amax_history_len=8) + + # Initialize state via a forward pass. + with autocast(enabled=True, recipe=fp8_recipe): + model(inp).float().sum().backward() + assert model.fp8_meta_tensors_initialized + + state_before = model.fp8_meta["scaling_fwd"] + state_before.scale.fill_(3.14) + state_before.amax_history.fill_(2.71) + scale_obj_id = id(state_before.scale) + amax_obj_id = id(state_before.amax_history) + scale_data_ptr = state_before.scale.data_ptr() + amax_data_ptr = state_before.amax_history.data_ptr() + + # Trigger role-driven invalidation. Setting a non-None role flips + # ``fp8_meta_tensors_initialized = False`` so the next ``set_meta_tensor`` + # falls through and creates a fresh ``RecipeState``. + model.output_quantizer_role = QuantizerRole( + module_type="dpa", tensor_type="qkv", name="downstream" + ) + assert not model.fp8_meta_tensors_initialized + + # Trigger the rebuild directly (no forward, so we can compare buffers exactly). + model.init_fp8_meta_tensors(fp8_recipe) + assert model.fp8_meta_tensors_initialized + + state_after = model.fp8_meta["scaling_fwd"] + assert state_after is not state_before, "state should have been rebuilt" + # Tensor objects must be inherited (not freshly allocated) so the + # FP8GlobalStateManager reduction buffer's reference stays valid. + assert ( + id(state_after.scale) == scale_obj_id + ), "scale tensor object replaced by rebuild; global reduction buffer would dangle" + assert id(state_after.amax_history) == amax_obj_id + assert state_after.scale.data_ptr() == scale_data_ptr + assert state_after.amax_history.data_ptr() == amax_data_ptr + # Sentinel values must be preserved. + assert state_after.scale.eq(3.14).all(), "scale was wiped by role-driven rebuild" + assert state_after.amax_history.eq(2.71).all(), "amax_history was wiped" + + +def test_role_change_preserves_custom_delayed_scaling_state(): + """CustomRecipe + DelayedScalingRequest: role-driven rebuild preserves inner DSRS. + + Same property as the built-in case, but for the + ``CustomRecipeState`` -> composed ``DelayedScalingRecipeState`` path. + The inner DS state must be re-used across the rebuild so its + accumulated buffers (and any external references to them) survive. + """ + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + from transformer_engine.pytorch.quantization import ( + CustomRecipeState, + DelayedScalingRequest, + ) + from transformer_engine.common.recipe import Format + + def ds_factory(role): + return DelayedScalingRequest(fp8_format=Format.HYBRID, amax_history_len=8) + + torch.manual_seed(0) + model = Linear(64, 64, params_dtype=torch.bfloat16, bias=False).cuda() + inp = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True) + custom_recipe = recipe.CustomRecipe(qfactory=ds_factory) + + # Initialize state via a forward pass. + with autocast(enabled=True, recipe=custom_recipe): + model(inp).float().sum().backward() + assert model.fp8_meta_tensors_initialized + + state_before = model.fp8_meta["scaling_fwd"] + assert isinstance(state_before, CustomRecipeState) + assert state_before._has_delayed_scaling + inner_before = state_before._ds_state + inner_before.scale.fill_(3.14) + inner_before.amax_history.fill_(2.71) + scale_obj_id = id(inner_before.scale) + amax_obj_id = id(inner_before.amax_history) + + # Trigger role-driven invalidation. + model.output_quantizer_role = QuantizerRole( + module_type="dpa", tensor_type="qkv", name="downstream" + ) + assert not model.fp8_meta_tensors_initialized + + # Rebuild. + model.init_fp8_meta_tensors(custom_recipe) + assert model.fp8_meta_tensors_initialized + + state_after = model.fp8_meta["scaling_fwd"] + assert isinstance(state_after, CustomRecipeState) + assert state_after is not state_before, "outer CustomRecipeState should have been rebuilt" + assert state_after._has_delayed_scaling, "rebuild lost the inner DS state" + inner_after = state_after._ds_state + # Inner DSRS object identity is preserved (we reuse the existing inner state), + # which means its buffers' tensor objects are also preserved. + assert ( + inner_after is inner_before + ), "inner DSRS replaced; FP8GlobalStateManager reduction buffer would dangle" + assert id(inner_after.scale) == scale_obj_id + assert id(inner_after.amax_history) == amax_obj_id + # Sentinel values preserved. + assert inner_after.scale.eq(3.14).all() + assert inner_after.amax_history.eq(2.71).all() + + +def test_role_change_does_not_invalidate_when_role_unchanged(): + """Setting the role to its current value is a no-op (no rebuild).""" + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + torch.manual_seed(0) + model = Linear(64, 64, params_dtype=torch.bfloat16, bias=False).cuda() + inp = torch.randn(32, 64, device="cuda", dtype=torch.bfloat16, requires_grad=True) + fp8_recipe = recipe.DelayedScaling(amax_history_len=8) + + role = QuantizerRole(module_type="dpa", tensor_type="qkv", name="x") + model.output_quantizer_role = role # initial set: state not yet built, no-op + + with autocast(enabled=True, recipe=fp8_recipe): + model(inp).float().sum().backward() + assert model.fp8_meta_tensors_initialized + + # Re-setting the same role value must not invalidate. + model.output_quantizer_role = QuantizerRole(module_type="dpa", tensor_type="qkv", name="x") + assert ( + model.fp8_meta_tensors_initialized + ), "Setting role to an equal value should be a no-op (frozen-dataclass __eq__)" + + +def test_custom_recipe_dpa_fp8(): + """DotProductAttention forward+backward with CustomRecipe and role-based mixed quantizers. + + Uses the nvfp4_linear_fp8_dpa_factory which dispatches: + * DPA S/dP slots -> DelayedScalingRequest (stateful) + * DPA QKV/O/dO/dQKV slots -> Float8CurrentScalingQuantizer + * Linear slots -> NVFP4Quantizer + """ + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + if not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + + from transformer_engine.pytorch.utils import get_device_compute_capability + + cc = get_device_compute_capability() + if cc < (9, 0) or cc >= (12, 0): + pytest.skip(f"FP8 attention not supported on sm{cc[0]*10+cc[1]}") + + from transformer_engine.pytorch.quantization import ( + DelayedScalingRequest, + CustomRecipeState, + ) + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, + ) + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_fp8_dpa_factory, + ) + + torch.manual_seed(42) + + H = 64 + NH = 4 + KV = H // NH + B = 2 + S = 32 + + # Build a small model: Linear -> DPA -> Linear + qkv_proj = Linear(H, 3 * H, params_dtype=torch.bfloat16, bias=False, name="qkv").cuda() + dpa = te.DotProductAttention( + NH, KV, attention_dropout=0.0, qkv_format="bshd", name="core_attention" + ) + out_proj = Linear(H, H, params_dtype=torch.bfloat16, bias=False, name="proj").cuda() + + custom_recipe = recipe.CustomRecipe( + qfactory=nvfp4_linear_fp8_dpa_factory, + fp8_dpa=True, + ) + + inp = torch.randn(B, S, H, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + with autocast(enabled=True, recipe=custom_recipe): + qkv = qkv_proj(inp).view(B, S, 3, NH, KV) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + attn_out = dpa(q, k, v, qkv_format="bshd").reshape(B, S, H) + out = out_proj(attn_out) + + loss = out.float().sum() + loss.backward() + + assert inp.grad is not None, "Input gradient should exist" + + # Verify DPA recipe state is CustomRecipeState + fwd_state = dpa.fp8_meta["scaling_fwd"] + assert isinstance( + fwd_state, CustomRecipeState + ), f"Expected CustomRecipeState for DPA fwd, got {type(fwd_state).__name__}" + + # Verify DPA quantizers: 9 forward slots (3 GEMMs x 3) + fwd_quantizers = dpa.quantizers["scaling_fwd"] + assert len(fwd_quantizers) == 9, f"Expected 9 fwd quantizers, got {len(fwd_quantizers)}" + + # Slots 0-2: QKV (GEMM1) -> current scaling (role: module_type="dpa") + # Slots 3-5: O (GEMM2) -> current scaling (role: name hint "dpa_output") + # Slots 6-8: S (GEMM3) -> delayed scaling (Float8Quantizer from DelayedScalingRequest) + for i in range(6): + assert isinstance(fwd_quantizers[i], Float8CurrentScalingQuantizer), ( + f"Slot {i} (QKV/O): expected Float8CurrentScalingQuantizer, " + f"got {type(fwd_quantizers[i]).__name__}" + ) + for i in range(6, 9): + assert isinstance(fwd_quantizers[i], Float8Quantizer), ( + f"Slot {i} (S): expected Float8Quantizer (delayed scaling), " + f"got {type(fwd_quantizers[i]).__name__}" + ) + + # Verify DS state exists for the S/dP delayed scaling requests + assert fwd_state._has_delayed_scaling, "DPA fwd state should have delayed scaling for S slots" + + # Verify backward quantizers exist too + bwd_quantizers = dpa.quantizers["scaling_bwd"] + assert len(bwd_quantizers) == 6, f"Expected 6 bwd quantizers, got {len(bwd_quantizers)}" + + # Slots 0-1: dQKV (GEMM1) -> current scaling (role: name hint "dpa_grad_input") + # Slots 2-3: dO (GEMM2) -> current scaling (role: module_type="dpa") + # Slots 4-5: dP (GEMM3) -> delayed scaling + for i in range(4): + assert isinstance(bwd_quantizers[i], Float8CurrentScalingQuantizer), ( + f"Bwd slot {i} (dQKV/dO): expected Float8CurrentScalingQuantizer, " + f"got {type(bwd_quantizers[i]).__name__}" + ) + for i in range(4, 6): + assert isinstance(bwd_quantizers[i], Float8Quantizer), ( + f"Bwd slot {i} (dP): expected Float8Quantizer (delayed scaling), " + f"got {type(bwd_quantizers[i]).__name__}" + ) + + # Linear modules should have CustomRecipeState with NVFP4 quantizers + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + qkv_fwd = qkv_proj.fp8_meta["scaling_fwd"] + assert isinstance( + qkv_fwd, CustomRecipeState + ), f"Expected CustomRecipeState for qkv_proj, got {type(qkv_fwd).__name__}" + qkv_fwd_quantizers = qkv_proj.quantizers["scaling_fwd"] + for i, q in enumerate(qkv_fwd_quantizers): + if q is not None: + assert isinstance( + q, NVFP4Quantizer + ), f"qkv_proj fwd slot {i}: expected NVFP4Quantizer, got {type(q).__name__}" + + +def test_custom_recipe_dpa_mxfp8(): + """DotProductAttention forward+backward with CustomRecipe and MXFP8 attention. + + Uses the nvfp4_linear_mxfp8_dpa_factory which dispatches: + * DPA roles (QKV/O/S/dO/dP/dQKV) -> MXFP8Quantizer (S/dP later nulled + out by ``get_attention_quantizers`` since the MXFP8 fused-attention + kernel handles those slots internally) + * DPA boundary hints -> MXFP8Quantizer + * Linear slots -> NVFP4Quantizer + + Mirrors the documented "NVFP4 linear + MXFP8 attention" combo from + ``dot_product_attention.py``'s recipe-combination table. + """ + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported on this device: {reason}") + if not te.is_mxfp8_available(): + pytest.skip("MXFP8 unsupported on this device") + if not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + + from transformer_engine.pytorch.utils import get_device_compute_capability + + cc = get_device_compute_capability() + if cc < (9, 0) or cc >= (12, 0): + pytest.skip(f"FP8 attention not supported on sm{cc[0]*10+cc[1]}") + + from transformer_engine.pytorch.quantization import CustomRecipeState + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_mxfp8_dpa_factory, + ) + + torch.manual_seed(42) + + # MXFP8 fused attention requires s_q % 128 == 0, s_kv % 128 == 0, + # d_qk % 32 == 0, d_v % 32 == 0. + H = 128 + NH = 4 + KV = H // NH # 32 + B = 2 + S = 128 + + # Build a small model: Linear -> DPA -> Linear + qkv_proj = Linear(H, 3 * H, params_dtype=torch.bfloat16, bias=False, name="qkv").cuda() + dpa = te.DotProductAttention( + NH, KV, attention_dropout=0.0, qkv_format="bshd", name="core_attention" + ) + out_proj = Linear(H, H, params_dtype=torch.bfloat16, bias=False, name="proj").cuda() + + custom_recipe = recipe.CustomRecipe( + qfactory=nvfp4_linear_mxfp8_dpa_factory, + fp8_dpa=True, + ) + + inp = torch.randn(B, S, H, device="cuda", dtype=torch.bfloat16, requires_grad=True) + + with autocast(enabled=True, recipe=custom_recipe): + qkv = qkv_proj(inp).view(B, S, 3, NH, KV) + q, k, v = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2] + # MXFP8 fused attention requires s_q % 128 == 0, s_kv % 128 == 0, + # d_qk % 32 == 0, d_v % 32 == 0. The B/S/H values above are picked + # to satisfy all four constraints (S=128, KV=32). + attn_out = dpa(q, k, v, qkv_format="bshd").reshape(B, S, H) + out = out_proj(attn_out) + + loss = out.float().sum() + loss.backward() + + assert inp.grad is not None, "Input gradient should exist" + + # DPA recipe state should be CustomRecipeState + fwd_state = dpa.fp8_meta["scaling_fwd"] + assert isinstance( + fwd_state, CustomRecipeState + ), f"Expected CustomRecipeState for DPA fwd, got {type(fwd_state).__name__}" + + # All DPA slots should resolve to MXFP8Quantizer (the factory returns MXFP8 + # uniformly for DPA roles; S/dP nulling happens inside get_attention_quantizers + # at fused-attn dispatch time, not here). + fwd_quantizers = dpa.quantizers["scaling_fwd"] + assert len(fwd_quantizers) == 9, f"Expected 9 fwd quantizers, got {len(fwd_quantizers)}" + for i, q in enumerate(fwd_quantizers): + assert isinstance( + q, MXFP8Quantizer + ), f"DPA fwd slot {i}: expected MXFP8Quantizer, got {type(q).__name__}" + + bwd_quantizers = dpa.quantizers["scaling_bwd"] + assert len(bwd_quantizers) == 6, f"Expected 6 bwd quantizers, got {len(bwd_quantizers)}" + for i, q in enumerate(bwd_quantizers): + assert isinstance( + q, MXFP8Quantizer + ), f"DPA bwd slot {i}: expected MXFP8Quantizer, got {type(q).__name__}" + + # MXFP8 attention has no delayed-scaling state (no S/dP DS-request slots). + assert ( + not fwd_state._has_delayed_scaling + ), "DPA fwd state should NOT have delayed scaling for the all-MXFP8 factory" + + # Linear modules should still be NVFP4 + qkv_fwd = qkv_proj.fp8_meta["scaling_fwd"] + assert isinstance( + qkv_fwd, CustomRecipeState + ), f"Expected CustomRecipeState for qkv_proj, got {type(qkv_fwd).__name__}" + qkv_fwd_quantizers = qkv_proj.quantizers["scaling_fwd"] + for i, q in enumerate(qkv_fwd_quantizers): + if q is not None: + assert isinstance( + q, NVFP4Quantizer + ), f"qkv_proj fwd slot {i}: expected NVFP4Quantizer, got {type(q).__name__}" + + +def test_custom_recipe_debug_tool_compat(): + """Custom recipe quantizers should work when wrapped by DebugQuantizer. + + Verifies that the debug tool (nvdlfw_inspect) can wrap custom-recipe + quantizers produced via QuantizerRole dispatch without errors. + """ + try: + import nvdlfw_inspect.api as debug_api + except ImportError: + pytest.skip("nvdlfw_inspect not installed") + + available, reason = te.is_fp8_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 unsupported: {reason}") + + import pathlib + import tempfile + + from transformer_engine.debug.pytorch.debug_state import TEDebugState + + te_debug_features = str( + pathlib.Path(__file__).resolve().parent.parent.parent + / "transformer_engine" + / "debug" + / "features" + ) + + # Log config that keeps DebugQuantizer active (not bypassed by no_debug_features_active) + log_config = """log: + layers: + layer_types: [linear] + enabled: True + transformer_engine: + LogTensorStats: + enabled: True + tensors: [activation, weight] + stats: [max] + start_step: 0 + end_step: 3 +""" + + torch.manual_seed(0) + + in_features = 64 + out_features = 64 + batch = 16 + + with tempfile.NamedTemporaryFile(mode="w+", suffix=".yaml", delete=False) as cfg: + cfg.write(log_config) + cfg.flush() + config_path = cfg.name + + try: + with tempfile.TemporaryDirectory() as log_dir: + debug_api.initialize( + config_file=config_path, + feature_dirs=te_debug_features, + log_dir=log_dir, + ) + + model = Linear( + in_features, out_features, params_dtype=torch.bfloat16, name="layer" + ).cuda() + + custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_quantizer_factory) + + assert TEDebugState.debug_enabled, "Debug mode should be active" + + for _ in range(3): + inp_step = torch.randn( + batch, in_features, device="cuda", dtype=torch.bfloat16, requires_grad=True + ) + with autocast(enabled=True, recipe=custom_recipe): + out = model(inp_step) + out.float().sum().backward() + debug_api.step() + + assert inp_step.grad is not None, "Input gradient should exist" + + log_files = list(pathlib.Path(log_dir).rglob("*.log")) + assert ( + len(log_files) > 0 + ), f"Debug log output expected in {log_dir} but no .log files found" + finally: + debug_api.end_debug() + TEDebugState._reset() + import os + + os.unlink(config_path) + + +# ---------------------------------------------------------------------- +# Role-aware dispatch in built-in block-scaling recipe states +# ---------------------------------------------------------------------- +# +# These tests exercise ``Float8BlockScalingRecipeState.make_quantizers`` and +# ``NVFP4BlockScalingRecipeState.make_quantizers`` directly to verify that +# per-slot dispatch is driven by ``QuantizerRole.tensor_type`` with a +# positional fallback that matches the legacy behavior. They construct the +# recipe state objects directly (no autocast / no fwd pass) so they don't +# depend on any module's ``get_quantizer_roles`` implementation. + + +def _fp8block_role(tensor_type): + """QuantizerRole helper for FP8-block tests.""" + return QuantizerRole(module_type="linear", tensor_type=tensor_type, name="t") + + +def test_fp8block_recipe_state_role_dispatch_forward(): + """Forward dispatch: input/output -> x cfg, weight -> w cfg.""" + available, reason = te.is_fp8_block_scaling_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 block scaling unsupported: {reason}") + + from transformer_engine.pytorch.quantization import Float8BlockScalingRecipeState + + fp8_recipe = recipe.Float8BlockScaling() + state = Float8BlockScalingRecipeState( + fp8_recipe, + mode="forward", + num_quantizers=3, + roles=[ + _fp8block_role("input"), + _fp8block_role("weight"), + _fp8block_role("output"), + ], + ) + quantizers = state.make_quantizers() + assert len(quantizers) == 3 + # input slot uses x cfg + assert quantizers[0].block_scaling_dim == fp8_recipe.x_block_scaling_dim + # weight slot uses w cfg + assert quantizers[1].block_scaling_dim == fp8_recipe.w_block_scaling_dim + # output slot mirrors input cfg (legacy behavior preserved) + assert quantizers[2].block_scaling_dim == fp8_recipe.x_block_scaling_dim + # Sanity: the recipe defaults distinguish x and w block scaling dims so + # the test would fail if dispatch were uniform. + assert fp8_recipe.x_block_scaling_dim != fp8_recipe.w_block_scaling_dim + + +def test_fp8block_recipe_state_role_dispatch_backward(): + """Backward dispatch: grad_output / grad_input both -> grad cfg.""" + available, reason = te.is_fp8_block_scaling_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 block scaling unsupported: {reason}") + + from transformer_engine.pytorch.quantization import Float8BlockScalingRecipeState + + fp8_recipe = recipe.Float8BlockScaling() + state = Float8BlockScalingRecipeState( + fp8_recipe, + mode="backward", + num_quantizers=2, + roles=[ + _fp8block_role("grad_output"), + _fp8block_role("grad_input"), + ], + ) + quantizers = state.make_quantizers() + assert len(quantizers) == 2 + for q in quantizers: + assert q.block_scaling_dim == fp8_recipe.grad_block_scaling_dim + + +def test_fp8block_recipe_state_positional_fallback_matches_explicit_roles(): + """``roles=None`` produces the same per-slot configs as explicit ``[input, weight, output]``.""" + available, reason = te.is_fp8_block_scaling_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 block scaling unsupported: {reason}") + + from transformer_engine.pytorch.quantization import Float8BlockScalingRecipeState + + fp8_recipe = recipe.Float8BlockScaling() + + explicit = Float8BlockScalingRecipeState( + fp8_recipe, + mode="forward", + num_quantizers=3, + roles=[ + _fp8block_role("input"), + _fp8block_role("weight"), + _fp8block_role("output"), + ], + ).make_quantizers() + + fallback = Float8BlockScalingRecipeState( + fp8_recipe, + mode="forward", + num_quantizers=3, + roles=None, + ).make_quantizers() + + assert len(explicit) == len(fallback) == 3 + for a, b in zip(explicit, fallback): + assert a.block_scaling_dim == b.block_scaling_dim + assert a.dtype == b.dtype + assert a.amax_epsilon == b.amax_epsilon + assert a.force_pow_2_scales == b.force_pow_2_scales + + +def test_fp8block_recipe_state_supports_non_multiple_of_three(): + """Two-slot forward (fusible-Linear shape) used to fail ``% 3 == 0`` assert.""" + available, reason = te.is_fp8_block_scaling_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 block scaling unsupported: {reason}") + + from transformer_engine.pytorch.quantization import Float8BlockScalingRecipeState + + fp8_recipe = recipe.Float8BlockScaling() + state = Float8BlockScalingRecipeState( + fp8_recipe, + mode="forward", + num_quantizers=2, + roles=[ + _fp8block_role("input"), + _fp8block_role("weight"), + ], + ) + quantizers = state.make_quantizers() + assert len(quantizers) == 2 + assert quantizers[0].block_scaling_dim == fp8_recipe.x_block_scaling_dim + assert quantizers[1].block_scaling_dim == fp8_recipe.w_block_scaling_dim + + +def test_fp8block_recipe_state_unknown_or_none_role_falls_back_positionally(): + """Per-slot ``None`` and unknown ``tensor_type`` use the positional pattern.""" + available, reason = te.is_fp8_block_scaling_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 block scaling unsupported: {reason}") + + from transformer_engine.pytorch.quantization import Float8BlockScalingRecipeState + + fp8_recipe = recipe.Float8BlockScaling() + # Slot 0: bare role (empty tensor_type) -> positional "input" -> x cfg + # Slot 1: unknown tensor_type "qkv" (DPA-style) -> positional "weight" -> w cfg + # Slot 2: None role -> positional "output" -> x cfg + state = Float8BlockScalingRecipeState( + fp8_recipe, + mode="forward", + num_quantizers=3, + roles=[ + QuantizerRole(), + QuantizerRole(module_type="dpa", tensor_type="qkv"), + None, + ], + ) + quantizers = state.make_quantizers() + assert len(quantizers) == 3 + assert quantizers[0].block_scaling_dim == fp8_recipe.x_block_scaling_dim + assert quantizers[1].block_scaling_dim == fp8_recipe.w_block_scaling_dim + assert quantizers[2].block_scaling_dim == fp8_recipe.x_block_scaling_dim + + +def _nvfp4_role(tensor_type): + return QuantizerRole(module_type="linear", tensor_type=tensor_type, name="t") + + +def test_nvfp4_recipe_state_role_dispatch_forward(): + """Forward dispatch: input/output -> inp cfg (RHT, 1D), weight -> weight cfg (no RHT, 2D).""" + if not torch.cuda.is_available() or not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + + from transformer_engine.pytorch.quantization import NVFP4BlockScalingRecipeState + + nvfp4_recipe = recipe.NVFP4BlockScaling() + state = NVFP4BlockScalingRecipeState( + nvfp4_recipe, + mode="forward", + num_quantizers=3, + roles=[ + _nvfp4_role("input"), + _nvfp4_role("weight"), + _nvfp4_role("output"), + ], + ) + quantizers = state.make_quantizers() + assert len(quantizers) == 3 + # input slot + assert quantizers[0].with_rht == nvfp4_recipe.fp4_quant_fwd_inp.random_hadamard_transform + assert quantizers[0].with_2d_quantization == nvfp4_recipe.fp4_quant_fwd_inp.fp4_2d_quantization + # weight slot + assert quantizers[1].with_rht == nvfp4_recipe.fp4_quant_fwd_weight.random_hadamard_transform + assert ( + quantizers[1].with_2d_quantization == nvfp4_recipe.fp4_quant_fwd_weight.fp4_2d_quantization + ) + # output slot mirrors input cfg + assert quantizers[2].with_rht == nvfp4_recipe.fp4_quant_fwd_inp.random_hadamard_transform + assert quantizers[2].with_2d_quantization == nvfp4_recipe.fp4_quant_fwd_inp.fp4_2d_quantization + # Sanity: defaults distinguish input vs weight (RHT and 2D toggles differ). + assert ( + nvfp4_recipe.fp4_quant_fwd_inp.random_hadamard_transform + != nvfp4_recipe.fp4_quant_fwd_weight.random_hadamard_transform + ) or ( + nvfp4_recipe.fp4_quant_fwd_inp.fp4_2d_quantization + != nvfp4_recipe.fp4_quant_fwd_weight.fp4_2d_quantization + ) + + +def test_nvfp4_recipe_state_role_dispatch_backward(): + """Backward dispatch: any slot -> grad cfg (uniform).""" + if not torch.cuda.is_available() or not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + + from transformer_engine.pytorch.quantization import NVFP4BlockScalingRecipeState + + nvfp4_recipe = recipe.NVFP4BlockScaling() + state = NVFP4BlockScalingRecipeState( + nvfp4_recipe, + mode="backward", + num_quantizers=2, + roles=[ + _nvfp4_role("grad_output"), + _nvfp4_role("grad_input"), + ], + ) + quantizers = state.make_quantizers() + assert len(quantizers) == 2 + for q in quantizers: + assert q.with_rht == nvfp4_recipe.fp4_quant_bwd_grad.random_hadamard_transform + assert q.with_2d_quantization == nvfp4_recipe.fp4_quant_bwd_grad.fp4_2d_quantization + assert q.stochastic_rounding == nvfp4_recipe.fp4_quant_bwd_grad.stochastic_rounding + + +def test_nvfp4_recipe_state_positional_fallback_matches_explicit_roles(): + """``roles=None`` matches explicit ``[input, weight, output]`` slot-for-slot.""" + if not torch.cuda.is_available() or not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + + from transformer_engine.pytorch.quantization import NVFP4BlockScalingRecipeState + + nvfp4_recipe = recipe.NVFP4BlockScaling() + + explicit = NVFP4BlockScalingRecipeState( + nvfp4_recipe, + mode="forward", + num_quantizers=3, + roles=[ + _nvfp4_role("input"), + _nvfp4_role("weight"), + _nvfp4_role("output"), + ], + ).make_quantizers() + + fallback = NVFP4BlockScalingRecipeState( + nvfp4_recipe, + mode="forward", + num_quantizers=3, + roles=None, + ).make_quantizers() + + assert len(explicit) == len(fallback) == 3 + for a, b in zip(explicit, fallback): + assert a.with_rht == b.with_rht + assert a.with_post_rht_amax == b.with_post_rht_amax + assert a.with_2d_quantization == b.with_2d_quantization + assert a.stochastic_rounding == b.stochastic_rounding + assert a.dtype == b.dtype + + +def test_nvfp4_recipe_state_supports_non_multiple_of_three(): + """Two-slot forward (fusible-Linear shape) succeeds with role-driven dispatch.""" + if not torch.cuda.is_available() or not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + + from transformer_engine.pytorch.quantization import NVFP4BlockScalingRecipeState + + nvfp4_recipe = recipe.NVFP4BlockScaling() + state = NVFP4BlockScalingRecipeState( + nvfp4_recipe, + mode="forward", + num_quantizers=2, + roles=[ + _nvfp4_role("input"), + _nvfp4_role("weight"), + ], + ) + quantizers = state.make_quantizers() + assert len(quantizers) == 2 + assert quantizers[0].with_rht == nvfp4_recipe.fp4_quant_fwd_inp.random_hadamard_transform + assert quantizers[1].with_rht == nvfp4_recipe.fp4_quant_fwd_weight.random_hadamard_transform + + +def test_nvfp4_recipe_state_unknown_or_none_role_falls_back_positionally(): + """Per-slot ``None`` and unknown ``tensor_type`` use the positional pattern.""" + if not torch.cuda.is_available() or not te.is_nvfp4_available(): + pytest.skip("NVFP4 unsupported on this device") + + from transformer_engine.pytorch.quantization import NVFP4BlockScalingRecipeState + + nvfp4_recipe = recipe.NVFP4BlockScaling() + # Slot 0: bare role (empty tensor_type) -> positional "input" -> inp cfg + # Slot 1: DPA-style unknown tensor_type "qkv" -> positional "weight" -> weight cfg + # Slot 2: None role -> positional "output" -> inp cfg + state = NVFP4BlockScalingRecipeState( + nvfp4_recipe, + mode="forward", + num_quantizers=3, + roles=[ + QuantizerRole(), + QuantizerRole(module_type="dpa", tensor_type="qkv"), + None, + ], + ) + quantizers = state.make_quantizers() + assert len(quantizers) == 3 + assert quantizers[0].with_rht == nvfp4_recipe.fp4_quant_fwd_inp.random_hadamard_transform + assert quantizers[1].with_rht == nvfp4_recipe.fp4_quant_fwd_weight.random_hadamard_transform + assert quantizers[2].with_rht == nvfp4_recipe.fp4_quant_fwd_inp.random_hadamard_transform + + +# ---------------------------------------------------------------------- +# RecipeState._slot_role primitive +# ---------------------------------------------------------------------- +# +# `_slot_role` is the primitive that role-driven recipe states use to +# resolve per-slot dispatch info. It returns the real role when one was +# provided and synthesizes one with the positional ``tensor_type`` fallback +# (and empty ``module_type``/``name``) otherwise. Future recipes that +# dispatch on ``module_type`` / ``name`` rely on this contract. +# +# We exercise these via a concrete ``Float8BlockScalingRecipeState`` since +# ``RecipeState`` is abstract; the helper itself is mode-aware but +# recipe-agnostic. + + +def _make_fp8block_state(*, mode, num_quantizers, roles): + from transformer_engine.pytorch.quantization import Float8BlockScalingRecipeState + + return Float8BlockScalingRecipeState( + recipe.Float8BlockScaling(), + mode=mode, + num_quantizers=num_quantizers, + roles=roles, + ) + + +def test_slot_role_passes_real_role_through_unchanged(): + """A real ``QuantizerRole`` from the producer is returned as-is.""" + available, reason = te.is_fp8_block_scaling_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 block scaling unsupported: {reason}") + + real = QuantizerRole(module_type="linear", tensor_type="weight", name="layer37.fc1") + state = _make_fp8block_state(mode="forward", num_quantizers=1, roles=[real]) + resolved = state._slot_role(0) + # Identity: no copying, the real instance is returned. + assert resolved is real + assert resolved.module_type == "linear" + assert resolved.tensor_type == "weight" + assert resolved.name == "layer37.fc1" + + +def test_slot_role_passes_unknown_tensor_type_through_unchanged(): + """A real role with non-canonical ``tensor_type`` is NOT remapped by ``_slot_role``. + + ``_slot_tensor_type`` would fall back to positional, but ``_slot_role`` + must preserve the original so module-type / name dispatch still works. + """ + available, reason = te.is_fp8_block_scaling_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 block scaling unsupported: {reason}") + + dpa_role = QuantizerRole(module_type="dpa", tensor_type="qkv", name="self_attention.dpa") + state = _make_fp8block_state(mode="forward", num_quantizers=1, roles=[dpa_role]) + resolved = state._slot_role(0) + assert resolved is dpa_role + assert resolved.tensor_type == "qkv" # unchanged, NOT folded into known set + # ``_slot_tensor_type`` still falls back to positional pattern[0] = "input". + assert state._slot_tensor_type(0) == "input" + + +def test_slot_role_returns_bare_role_when_per_slot_role_is_none(): + """Boundary slot (``roles[i] is None``) returns a bare ``QuantizerRole()``. + + The primitive does NOT synthesize a positional ``tensor_type`` — that's + a tensor-type-dispatch policy owned by ``_slot_tensor_type``. + """ + available, reason = te.is_fp8_block_scaling_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 block scaling unsupported: {reason}") + + real_input = QuantizerRole(module_type="linear", tensor_type="input", name="t") + real_weight = QuantizerRole(module_type="linear", tensor_type="weight", name="t") + # Slot 2 (output) is None: typical for Linear without parent setting + # ``_output_quantizer_role``. + state = _make_fp8block_state( + mode="forward", num_quantizers=3, roles=[real_input, real_weight, None] + ) + # Real slots pass through. + assert state._slot_role(0) is real_input + assert state._slot_role(1) is real_weight + # None slot returns a bare QuantizerRole(): all fields empty, no + # tensor-type-specific synthesis. + bare = state._slot_role(2) + assert bare.tensor_type == "" + assert bare.module_type == "" + assert bare.name == "" + # Consumers get positional fallback through _slot_tensor_type, not _slot_role. + assert state._slot_tensor_type(2) == "output" + + +def test_slot_role_returns_bare_role_when_roles_list_is_none(): + """``roles=None`` yields bare ``QuantizerRole()`` for every slot, fwd and bwd. + + Positional fallback for tensor types lives in ``_slot_tensor_type``, not here. + """ + available, reason = te.is_fp8_block_scaling_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 block scaling unsupported: {reason}") + + fwd = _make_fp8block_state(mode="forward", num_quantizers=4, roles=None) + # _slot_role is field-agnostic: every slot is a bare QuantizerRole(). + for i in range(4): + role = fwd._slot_role(i) + assert role.tensor_type == "" + assert role.module_type == "" + assert role.name == "" + # _slot_tensor_type applies the positional fallback (with wrap). + fwd_types = [fwd._slot_tensor_type(i) for i in range(4)] + assert fwd_types == ["input", "weight", "output", "input"] + + bwd = _make_fp8block_state(mode="backward", num_quantizers=3, roles=None) + for i in range(3): + assert bwd._slot_role(i).tensor_type == "" + bwd_types = [bwd._slot_tensor_type(i) for i in range(3)] + assert bwd_types == ["grad_output", "grad_input", "grad_output"] + + +def test_slot_role_supports_module_type_only_role(): + """A role that fills ONLY ``module_type`` is preserved as-is. + + This is the producer convention for future module-type-driven recipes: + fill only the field(s) you have signal for. ``_slot_role`` must not + invent a ``tensor_type`` to mask the empty one (otherwise the module-type + branch in a mixed recipe would never see a clean signal). + """ + available, reason = te.is_fp8_block_scaling_available(return_reason=True) + if not torch.cuda.is_available() or not available: + pytest.skip(f"FP8 block scaling unsupported: {reason}") + + moe = QuantizerRole(module_type="moe_expert") + state = _make_fp8block_state(mode="forward", num_quantizers=1, roles=[moe]) + resolved = state._slot_role(0) + assert resolved is moe + assert resolved.module_type == "moe_expert" + assert resolved.tensor_type == "" # NOT auto-filled + assert resolved.name == "" + # Tensor-type-only recipes fall back to positional for this slot. + assert state._slot_tensor_type(0) == "input" diff --git a/tests/pytorch/test_float8_current_scaling_exact.py b/tests/pytorch/test_float8_current_scaling_exact.py index 99ab9c4984..3b964a5af9 100644 --- a/tests/pytorch/test_float8_current_scaling_exact.py +++ b/tests/pytorch/test_float8_current_scaling_exact.py @@ -14,7 +14,7 @@ from transformer_engine.pytorch.quantization import autocast, get_fp8_torch_dtype from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.custom_recipes.quantization import MMParams -from transformer_engine.pytorch.custom_recipes.quantization_current_scaling import ( +from transformer_engine.pytorch.custom_recipes.quantization_ref_current_scaling import ( CurrentScalingQuantizerRef, ) diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index fbba27941c..aa7cb50e8b 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -19,7 +19,7 @@ static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id const int tid = threadIdx.x; const int idx = bid * blockDim.x + tid; - if (idx >= num_rows * topK) return; + if (idx >= static_cast(num_rows) * topK) return; int source_row = sorted_row_id[idx]; int source_token_id = source_row / topK; @@ -27,10 +27,10 @@ static __global__ void moe_permute_row_map(const int *sorted_row_id, int *row_id if (idx >= num_out_tokens) { // Set the indices of dropped tokens to -1 - row_id_map[source_topK_id * num_rows + source_token_id] = -1; + row_id_map[static_cast(source_topK_id) * num_rows + source_token_id] = -1; } else { // Create a row id map for subsequent unpermute operation - row_id_map[source_topK_id * num_rows + source_token_id] = idx; + row_id_map[static_cast(source_topK_id) * num_rows + source_token_id] = idx; } } @@ -42,7 +42,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const TCompute *s_prob = reinterpret_cast(s_mem); // Each block corresponds to one dest token - const int source_token = blockIdx.x; + const int64_t source_token = blockIdx.x; const int tid = threadIdx.x; if (hasProb) { @@ -65,7 +65,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const TCompute frag_elem[kElementsPerAccess]; TCompute frag_sum[kElementsPerAccess]; - int source_row = row_id_map[source_token]; + int64_t source_row = row_id_map[source_token]; // source_row == -1 represents a dropped token if (source_row != -1) { @@ -134,7 +134,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac TCompute *s_prob = reinterpret_cast(s_mem); // Each block corresponds to one source token - const int source_token = blockIdx.x; + const int64_t source_token = blockIdx.x; const int tid = threadIdx.x; if (hasProb) { @@ -172,7 +172,7 @@ __global__ void moe_permute_kernel(const T *input_bwd, const T *input_fwd, T *ac for (int k = 0; k < topKTile; k++) { if (k == topK) break; - int dest_row = row_id_map[index]; + int64_t dest_row = row_id_map[index]; index += num_rows; if (dest_row != -1) { @@ -239,7 +239,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, // moe_permute_fwd int threads = 64; - int blocks = (num_rows * topK + threads - 1) / threads; + int blocks = (static_cast(num_rows) * topK + threads - 1) / threads; moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, num_out_tokens); @@ -371,6 +371,13 @@ void nvte_device_radix_sort_pairs(void *temp_storage, size_t *temp_storage_bytes int *keys_out, int *values_in, int *values_out, size_t num_items) { NVTE_API_CALL(nvte_device_radix_sort_pairs); - cub::DeviceRadixSort::SortPairs(temp_storage, *temp_storage_bytes, keys_in, keys_out, values_in, - values_out, num_items); + // Sort keys as uint32_t so any negative-int sentinel (e.g. `-1` placed by an + // expert-parallel rank mask) becomes a large unsigned value and lands at the + // tail of the sorted output, matching the existing capacity-drop convention + // (drops encoded as a large positive expert id) and the + // `idx >= num_out_tokens` drop branch in moe_permute_row_map. + auto *u_keys_in = reinterpret_cast(keys_in); + auto *u_keys_out = reinterpret_cast(keys_out); + cub::DeviceRadixSort::SortPairs(temp_storage, *temp_storage_bytes, u_keys_in, u_keys_out, + values_in, values_out, num_items); } diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 0d0b2fd37f..9599663691 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -558,19 +558,33 @@ class CustomRecipe(Recipe): Parameters ---------- qfactory : Callable - Factory callable that returns a quantizer instance for a - given semantic tensor role. - The callable is typically invoked as:: + Factory callable that returns a quantizer instance *or* a + ``QuantizerRequest`` subclass for a given ``QuantizerRole``. + The callable is invoked as:: qfactory( - role: str, - ) + role: QuantizerRole, + ) -> Union[Quantizer, QuantizerRequest] - Where `role` is one of the following strings for e.g. te.Linear - (stable public contract): + ``QuantizerRole`` is a frozen dataclass with the following fields: + + - ``module_type`` (str): module type (empty string when not set), e.g. + ``"linear"``, ``"grouped_linear"``, ``"dpa"``. + - ``tensor_type`` (str): what tensor is being quantized (empty + string when not set), e.g. ``"input"``, ``"weight"``, ``"grad_output"``. + - ``name`` (str): caller-provided module instance name (empty + string when not set), e.g. ``"qkv"``, ``"proj"``, ``"fc1"``, ``"fc2"``. + + For stateful quantizers (delayed scaling), return a + ``DelayedScalingRequest`` dataclass instead of a quantizer. + TE will allocate shared scale/amax_history buffers and create + ``Float8Quantizer`` instances integrated with the existing + delayed-scaling reduction infrastructure. + + See ``transformer_engine.pytorch.quantization.QuantizerRole`` + and ``transformer_engine.pytorch.quantization.DelayedScalingRequest`` + for full documentation. - - forward: "linear_input", "linear_weight", "linear_output" - - backward: "linear_grad_output", "linear_grad_input" backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -580,6 +594,11 @@ class CustomRecipe(Recipe): qfactory: Callable[..., Any] + # fp8_format does not affect quantization (quantization factory controls that), + # but TE internals (e.g. get_fp8_te_dtype, backend selection) read it + # from the recipe. HYBRID (E4M3 fwd, E5M2 bwd) is a safe default. + fp8_format: Format = Format.HYBRID + fp8_dpa: bool = False fp8_mha: bool = False backward_override: Optional[str] = os.getenv("NVTE_BACKWARD_OVERRIDE", None) diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index d145cf0a21..3ff0d75ee4 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -48,6 +48,9 @@ from transformer_engine.pytorch.quantization import is_fp8_block_scaling_available from transformer_engine.pytorch.quantization import is_nvfp4_available from transformer_engine.pytorch.quantization import get_default_recipe +from transformer_engine.pytorch.quantization import QuantizerRole +from transformer_engine.pytorch.quantization import QuantizerRequest +from transformer_engine.pytorch.quantization import DelayedScalingRequest from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.pytorch.utils import is_bf16_available diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py index 79ebbd4afa..6e097265ff 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py @@ -175,6 +175,25 @@ _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" +def _qkv_quantizer_type(qkv_quantizer): + """Map a DPA QKV quantizer instance to its kernel-facing FP8 sub-recipe label. + + Returns one of ``'delayed'`` / ``'current'`` / ``'mxfp8'``. Used by FP8 + attention forward/backward to dispatch save-for-backward and + re-quantization decisions from the *quantizer instance* rather than + the top-level ``Recipe`` type, so that ``CustomRecipe`` + is handled correctly. Built-in recipes already + produce the matching quantizer instances, so behavior is preserved. + """ + if isinstance(qkv_quantizer, Float8Quantizer): + return "delayed" + if isinstance(qkv_quantizer, Float8CurrentScalingQuantizer): + return "current" + if isinstance(qkv_quantizer, MXFP8Quantizer): + return "mxfp8" + raise TypeError(f"Unsupported FP8 attention QKV quantizer: {type(qkv_quantizer).__name__}") + + class FP8EmulationFunc(torch.autograd.Function): """ Emulate the effects of FP8 quantization on tensors. Used in UnfusedDotProductAttention as follows: @@ -491,7 +510,7 @@ def forward( fp8_recipe = fp8_meta["local_recipes"][0] # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) # S/dP are forced to use DS quantizers in DPA.init_fp8_metadata; revert them here for true CS emulation if fp8_recipe.float8_current_scaling(): @@ -1318,9 +1337,15 @@ def forward( # get quantizers from DPA; all Nones if not fp8 QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) + # Effective FP8 sub-recipe label inferred from the QKV quantizer + # instance. Drives save-for-backward and re-quantization dispatch + # below so that CustomRecipe (and built-in recipes alike) work + # without depending on `fp8_recipe.()`. + qkv_type = _qkv_quantizer_type(QKV_quantizer) if fp8 else None + # get nominal data type for out # FP16/BF16 attention: torch.float16 or torch.bfloat16 # FP8 attention: torch.float16 or torch.bfloat16 @@ -1402,19 +1427,13 @@ def forward( not is_bwd_fp8 or ( is_bwd_fp8 - and ( - (fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16) - or fp8_recipe.mxfp8() - ) + and ((qkv_type == "current" and _dpa_fp8_cs_o_in_f16) or qkv_type == "mxfp8") ) ) bwd_requires_o_fp8 = ( is_training and is_bwd_fp8 - and ( - fp8_recipe.delayed() - or (fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16) - ) + and (qkv_type == "delayed" or (qkv_type == "current" and not _dpa_fp8_cs_o_in_f16)) ) if isinstance(out_, QuantizedTensorStorage): if not is_output_fp8 or bwd_requires_o_f16: @@ -1442,14 +1461,10 @@ def forward( fp8_tensors = (None, None, None, None) f16_tensors = (None, None, None, None) if is_bwd_fp8: - if ( - fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16 - ) or fp8_recipe.mxfp8(): + if (qkv_type == "current" and _dpa_fp8_cs_o_in_f16) or qkv_type == "mxfp8": fp8_tensors = (q_fp8, k_fp8, v_fp8, None) f16_tensors = (None, None, None, out_f16) - elif fp8_recipe.delayed() or ( - fp8_recipe.float8_current_scaling() and not _dpa_fp8_cs_o_in_f16 - ): + elif qkv_type == "delayed" or (qkv_type == "current" and not _dpa_fp8_cs_o_in_f16): fp8_tensors = (q_fp8, k_fp8, v_fp8, out_fp8) else: if is_input_fp8: @@ -1536,6 +1551,7 @@ def forward( ctx.dO_quantizer = dO_quantizer ctx.dP_quantizer = dP_quantizer ctx.S_quantizer = S_quantizer + ctx.qkv_type = qkv_type if ctx.fp8 and isinstance(ctx.S_quantizer, Float8Quantizer): ctx.S_quantizer = S_quantizer.copy() ctx.S_quantizer.scale = S_quantizer.scale.clone() @@ -1706,9 +1722,9 @@ def backward(ctx, d_out, *_args): # MXFP8BlockScaling: # out_, dq_, dk_, dv_, d_out: torch.Tensor; dtype = torch.float16 or torch.bfloat16 out_ = out_fp8 - if ctx.fp8_recipe.float8_current_scaling() and _dpa_fp8_cs_o_in_f16: + if ctx.qkv_type == "current" and _dpa_fp8_cs_o_in_f16: out_ = out - if ctx.fp8_recipe.mxfp8(): + if ctx.qkv_type == "mxfp8": out_ = out aux_ctx_tensors.append(d_out) dq_, dk_, dv_, *rest = fused_attn_bwd( @@ -2059,7 +2075,7 @@ def forward( " with FP8!" ) if fp8_recipe.float8_current_scaling() and context_parallel: - all_quantizers = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + all_quantizers = dpa_utils.get_attention_quantizers(fp8, quantizers) for q in all_quantizers: if isinstance(q, Float8CurrentScalingQuantizer): q.with_amax_reduction = True diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index 32eb1b597a..995ecf31b4 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -58,6 +58,29 @@ _dpa_fp8_cs_o_in_f16 = os.getenv("NVTE_DPA_FP8CS_O_in_F16", "1") == "1" +def _reject_custom_recipe_under_cp(fp8, fp8_recipe): + """Fail fast when CustomRecipe meets context-parallel FP8 attention. + + Single-device FP8 attention dispatch was migrated to read quantizer + instance types (see ``backends._qkv_quantizer_type`` and + ``utils.get_attention_quantizers``), which makes CustomRecipe work end + to end. The CP code path in this module still dispatches on + ``fp8_recipe.()`` at ~90 sites; under CustomRecipe those + predicates all return False and the dispatch silently falls through to + incorrect tensor save / amax-reduction shapes. Until that migration + lands, surface the limitation here with a clear error rather than + failing later in C++ assertions or with silently-wrong gradients. + """ + if fp8 and fp8_recipe is not None and fp8_recipe.custom(): + raise NotImplementedError( + "CustomRecipe + Context Parallelism is not yet supported for FP8 " + "DotProductAttention. Either disable context parallelism, or use a " + "built-in FP8 recipe (DelayedScaling, Float8CurrentScaling, " + "MXFP8BlockScaling) for CP. The single-device CustomRecipe + DPA " + "path is supported." + ) + + def get_bsh_dims(tensor_format): """Get batch dimension and sequence dimension from tensor format""" if tensor_format in ["bshd", "sbhd", "bhsd"]: @@ -1453,6 +1476,7 @@ def forward( fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + _reject_custom_recipe_under_cp(fp8, fp8_recipe) ( QKV_quantizer, O_quantizer, @@ -1460,7 +1484,7 @@ def forward( dQKV_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + ) = dpa_utils.get_attention_quantizers(fp8, quantizers) # q, k, v a2a: gather s and split h # FP8DS/CS: Float8Tensor -> torch.uint8 -> Float8Tensor @@ -3043,6 +3067,7 @@ def forward( fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + _reject_custom_recipe_under_cp(fp8, fp8_recipe) ( QKV_quantizer, O_quantizer, @@ -3050,7 +3075,7 @@ def forward( dQKV_quantizer, dO_quantizer, dP_quantizer, - ) = dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + ) = dpa_utils.get_attention_quantizers(fp8, quantizers) fwd_nominal_dtype = q.dtype q_fp8, k_fp8, v_fp8 = (q, k, v) if is_input_fp8 else (None, None, None) q_f16, k_f16, v_f16 = (None, None, None) if is_input_fp8 else (q, k, v) @@ -3904,13 +3929,14 @@ def forward( fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_meta is not None and fp8_meta.get("local_recipes", None) is not None: fp8_recipe = fp8_meta["local_recipes"][0] + _reject_custom_recipe_under_cp(fp8, fp8_recipe) fwd_nominal_dtype = q.dtype fused_attn_backend = None max_logit = None QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer = ( - dpa_utils.get_attention_quantizers(fp8, fp8_recipe, quantizers) + dpa_utils.get_attention_quantizers(fp8, quantizers) ) q_fp8, k_fp8, v_fp8 = (None, None, None) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 17e9a337a4..b38b66c3e6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -23,6 +23,7 @@ ) from transformer_engine.pytorch.utils import get_cudnn_version from transformer_engine.pytorch.quantization import ( + QuantizerRole, get_fp8_te_dtype, FP8GlobalStateManager, RecipeState, @@ -312,6 +313,8 @@ class DotProductAttention(TransformerEngineBaseModule): `_). :math:`\text{max_logit} = \max(S)`, where :math:`S = \text{mask}(Q \cdot K^T \cdot \text{softmax_scale} + \text{bias})` of shape ``[b, h, s_q, s_kv]``, and :math:`\text{max_logit}` is of shape ``[h]``. + name : Optional[str], default = None + module instance name. Parallelism parameters ---------------------- @@ -371,8 +374,9 @@ def __init__( softmax_scale: Optional[float] = None, softmax_type: str = "vanilla", return_max_logit: Optional[bool] = False, + name: Optional[str] = None, ) -> None: - super().__init__() + super().__init__(name=name) self.logger = logging.getLogger("DotProductAttention") self.logger.setLevel(attn_log._log_level) @@ -612,6 +616,7 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: # global recipe set in autocast() fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8_recipe.custom(): + super().init_fp8_metadata(num_gemms=num_gemms) return # switch/append recipe: fp8_recipe stays unchanged, but DPA.fp8_meta["recipe"] may be set to @@ -820,6 +825,9 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> None: """Override to allow multiple recipes. Init scales and amaxes for fwd | bwd.""" + if isinstance(recipe, Recipe) and recipe.custom(): + TransformerEngineBaseModule.set_meta_tensor(self, fwd, recipe) + return if isinstance(recipe, Recipe): recipe = [recipe] fp8_recipe_dpa = recipe[-1] @@ -859,13 +867,127 @@ def set_meta_tensor(self, fwd: bool, recipe: Union[Recipe, List[Recipe]]) -> Non for i in range(len(recipe)) ] - self.fp8_meta[fp8_meta_tensor_key] = ( - recipe_states[-1] if len(recipe) == 2 else recipe_states[0] - ) + # Reached the rebuild path because ``fp8_meta_tensors_initialized`` + # was flipped to False after first init — most commonly because the + # base-class ``output_quantizer_role`` / ``grad_input_quantizer_role`` + # setter invalidated state when MHA wired boundary roles. That + # setter is recipe-agnostic, so this code fires for built-in + # recipes too even though they don't consume role information here + # (e.g. ``test_dpa_fp8_extra_state`` reaches this path with pure + # DelayedScaling). + # + # Rebuilding the recipe state must preserve persistent training + # buffers (delayed-scaling ``scale`` / ``amax_history``) so the new + # quantizer instances and the ``FP8GlobalStateManager`` reduction + # buffers end up viewing the SAME tensor objects, and so any + # checkpoint-loaded state isn't silently destroyed on the first + # forward after ``load_state_dict``. + # + # Inheritance targets the "primary" state stored under + # ``fp8_meta[fp8_meta_tensor_key]`` — the one tracked across + # ``set_meta_tensor`` calls. Auxiliary states in a multi-recipe + # splice (e.g. the CS half of ``[CS, DS]``) are stateless and have + # nothing to inherit. + old_state = self.fp8_meta.get(fp8_meta_tensor_key) + primary_idx = -1 if len(recipe) == 2 else 0 + if old_state is not None: + recipe_states[primary_idx].inherit_state_from(old_state) + + self.fp8_meta[fp8_meta_tensor_key] = recipe_states[primary_idx] self.quantizers[fp8_meta_tensor_key] = [] for recipe_state in recipe_states: self.quantizers[fp8_meta_tensor_key].extend(recipe_state.make_quantizers()) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``DotProductAttention``. + + DPA internally performs two matmuls:: + + S = softmax(Q · K^T) (GEMM1) + O = S · V (GEMM2) + + cuDNN's fused-attention API exposes FP8 scale/amax descriptors as + a flat array of **slot groups** numbered 1-3. The numbering is a + cuDNN convention — it does *not* correspond to operation order + inside DPA: + + Forward (3 slot groups × 3 positions = 9 slots): + + =========== =========================================== =========== + Slot group Primary tensor cuDNN enum + =========== =========================================== =========== + Group 1 QKV — inputs to GEMM1 (Q·K^T) GEMM1_OUTPUT + Group 2 O — output of GEMM2 (S·V) GEMM2_INPUT + Group 3 S — post-softmax, input to GEMM2 (S·V) GEMM3_OUTPUT + =========== =========================================== =========== + + Backward (3 slot groups × 2 positions = 6 slots): + + =========== =========================================== =========== + Slot group Primary tensor cuDNN enum + =========== =========================================== =========== + Group 1 dQKV — gradients flowing back to Q, K, V GRAD_OUTPUT1 + Group 2 dO — gradient of the attention output GRAD_INPUT2 + Group 3 dP — gradient of the softmax output GRAD_INPUT3 + =========== =========================================== =========== + + Unused positions within a group share the role of the group's + primary tensor. + + **Boundary slots** — O (fwd) and dQKV (bwd) leave DPA and enter + the next module (e.g. proj linear). DPA does not know that + consumer, so these default to ``None``. The parent module + (e.g. ``MultiheadAttention``) can set + :attr:`output_quantizer_role` / :attr:`grad_input_quantizer_role` + to fill in the consumer identity. + + When not set, a hint-only ``QuantizerRole`` with empty + ``module_type`` / ``tensor_type`` is emitted, with ``name`` + containing ``"dpa_output"`` or ``"dpa_grad_input"``. This lets + the factory return a DPA-compatible quantizer (required by the + fused kernel) even when the downstream consumer is unknown. + """ + name = self.name or "" + if fwd: + qkv_role = QuantizerRole(module_type="dpa", tensor_type="qkv", name=name) + o_role = self._output_quantizer_role + if o_role is None: + o_role = QuantizerRole(name=f"{name}.dpa_output" if name else "dpa_output") + s_role = QuantizerRole(module_type="dpa", tensor_type="s", name=name) + base = [ + qkv_role, + qkv_role, + qkv_role, # Group 1: QKV (inputs to Q·K^T) + o_role, + o_role, + o_role, # Group 2: O (output of S·V) — boundary + s_role, + s_role, + s_role, # Group 3: S (post-softmax, input to S·V) + ] + else: + dqkv_role = self._grad_input_quantizer_role + if dqkv_role is None: + dqkv_role = QuantizerRole( + name=f"{name}.dpa_grad_input" if name else "dpa_grad_input" + ) + do_role = QuantizerRole(module_type="dpa", tensor_type="do", name=name) + dp_role = QuantizerRole(module_type="dpa", tensor_type="dp", name=name) + base = [ + dqkv_role, + dqkv_role, # Group 1: dQKV (grads to Q,K,V) — boundary + do_role, + do_role, # Group 2: dO (grad of attention output) + dp_role, + dp_role, # Group 3: dP (grad of softmax output) + ] + return base[:num_quantizers] + @no_torch_dynamo(recursive=False) def forward( self, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7df5daabe5..1f1637cecd 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -2334,7 +2334,7 @@ def check_set_window_size( return window_size -def get_attention_quantizers(fp8, fp8_recipe, quantizers): +def get_attention_quantizers(fp8, quantizers): """Get the list of quantizers used in attention from the quantizers list.""" if not fp8: return [None] * 6 @@ -2363,7 +2363,11 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): dQKV_quantizer.internal = False dQKV_quantizer.set_usage(rowwise=True, columnwise=False) - if fp8_recipe.mxfp8(): + # MXFP8 attention: detect from the QKV quantizer instance rather than the + # recipe predicate so that CustomRecipe (whose `mxfp8()` predicate returns + # False) gets the same treatment as the built-in MXFP8 recipe. The kernel + # handles S/dP internally for MXFP8, hence S/dP are nulled out. + if isinstance(QKV_quantizer, MXFP8Quantizer): QKV_quantizer.columnwise_usage = True QKV_quantizer.optimize_for_gemm = True S_quantizer = None @@ -2374,6 +2378,29 @@ def get_attention_quantizers(fp8, fp8_recipe, quantizers): dP_quantizer = None dQKV_quantizer.columnwise_usage = True + _fp8_types = (Float8Quantizer, Float8CurrentScalingQuantizer, MXFP8Quantizer) + # S/dP are intentionally None under MXFP8 attention; skip the type check + # for those slots in that case. + _allow_none = {"S", "dP"} if isinstance(QKV_quantizer, MXFP8Quantizer) else set() + for _name, _q in [ + ("QKV", QKV_quantizer), + ("O", O_quantizer), + ("S", S_quantizer), + ("dQKV", dQKV_quantizer), + ("dO", dO_quantizer), + ("dP", dP_quantizer), + ]: + if _q is None and _name in _allow_none: + continue + assert isinstance(_q, _fp8_types), ( + "FP8 attention requires FP8-compatible quantizers for all DPA tensor slots, " + f"but {_name} quantizer is {type(_q).__name__}. " + "When using CustomRecipe with fp8_dpa=True, ensure the factory returns an " + "FP8 quantizer (Float8Quantizer, Float8CurrentScalingQuantizer, or " + "MXFP8Quantizer) for all DPA roles (module_type='dpa') and for None roles " + "(boundary slots like O output and dQKV grad-input)." + ) + return QKV_quantizer, O_quantizer, S_quantizer, dQKV_quantizer, dO_quantizer, dP_quantizer diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index afc4622b22..70ae9dfc21 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -8,7 +8,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch -from transformer_engine.pytorch.quantization import FP8GlobalStateManager +from transformer_engine.pytorch.quantization import FP8GlobalStateManager, QuantizerRole from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm @@ -461,6 +461,7 @@ def __init__( layer_number=self.layer_number, attention_type=self.attention_type, softmax_type=self.softmax_type, + name=name + ".core_attention" if name is not None else None, ) # Linear @@ -478,6 +479,84 @@ def __init__( **common_gemm_kwargs, ) + def _update_output_quantizer_roles( + self, + qkv_fp8_output: bool, + proj_fp8_grad: bool, + dpa_fp8_output: bool, + ) -> None: + """Set quantizer roles at the boundaries between QKV, DPA, and proj. + + MHA contains three submodules connected as follows:: + + Forward: QKV linear ──(QKV tensor)──> DPA ──(O tensor)──> Proj linear + Backward: QKV linear <──(dQKV tensor)── DPA <──(dO tensor)── Proj linear + + Each submodule owns quantizers for its internal tensors, but the + *boundary* tensors (the arrows above) need to know which module + will *consume* them so the quantizer factory can pick the right + format. This method sets those boundary roles on all four edges: + + 1. ``qkv_fp8_output`` — **QKV linear → DPA (fwd)**: the QKV + linear's ``output_quantizer_role`` is told its consumer is DPA. + 2. ``proj_fp8_grad`` — **Proj linear ← DPA (bwd)**: proj's + ``grad_input_quantizer_role`` is told its producer is DPA. + 3. ``dpa_fp8_output`` — **DPA → Proj linear (fwd)**: DPA's + ``output_quantizer_role`` is told its consumer is the proj linear. + 4. ``dpa_fp8_output`` — **DPA ← QKV linear (bwd)**: DPA's + ``grad_input_quantizer_role`` is told its consumer is QKV linear. + + When a flag is ``False`` the corresponding role is reset to ``None`` + so the module falls back to its own default. + """ + dpa_name = self.core_attention.name or "" + + # ── Boundary 1 (fwd): QKV linear output → consumed by DPA ──────── + qkv_output_role = ( + QuantizerRole(module_type="dpa", tensor_type="qkv", name=dpa_name) + if qkv_fp8_output + else None + ) + if self.attention_type == "self": + if self.input_layernorm: + self.layernorm_qkv.output_quantizer_role = qkv_output_role + else: + self.qkv.output_quantizer_role = qkv_output_role + elif self.attention_type == "cross": + if self.input_layernorm: + self.layernorm_query.output_quantizer_role = qkv_output_role + else: + self.query_layer.output_quantizer_role = qkv_output_role + self.key_value.output_quantizer_role = qkv_output_role + + # ── Boundary 2 (bwd): Proj grad-input ← produced by DPA ────────── + proj_grad_input_role = ( + QuantizerRole(module_type="dpa", tensor_type="do", name=dpa_name) + if proj_fp8_grad + else None + ) + self.proj.grad_input_quantizer_role = proj_grad_input_role + + # ── Boundary 3 (fwd): DPA output (O) → consumed by Proj linear ─── + proj_name = self.proj.name or "" + self.core_attention.output_quantizer_role = ( + QuantizerRole(module_type="linear", tensor_type="input", name=proj_name) + if dpa_fp8_output + else None + ) + + # ── Boundary 4 (bwd): DPA grad-input (dQKV) → consumed by QKV linear + if self.attention_type == "self": + qkv_linear = self.layernorm_qkv if self.input_layernorm else self.qkv + else: + qkv_linear = self.layernorm_query if self.input_layernorm else self.query_layer + qkv_name = qkv_linear.name or "" + self.core_attention.grad_input_quantizer_role = ( + QuantizerRole(module_type="linear", tensor_type="grad_output", name=qkv_name) + if dpa_fp8_output + else None + ) + def fast_setattr(self, name: str, value: Any) -> None: """Fast attribute set for non-parameter fields.""" self.__dict__[name] = value @@ -822,6 +901,8 @@ def forward( # 1. FP8CS recipe: produce F16 grads; again, due to cuBLAS limitation proj_fp8_grad = dpa_fp8_output and not float8_current_scaling + self._update_output_quantizer_roles(qkv_fp8_output, proj_fp8_grad, dpa_fp8_output) + layernorm_output = None if self.attention_type == "self": # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn] diff --git a/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py new file mode 100644 index 0000000000..d660e5a53b --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_factory_examples.py @@ -0,0 +1,271 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Quantizer factory examples. + +Demonstrates how to use the ``CustomRecipe`` + ``qfactory`` interface to apply +*different* quantization recipes to different module/tensor types/instances within the same model. + +Usage:: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_mxfp8_grouped_linear_factory, + nvfp4_linear_fp8_dpa_factory, + nvfp4_linear_mxfp8_dpa_factory, + ) + + # Mixed module types: NVFP4 for Linear, MXFP8 for GroupedLinear + recipe = CustomRecipe(qfactory=nvfp4_linear_mxfp8_grouped_linear_factory) + with autocast(recipe=recipe): + output = model(input) + + # NVFP4 for Linear, FP8 current-scaling + delayed-scaling for DPA + recipe = CustomRecipe(qfactory=nvfp4_linear_fp8_dpa_factory, fp8_dpa=True) + with autocast(recipe=recipe): + output = model(input) + + # NVFP4 for Linear, MXFP8 for DPA + recipe = CustomRecipe(qfactory=nvfp4_linear_mxfp8_dpa_factory, fp8_dpa=True) + with autocast(recipe=recipe): + output = model(input) +""" + +from __future__ import annotations + +from typing import Optional + +import transformer_engine_torch as tex + +from transformer_engine.pytorch.quantization import QuantizerRole + + +def nvfp4_linear_mxfp8_grouped_linear_factory( + role: Optional[QuantizerRole], +): + """Quantizer factory: NVFP4 for ``Linear``, MXFP8 for ``GroupedLinear``. + + Dispatch logic: + * ``role.module_type == "grouped_linear"`` -> MXFP8 (E4M3, block-32) + * everything else (``"linear"`` or unknown) -> NVFP4 (E2M1) + + NVFP4 settings follow the built-in ``NVFP4BlockScaling`` defaults: + * Weights: 2D quantization (16x16), no RHT, no stochastic rounding + * Inputs: 1D quantization, RHT enabled, no stochastic rounding + * Grads: 1D quantization, RHT enabled, stochastic rounding enabled + """ + is_grouped_linear = role is not None and role.module_type == "grouped_linear" + + if is_grouped_linear: + return _make_mxfp8_quantizer() + + return _make_nvfp4_quantizer(role) + + +def _make_mxfp8_quantizer(): + """Return an MXFP8 quantizer with default settings (E4M3, block-32, E8M0 scales).""" + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + return MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + ) + + +def _make_nvfp4_quantizer(role: Optional[QuantizerRole]): + """Return an NVFP4 quantizer configured per tensor role. + + Mirrors :class:`NVFP4BlockScaling` recipe defaults. + """ + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + is_linear = role is not None and role.module_type == "linear" + is_weight = is_linear and role.tensor_type == "weight" + is_grad = is_linear and role.tensor_type == "grad_output" + + if is_weight: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + stochastic_rounding=False, + with_random_sign_mask=True, + ) + + if is_grad: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=True, + ) + + +def nvfp4_linear_fp8_dpa_factory( + role: Optional[QuantizerRole], +): + """Quantizer factory: NVFP4 for ``Linear``, mixed FP8 for ``DotProductAttention``. + + This factory demonstrates how to use ``CustomRecipe`` with ``fp8_dpa=True`` + to combine NVFP4 quantization for linear layers with FP8 attention. + + DPA tensor types (``role.module_type == "dpa"``): + + =========== ============================================================ + tensor_type Description + =========== ============================================================ + ``"qkv"`` Query, Key, Value inputs to the first attention GEMM + ``"s"`` Softmax output (S = softmax(Q·K^T)), fed into the second GEMM + ``"o"`` Attention output (O = S·V) + ``"do"`` Gradient of the attention output (dO), backward input + ``"dp"`` Gradient of the softmax output (dP = dO·V^T), backward + ``"dqkv"`` Gradient flowing back to Q, K, V + =========== ============================================================ + + Dispatch logic: + * ``role.module_type == "dpa"`` with ``tensor_type in ("s", "dp")`` + -> FP8 delayed scaling (stateful amax tracking) + * ``role.module_type == "dpa"`` (QKV, dO) + -> FP8 current scaling (E4M3) + * DPA boundary hints (``"dpa_output"`` / ``"dpa_grad_input"`` in ``role.name``) + -> FP8 current scaling placeholder. The fused attention kernel requires + FP8-compatible quantizers in all DPA slots, even when the output is + produced in BF16 (``fp8_mha=False``). DPA emits these hint-only roles + (with empty ``module_type`` and ``tensor_type``) when the downstream + consumer is unknown. + * everything else (``"linear"`` / ``"grouped_linear"`` / ``None``) + -> NVFP4 (E2M1), configured per tensor role + + Usage:: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_fp8_dpa_factory, + ) + + recipe = CustomRecipe( + qfactory=nvfp4_linear_fp8_dpa_factory, + fp8_dpa=True, + ) + with autocast(recipe=recipe): + output = model(input) + """ + from transformer_engine.pytorch.quantization import DelayedScalingRequest + from transformer_engine.pytorch.tensor.float8_tensor import Float8CurrentScalingQuantizer + + is_dpa = role is not None and role.module_type == "dpa" + is_softmax_or_dp = is_dpa and role.tensor_type in ("s", "dp") + + if is_softmax_or_dp: + return DelayedScalingRequest() + + if is_dpa: + return Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + + # DPA boundary slots (O output / dQKV grad-input): the fused attention + # kernel only supports FP8 quantizers here, regardless of the linear recipe. + is_dpa_boundary = ( + role is not None + and not role.module_type + and ("dpa_output" in role.name or "dpa_grad_input" in role.name) + ) + if is_dpa_boundary: + return Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + + return _make_nvfp4_quantizer(role) + + +def nvfp4_linear_mxfp8_dpa_factory( + role: Optional[QuantizerRole], +): + """Quantizer factory: NVFP4 for ``Linear``, MXFP8 for ``DotProductAttention``. + + Mirrors the documented "NVFP4 linear + MXFP8 attention" combo from + :mod:`transformer_engine.pytorch.attention.dot_product_attention.dot_product_attention` + (see the recipe-combination table at the top of that module). With + ``CustomRecipe`` the per-tensor decision is made directly here, so the + ``NVTE_DPA_FP8_RECIPE="MXFP8BlockScaling"`` env override that the + built-in recipes would otherwise need is unnecessary. + + DPA tensor types (``role.module_type == "dpa"``): + + =========== ============================================================ + tensor_type Description + =========== ============================================================ + ``"qkv"`` Query, Key, Value inputs to the first attention GEMM + ``"s"`` Softmax output (S = softmax(Q·K^T)), fed into the second GEMM + ``"o"`` Attention output (O = S·V) + ``"do"`` Gradient of the attention output (dO), backward input + ``"dp"`` Gradient of the softmax output (dP = dO·V^T), backward + ``"dqkv"`` Gradient flowing back to Q, K, V + =========== ============================================================ + + Dispatch logic: + * ``role.module_type == "dpa"`` -> MXFP8 (E4M3, block-32) + The MXFP8 fused-attention kernel handles the S/dP slots + internally, so any quantizer returned for those roles is later + nulled out by ``get_attention_quantizers``. Returning MXFP8 is + the simplest valid choice. + * DPA boundary hints (``"dpa_output"`` / ``"dpa_grad_input"`` in + ``role.name``) -> MXFP8 placeholder. The fused attention kernel + requires FP8-compatible quantizers in all DPA slots. + * everything else (``"linear"`` / ``"grouped_linear"`` / ``None``) + -> NVFP4 (E2M1), configured per tensor role. + + Usage:: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_factory_examples import ( + nvfp4_linear_mxfp8_dpa_factory, + ) + + recipe = CustomRecipe( + qfactory=nvfp4_linear_mxfp8_dpa_factory, + fp8_dpa=True, + ) + with autocast(recipe=recipe): + output = model(input) + """ + is_dpa = role is not None and role.module_type == "dpa" + if is_dpa: + return _make_mxfp8_quantizer() + + # DPA boundary slots (O output / dQKV grad-input): emitted by DPA with + # empty `module_type` and a `name` like ".dpa_output". The fused + # attention kernel requires an FP8-compatible quantizer here even when + # the downstream consumer is unknown. + is_dpa_boundary = ( + role is not None + and not role.module_type + and ("dpa_output" in role.name or "dpa_grad_input" in role.name) + ) + if is_dpa_boundary: + return _make_mxfp8_quantizer() + + return _make_nvfp4_quantizer(role) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py new file mode 100644 index 0000000000..22eafaa665 --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_recipes_base.py @@ -0,0 +1,179 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Quantizer factory examples using real silicon quantizers. + +Each factory below replicates the behaviour of built-in TE recipe but via the +``CustomRecipe`` + ``qfactory`` interface. This is useful when you want to +start from a known-good recipe and then selectively override quantizer settings +for specific layers / tensor types. + +Usage (any factory):: + + from transformer_engine.common.recipe import CustomRecipe + from transformer_engine.pytorch.quantization import autocast + from transformer_engine.pytorch.custom_recipes.quantization_recipes_base import ( + nvfp4_quantizer_factory, + ) + + recipe = CustomRecipe(qfactory=nvfp4_quantizer_factory) + with autocast(recipe=recipe): + output = model(input) +""" + +from __future__ import annotations + +from typing import Optional + +import torch +import transformer_engine_torch as tex + +from transformer_engine.pytorch.quantization import QuantizerRole + + +def delayed_scaling_quantizer_factory( + role: Optional[QuantizerRole], # pylint: disable=unused-argument +) -> "DelayedScalingRequest": + """Factory that mirrors :class:`DelayedScaling` recipe defaults. + + Returns a :class:`DelayedScalingRequest` for every slot. TE allocates + shared scale/amax_history buffers and wires them into the existing + delayed-scaling reduction path. + + * HYBRID format: E4M3 forward, E5M2 backward + * amax_history_len = 1024 + * reduce_amax = True + """ + from transformer_engine.pytorch.quantization import DelayedScalingRequest + from transformer_engine.common.recipe import Format + + return DelayedScalingRequest(fp8_format=Format.HYBRID) + + +def current_scaling_quantizer_factory( + role: Optional[QuantizerRole], +) -> "Float8CurrentScalingQuantizer": + """Factory that mirrors :class:`Float8CurrentScaling` recipe defaults. + + * Forward tensors (input, weight) → E4M3 + * Backward tensors (grad_output) → E5M2 + """ + from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8CurrentScalingQuantizer, + ) + + is_backward = role is not None and role.tensor_type == "grad_output" + fp8_dtype = tex.DType.kFloat8E5M2 if is_backward else tex.DType.kFloat8E4M3 + + return Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, + device=torch.device("cuda"), + force_pow_2_scales=False, # constrain scale to powers of 2 + amax_epsilon=0.0, # clamp amax from below to avoid div-by-zero + ) + + +def mxfp8_quantizer_factory( + role: Optional[QuantizerRole], # pylint: disable=unused-argument +) -> "MXFP8Quantizer": + """Factory that mirrors :class:`MXFP8BlockScaling` recipe defaults. + + * E4M3 by default for all tensors + * Block size 32, power-of-2 (E8M0) scales + """ + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer + + return MXFP8Quantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + ) + + +def float8_block_scaling_quantizer_factory( + role: Optional[QuantizerRole], +) -> "Float8BlockQuantizer": + """Factory that mirrors :class:`Float8BlockScaling` recipe defaults. + + * E4M3 by default for all tensors + * Weights use 2D block scaling, everything else uses 1D + * Power-of-2 scales by default + """ + from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, + ) + + is_weight = ( + role is not None + and role.module_type in ("linear", "grouped_linear") + and role.tensor_type == "weight" + ) + block_scaling_dim = 2 if is_weight else 1 + + return Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + amax_epsilon=0.0, # clamp amax from below to avoid div-by-zero + force_pow_2_scales=True, + block_scaling_dim=block_scaling_dim, # 1 = 1D (1×128), 2 = 2D (128×128) + ) + + +def nvfp4_quantizer_factory( + role: Optional[QuantizerRole], +) -> "NVFP4Quantizer": + """Factory that mirrors :class:`NVFP4BlockScaling` recipe defaults. + + * All tensors quantized to E2M1 (FP4) + * Weights: 2D quantization (16x16 blocks), no RHT, no stochastic rounding + * Inputs: 1D quantization, RHT enabled, no stochastic rounding + * Grads: 1D quantization, RHT enabled, stochastic rounding enabled + + Quantizer knobs: + fp4_dtype - E2M1 (only supported format) + with_rht - randomized Hadamard transform (smooths outliers) + with_post_rht_amax - recompute amax after RHT (should match with_rht) + with_2d_quantization - 16x16 2D blocks (vs 1x16 1D) + stochastic_rounding - probabilistic rounding to reduce quant bias + with_random_sign_mask - random sign flip in the Hadamard matrix + """ + from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer + + is_linear = role is not None and role.module_type in ("linear", "grouped_linear") + is_weight = is_linear and role.tensor_type == "weight" + is_grad = is_linear and role.tensor_type == "grad_output" + + if is_weight: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=True, + stochastic_rounding=False, + with_random_sign_mask=True, + ) + + if is_grad: + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=True, + with_random_sign_mask=True, + ) + + # For input and unknown roles + return NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=True, + columnwise=True, + with_rht=True, + with_post_rht_amax=True, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=True, + ) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py similarity index 98% rename from transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py rename to transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py index 8580cf4a33..ecbb667ecf 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_current_scaling.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_current_scaling.py @@ -18,17 +18,18 @@ def current_scaling_ref_quantizer_factory(role): """Factory function for current scaling reference quantizer. - Usage with CustomRecipe and autocast: + Receives a :class:`~transformer_engine.pytorch.quantization.QuantizerRole`. + + Backward tensors use E5M2, everything else uses E4M3. + + Usage with CustomRecipe and autocast:: + custom_recipe = recipe.CustomRecipe(qfactory=current_scaling_ref_quantizer_factory) with autocast(recipe=custom_recipe): output = model(input) """ - if role in ("linear_input", "linear_weight"): - dtype = torch.float8_e4m3fn - elif role in ("linear_output", "linear_grad_output"): - dtype = torch.float8_e5m2 - else: - return None + is_backward = role is not None and role.tensor_type == "grad_output" + dtype = torch.float8_e5m2 if is_backward else torch.float8_e4m3fn return CurrentScalingQuantizerRef( dtype=dtype, rowwise=True, diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py similarity index 98% rename from transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py rename to transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py index 12f8ef8f5b..acb7abefd1 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_ref_nvfp4.py @@ -18,33 +18,32 @@ def nvfp4_ref_rht_2d_quantizer_factory(role): """ Quantizer factory for NVFP4 recipe reference implementation (RHT and 2D quantization for weights). - Usage with CustomRecipe and autocast: + Receives a :class:`~transformer_engine.pytorch.quantization.QuantizerRole`. + + Usage with CustomRecipe and autocast:: + custom_recipe = recipe.CustomRecipe(qfactory=nvfp4_ref_rht_2d_quantizer_factory) - with autocast(fp8_recipe=custom_recipe): + with autocast(recipe=custom_recipe): output = model(input) """ - if role == "linear_input": - return NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, - ) - if role == "linear_weight": + is_weight_tensor_in_gemm = ( + role is not None + and role.module_type in ("linear", "grouped_linear") + and role.tensor_type == "weight" + ) + if is_weight_tensor_in_gemm: # 2D quantization for weights in GEMM-based modules return NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, quant_tile_shape=(16, 16), pow_2_scales=False, with_rht=False, ) - if role == "linear_grad_output": - return NVFP4QuantizerRef( - dtype=utils.Fp4Formats.E2M1, - quant_tile_shape=(1, 16), - pow_2_scales=False, - with_rht=True, - ) - return None + return NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + quant_tile_shape=(1, 16), + pow_2_scales=False, + with_rht=True, + ) def cast_to_fp4x2(x): diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e6bedee0c0..746177ec78 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -27,8 +27,11 @@ Float8CurrentScalingRecipeState, Float8BlockScalingRecipeState, NVFP4BlockScalingRecipeState, + CustomRecipeState, FP8GlobalStateManager, + QuantizerRole, RecipeState, + _has_delayed_scaling_state, ) from ..distributed import ( gather_along_first_dim, @@ -789,6 +792,8 @@ def __init__(self, name: Optional[str] = None) -> None: self.activation_dtype: Optional[torch.dtype] = None self.wgrad_accumulation_and_reduce_hooks = [] self.wgrad_store = None + self._output_quantizer_role: Optional[QuantizerRole] = None + self._grad_input_quantizer_role: Optional[QuantizerRole] = None if not TEDebugState.debug_enabled: TEDebugState.initialize() @@ -809,6 +814,72 @@ def module_setattr(self, name: str, value: Any) -> None: """ super().__setattr__(name, value) + @property + def output_quantizer_role(self) -> Optional[QuantizerRole]: + """Caller-configurable :class:`QuantizerRole` for the forward output quantizer. + + When set, overrides the default role used by :meth:`get_quantizer_roles` + for the forward-pass output quantizer slot. Setting this after + quantizers have been created forces their recreation on the next + forward pass. + + See also :attr:`grad_input_quantizer_role` for the backward-pass + counterpart. + """ + return self._output_quantizer_role + + @output_quantizer_role.setter + def output_quantizer_role(self, role: Optional[QuantizerRole]) -> None: + if role == self._output_quantizer_role: + return + self._output_quantizer_role = role + if self.fp8_meta_tensors_initialized: + self.fp8_meta_tensors_initialized = False + + @property + def grad_input_quantizer_role(self) -> Optional[QuantizerRole]: + """Caller-configurable :class:`QuantizerRole` for the grad-input quantizer. + + Backward-pass counterpart of :attr:`output_quantizer_role`. + """ + return self._grad_input_quantizer_role + + @grad_input_quantizer_role.setter + def grad_input_quantizer_role(self, role: Optional[QuantizerRole]) -> None: + if role == self._grad_input_quantizer_role: + return + self._grad_input_quantizer_role = role + if self.fp8_meta_tensors_initialized: + self.fp8_meta_tensors_initialized = False + + def _warn_missing_output_quantizer_role( + self, + fp8_output: bool, + fp8_grad: bool, + ) -> None: + """Warn when quantized output is requested but no consumer role is set. + + Only relevant for ``CustomRecipe`` where the ``qfactory`` dispatches + on roles. Built-in recipes ignore role metadata. + """ + recipe = FP8GlobalStateManager.get_fp8_recipe() + if not recipe.custom(): + return + if fp8_output and self._output_quantizer_role is None: + warnings.warn( + f"{type(self).__name__}: fp8_output=True but " + "output_quantizer_role is not set. The CustomRecipe qfactory " + "will receive None for the output quantizer role.", + stacklevel=3, + ) + if fp8_grad and self._grad_input_quantizer_role is None: + warnings.warn( + f"{type(self).__name__}: fp8_grad=True but " + "grad_input_quantizer_role is not set. The CustomRecipe " + "qfactory will receive None for the grad-input quantizer role.", + stacklevel=3, + ) + @property def is_fsdp2(self) -> bool: """Whether this module is wrapped with FSDP2.""" @@ -901,21 +972,124 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: return if recipe.nvfp4() and isinstance(recipe_state, NVFP4BlockScalingRecipeState): return + if recipe.custom() and isinstance(recipe_state, CustomRecipeState): + return # Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and # 2 (grad_output and grad_input) for bwd num_fp8_tensors = self.fp8_meta["num_gemms"] * 3 if fwd else self.fp8_meta["num_gemms"] * 2 # Initialize recipe state and quantizers - recipe_state = RecipeState.create( + roles = self.get_quantizer_roles( # pylint: disable=assignment-from-none + fwd=fwd, num_quantizers=num_fp8_tensors + ) + if roles is not None: + assert ( + len(roles) == num_fp8_tensors + ), f"Recipe roles must match number of quantizers ({len(roles)=} vs {num_fp8_tensors=})" + recipe_state = RecipeState.create( # pylint: disable=assignment-from-none recipe, mode=("forward" if fwd else "backward"), num_quantizers=num_fp8_tensors, + roles=roles, ) + # Reached the rebuild path because ``fp8_meta_tensors_initialized`` + # was flipped to False after first init — most commonly because the + # ``output_quantizer_role`` / ``grad_input_quantizer_role`` setter + # invalidated state when a parent module (e.g. ``MultiheadAttention``) + # wired boundary roles. That setter is recipe-agnostic, so this code + # fires even for built-in recipes that don't consume role information + # in ``make_quantizers``. + # + # Rebuilding the recipe state must preserve persistent training + # buffers (delayed-scaling ``scale`` / ``amax_history``) so the new + # quantizer instances and the ``FP8GlobalStateManager`` reduction + # buffers end up viewing the SAME tensor objects, and so any + # checkpoint-loaded state isn't silently destroyed on the first + # forward after ``load_state_dict``. + old_state = self.fp8_meta.get(fp8_meta_tensor_key) + if old_state is not None: + recipe_state.inherit_state_from(old_state) + self.fp8_meta[fp8_meta_tensor_key] = recipe_state self.quantizers[fp8_meta_tensor_key] = recipe_state.make_quantizers() + def get_quantizer_roles( + self, + *, + fwd: bool, # pylint: disable=unused-argument + num_quantizers: int, # pylint: disable=unused-argument + ) -> Optional[List[QuantizerRole]]: + """Return an ordered list of :class:`QuantizerRole` for quantizers. + + Overview + -------- + When using ``CustomRecipe``, the quantizer factory is called once + per quantizer slot. Each call receives a ``QuantizerRole`` that + tells the factory *what* is being quantized so it can return the + right quantizer. + + This method builds the role list. Subclasses override it to + describe their internal GEMM layout. + + Slot layout + ----------- + Return one ``QuantizerRole`` per slot, in the same order as the + module's quantizer array. For example, ``Linear`` uses 3 + forward slots ``[input, weight, output]`` and 2 backward slots + ``[grad_output, grad_input]``. Multi-GEMM modules like + ``LayerNormMLP`` repeat that pattern for each GEMM: + ``[fc1_input, fc1_weight, fc1_output, fc2_input, fc2_weight, fc2_output]``. + + What to put in each slot + ------------------------ + Create a ``QuantizerRole(module_type=..., tensor_type=..., + name=...)`` for each slot: + + * **module_type** — the kind of module: ``"linear"``, + ``"grouped_linear"``, ``"dpa"``, etc. The factory can dispatch + on this to use different quantization formats per module type. + * **tensor_type** — what tensor this slot holds, in the module's + own vocabulary. For linears: ``"input"``, ``"weight"``, + ``"grad_output"``, etc. For DPA: ``"qkv"``, ``"s"``, + ``"do"``, ``"dp"``, etc. + * **name** — the instance name (e.g. ``"encoder.layer0.fc1"``), + propagated from the ``name=`` constructor argument. The factory + can dispatch on this to target specific layers. + + Boundary slots + -------------- + The last slot of a forward GEMM group (output) and the last slot + of a backward group (grad_input) are **boundary** slots — the + tensor leaves this module and enters an unknown consumer. For + these slots, use ``self._output_quantizer_role`` (fwd) and + ``self._grad_input_quantizer_role`` (bwd), which default to + ``None``. A parent module (e.g. ``MultiheadAttention``) can set + these attributes to fill in the consumer identity; see + ``MultiheadAttention._update_output_quantizer_roles`` for an + example. + + Return value + ------------ + * A list of ``QuantizerRole`` with length ``num_quantizers``. + * ``None`` to opt out of role-based dispatch. + + Not implemented (default) + ~~~~~~~~~~~~~~~~~~~~~~~~~ + The base implementation returns ``None``. When ``None`` is + returned, ``CustomRecipeState`` emits a warning and falls back + to bare ``QuantizerRole()`` instances (all fields empty) for + every slot. The factory still gets called once per slot, but + every call receives an identical empty role — it cannot + distinguish input from weight, forward from backward, or one + module from another. What happens then depends entirely on the + factory: it may return the same quantizer for all slots (acting + as a uniform recipe), or it may raise an error if it requires + meaningful roles to dispatch on. + """ + return None + def _update_weight_quantizers(self) -> None: """Update the quantizers for the weight tensors.""" weight_tensors = self._get_weight_tensors() @@ -1024,7 +1198,7 @@ def to_cpu(src: torch.Tensor) -> torch.Tensor: # Copy tensors to CPU and store state = {} state["recipe"] = self.fp8_meta["recipe"] - if state["recipe"].delayed(): + if _has_delayed_scaling_state(self.fp8_meta): state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale) state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history) state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale) @@ -1096,7 +1270,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: dst.copy_(src, non_blocking=True) # Load tensors - if self.fp8_meta["recipe"].delayed(): + if _has_delayed_scaling_state(self.fp8_meta): copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale) copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history) copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale) @@ -1223,7 +1397,7 @@ def prepare_forward( # Activation recomputation is used and this is the second forward phase. if self.fp8 and in_fp8_activation_recompute_phase(): - delayed_scaling_recipe = self.fp8_meta["recipe"].delayed() + delayed_scaling_recipe = _has_delayed_scaling_state(self.fp8_meta) FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta) else: if not inp.is_cuda: @@ -1242,14 +1416,15 @@ def prepare_forward( self.init_fp8_metadata(num_gemms=num_gemms) self._check_weight_tensor_recipe_correspondence() - delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() + delayed_scaling_recipe = self.fp8 and _has_delayed_scaling_state(self.fp8_meta) if delayed_scaling_recipe: if self.sequence_parallel: - if not self.fp8_meta["recipe"].reduce_amax: - raise ValueError( - "Amax reduction across tensor parallel group is " - "necessary when using sequence parallelism with FP8." - ) + assert ( + self.fp8_meta["recipe"].custom() or self.fp8_meta["recipe"].reduce_amax + ), ( + "Amax reduction across tensor parallel group is " + "necessary when using sequence parallelism with FP8." + ) if not FP8GlobalStateManager.fp8_graph_capturing(): FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta) @@ -1268,7 +1443,7 @@ def end_forward(self): Required to be called at the end of the forward function to properly handle DelayedScaling metadata handling and the NVTX ranges. """ - delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed() + delayed_scaling_recipe = self.fp8 and _has_delayed_scaling_state(self.fp8_meta) if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase(): FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta) nvtx_range_pop() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 4ae7b47b9b..e950f26571 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -24,7 +24,7 @@ _2X_ACC_WGRAD, ) from ._common import WeightGradStore -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( divide, cast_if_needed, @@ -116,7 +116,18 @@ def forward( # Configure quantizers if save_original_input and isinstance(input_quantizers[0], Float8Quantizer): - raise ValueError("DelayedScaling recipe is not supported with save_original_input") + if FP8GlobalStateManager.get_fp8_recipe().custom(): + # Custom recipe factory may produce DS quantizers unknown to caller. + # TODO(negvet): fix on Megatron side — guard should also exclude 'custom', or + # better: check at runtime whether quantizers are DS-based. + warnings.warn( + "save_original_input is incompatible with delayed-scaling quantizers " + "(Float8Quantizer). Disabling save_original_input for this module.", + stacklevel=2, + ) + save_original_input = False + else: + raise ValueError("DelayedScaling recipe is not supported with save_original_input") if input_quantizers[0] is not None: for input_quantizer in input_quantizers: input_quantizer.set_usage( @@ -829,6 +840,33 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``GroupedLinear``. + + For grouped GEMMs we repeat the same pattern for each GEMM in + order. The output (fwd) and grad-input (bwd) slots default to + ``None`` (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. + """ + name = self.name or "" + if fwd: + base = [ + QuantizerRole(module_type="grouped_linear", tensor_type="input", name=name), + QuantizerRole(module_type="grouped_linear", tensor_type="weight", name=name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="grouped_linear", tensor_type="grad_output", name=name), + self._grad_input_quantizer_role, + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def make_grouped_weights(self, defer_init=False) -> None: """ Convert parameters into a GroupedTensor and re-register them as parameters. diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index abfa6af034..8c88f3ee82 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -28,7 +28,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( assert_dim_for_fp8_exec, cast_if_needed, @@ -1504,6 +1504,32 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``LayerNormLinear``. + + The output (fwd) and grad-input (bwd) slots default to ``None`` + (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. + """ + name = self.name or "" + if fwd: + base = [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), + self._grad_input_quantizer_role, + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -1713,6 +1739,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): if not self.fp8: return [None] * 6 + + self._warn_missing_output_quantizer_role(fp8_output, fp8_grad) + grad_input_quantizer = None grad_weight_quantizer = None grad_output_quantizer = None diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4fa7eb2856..46918ff0f1 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -29,7 +29,7 @@ _2X_ACC_DGRAD, _2X_ACC_WGRAD, ) -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..jit import ( bias_gelu_fused, bgrad_dgelu_fused, @@ -2104,6 +2104,53 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: elif recipe.nvfp4(): self._customize_quantizers_nvfp4(fwd, recipe) + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``LayerNormMLP``. + + Each internal GEMM (fc1, fc2) gets a distinct name suffix so that + custom-recipe factories can target them individually. + + The module's final output (fc2 fwd) and final grad (fc1 bwd) + slots default to ``None`` (unknown consumer). Set + :attr:`output_quantizer_role` / :attr:`grad_input_quantizer_role` + to provide consumer identity. Internal boundaries use fixed + roles with known consumer identity. + """ + base_name = self.name or "" + fc1_name = f"{base_name}.fc1" if base_name else "fc1" + fc2_name = f"{base_name}.fc2" if base_name else "fc2" + # Roles use the *consumer's* identity: internal boundary tensors are + # labeled with the downstream module that will consume them. + # + # Forward: fc1_input -> fc1 GEMM -> [act] -> fc2_input -> fc2 GEMM -> output + # Backward: grad_input <- fc1 GEMM <- [act'] <- fc2 GEMM <- grad_output + if fwd: + base = [ + QuantizerRole(module_type="linear", tensor_type="input", name=fc1_name), + QuantizerRole(module_type="linear", tensor_type="weight", name=fc1_name), + # fc1 output — consumed by fc2 (via activation), so labeled as fc2 input + QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="input", name=fc2_name), + QuantizerRole(module_type="linear", tensor_type="weight", name=fc2_name), + # fc2 output — boundary, consumer unknown + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name), + # fc1 grad_input — boundary, consumer unknown + self._grad_input_quantizer_role, + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc2_name), + # fc2 grad_input — consumed by fc1 (via activation'), so labeled as fc1 grad_output + QuantizerRole(module_type="linear", tensor_type="grad_output", name=fc1_name), + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def reset_layer_norm_parameters(self) -> None: """Init LN params""" warnings.warn( @@ -2336,6 +2383,9 @@ def forward( return out def _get_quantizers(self, fp8_output, is_grad_enabled): + if self.fp8: + self._warn_missing_output_quantizer_role(fp8_output, False) + ( fc1_input_quantizer, fc1_output_quantizer, diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2b14eaaf2e..dcbb9eaf93 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,7 +3,8 @@ # See LICENSE for license information. """Linear API""" -from typing import Callable, Dict, Optional, Tuple, Union, List +from dataclasses import dataclass +from typing import Any, Callable, Dict, Optional, Tuple, Union, List from functools import reduce from operator import mul as multiply_op import warnings @@ -27,13 +28,12 @@ _2X_ACC_WGRAD, ) from ._common import noop_cat, WeightGradStore -from ..quantization import FP8GlobalStateManager +from ..quantization import FP8GlobalStateManager, QuantizerRole from ..utils import ( cast_if_needed, clear_tensor_data, divide, init_method_constant, - requires_grad, needs_quantized_gemm, assert_dim_for_fp8_exec, nvtx_range_pop, @@ -80,6 +80,163 @@ __all__ = ["Linear"] +TensorOrQuantized = Union[torch.Tensor, QuantizedTensorStorage] + + +@dataclass(slots=True) +class LinearFwdArgs: + """Single-argument bag for the forward path of :class:`_Linear`.""" + + # --- Differentiable tensors (also passed positionally to autograd) --- + weight: TensorOrQuantized + inp: torch.Tensor + bias: Optional[torch.Tensor] + + # --- Non-differentiable cached tensors --- + weight_workspace: Optional[torch.Tensor] + + # --- requires_grad flags (cached so backward does not re-query) --- + input_requires_grad: bool + weight_requires_grad: bool + bias_requires_grad: bool + + # --- Quantizers --- + input_quantizer: Optional[Quantizer] + weight_quantizer: Optional[Quantizer] + output_quantizer: Optional[Quantizer] + grad_input_quantizer: Optional[Quantizer] + grad_weight_quantizer: Optional[Quantizer] + grad_output_quantizer: Optional[Quantizer] + + # --- Numerical / dtype config --- + activation_dtype: torch.dtype + fp8: bool + fp8_calibration: bool + fp8_output: bool + save_original_input: bool + backward_override: Optional[str] + custom: bool + debug: bool + + # --- Weight-workspace caching --- + is_first_microbatch: Optional[bool] + cache_weight: bool + skip_fp8_weight_update: Optional[bool] + + # --- Tensor / sequence parallelism --- + parallel_mode: Optional[str] + tp_group: Optional[Any] + tp_size: int + tensor_parallel: bool + sequence_parallel: bool + symmetric_ar_type: Optional[str] + backward_input_needs_gather: bool + + # --- Userbuffers (comm + GEMM overlap) --- + ub_name: Optional[str] + ub_overlap_ag_fprop: bool + ub_overlap_rs_fprop: bool + ub_overlap_ag_dgrad: bool + ub_overlap_rs_dgrad: bool + ub_bulk_dgrad: bool + ub_bulk_wgrad: bool + + # --- FSDP --- + fsdp_group: Optional[Any] + is_fsdp2: bool + + # --- Weight-grad scheduling --- + fuse_wgrad_accumulation: bool + wgrad_store: Optional[Any] + + # --- Misc --- + cpu_offloading: bool + is_grad_enabled: bool + + +@dataclass(slots=True) +class LinearBwdArgs: + """Single-argument bag for the backward path of :class:`_Linear`.""" + + # --- Saved / restored tensors (populated at backward entry) --- + grad_output: Optional[torch.Tensor] = None + inputmat: Optional[TensorOrQuantized] = None + weight_fp8: Optional[TensorOrQuantized] = None + saved_weight: Optional[TensorOrQuantized] = None + bias: Optional[torch.Tensor] = None + + # --- Quantizers --- + input_quantizer: Optional[Quantizer] = None + weight_quantizer: Optional[Quantizer] = None + grad_input_quantizer: Optional[Quantizer] = None + grad_weight_quantizer: Optional[Quantizer] = None + grad_output_quantizer: Optional[Quantizer] = None + + # --- Differentiability summary --- + use_bias: bool = False + requires_dgrad: bool = False + requires_wgrad: bool = False + inp_shape: Optional[torch.Size] = None + + # --- Numerical / dtype config --- + activation_dtype: Optional[torch.dtype] = None + fp8: bool = False + fp8_recipe: Optional[Recipe] = None + backward_override: Optional[str] = None + is_weight_param_quantized: bool = False + custom: bool = False + debug: bool = False + + # --- Tensor / sequence parallelism --- + parallel_mode: Optional[str] = None + tp_group: Optional[Any] = None + tp_size: int = 1 + tensor_parallel: bool = False + sequence_parallel: bool = False + backward_input_needs_gather: bool = False + + # --- Userbuffers (comm + GEMM overlap) --- + ub_name: Optional[str] = None + ub_overlap_ag: bool = False + ub_overlap_rs_dgrad: bool = False + ub_bulk_dgrad: bool = False + ub_bulk_wgrad: bool = False + + # --- FSDP --- + fsdp_group: Optional[Any] = None + fsdp_shapes: Any = None + is_fsdp2: bool = False + + # --- Weight-grad scheduling / accumulation --- + is_first_microbatch: Optional[bool] = None + fuse_wgrad_accumulation: bool = False + wgrad_store: Optional[Any] = None + origin_weight_ref: Optional[Any] = None + origin_weight_overwrites_main_grad: bool = False + main_grad_func: Optional[Callable[[], torch.Tensor]] = None + + # --- FP8 reduce-and-update bookkeeping --- + reduce_and_update_bwd_fp8_tensors: bool = False + + # --- Misc --- + cpu_offloading: bool = False + owns_input: bool = False + + # --- Per-backward scratch state (populated inside _linear_backward) --- + ub_obj_gradout: Optional[Any] = None + + def setup_saved_tensors(self, ctx: torch.autograd.function.FunctionCtx) -> None: + """Pull saved tensors from ``ctx`` into the fields backward consumes.""" + ( + self.inputmat, + self.weight_fp8, + self.saved_weight, + self.bias, + ) = restore_from_func_ctx( + ctx + ) # pylint: disable=unbalanced-tuple-unpacking + + def _check_fp8_reduce_and_update(): """Check if this is the first FP8 module (for backward reduce-and-update).""" qstate = FP8GlobalStateManager.quantization_state @@ -91,54 +248,39 @@ def _check_fp8_reduce_and_update(): def _linear_forward_impl( - weight: torch.Tensor, - weight_workspace: Optional[torch.Tensor], - inp: torch.Tensor, - bias: Optional[torch.Tensor], - non_tensor_args: Tuple, - input_quantizer: Optional[Quantizer], - weight_quantizer: Optional[Quantizer], - output_quantizer: Optional[Quantizer], -) -> Tuple: + args: LinearFwdArgs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple], None, Optional[Dict]]: """Forward implementation for the linear layer. - Returns (out, tensors_to_save, tensor_objects, ctx_attrs) where the last - three are None when gradients are disabled. + Returns ``(out, new_weight_workspace, tensors_to_save_from_forward, None, + ctx_attrs)``. ``new_weight_workspace`` is the freshly produced FP8 weight + workspace (returned alongside ``out`` so the caller can refresh its + cache). The last three are ``None`` when gradients are disabled. """ - ( - is_first_microbatch, - fp8, - fp8_calibration, - _wgrad_store, - _fuse_wgrad_accumulation, - cpu_offloading, - tp_group, - tp_size, - sequence_parallel, - tensor_parallel, - activation_dtype, - parallel_mode, - is_grad_enabled, - ub_overlap_rs_fprop, - _ub_overlap_ag_dgrad, - ub_overlap_ag_fprop, - _ub_overlap_rs_dgrad, - _ub_bulk_dgrad, - _ub_bulk_wgrad, - ub_name, - _fp8_output, - fsdp_group, - cache_weight, - skip_fp8_weight_update, - symmetric_ar_type, - save_original_input, - debug, - backward_override, - custom, - backward_input_needs_gather, - is_fsdp2, - ) = non_tensor_args + weight = args.weight + inp = args.inp + bias = args.bias + input_quantizer = args.input_quantizer + weight_quantizer = args.weight_quantizer + output_quantizer = args.output_quantizer + is_first_microbatch = args.is_first_microbatch + fp8 = args.fp8 + cpu_offloading = args.cpu_offloading + tp_group = args.tp_group + sequence_parallel = args.sequence_parallel + activation_dtype = args.activation_dtype + parallel_mode = args.parallel_mode + is_grad_enabled = args.is_grad_enabled + ub_overlap_rs_fprop = args.ub_overlap_rs_fprop + ub_overlap_ag_fprop = args.ub_overlap_ag_fprop + ub_name = args.ub_name + fsdp_group = args.fsdp_group + symmetric_ar_type = args.symmetric_ar_type + save_original_input = args.save_original_input + debug = args.debug + backward_override = args.backward_override + is_fsdp2 = args.is_fsdp2 if backward_override == "high_precision": save_original_input = True @@ -189,7 +331,7 @@ def _linear_forward_impl( if fp8 or debug: if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") - if not isinstance(inputmat, QuantizedTensorStorage) and not custom: + if not isinstance(inputmat, QuantizedTensorStorage) and not args.custom: own_quantized_input = True input_quantizer.set_usage( rowwise=True, @@ -280,12 +422,12 @@ def _linear_forward_impl( weightmat, new_weight_workspace = quantize_weight( tensor=weight, quantizer=weight_quantizer, - workspace=weight_workspace, + workspace=args.weight_workspace, update_workspace=update_ws, - skip_update_flag=skip_fp8_weight_update, + skip_update_flag=args.skip_fp8_weight_update, fsdp_group=fsdp_group, workspace_dtype=activation_dtype, - cache=cache_weight, + cache=args.cache_weight, ) weightmat.update_usage(rowwise_usage=True) @@ -303,7 +445,7 @@ def _linear_forward_impl( bias = cast_if_needed(bias, bias_dtype) if bias is not None else bias # Calibrate quantizers if needed - if not fp8 and fp8_calibration: + if not fp8 and args.fp8_calibration: if input_quantizer is not None: input_quantizer.calibrate(inputmat_total) if weight_quantizer is not None: @@ -363,12 +505,12 @@ def _linear_forward_impl( out = None if ub_overlap_rs_fprop: out = reduce_scatter_out - elif parallel_mode == "row" and tp_size > 1: + elif parallel_mode == "row" and args.tp_size > 1: nvtx_range_push(f"{nvtx_label}.row_parallel_comm") out = gemm_out if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) - elif tensor_parallel: + elif args.tensor_parallel: if symmetric_ar_type is not None: out, _ = symmetric_all_reduce(out, tp_group, all_reduce_type=symmetric_ar_type) else: @@ -381,8 +523,7 @@ def _linear_forward_impl( # ------------------------------------------------------ # Prepare backward state - tensors_to_save = None - tensor_objects = None + tensors_to_save_from_forward = None ctx_attrs = None if is_grad_enabled: @@ -398,7 +539,8 @@ def _linear_forward_impl( if backward_override is not None: inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) elif ( - backward_input_needs_gather and weight_quantizer.supports_only_rowwise_all_gather() + args.backward_input_needs_gather + and weight_quantizer.supports_only_rowwise_all_gather() ): # All-gather is not supported with FP8 column-wise data inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) @@ -434,226 +576,224 @@ def _linear_forward_impl( wt_save = weightmat if is_fsdp2 and weightmat is not weight: wt_save = None - tensors_to_save, tensor_objects = prepare_for_saving( - saved_inputmat, - wt_save, - weight, - bias, - ) - owns_input = saved_inputmat is not inp + # Dedup save slots that alias forward inputs; ``_linear_setup_ctx`` + # rebuilds the refs from ``inp`` / ``weight`` / ``bias``. + # Needed for torch.compile to work correctly. + saved_tensor_aliases = ( + "inp" if saved_inputmat is inp else None, + "weight" if wt_save is weight else None, + "weight", # ``saved_weight`` slot is always the weight parameter + "bias" if bias is not None else None, + ) + tensors_to_save_from_forward = ( + None if saved_tensor_aliases[0] is not None else saved_inputmat, + None if saved_tensor_aliases[1] is not None else wt_save, + None, + None if saved_tensor_aliases[3] is not None else bias, + ) ctx_attrs = { - "weight_quantizer": weight_quantizer, "fsdp_shapes": fsdp_shapes, - "owns_input": owns_input, - "is_fsdp2": is_fsdp2, + "saved_tensor_aliases": saved_tensor_aliases, } - return out, new_weight_workspace, tensors_to_save, tensor_objects, ctx_attrs + return out, new_weight_workspace, tensors_to_save_from_forward, None, ctx_attrs def _linear_setup_ctx( - ctx, - tensors_to_save, - tensor_objects, - ctx_attrs, - inp, - weight, - bias, - non_tensor_args, - input_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, -): - """Save forward state into autograd context for backward pass.""" - ctx.save_for_backward(*tensors_to_save) - ctx.tensor_objects = tensor_objects - - ( - is_first_microbatch, - fp8, - _fp8_calibration, - wgrad_store, - fuse_wgrad_accumulation, - cpu_offloading, - tp_group, - tp_size, - sequence_parallel, - tensor_parallel, - activation_dtype, - parallel_mode, - _is_grad_enabled, - _ub_overlap_rs_fprop, - ub_overlap_ag_dgrad, - _ub_overlap_ag_fprop, - ub_overlap_rs_dgrad, - ub_bulk_dgrad, - ub_bulk_wgrad, - ub_name, - _fp8_output, - fsdp_group, - _cache_weight, - _skip_fp8_weight_update, - _symmetric_ar_type, - _save_original_input, - debug, - backward_override, - custom, - backward_input_needs_gather, - _is_fsdp2, - ) = non_tensor_args - - # Values derived from input tensors - ctx.use_bias = bias is not None - ctx.requires_dgrad = inp.requires_grad - ctx.requires_wgrad = weight.requires_grad - ctx.inp_shape = inp.shape + bwd_args: LinearBwdArgs, + fwd_args: LinearFwdArgs, + out: torch.Tensor, + ctx_attrs: Dict, + tensors_to_save_from_forward: Tuple[Any, ...], +) -> Tuple[Any, ...]: + """Populate ``bwd_args`` from forward state. + + Returns the merged list of tensors that should be passed through + ``prepare_for_saving`` by the caller (``_Linear.forward``). Keeping the + ``prepare_for_saving`` call out of here lets callers stitch in extra + tensors (e.g. the original ``weight`` parameter so backward can reuse it + for FSDP2 re-quantization) without having to mutate the structured + metadata returned by ``prepare_for_saving``. + """ + del out # No-op; kept for symmetry with the compile-time helper signature. + + inp = fwd_args.inp + weight = fwd_args.weight + bias = fwd_args.bias + + backward_override = fwd_args.backward_override + fp8 = fwd_args.fp8 + fuse_wgrad_accumulation = fwd_args.fuse_wgrad_accumulation # Quantizers - ctx.input_quantizer = input_quantizer - ctx.grad_input_quantizer = grad_input_quantizer - ctx.grad_weight_quantizer = grad_weight_quantizer - ctx.grad_output_quantizer = grad_output_quantizer - - # Values from non_tensor_args - ctx.activation_dtype = activation_dtype - ctx.fp8 = fp8 - ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None - ctx.backward_override = backward_override - ctx.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) - ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - ctx.cpu_offloading = cpu_offloading - ctx.is_first_microbatch = is_first_microbatch - ctx.sequence_parallel = sequence_parallel - ctx.tensor_parallel = tensor_parallel - ctx.parallel_mode = parallel_mode - ctx.tp_group = tp_group - ctx.tp_size = tp_size - ctx.ub_name = ub_name - ctx.fsdp_group = fsdp_group - ctx.debug = debug - ctx.wgrad_store = wgrad_store - ctx.ub_overlap_ag = ub_overlap_ag_dgrad - - ctx.ub_overlap_rs_dgrad = ub_overlap_rs_dgrad - ctx.ub_bulk_dgrad = ub_bulk_dgrad - ctx.ub_bulk_wgrad = ub_bulk_wgrad - - # Derived values - ctx.backward_input_needs_gather = backward_input_needs_gather - ctx.custom = custom - - # main_grad_func setup - if fuse_wgrad_accumulation and weight.requires_grad: - ctx.origin_weight_ref = weakref.ref(weight) - ctx.origin_weight_overwrites_main_grad = getattr(weight, "overwrite_main_grad", False) + bwd_args.input_quantizer = fwd_args.input_quantizer + bwd_args.weight_quantizer = ( + weight._quantizer if isinstance(weight, QuantizedTensor) else fwd_args.weight_quantizer + ) + bwd_args.grad_input_quantizer = fwd_args.grad_input_quantizer + bwd_args.grad_weight_quantizer = fwd_args.grad_weight_quantizer + bwd_args.grad_output_quantizer = fwd_args.grad_output_quantizer + + # Differentiability summary + bwd_args.use_bias = bias is not None + bwd_args.requires_dgrad = fwd_args.input_requires_grad + bwd_args.requires_wgrad = fwd_args.weight_requires_grad + bwd_args.inp_shape = inp.shape + + # Numerical / dtype config + bwd_args.activation_dtype = fwd_args.activation_dtype + bwd_args.fp8 = fp8 + bwd_args.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None + bwd_args.backward_override = backward_override + bwd_args.is_weight_param_quantized = isinstance(weight, QuantizedTensorStorage) + bwd_args.custom = fwd_args.custom + bwd_args.debug = fwd_args.debug + + # Tensor / sequence parallelism + bwd_args.parallel_mode = fwd_args.parallel_mode + bwd_args.tp_group = fwd_args.tp_group + bwd_args.tp_size = fwd_args.tp_size + bwd_args.tensor_parallel = fwd_args.tensor_parallel + bwd_args.sequence_parallel = fwd_args.sequence_parallel + bwd_args.backward_input_needs_gather = fwd_args.backward_input_needs_gather + + # Userbuffers + bwd_args.ub_name = fwd_args.ub_name + bwd_args.ub_overlap_ag = fwd_args.ub_overlap_ag_dgrad + bwd_args.ub_overlap_rs_dgrad = fwd_args.ub_overlap_rs_dgrad + bwd_args.ub_bulk_dgrad = fwd_args.ub_bulk_dgrad + bwd_args.ub_bulk_wgrad = fwd_args.ub_bulk_wgrad + + # FSDP + bwd_args.fsdp_group = fwd_args.fsdp_group + bwd_args.fsdp_shapes = ctx_attrs["fsdp_shapes"] + bwd_args.is_fsdp2 = fwd_args.is_fsdp2 + + # Weight-grad scheduling / accumulation + bwd_args.is_first_microbatch = fwd_args.is_first_microbatch + bwd_args.fuse_wgrad_accumulation = fuse_wgrad_accumulation + bwd_args.wgrad_store = fwd_args.wgrad_store + if fuse_wgrad_accumulation and fwd_args.weight_requires_grad: + bwd_args.origin_weight_ref = weakref.ref(weight) + bwd_args.origin_weight_overwrites_main_grad = getattr(weight, "overwrite_main_grad", False) if hasattr(weight, "__fsdp_param__"): - ctx.main_grad_func = weight.get_main_grad + bwd_args.main_grad_func = weight.get_main_grad else: - ctx.main_grad_func = lambda: weight.main_grad + bwd_args.main_grad_func = lambda: weight.main_grad - # Forward-computed values that can't be derived here - ctx.weight_quantizer = ctx_attrs["weight_quantizer"] - ctx.fsdp_shapes = ctx_attrs["fsdp_shapes"] - ctx.owns_input = ctx_attrs["owns_input"] - ctx.is_fsdp2 = ctx_attrs["is_fsdp2"] + # Misc + bwd_args.cpu_offloading = fwd_args.cpu_offloading - # backward overrides if backward_override is not None: - ctx.fp8 = False - ctx.debug = False - ctx.ub_overlap_ag = False - ctx.ub_overlap_rs_dgrad = False - ctx.ub_bulk_dgrad = False - ctx.ub_bulk_wgrad = False - ctx.grad_input_quantizer = None - ctx.grad_weight_quantizer = None - ctx.grad_output_quantizer = None - - -def _linear_backward( - ctx, - grad_output: torch.Tensor, - input_quantizer: Optional[Quantizer], - weight_quantizer: Optional[Quantizer], - grad_input_quantizer: Optional[Quantizer], - grad_weight_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], -) -> Tuple[Union[torch.Tensor, None], ...]: - """Backward implementation for the linear layer.""" + bwd_args.fp8 = False + bwd_args.debug = False + bwd_args.ub_overlap_ag = False + bwd_args.ub_overlap_rs_dgrad = False + bwd_args.ub_bulk_dgrad = False + bwd_args.ub_bulk_wgrad = False + bwd_args.grad_input_quantizer = None + bwd_args.grad_weight_quantizer = None + bwd_args.grad_output_quantizer = None + + saved_inputmat, wt_save, saved_weight, saved_bias = tensors_to_save_from_forward + inputmat_alias, wt_save_alias, saved_weight_alias, bias_alias = ctx_attrs[ + "saved_tensor_aliases" + ] + bwd_args.owns_input = inputmat_alias != "inp" + if inputmat_alias == "inp": + saved_inputmat = inp + if wt_save_alias == "weight": + wt_save = weight + if saved_weight_alias == "weight": + saved_weight = weight + if bias_alias == "bias": + saved_bias = bias + return (saved_inputmat, wt_save, saved_weight, saved_bias) + + +def _linear_backward(args: LinearBwdArgs) -> Tuple[Union[torch.Tensor, None], ...]: + """Backward implementation for the linear layer. + + Caller must have populated ``args.grad_output`` and run + ``args.setup_saved_tensors(ctx)`` before invocation. + """ + bwd_args = args + grad_output = args.grad_output + assert grad_output is not None + inputmat = args.inputmat + weight_fp8 = args.weight_fp8 + saved_weight = args.saved_weight + bias = args.bias + input_quantizer = args.input_quantizer + weight_quantizer = args.weight_quantizer + grad_input_quantizer = args.grad_input_quantizer + grad_weight_quantizer = args.grad_weight_quantizer + grad_output_quantizer = args.grad_output_quantizer # NVTX label for profiling nvtx_label = "transformer_engine._Linear.backward" - if ctx.ub_name is not None: - nvtx_label = f"{nvtx_label}.{ctx.ub_name}" + if bwd_args.ub_name is not None: + nvtx_label = f"{nvtx_label}.{bwd_args.ub_name}" with get_nvtx_range_context("_Linear_backward"): - ( - inputmat, - weight_fp8, - saved_weight, - bias, - ) = restore_from_func_ctx( # pylint: disable=unbalanced-tuple-unpacking - ctx - ) - origin_weight_python_object = None - origin_weight_overwrites_main_grad = getattr( - ctx, "origin_weight_overwrites_main_grad", False - ) + origin_weight_overwrites_main_grad = bwd_args.origin_weight_overwrites_main_grad main_grad = None - if ctx.fuse_wgrad_accumulation and ctx.requires_wgrad: - origin_weight_ref = ctx.origin_weight_ref - ctx.origin_weight_ref = None + if bwd_args.fuse_wgrad_accumulation and bwd_args.requires_wgrad: + origin_weight_ref = bwd_args.origin_weight_ref + bwd_args.origin_weight_ref = None origin_weight_python_object = ( origin_weight_ref() if origin_weight_ref is not None else None ) assert ( origin_weight_python_object is not None ), "weight was removed while fuse_wgrad_accumulation=True" - main_grad = ctx.main_grad_func() + main_grad = bwd_args.main_grad_func() origin_weight_python_object.main_grad = main_grad # Gather intermediate/activation tensors if needed - # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already + # NOTE: weight_fp8 = weight when bwd_args.fp8 == False and torch.disttributed.FSDP already # shards/unshards the base weights so we don't do it ourselves nvtx_range_push(f"{nvtx_label}.fsdp_gather") _fsdp_gather_tensors( - ctx.fsdp_group, - ctx.fsdp_shapes, + bwd_args.fsdp_group, + bwd_args.fsdp_shapes, inputmat, weight_fp8, ) nvtx_range_pop(f"{nvtx_label}.fsdp_gather") # Configure Userbuffers communication (comm+GEMM overlap) - ctx.ub_obj_gradout = None + bwd_args.ub_obj_gradout = None ub_obj_dgrad = None ub_obj_wgrad = None ub_type_dgrad = None ub_type_wgrad = None - dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] - if ctx.ub_overlap_ag: + dgrad_shape = [ + reduce(multiply_op, bwd_args.inp_shape[:-1]), + bwd_args.inp_shape[-1], + ] + if bwd_args.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout + bwd_args.ub_obj_gradout = get_ub(bwd_args.ub_name + "_dgrad", bwd_args.fp8) + ub_obj_dgrad = bwd_args.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG - elif ctx.ub_overlap_rs_dgrad: + elif bwd_args.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout + bwd_args.ub_obj_gradout = get_ub(bwd_args.ub_name + "_dgrad", bwd_args.fp8) + ub_obj_dgrad = bwd_args.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: - if ctx.ub_bulk_dgrad: + if bwd_args.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) - ub_obj_dgrad = ctx.ub_obj_gradout + bwd_args.ub_obj_gradout = get_ub(bwd_args.ub_name + "_dgrad", bwd_args.fp8) + ub_obj_dgrad = bwd_args.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG - if ctx.ub_bulk_wgrad: + if bwd_args.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_wgrad = get_ub(bwd_args.ub_name + "_wgrad", bwd_args.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -670,7 +810,7 @@ def _linear_backward( if grad_output_quantizer is not None: quantizer = grad_output_quantizer quantizer.set_usage(rowwise=True, columnwise=True) - if ctx.ub_overlap_ag: + if bwd_args.ub_overlap_ag: # Userbuffers only supports communication for one # tensor usage at a time. Configure quantizer with # usage for only dgrad GEMM. @@ -680,20 +820,28 @@ def _linear_backward( # on whether wgrad calculations will be performed. # NOTE: If requires_dgrad is False, disabling `rowwise` quantization and keeping `columnwise` quantization # results in `Assertion failed: output_tensor->has_data(). Quantizing in only the columnwise direction not supported yet!` - # NOTE: For `ctx.bias is True`, selected quantize kernel errors with + # NOTE: For `bias is True`, selected quantize kernel errors with # `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.` - if not ctx.use_bias and not ctx.requires_wgrad and grad_output_quantizer is not None: + if ( + not bwd_args.use_bias + and not bwd_args.requires_wgrad + and grad_output_quantizer is not None + ): grad_output_quantizer.set_usage(columnwise=False) - # Prepare grad output tensor + # Prepare grad output tensor. + # ``grad_output_preprocess`` accesses a small set of attributes + # (sequence_parallel, fp8, backward_override, debug, ub_overlap_ag, + # tp_group, ub_obj_gradout, use_bias). ``LinearBwdArgs`` exposes the + # same names so we can pass it directly. nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") ( grad_output, grad_bias, ) = TransformerEngineBaseModule.grad_output_preprocess( - ctx, + bwd_args, grad_output, - ctx.parallel_mode == "row", + bwd_args.parallel_mode == "row", grad_output_quantizer, ) nvtx_range_pop(f"{nvtx_label}.grad_output_preprocess") @@ -710,12 +858,12 @@ def _linear_backward( # -------------------------------------------------- inputmat_total = None inputmat_total_work = None - if ctx.requires_wgrad: - if ctx.fp8 or ctx.debug: + if bwd_args.requires_wgrad: + if bwd_args.fp8 or bwd_args.debug: if isinstance(inputmat, QuantizedTensorStorage): # Input tensor is already quantized pass - elif ctx.debug or ctx.custom: + elif bwd_args.debug or bwd_args.custom: # Debug quantizer will be applied immediately before wgrad GEMM pass else: @@ -725,19 +873,19 @@ def _linear_backward( # All-gather is not supported with FP8 column-wise data quantizer.set_usage( rowwise=True, - columnwise=not ctx.backward_input_needs_gather, + columnwise=not bwd_args.backward_input_needs_gather, ) else: quantizer.set_usage(rowwise=False, columnwise=True) inputmat = quantizer(inputmat) else: if isinstance(inputmat, QuantizedTensorStorage): - inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) + inputmat = inputmat.dequantize(dtype=bwd_args.activation_dtype) else: - inputmat = cast_if_needed(inputmat, ctx.activation_dtype) - if ctx.backward_input_needs_gather: + inputmat = cast_if_needed(inputmat, bwd_args.activation_dtype) + if bwd_args.backward_input_needs_gather: quantizer = None - if ctx.fp8 or ctx.debug: + if bwd_args.fp8 or bwd_args.debug: quantizer = input_quantizer if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually @@ -745,18 +893,18 @@ def _linear_backward( else: # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) - if ctx.ub_bulk_dgrad: + if bwd_args.ub_bulk_dgrad: inputmat_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_dgrad, inputmat, quantizer, - ctx.tp_group, + bwd_args.tp_group, ) else: nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") inputmat_total, inputmat_total_work = gather_along_first_dim( inputmat, - ctx.tp_group, + bwd_args.tp_group, async_op=True, quantizer=quantizer, ) @@ -773,7 +921,7 @@ def _linear_backward( dgrad = None dgrad_work = None - if ctx.requires_dgrad: + if bwd_args.requires_dgrad: # FSDP2: Re-create workspace from all-gathered weight when # workspace was not saved. (Issue #2681) @@ -784,15 +932,15 @@ def _linear_backward( # saved weight is already set to right usages by # fsdp2 quantized-tensor hooks when workspace was not saved. weight_fp8 = saved_weight - elif ctx.weight_quantizer is not None: - ctx.weight_quantizer.set_usage(rowwise=True, columnwise=True) - weight_fp8 = ctx.weight_quantizer(saved_weight) + elif bwd_args.weight_quantizer is not None: + bwd_args.weight_quantizer.set_usage(rowwise=True, columnwise=True) + weight_fp8 = bwd_args.weight_quantizer(saved_weight) # Make sure required data is available if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(rowwise_usage=True) if ( - ctx.fp8 + bwd_args.fp8 and weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensorStorage) ): @@ -800,8 +948,8 @@ def _linear_backward( # Choose whether to use GEMM kernel with split accumulator use_split_accumulator = _2X_ACC_DGRAD - if ctx.fp8: - recipe = ctx.fp8_recipe + if bwd_args.fp8: + recipe = bwd_args.fp8_recipe if hasattr(recipe, "fp8_gemm_dgrad"): use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator @@ -812,11 +960,13 @@ def _linear_backward( # Output buffers for Userbuffers reduce-scatter gemm_out = None reduce_scatter_out = None - if ctx.ub_overlap_rs_dgrad: + if bwd_args.ub_overlap_rs_dgrad: reduce_scatter_out = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device + dgrad_shape, + dtype=bwd_args.activation_dtype, + device=grad_output_arg.device, ) - elif ctx.ub_bulk_wgrad: + elif bwd_args.ub_bulk_wgrad: gemm_out = ub_obj_wgrad.get_buffer(local_chunk=False) # dgrad GEMM @@ -824,15 +974,15 @@ def _linear_backward( nvtx_range_push(f"{nvtx_label}.dgrad_gemm") weight_for_dgrad = weight_fp8 - if ctx.backward_override == "dequantized": + if bwd_args.backward_override == "dequantized": if isinstance(weight_for_dgrad, QuantizedTensorStorage): - weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=bwd_args.activation_dtype) else: - weight_for_dgrad = cast_if_needed(weight_for_dgrad, ctx.activation_dtype) - elif ctx.backward_override == "high_precision": + weight_for_dgrad = cast_if_needed(weight_for_dgrad, bwd_args.activation_dtype) + elif bwd_args.backward_override == "high_precision": weight_for_dgrad = saved_weight if isinstance(weight_for_dgrad, QuantizedTensorStorage): - weight_for_dgrad = weight_for_dgrad.dequantize(dtype=ctx.activation_dtype) + weight_for_dgrad = weight_for_dgrad.dequantize(dtype=bwd_args.activation_dtype) gemm_out, *_, reduce_scatter_out = general_gemm( weight_for_dgrad, grad_output, @@ -840,12 +990,12 @@ def _linear_backward( grad=True, quantization_params=grad_input_quantizer, out=gemm_out, - out_dtype=ctx.activation_dtype, + out_dtype=bwd_args.activation_dtype, use_split_accumulator=use_split_accumulator, ub=ub_obj_dgrad, ub_type=ub_type_dgrad, extra_output=reduce_scatter_out, - bulk_overlap=ctx.ub_bulk_dgrad, + bulk_overlap=bwd_args.ub_bulk_dgrad, ) nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") @@ -854,26 +1004,26 @@ def _linear_backward( # and 2d block-scaled weights in TE managed memory. So we need to clear # it here. # (Issues #2681, #2717) - if getattr(ctx, "is_fsdp2", False) and isinstance(weight_fp8, QuantizedTensorStorage): + if bwd_args.is_fsdp2 and isinstance(weight_fp8, QuantizedTensorStorage): clear_columnwise_cache(weight_fp8) # Prepare grad input tensor # Note: Perform tensor-parallel communication - if ctx.ub_overlap_rs_dgrad: + if bwd_args.ub_overlap_rs_dgrad: dgrad = reduce_scatter_out - elif ctx.ub_bulk_wgrad: + elif bwd_args.ub_bulk_wgrad: dgrad = ub_obj_wgrad.get_buffer(local_chunk=True) - elif ctx.parallel_mode == "column" and ctx.tp_size > 1: + elif bwd_args.parallel_mode == "column" and bwd_args.tp_size > 1: nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") dgrad = gemm_out - if ctx.sequence_parallel: + if bwd_args.sequence_parallel: dgrad, dgrad_work = reduce_scatter_along_first_dim( dgrad, - ctx.tp_group, + bwd_args.tp_group, async_op=True, ) else: - dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + dgrad, dgrad_work = allreduce(dgrad, bwd_args.tp_group, async_op=True) nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") else: dgrad = gemm_out @@ -887,7 +1037,7 @@ def _linear_backward( # -------------------------------------------------- wgrad = None - if ctx.requires_wgrad: + if bwd_args.requires_wgrad: # Prepare input tensor # Note: Synchronize tensor-parallel communication and @@ -895,7 +1045,7 @@ def _linear_backward( if inputmat_total_work is not None: inputmat_total_work.wait() inputmat_total_work = None - if ctx.fp8 or ctx.debug: + if bwd_args.fp8 or bwd_args.debug: if isinstance(inputmat_total, QuantizedTensorStorage): inputmat_total.update_usage(columnwise_usage=True) else: @@ -905,7 +1055,7 @@ def _linear_backward( # Prepare grad output tensor # Note: Synchronize tensor-parallel communication and # make sure required data is available - if ctx.ub_overlap_ag and isinstance(grad_output_quantizer, MXFP8Quantizer): + if bwd_args.ub_overlap_ag and isinstance(grad_output_quantizer, MXFP8Quantizer): # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we @@ -917,7 +1067,7 @@ def _linear_backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) + ub_obj_overlap_wgrad = get_ub(bwd_args.ub_name + "_wgrad", bwd_args.fp8) grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -929,7 +1079,7 @@ def _linear_backward( ub_obj_overlap_wgrad, grad_output_arg, grad_output_quantizer, - ctx.tp_group, + bwd_args.tp_group, ) # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm @@ -937,7 +1087,7 @@ def _linear_backward( ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream ) - if ctx.fp8 or ctx.debug: + if bwd_args.fp8 or bwd_args.debug: if isinstance(grad_output, QuantizedTensorStorage): grad_output.update_usage(columnwise_usage=True) else: @@ -946,31 +1096,35 @@ def _linear_backward( # Figure out whether to use split accumulator use_split_accumulator = _2X_ACC_WGRAD - if ctx.fp8: - recipe = ctx.fp8_recipe + if bwd_args.fp8: + recipe = bwd_args.fp8_recipe if hasattr(recipe, "fp8_gemm_wgrad"): use_split_accumulator = recipe.fp8_gemm_wgrad.use_split_accumulator # Figure out whether to output wgrad GEMM directly into main grad - if ctx.is_first_microbatch is not None: + if bwd_args.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( - ctx.fuse_wgrad_accumulation and not ctx.is_first_microbatch + bwd_args.fuse_wgrad_accumulation and not bwd_args.is_first_microbatch ) else: - accumulate_wgrad_into_param_main_grad = ctx.fuse_wgrad_accumulation + accumulate_wgrad_into_param_main_grad = bwd_args.fuse_wgrad_accumulation # Output buffer for overlapping FP8 grad input # reduce-scatter with wgrad GEMM reduce_scatter_out = None - if ctx.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): + if bwd_args.ub_bulk_wgrad and ub_obj_wgrad.is_fp8_ubuf(): reduce_scatter_out = torch.empty( - dgrad_shape, dtype=ctx.activation_dtype, device=grad_output_arg.device + dgrad_shape, + dtype=bwd_args.activation_dtype, + device=grad_output_arg.device, ) # Arguments to include in wgrad GEMM closure wgrad_gemm_kwargs = { "out_dtype": ( - main_grad.dtype if ctx.fuse_wgrad_accumulation else ctx.activation_dtype + main_grad.dtype + if bwd_args.fuse_wgrad_accumulation + else bwd_args.activation_dtype ), "quantization_params": grad_weight_quantizer, "accumulate": ( @@ -979,14 +1133,14 @@ def _linear_backward( else False ), "layout": "NT", - "out": main_grad if ctx.fuse_wgrad_accumulation else None, - "bias": (bias if (grad_bias is None and not ctx.fp8) else None), + "out": main_grad if bwd_args.fuse_wgrad_accumulation else None, + "bias": (bias if (grad_bias is None and not bwd_args.fp8) else None), "use_split_accumulator": use_split_accumulator, "grad": True, "ub": ub_obj_wgrad, "ub_type": ub_type_wgrad, "extra_output": reduce_scatter_out, - "bulk_overlap": ctx.ub_bulk_wgrad, + "bulk_overlap": bwd_args.ub_bulk_wgrad, } def wgrad_gemm( @@ -1007,7 +1161,7 @@ def wgrad_gemm( return dw, db # Choose whether to call wgrad GEMM now or delay - if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): + if bwd_args.wgrad_store is not None and bwd_args.wgrad_store.delay_wgrad_compute(): if ( wgrad_gemm_kwargs["ub"] is not None or wgrad_gemm_kwargs["ub_type"] is not None @@ -1018,7 +1172,7 @@ def wgrad_gemm( "Delayed weight grad computation is not supported " "with Userbuffers (tensor-parallel communication overlapping)" ) - ctx.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm) + bwd_args.wgrad_store.put([inputmat_total, grad_output], wgrad_gemm) else: # Call wgrad GEMM now @@ -1030,18 +1184,18 @@ def wgrad_gemm( del grad_bias_ # Deallocate tensors if permitted - if ctx.owns_input: + if bwd_args.owns_input: # Input tensor is internal clear_tensor_data(inputmat_total) - elif ctx.backward_input_needs_gather: + elif bwd_args.backward_input_needs_gather: # Gathered input tensor is internal clear_tensor_data(inputmat_total) - if ctx.parallel_mode == "row" and ctx.sequence_parallel: + if bwd_args.parallel_mode == "row" and bwd_args.sequence_parallel: # Gathered grad output tensor is internal clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM - if ctx.ub_bulk_wgrad: + if bwd_args.ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): dgrad = reduce_scatter_out else: @@ -1052,7 +1206,7 @@ def wgrad_gemm( # -------------------------------------------------- # Don't return grad bias if not needed - if not ctx.use_bias: + if not bwd_args.use_bias: grad_bias = None # Make sure all tensor-parallel communication is finished @@ -1063,9 +1217,9 @@ def wgrad_gemm( dgrad_work.wait() dgrad_work = None - if ctx.requires_wgrad: + if bwd_args.requires_wgrad: # Handle custom DDP from mcore. - if ctx.fuse_wgrad_accumulation and hasattr( + if bwd_args.fuse_wgrad_accumulation and hasattr( origin_weight_python_object, "grad_added_to_main_grad" ): origin_weight_python_object.grad_added_to_main_grad = True @@ -1080,26 +1234,18 @@ def wgrad_gemm( list(main_grad.shape), origin_weight_python_object.dtype, ) - elif ctx.fuse_wgrad_accumulation: + elif bwd_args.fuse_wgrad_accumulation: wgrad = None else: wgrad = None # Scatter fp8 weight buffers - if ctx.fp8 and not ctx.is_weight_param_quantized: - _fsdp_scatter_tensors(ctx.fsdp_group, weight_fp8) + if bwd_args.fp8 and not bwd_args.is_weight_param_quantized: + _fsdp_scatter_tensors(bwd_args.fsdp_group, weight_fp8) return ( wgrad, - None, # weight_workspace - dgrad.view(ctx.inp_shape) if ctx.requires_dgrad else None, + dgrad.view(bwd_args.inp_shape) if bwd_args.requires_dgrad else None, grad_bias, - None, - None, - None, - None, - None, - None, - None, ) @@ -1112,73 +1258,75 @@ class _Linear(torch.autograd.Function): def forward( ctx, weight: torch.Tensor, - weight_workspace: Optional[torch.Tensor], inp: torch.Tensor, bias: Optional[torch.Tensor], - non_tensor_args: Tuple, - input_quantizer: Optional[Quantizer], - weight_quantizer: Optional[Quantizer], - output_quantizer: Optional[Quantizer], - grad_input_quantizer: Optional[Quantizer], - grad_weight_quantizer: Optional[Quantizer], - grad_output_quantizer: Optional[Quantizer], + fwd_args: LinearFwdArgs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """Forward pass: compute linear output and set up autograd context.""" - out, new_weight_workspace, tensors_to_save, tensor_objects, ctx_attrs = ( - _linear_forward_impl( - weight, - weight_workspace, - inp, - bias, - non_tensor_args, - input_quantizer, - weight_quantizer, - output_quantizer, - ) - ) + """Forward pass: compute linear output and set up autograd context. + + ``weight``, ``inp`` and ``bias`` are positional Tensor arguments so + autograd tracks them; they are immediately re-attached to ``fwd_args`` + so every downstream helper can be invoked with a single argument. + + ``weight_workspace`` is intentionally NOT a positional input: it is a + non-differentiable cached tensor passed in via + ``fwd_args.weight_workspace`` and the freshly produced workspace is + returned as a separate output so the module can refresh its cache. + """ + fwd_args.weight = weight + fwd_args.inp = inp + fwd_args.bias = bias + ( + out, + new_weight_workspace, + tensors_to_save_from_forward, + _, + ctx_attrs, + ) = _linear_forward_impl(fwd_args) if ctx is not None: - _linear_setup_ctx( - ctx, - tensors_to_save, - tensor_objects, + bwd_args = LinearBwdArgs() + tensors_to_save_from_setup = _linear_setup_ctx( + bwd_args, + fwd_args, + out, ctx_attrs, - inp, - weight, - bias, - non_tensor_args, - input_quantizer=input_quantizer, - grad_input_quantizer=grad_input_quantizer, - grad_weight_quantizer=grad_weight_quantizer, - grad_output_quantizer=grad_output_quantizer, + tensors_to_save_from_forward, ) - fp8 = non_tensor_args[1] - if fp8 and requires_grad(inp, weight, bias): - ctx.reduce_and_update_bwd_fp8_tensors = _check_fp8_reduce_and_update() - else: - ctx.reduce_and_update_bwd_fp8_tensors = False - if ctx.backward_override is not None: - ctx.reduce_and_update_bwd_fp8_tensors = False + tensors_to_save, tensor_objects = prepare_for_saving(*tensors_to_save_from_setup) + ctx.save_for_backward(*tensors_to_save) + ctx.tensor_objects = tensor_objects + ctx.backward_objects = bwd_args + if fwd_args.fp8 and ( + fwd_args.input_requires_grad + or fwd_args.weight_requires_grad + or fwd_args.bias_requires_grad + ): + bwd_args.reduce_and_update_bwd_fp8_tensors = _check_fp8_reduce_and_update() + if fwd_args.backward_override is not None: + bwd_args.reduce_and_update_bwd_fp8_tensors = False return out, new_weight_workspace @staticmethod def backward( - ctx, grad_output: torch.Tensor, _grad_weight_workspace + ctx, + grad_output: torch.Tensor, + _grad_weight_workspace, ) -> Tuple[Union[torch.Tensor, None], ...]: """Backward pass: compute gradients and reduce FP8 scaling factors.""" + bwd_args: LinearBwdArgs = ctx.backward_objects + bwd_args.grad_output = grad_output + bwd_args.setup_saved_tensors(ctx) nvtx_label = "transformer_engine._Linear.backward" - if ctx.ub_name is not None: - nvtx_label = f"{nvtx_label}.{ctx.ub_name}" - result = _linear_backward( - ctx, - grad_output, - input_quantizer=ctx.input_quantizer, - weight_quantizer=ctx.weight_quantizer, - grad_input_quantizer=ctx.grad_input_quantizer, - grad_weight_quantizer=ctx.grad_weight_quantizer, - grad_output_quantizer=ctx.grad_output_quantizer, - ) - if ctx.reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): + if bwd_args.ub_name is not None: + nvtx_label = f"{nvtx_label}.{bwd_args.ub_name}" + result = _linear_backward(bwd_args) + (None,) # fwd_args grad slot + reduce_and_update_bwd_fp8_tensors = bwd_args.reduce_and_update_bwd_fp8_tensors + # Drop all references held by bwd_args (saved tensors, quantizers, weakrefs, + # main_grad closure) so they don't outlive backward via ctx under retain_graph. + ctx.backward_objects = None + del bwd_args + if reduce_and_update_bwd_fp8_tensors and not is_graph_capturing(): nvtx_range_push(f"{nvtx_label}.reduce_and_update_fp8_tensors") FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) nvtx_range_pop(f"{nvtx_label}.reduce_and_update_fp8_tensors") @@ -1510,6 +1658,32 @@ def __init__( if name in self.weight_names or name in self.bias_names: param.skip_backward_post_hook = True + def get_quantizer_roles( + self, + *, + fwd: bool, + num_quantizers: int, + ) -> Optional[List[QuantizerRole]]: + """QuantizerRole list for quantizers used by ``Linear``. + + The output (fwd) and grad-input (bwd) slots default to ``None`` + (unknown consumer). Set :attr:`output_quantizer_role` / + :attr:`grad_input_quantizer_role` to provide consumer identity. + """ + name = self.name or "" + if fwd: + base = [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + self._output_quantizer_role, + ] + else: + base = [ + QuantizerRole(module_type="linear", tensor_type="grad_output", name=name), + self._grad_input_quantizer_role, + ] + return [base[i % len(base)] for i in range(num_quantizers)] + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -1659,52 +1833,74 @@ def forward( ub_bulk_dgrad = self.ub_bulk_dgrad ub_bulk_wgrad = self.ub_bulk_wgrad - non_tensor_args = ( - is_first_microbatch, - self.fp8, - self.fp8_calibration, - self.wgrad_store, - self.fuse_wgrad_accumulation, - is_cpu_offload_enabled(), - self.tp_group, - self.tp_size, - self.sequence_parallel, - self.tp_size > 1, - self.activation_dtype, - self.parallel_mode, - is_grad_enabled, - ub_overlap_rs_fprop, - ub_overlap_ag_dgrad, - ub_overlap_ag_fprop, - ub_overlap_rs_dgrad, - ub_bulk_dgrad, - ub_bulk_wgrad, - self.ub_name, - fp8_output, - self.fsdp_group, - cache_name is not None, - skip_fp8_weight_update, - self.symmetric_ar_type, - self.save_original_input, - debug, - backward_override, - custom, - backward_input_needs_gather, - self.is_fsdp2, + linear_bias_tensor = ( + bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None + ) + wgrad_store = self.wgrad_store if self.wgrad_store.delay_wgrad_compute() else None + fwd_args = LinearFwdArgs( + # tensors + weight=weight_tensor, + inp=inp, + bias=linear_bias_tensor, + weight_workspace=weight_workspace, + # requires_grad flags + input_requires_grad=inp.requires_grad, + weight_requires_grad=weight_tensor.requires_grad, + bias_requires_grad=( + linear_bias_tensor.requires_grad if linear_bias_tensor is not None else False + ), + # quantizers + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, + grad_input_quantizer=grad_input_quantizer, + grad_weight_quantizer=grad_weight_quantizer, + grad_output_quantizer=grad_output_quantizer, + # numerical / dtype config + activation_dtype=self.activation_dtype, + fp8=self.fp8, + fp8_calibration=self.fp8_calibration, + fp8_output=fp8_output, + save_original_input=self.save_original_input, + backward_override=backward_override, + custom=custom, + debug=debug, + # weight-workspace caching + is_first_microbatch=is_first_microbatch, + cache_weight=cache_name is not None, + skip_fp8_weight_update=skip_fp8_weight_update, + # tensor / sequence parallelism + parallel_mode=self.parallel_mode, + tp_group=self.tp_group, + tp_size=self.tp_size, + tensor_parallel=self.tp_size > 1, + sequence_parallel=self.sequence_parallel, + symmetric_ar_type=self.symmetric_ar_type, + backward_input_needs_gather=backward_input_needs_gather, + # userbuffers + ub_name=self.ub_name, + ub_overlap_ag_fprop=ub_overlap_ag_fprop, + ub_overlap_rs_fprop=ub_overlap_rs_fprop, + ub_overlap_ag_dgrad=ub_overlap_ag_dgrad, + ub_overlap_rs_dgrad=ub_overlap_rs_dgrad, + ub_bulk_dgrad=ub_bulk_dgrad, + ub_bulk_wgrad=ub_bulk_wgrad, + # FSDP + fsdp_group=self.fsdp_group, + is_fsdp2=self.is_fsdp2, + # weight-grad scheduling + fuse_wgrad_accumulation=self.fuse_wgrad_accumulation, + wgrad_store=wgrad_store, + # misc + cpu_offloading=is_cpu_offload_enabled(), + is_grad_enabled=is_grad_enabled, ) out, new_weight_workspace = linear_fn( *autograd_ctx, weight_tensor, - weight_workspace, inp, - bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None, - non_tensor_args, - input_quantizer, - weight_quantizer, - output_quantizer, - grad_input_quantizer, - grad_weight_quantizer, - grad_output_quantizer, + linear_bias_tensor, + fwd_args, ) if new_weight_workspace is not None and cache_name is not None: @@ -1724,6 +1920,9 @@ def forward( def _get_quantizers(self, fp8_output, fp8_grad, is_grad_enabled): if not self.fp8: return [None] * 6 + + self._warn_missing_output_quantizer_role(fp8_output, fp8_grad) + grad_input_quantizer = None grad_weight_quantizer = None grad_output_quantizer = None diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 41f0855f1d..95e0440303 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -19,7 +19,7 @@ gather_along_first_dim, reduce_scatter_along_first_dim, ) -from ...quantization import FP8GlobalStateManager, Recipe +from ...quantization import FP8GlobalStateManager, QuantizerRole, Recipe from ...module.base import ( _2X_ACC_FPROP, _2X_ACC_DGRAD, @@ -275,6 +275,21 @@ def num_quantizers(self, mode: str) -> int: return 1 return 0 + def get_quantizer_roles(self, mode: str) -> Optional[list[QuantizerRole]]: + name = getattr(self, "name", "") or "" + if mode == "forward": + # BasicLinear owns input and weight quantizers. + # Output quantizer is provided by the next op (as its input quantizer). + return [ + QuantizerRole(module_type="linear", tensor_type="input", name=name), + QuantizerRole(module_type="linear", tensor_type="weight", name=name), + ] + if mode == "backward": + # BasicLinear owns grad_output quantizer. + # Grad_input quantizer is provided by the previous op (as its grad_output quantizer). + return [QuantizerRole(module_type="linear", tensor_type="grad_output", name=name)] + return None + def reset_parameters(self) -> None: """Initialize parameter buffers and values""" diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index c5c8ea3463..1687187230 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -16,6 +16,7 @@ from transformer_engine.common.recipe import Recipe from ..quantization import ( FP8GlobalStateManager, + QuantizerRole, RecipeState, autocast, ) @@ -209,6 +210,17 @@ def num_quantizers( """ return 0 + def get_quantizer_roles( + self, mode: str # pylint: disable=unused-argument + ) -> Optional[list[QuantizerRole]]: + """Return an ordered list of :class:`QuantizerRole` for quantizers. + + The returned list must be aligned with the internal quantizer ordering and + must have length ``num_quantizers(mode)`` for supported modes. + Returning ``None`` means "no explicit roles". + """ + return None + def get_input_quantizer(self) -> Optional[Quantizer]: if self.num_quantizers("forward") > 0: return self.get_quantizer("forward", 0) @@ -268,10 +280,17 @@ def reset_recipe_state( ) # Construct quantization recipe state - recipe_state = RecipeState.create( + roles = self.get_quantizer_roles(mode) # pylint: disable=assignment-from-none + if roles is not None: + assert len(roles) == num_quantizers, ( + "Recipe roles must match number of quantizers " + f"({len(roles)=} vs {num_quantizers=})" + ) + recipe_state = RecipeState.create( # pylint: disable=assignment-from-none recipe, mode=mode, num_quantizers=num_quantizers, + roles=roles, ) fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( forward=(mode == "forward"), diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index e9f009d93d..82b8274378 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -6,7 +6,7 @@ from __future__ import annotations import abc -import itertools +import dataclasses import warnings import os from dataclasses import dataclass, field @@ -41,6 +41,9 @@ "is_nvfp4_available", "get_default_recipe", "get_align_size_for_quantization", + "QuantizerRole", + "QuantizerRequest", + "DelayedScalingRequest", ] @@ -50,6 +53,99 @@ _FP8_BLOCK_SCALING_SUPPORT: Optional[Tuple[bool, str]] = None +@dataclasses.dataclass(frozen=True) +class QuantizerRole: + """Identity of a tensor slot requesting a quantizer. + + TE modules populate all fields they know about. + User factories inspect only the fields they care about. + + .. warning:: + **EXPERIMENTAL**: QuantizerRole is experimental, still under active development, + and the API is subject to change without notice. Use at your own risk. + + Fields + ------ + module_type : str + Module type that emits this role, e.g. `"linear"`, `"grouped_linear"`, `"dpa"`. + Empty string when not provided. + tensor_type : str + What tensor is being quantized, in the module's own vocabulary. + Linear modules: `"input"`, `"weight"`, `"grad_output"`, etc. + DPA: `"qkv"`, `"s"`, etc. + Empty string when not provided. + name : str + Caller-provided module instance name (e.g. set by the training + framework), e.g. + `"qkv"`, `"proj"`, `"fc1"`, `"fc2"`, `"linear_39"`. + Empty string when not provided. + """ + + module_type: str = "" + tensor_type: str = "" + name: str = "" + + def __str__(self) -> str: + parts = [] + if self.module_type: + parts.append(f"module_type={self.module_type}") + if self.tensor_type: + parts.append(f"tensor_type={self.tensor_type}") + if self.name: + parts.append(f"name={self.name}") + return "|".join(parts) if parts else "QuantizerRole()" + + +@dataclasses.dataclass(frozen=True) +class QuantizerRequest: + """Base class for stateful quantizer requests. + + Custom recipe factories return ``QuantizerRequest`` subclasses (instead of + quantizer instances) when the quantizer requires TE-managed shared state. + TE detects these requests, allocates the required state, and replaces them + with real quantizer instances. + + .. warning:: + **EXPERIMENTAL**: QuantizerRequest is experimental, still under active + development, and the API is subject to change without notice. + """ + + +@dataclasses.dataclass(frozen=True) +class DelayedScalingRequest(QuantizerRequest): + """Request a Float8Quantizer with TE-managed delayed scaling state. + + .. warning:: + **EXPERIMENTAL**: DelayedScalingRequest is experimental, still under active + development, and the API is subject to change without notice. + + All ``DelayedScalingRequest`` instances within the same ``CustomRecipeState`` + must share identical parameter values. + + Parameters + ---------- + fp8_format : Format, default = Format.HYBRID + Controls fwd/bwd dtype (HYBRID = E4M3 fwd, E5M2 bwd). + margin : int, default = 0 + Margin for scaling factor computation. + amax_history_len : int, default = 1024 + Length of the amax history window. + amax_compute_algo : str or Callable, default = "max" + Algorithm for choosing amax from history. + scaling_factor_compute_algo : Callable or None, default = None + Custom scaling factor computation. + reduce_amax : bool, default = True + Whether to all-reduce amax across the distributed group. + """ + + fp8_format: Format = Format.HYBRID + margin: int = 0 + amax_history_len: int = 1024 + amax_compute_algo: Union[str, Callable] = "max" + scaling_factor_compute_algo: Optional[Callable] = None + reduce_amax: bool = True + + def _compute_fp8_support() -> Tuple[bool, str]: """Return if fp8 support is available""" if get_device_compute_capability() >= (9, 0): # hopper and above @@ -383,7 +479,7 @@ def add_fp8_tensors_to_global_buffer( fp8_meta: Dict[str, Any], ) -> None: """ - Delayed scaling only. + Delayed scaling only (built-in or custom recipe with DS requests). The amax reduction process happens completely outside the FP8 modules. To participate in the reduction, the only role played by a module is @@ -398,8 +494,8 @@ def add_fp8_tensors_to_global_buffer( wrapper. For non CG case, it's called from within the module. """ - # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + # noop unless delayed scaling state is present + if not _has_delayed_scaling_state(fp8_meta): return # Every module must call this function exactly once since @@ -417,7 +513,17 @@ def add_fp8_tensors_to_global_buffer( # Handles non-parameter FP8 modules, e.g. DPA. continue - key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) + state = fp8_meta[fp8_meta_tensor_key] + + # Determine recipe + buffers: built-in DS or custom with DS requests + if isinstance(state, CustomRecipeState) and state._has_delayed_scaling: + inner_recipe = state._inner_delayed_scaling_recipe + key = cls.get_key_in_buffer(forward, inner_recipe, fp8_meta["fp8_group"]) + # Register inner recipe in autocast_arguments for reduction + autocast_key = cls.get_unique_autocast_key(inner_recipe, fp8_meta["fp8_group"]) + qstate.autocast_arguments[autocast_key] = (inner_recipe, fp8_meta["fp8_group"]) + else: + key = cls.get_key_in_buffer(forward, fp8_meta["recipe"], fp8_meta["fp8_group"]) if key not in qstate.global_amax_buffer: qstate.global_amax_buffer[key] = [fp8_meta[fp8_meta_tensor_key].amax_history[0]] @@ -655,7 +761,7 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) - """ # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + if not _has_delayed_scaling_state(fp8_meta): return buffer_position_key = "global_fp8_buffer_pos_fwd_recompute" @@ -682,7 +788,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non 1 forward for indentical numerical outputs. """ # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + if not _has_delayed_scaling_state(fp8_meta): return # Store updated amaxes and scales from phase 1 post forward. @@ -703,7 +809,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None: """Restore latest scaling factors and amaxes after recompute forward run.""" # delayed scaling only function, noop for any other recipe - if not fp8_meta["recipe"].delayed(): + if not _has_delayed_scaling_state(fp8_meta): return fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"]) @@ -1026,8 +1132,112 @@ class RecipeState(abc.ABC): This class may pack together the state for multiple quantizers, which is helpful for applying fused kernels with less overhead. + Subclasses that own mutable training buffers (e.g. delayed scaling's + ``scale`` / ``amax_history``) MUST list them in + :attr:`_persistent_state_buffers`. These buffers are preserved across + role-driven rebuilds and post-checkpoint resume via + :meth:`inherit_state_from`. Stateless subclasses leave the attribute + empty. """ + roles: Optional[List[QuantizerRole]] + mode: str + + # Names of mutable torch.Tensor attributes that represent persistent + # training state (e.g. running scale, amax history). The default + # ``inherit_state_from`` rebinds these from a predecessor RecipeState + # so external references (e.g. ``FP8GlobalStateManager`` reduction + # buffers) keep pointing at the same backing tensor. + _persistent_state_buffers: Tuple[str, ...] = () + + # Canonical tensor types that a recipe state can dispatch on. + _KNOWN_TENSOR_TYPES = ("input", "weight", "output", "grad_output", "grad_input") + # Positional fallback used when no role information is available: the + # tensor type at slot ``i`` defaults to ``_FWD_DEFAULT_TENSOR_TYPES[i % len]`` + # (forward) or ``_BWD_DEFAULT_TENSOR_TYPES[i % len]`` (backward). Mirrors + # the ``[input, weight, output, ...]`` / ``[grad_output, grad_input, ...]`` + # convention assumed by ``module/base.py::set_meta_tensor``. + _FWD_DEFAULT_TENSOR_TYPES = ("input", "weight", "output") + _BWD_DEFAULT_TENSOR_TYPES = ("grad_output", "grad_input") + + @staticmethod + def _validate_roles( + roles: Optional[List[QuantizerRole]], + num_quantizers: int, + ) -> None: + """Validate that ``roles``, if provided, has length ``num_quantizers``.""" + if roles is not None and len(roles) != num_quantizers: + raise ValueError( + "RecipeState requires roles to match num_quantizers " + f"({len(roles)=} vs {num_quantizers=})" + ) + + def _slot_role(self, idx: int) -> QuantizerRole: + """Resolve slot ``idx`` to a non-``None`` :class:`QuantizerRole`. + + This is the field-agnostic primitive that role-driven recipe states + use to dispatch on any combination of role fields (``tensor_type``, + ``module_type``, ``name``, future fields). + + Resolution rules: + + * If a real ``QuantizerRole`` was provided for this slot, it is + returned unchanged. Producers fill only the fields they know about; + the rest carry the dataclass defaults (empty strings). Consumers + should treat an empty field as "no signal" rather than as "no role + provided". + * Otherwise (whole ``roles`` list missing, or this slot is ``None``), + a bare ``QuantizerRole()`` with all fields empty is returned. + Field-specific fallback policies belong to the individual + dispatch convenience accessors (e.g. :meth:`_slot_tensor_type`), + not to this primitive — that way a future recipe state that + dispatches on, say, ``module_type`` is free to define its own + fallback policy without impacting tensor-type dispatch. + + The "real role vs bare-default role" distinction is hidden from + dispatch logic here. Recipe states that need to *warn* on missing + roles (as :class:`CustomRecipeState` does) should consult + ``self.roles[idx]`` directly. + """ + if self.roles is not None: + role = self.roles[idx] + if role is not None: + return role + return QuantizerRole() + + def _slot_tensor_type(self, idx: int) -> str: + """Convenience accessor: tensor-type dispatch with positional fallback. + + Resolves to one of :attr:`_KNOWN_TENSOR_TYPES`. Used by recipe states + whose dispatch only depends on the tensor's role within a GEMM + (input / weight / output / grad_output / grad_input), e.g. + Float8BlockScalingRecipeState, NVFP4BlockScalingRecipeState. + + Behavior: + + * If the resolved :meth:`_slot_role` carries a ``tensor_type`` in + :attr:`_KNOWN_TENSOR_TYPES`, return it. + * Otherwise (no role provided, a role with empty / non-canonical + ``tensor_type`` like DPA's ``"qkv"``, or a role that intentionally + only sets ``module_type``/``name``), fall back to the positional + default (forward: ``[input, weight, output, ...]``; + backward: ``[grad_output, grad_input, ...]``) indexed by + ``idx % len(default_tensor_types)``. + + This fallback policy is local to tensor-type dispatch; it does not + affect :meth:`_slot_role` or any other accessor. + """ + role = self._slot_role(idx) + if role.tensor_type in self._KNOWN_TENSOR_TYPES: + return role.tensor_type + # Positional fallback: tensor_type is missing or non-canonical. + default_tensor_types = ( + self._FWD_DEFAULT_TENSOR_TYPES + if self.mode == "forward" + else self._BWD_DEFAULT_TENSOR_TYPES + ) + return default_tensor_types[idx % len(default_tensor_types)] + @staticmethod def create( recipe: Recipe, @@ -1035,6 +1245,7 @@ def create( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, + roles: Optional[List[QuantizerRole]] = None, ) -> RecipeState: """Factory method to create the state for a quantization recipe @@ -1048,6 +1259,9 @@ def create( Number of quantizers to create state for. device: torch.device, default = default CUDA device Device for quantized tensors. + roles: list of QuantizerRole, optional + Semantic roles for each quantizer slot. When provided, must + have length ``num_quantizers``. Returns ------- @@ -1076,6 +1290,7 @@ def create( mode=mode, num_quantizers=num_quantizers, device=device, + roles=roles, ) @abc.abstractmethod @@ -1088,6 +1303,43 @@ def make_quantizers(self) -> list: """ + def inherit_state_from(self, other: "RecipeState") -> bool: + """Take over persistent training buffers from a predecessor state. + + Used when a ``RecipeState`` is being replaced (e.g. role-driven + rebuild, post-checkpoint resume) but its mutable buffers must + survive. The default implementation rebinds attributes listed in + :attr:`_persistent_state_buffers` to ``other``'s tensor objects. + Rebinding (rather than copying values) ensures any external + references — most importantly the + :class:`FP8GlobalStateManager` reduction buffers — keep pointing + at storage that is also visible to this state's quantizers, so + amax reductions and quantization stay consistent. + + Subclasses with composed sub-states (e.g. :class:`CustomRecipeState` + owning an inner :class:`DelayedScalingRecipeState`) override this + to recurse / stash for later use during ``make_quantizers``. + + Returns + ------- + bool + ``True`` if any persistent buffer was inherited; ``False`` if + the states are incompatible (different class, mismatched + shapes / dtypes) and a fresh state should be used instead. + """ + if type(self) is not type(other): + return False + if not self._persistent_state_buffers: + return False + for name in self._persistent_state_buffers: + src = getattr(other, name) + dst = getattr(self, name) + if src.shape != dst.shape or src.dtype != dst.dtype: + return False + for name in self._persistent_state_buffers: + setattr(self, name, getattr(other, name)) + return True + class DelayedScalingRecipeState(RecipeState): """State for FP8 quantization with per-tensor delayed scaling. @@ -1105,6 +1357,10 @@ class DelayedScalingRecipeState(RecipeState): scale: torch.Tensor amax_history: torch.Tensor + # Persistent training state inherited across role-driven rebuilds. + # See ``RecipeState.inherit_state_from``. + _persistent_state_buffers = ("scale", "amax_history") + def __init__( self, recipe: DelayedScaling, @@ -1112,10 +1368,13 @@ def __init__( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, + roles: Optional[List[QuantizerRole]] = None, ) -> None: + self._validate_roles(roles, num_quantizers) self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers + self.roles = roles self.dtype = get_fp8_te_dtype(recipe, mode == "forward") # Allocate buffers @@ -1158,10 +1417,13 @@ def __init__( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, + roles: Optional[List[QuantizerRole]] = None, ) -> None: + self._validate_roles(roles, num_quantizers) self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers + self.roles = roles self.dtype = get_fp8_te_dtype(recipe, mode == "forward") # Allocate buffers @@ -1198,10 +1460,13 @@ def __init__( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, + roles: Optional[List[QuantizerRole]] = None, ) -> None: + self._validate_roles(roles, num_quantizers) self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers + self.roles = roles self.dtype = get_fp8_te_dtype(recipe, mode == "forward") # Allocate buffers @@ -1235,10 +1500,13 @@ def __init__( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, + roles: Optional[List[QuantizerRole]] = None, ) -> None: + self._validate_roles(roles, num_quantizers) self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers + self.roles = roles self.qx_dtype = get_fp8_te_dtype(recipe, True) self.qw_dtype = get_fp8_te_dtype(recipe, True) self.qgrad_dtype = get_fp8_te_dtype(recipe, False) @@ -1249,75 +1517,54 @@ def __init__( self.device = device def make_quantizers(self) -> list: + """Build one ``Float8BlockQuantizer`` per slot, dispatched by tensor type. + + Per-slot behavior, resolved via :meth:`RecipeState._slot_tensor_type`: + + * ``"weight"`` uses ``recipe.fp8_quant_fwd_weight`` and + ``recipe.w_block_scaling_dim``. + * ``"input"`` / ``"output"`` (and any unknown forward slot) use + ``recipe.fp8_quant_fwd_inp`` and ``recipe.x_block_scaling_dim``. + * ``"grad_output"`` / ``"grad_input"`` (and any unknown backward slot) + use ``recipe.fp8_quant_bwd_grad`` and ``recipe.grad_block_scaling_dim``. + + When the owning module/op provides a role list via + ``get_quantizer_roles``, the per-slot ``tensor_type`` drives dispatch. + Otherwise (or for boundary slots whose role is ``None``), the + positional fallback ``[input, weight, output, ...]`` / + ``[grad_output, grad_input, ...]`` is used. This matches the legacy + index-based convention, so behavior is unchanged for + modules that haven't adopted roles yet. + """ # TODO(ksivamani); Find better design for this, adding here to avoid circular import. from .tensor.float8_blockwise_tensor import Float8BlockQuantizer - if self.mode == "forward": - # The index convention (coming from base.py set_meta_tensor) - # is somewhat awkward, and doesn't play nicely with QuantizeOp, - # which is not associated with a GEMM. - assert self.num_quantizers % 3 == 0 # x, w, output per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qw_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_weight.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_weight.power_2_scale, - block_scaling_dim=self.recipe.w_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qx_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_fwd_inp.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_fwd_inp.power_2_scale, - block_scaling_dim=self.recipe.x_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 3) - ] - ) + def _make(tensor_type: str) -> Float8BlockQuantizer: + if tensor_type == "weight": + qparams = self.recipe.fp8_quant_fwd_weight + fp8_dtype = self.qw_dtype + block_scaling_dim = self.recipe.w_block_scaling_dim + elif tensor_type in ("grad_output", "grad_input"): + qparams = self.recipe.fp8_quant_bwd_grad + fp8_dtype = self.qgrad_dtype + block_scaling_dim = self.recipe.grad_block_scaling_dim + else: + # "input", "output", or any unknown forward type fall back to + # the input config, matching the legacy positional behavior. + qparams = self.recipe.fp8_quant_fwd_inp + fp8_dtype = self.qx_dtype + block_scaling_dim = self.recipe.x_block_scaling_dim + return Float8BlockQuantizer( + fp8_dtype=fp8_dtype, + rowwise=True, + columnwise=True, + amax_epsilon=qparams.amax_epsilon, + force_pow_2_scales=qparams.power_2_scale, + block_scaling_dim=block_scaling_dim, ) - assert self.mode == "backward", f"Unexpected mode {self.mode}" - assert self.num_quantizers % 2 == 0 # grad_output and grad_input per gemm - return list( - itertools.chain.from_iterable( - [ - [ - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - Float8BlockQuantizer( - fp8_dtype=self.qgrad_dtype, - rowwise=True, - columnwise=True, - amax_epsilon=self.recipe.fp8_quant_bwd_grad.amax_epsilon, - force_pow_2_scales=self.recipe.fp8_quant_bwd_grad.power_2_scale, - block_scaling_dim=self.recipe.grad_block_scaling_dim, - ), - ] - for _ in range(self.num_quantizers // 2) - ] - ) - ) + assert self.mode in ("forward", "backward"), f"Unexpected mode {self.mode}" + return [_make(self._slot_tensor_type(idx)) for idx in range(self.num_quantizers)] class NVFP4BlockScalingRecipeState(RecipeState): @@ -1338,10 +1585,13 @@ def __init__( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, + roles: Optional[List[QuantizerRole]] = None, ) -> None: + self._validate_roles(roles, num_quantizers) self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers + self.roles = roles self.dtype = get_fp4_te_dtype(recipe) # Allocate buffers @@ -1349,63 +1599,184 @@ def __init__( device = torch.device("cuda") def make_quantizers(self) -> list: + """Build one ``NVFP4Quantizer`` per slot, dispatched by tensor type. + + Per-slot behavior, resolved via :meth:`RecipeState._slot_tensor_type`: + + * Forward, ``"weight"`` -> ``recipe.fp4_quant_fwd_weight``. + * Forward, ``"input"`` / ``"output"`` (and any unknown forward type) -> + ``recipe.fp4_quant_fwd_inp``. + * Backward, any slot -> ``recipe.fp4_quant_bwd_grad``. + + When the owning module/op provides a role list via + ``get_quantizer_roles``, the per-slot ``tensor_type`` drives dispatch. + Otherwise (or for boundary slots whose role is ``None``), the + positional fallback ``[input, weight, output, ...]`` is used; on this + layout slot ``idx % 3 == 1`` is always weight and the rest fall into + the input config, matching the legacy index-based behavior. + """ from .tensor.nvfp4_tensor import NVFP4Quantizer - # The index convention (coming from base.py set_meta_tensor) - # is somewhat awkward. It assumes forward quantizers are - # ordered [input, weight, output, ...] and backward quantizers - # are ordered [grad_output, grad_input, ...]. This doesn't - # play nicely with fusible ops: Linear op doesn't own output - # or grad input quantizers, Quantize op only owns input and - # grad output quantizers. - - if self.mode == "forward": - - def _make_quantizer(idx: int) -> NVFP4Quantizer: - qparams = ( - self.recipe.fp4_quant_fwd_weight - if idx % 3 == 1 - else self.recipe.fp4_quant_fwd_inp - ) - return NVFP4Quantizer( - fp4_dtype=self.dtype, - rowwise=True, - columnwise=True, - with_rht=qparams.random_hadamard_transform, - with_post_rht_amax=qparams.random_hadamard_transform, - with_2d_quantization=qparams.fp4_2d_quantization, - stochastic_rounding=qparams.stochastic_rounding, - row_scaled_nvfp4=self.recipe.row_scaled_activation and idx % 3 != 1, - ) + def _qparams(tensor_type: str): + if self.mode == "backward": + return self.recipe.fp4_quant_bwd_grad + if tensor_type == "weight": + return self.recipe.fp4_quant_fwd_weight + return self.recipe.fp4_quant_fwd_inp + + def _make(tensor_type: str) -> NVFP4Quantizer: + qparams = _qparams(tensor_type) + return NVFP4Quantizer( + fp4_dtype=self.dtype, + rowwise=True, + columnwise=True, + with_rht=qparams.random_hadamard_transform, + with_post_rht_amax=qparams.random_hadamard_transform, + with_2d_quantization=qparams.fp4_2d_quantization, + stochastic_rounding=qparams.stochastic_rounding, + row_scaled_nvfp4=( + self.mode == "forward" + and tensor_type != "weight" + and self.recipe.row_scaled_activation + ), + ) - return [_make_quantizer(idx) for idx in range(self.num_quantizers)] - - if self.mode == "backward": - return [ - NVFP4Quantizer( - fp4_dtype=self.dtype, - rowwise=True, - columnwise=True, - with_rht=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, - with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform, - with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization, - stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding, - row_scaled_nvfp4=False, + if self.mode not in ("forward", "backward"): + raise RuntimeError(f"Unexpected recipe mode ({self.mode})") + + return [_make(self._slot_tensor_type(idx)) for idx in range(self.num_quantizers)] + + +def _handle_delayed_scaling_requests( + raw: list, + device: torch.device, + mode: str, + *, + existing_ds_state: Optional["DelayedScalingRecipeState"] = None, +) -> Optional["DelayedScalingRecipeState"]: + """Detect DelayedScalingRequest items, allocate shared state, replace with real quantizers. + + All DS requests in the same RecipeState must share identical parameters. + + When ``existing_ds_state`` is provided and compatible (same dtype, + same number of DS slots, same ``amax_history_len``), it is reused + instead of allocating fresh buffers. Reusing preserves accumulated + ``scale`` / ``amax_history`` across role-driven rebuilds — important + for post-checkpoint resume and mid-training factory swaps. The + ``Float8Quantizer`` instances built here will then view into the + SAME tensor objects already registered with + ``FP8GlobalStateManager``'s reduction buffers, keeping reduction + and quantization consistent. + + Returns a ``DelayedScalingRecipeState`` owning the shared buffers, or + ``None`` when no DS requests are present. + """ + ds_items = [(i, r) for i, r in enumerate(raw) if isinstance(r, DelayedScalingRequest)] + if not ds_items: + return None + + r0 = ds_items[0][1] + + # Validate all DS requests share same params + for idx, req in ds_items[1:]: + for field_name in ( + "fp8_format", + "margin", + "amax_history_len", + "amax_compute_algo", + "scaling_factor_compute_algo", + "reduce_amax", + ): + v0 = getattr(r0, field_name) + vi = getattr(req, field_name) + if v0 != vi: + raise ValueError( + "All DelayedScalingRequests in one CustomRecipeState must match. " + f"Slot 0 has {field_name}={v0!r}, slot {idx} has {vi!r}." ) - for _ in range(self.num_quantizers) - ] - raise RuntimeError(f"Unexpected recipe mode ({self.mode})") + # Build a real DelayedScalingRecipeState to own the shared buffers. + inner_recipe = DelayedScaling( + fp8_format=r0.fp8_format, + margin=r0.margin, + amax_history_len=r0.amax_history_len, + amax_compute_algo=r0.amax_compute_algo, + scaling_factor_compute_algo=r0.scaling_factor_compute_algo, + reduce_amax=r0.reduce_amax, + ) + n = len(ds_items) + + # Reuse a compatible existing DSRS so its scale / amax_history (and any + # external references to them) survive the rebuild. + expected_dtype = get_fp8_te_dtype(inner_recipe, mode == "forward") + dsrs = None + if existing_ds_state is not None: + if ( + existing_ds_state.num_quantizers == n + and existing_ds_state.dtype == expected_dtype + and existing_ds_state.amax_history.shape[0] == r0.amax_history_len + ): + dsrs = existing_ds_state + + if dsrs is None: + dsrs = DelayedScalingRecipeState( + inner_recipe, + mode=mode, + num_quantizers=n, + device=device, + ) + + # Splice Float8Quantizer instances (backed by dsrs buffers) into raw list. + quantizers = dsrs.make_quantizers() + for j, (idx, _req) in enumerate(ds_items): + raw[idx] = quantizers[j] + + return dsrs + + +def _has_delayed_scaling_state(fp8_meta: Dict[str, Any]) -> bool: + """Check if fp8_meta has delayed scaling state (built-in or custom).""" + if fp8_meta["recipe"].delayed(): + return True + if fp8_meta["recipe"].custom(): + for key in ("scaling_fwd", "scaling_bwd"): + state = fp8_meta.get(key) + if isinstance(state, CustomRecipeState) and state._has_delayed_scaling: + return True + return False class CustomRecipeState(RecipeState): - """State for CustomRecipe: produce quantizers per tensor.""" + """State for CustomRecipe: produce quantizers per tensor. + + Stateful quantizer support: + - Supports stateful quantizers (e.g. delayed scaling) via ``DelayedScalingRequest``. + - The factory returns request dataclasses for stateful quantizers; TE detects them, + allocates shared buffers, and replaces with real quantizer instances. + - Stateful recipe state is composed via real TE recipe state objects (e.g. + ``DelayedScalingRecipeState``), not reimplemented. + """ recipe: CustomRecipe mode: str num_quantizers: int device: Optional[torch.device] + # -- Composed sub-states for stateful sub-recipes -- + # + # When the qfactory returns request objects (e.g. ``DelayedScalingRequest``) + # for a stateful built-in recipe, ``make_quantizers`` allocates a real + # built-in ``RecipeState`` for those slots and reuses its persistent + # buffers across role-driven rebuilds via ``inherit_state_from``. One + # ``__state`` / ``__state_to_inherit`` pair per stateful recipe. + + # Delayed scaling (``DelayedScalingRequest`` -> ``DelayedScalingRecipeState``): + # ``_ds_state`` owns shared ``scale`` / ``amax_history`` for DS slots in this + # CustomRecipeState; ``_ds_state_to_inherit`` is a transient stash set by + # ``inherit_state_from`` and consumed by the next ``make_quantizers`` call. + _ds_state: Optional[DelayedScalingRecipeState] + _ds_state_to_inherit: Optional[DelayedScalingRecipeState] + def __init__( self, recipe: CustomRecipe, @@ -1413,39 +1784,106 @@ def __init__( mode: str, num_quantizers: int = 1, device: Optional[torch.device] = None, + roles: Optional[List[QuantizerRole]] = None, ) -> None: + self._validate_roles(roles, num_quantizers) self.recipe = recipe self.mode = mode self.num_quantizers = num_quantizers + self.roles = roles if device is None: device = torch.device("cuda") self.device = device + # -- Stateful sub-state slots (initialized empty) -- + # Delayed scaling + self._ds_state = None + self._ds_state_to_inherit = None + if getattr(recipe, "qfactory", None) is None: raise ValueError("CustomRecipe requires `qfactory`.") def make_quantizers(self) -> list: qfactory = self.recipe.qfactory - out = [] - - # TODO(negvet): make_quantizers() should take roles from the operation - # Hardcode linear-specific roles for now - roles: List[str] - if self.mode == "forward": - roles = [ - ("linear_input", "linear_weight", "linear_output")[i % 3] - for i in range(self.num_quantizers) - ] - elif self.mode == "backward": - roles = [ - ("linear_grad_output", "linear_grad_input")[i % 2] - for i in range(self.num_quantizers) - ] - else: - roles = ["unknown"] * self.num_quantizers - for i in range(self.num_quantizers): - # Get quantizer from the user defined factory - quantizer = qfactory(roles[i]) - out.append(quantizer) - return out + roles = self.roles + if roles is None: + warnings.warn( + "CustomRecipeState: no QuantizerRole list provided by the module/op. " + "Falling back to bare QuantizerRole() defaults. " + "Override get_quantizer_roles() to provide meaningful roles.", + stacklevel=2, + ) + roles = [QuantizerRole() for _ in range(self.num_quantizers)] + + # qfactory must return a Quantizer or QuantizerRequest for every slot. + # None is not a valid return value — it would silently disable quantization + # for that tensor, risking hard-to-detect performance regressions. + # TODO(negvet): Introduce an explicit IdentityQuantizer for intentional no-op + # quantization. Until then, None is rejected. + raw = [qfactory(roles[i]) for i in range(self.num_quantizers)] + for i, q in enumerate(raw): + if q is None: + raise ValueError( + f"CustomRecipe qfactory returned None for slot {i} " + f"(role={roles[i]}). Every slot must return a Quantizer " + "instance or a QuantizerRequest." + ) + + # -- Delayed scaling sub-state -- + # If a predecessor stashed a compatible inner DSRS via + # ``inherit_state_from``, reuse it so accumulated scale / amax_history + # survive the rebuild. Consume the stash so a subsequent + # ``make_quantizers`` doesn't reuse it again unintentionally. + existing_ds_state = self._ds_state_to_inherit + self._ds_state_to_inherit = None + self._ds_state = _handle_delayed_scaling_requests( + raw, + self.device, + self.mode, + existing_ds_state=existing_ds_state, + ) + + return raw + + def inherit_state_from(self, other: "RecipeState") -> bool: + """Stash ``other``'s composed sub-states for reuse on next ``make_quantizers``. + + ``CustomRecipeState`` cannot inherit declaratively because its + persistent state lives in composed sub-states (one per stateful + sub-recipe) that are allocated only when ``make_quantizers`` runs. + For each stateful sub-recipe we stash the predecessor's sub-state + and let the next ``make_quantizers`` decide whether the + predecessor's shape is compatible with the new factory output. + """ + if not isinstance(other, CustomRecipeState): + return False + + inherited_any = False + + # -- Delayed scaling sub-state -- + if other._ds_state is not None: + self._ds_state_to_inherit = other._ds_state + inherited_any = True + + return inherited_any + + # -- Delegation to composed DelayedScalingRecipeState -- + + @property + def _has_delayed_scaling(self) -> bool: + return self._ds_state is not None + + @property + def amax_history(self) -> Optional[torch.Tensor]: + """Amax history from the composed delayed-scaling state, if any.""" + return self._ds_state.amax_history if self._ds_state else None + + @property + def scale(self) -> Optional[torch.Tensor]: + """Current scale from the composed delayed-scaling state, if any.""" + return self._ds_state.scale if self._ds_state else None + + @property + def _inner_delayed_scaling_recipe(self) -> Optional[DelayedScaling]: + return self._ds_state.recipe if self._ds_state else None