From 6cdd7115e1c1e6605feaff6ebe111ea0466cd71a Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Mon, 11 May 2026 21:25:17 -0700 Subject: [PATCH] [PyTorch] CPU overhead optimizations for te autocast (#2957) * cpu optimizations for te autocast Signed-off-by: Varun Thumbe * address some review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * clean comments Signed-off-by: Varun Thumbe --------- Signed-off-by: Varun Thumbe Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/common/recipe/__init__.py | 70 +++++++++++--- transformer_engine/pytorch/quantization.py | 97 +++++++++++++------- 2 files changed, 118 insertions(+), 49 deletions(-) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 9599663691..b773a81d1b 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -4,6 +4,7 @@ """This module provides predefined FP8 recipes.""" from __future__ import annotations +import abc import os from enum import Enum from typing import Any, Literal, Optional, Union, Callable, NamedTuple @@ -60,6 +61,16 @@ class MMParams: use_split_accumulator: bool = True + def __post_init__(self) -> None: + object.__setattr__( + self, + "_cached_repr", + f"MMParams(use_split_accumulator={self.use_split_accumulator})", + ) + + def __repr__(self) -> str: + return self._cached_repr + @dataclass(frozen=True) class QParams: @@ -76,21 +87,50 @@ class QParams: stochastic_rounding: bool = False fp4_2d_quantization: bool = False - def __repr__(self) -> str: - return ( + def __post_init__(self) -> None: + object.__setattr__( + self, + "_cached_repr", f"Qparams(\npower_2_scale={self.power_2_scale},\n" f"amax_epsilon={self.amax_epsilon},\n" f"random_hadamard_transform={self.random_hadamard_transform},\n" f"stochastic_rounding={self.stochastic_rounding},\n" - f"fp4_2d_quantization={self.fp4_2d_quantization}\n)" + f"fp4_2d_quantization={self.fp4_2d_quantization}\n)", ) + def __repr__(self) -> str: + return self._cached_repr + class Recipe: """ Base recipe class. """ + # Cached string representation. Lazily populated by ``__repr__`` in + # subclasses and invalidated by ``__setattr__`` whenever any attribute + # changes. This makes repeated ``str(recipe)`` calls much cheaper + _cached_repr: Optional[str] = None + + def __setattr__(self, name: str, value: Any) -> None: + # Invalidate the cached repr on any attribute mutation. + if name != "_cached_repr": + object.__setattr__(self, "_cached_repr", None) + object.__setattr__(self, name, value) + + @abc.abstractmethod + def _make_repr(self) -> str: + """Build the string representation for this recipe. + + Subclasses must override this method. The result is cached by + ``__repr__`` and reused until any attribute is mutated. + """ + + def __repr__(self) -> str: + if self._cached_repr is None: + self._cached_repr = self._make_repr() + return self._cached_repr + @classmethod def nvfp4(cls): """Whether the given recipe is NVFP4 1D block scaling.""" @@ -127,7 +167,7 @@ def custom(cls): return issubclass(cls, CustomRecipe) -@dataclass() +@dataclass(repr=False) class DelayedScaling(Recipe): """ Use the delayed scaling factor strategy. Use scale factor from previous @@ -227,7 +267,7 @@ def __post_init__(self) -> None: self.backward_override is None ), "Delayed scaling only supports backward_override=None." - def __repr__(self) -> str: + def _make_repr(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " @@ -240,7 +280,7 @@ def __repr__(self) -> str: ) -@dataclass() +@dataclass(repr=False) class Float8CurrentScaling(Recipe): """ Use the per-tensor current scaling factor strategy. @@ -275,7 +315,7 @@ def __post_init__(self) -> None: self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." - def __repr__(self) -> str: + def _make_repr(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " @@ -291,7 +331,7 @@ def __repr__(self) -> str: ) -@dataclass() +@dataclass(repr=False) class MXFP8BlockScaling(Recipe): """ Use the MXFP8 scaling factor strategy. @@ -333,7 +373,7 @@ def __post_init__(self) -> None: self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." - def __repr__(self) -> str: + def _make_repr(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"margin={self.margin}, " @@ -342,7 +382,7 @@ def __repr__(self) -> str: ) -@dataclass() +@dataclass(repr=False) class Float8BlockScaling(Recipe): """ Use block-wise scaling for FP8 tensors. @@ -414,7 +454,7 @@ def __post_init__(self) -> None: self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." - def __repr__(self) -> str: + def _make_repr(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"format={str(self.fp8_format).split('.')[1]}, " @@ -433,7 +473,7 @@ def __repr__(self) -> str: ) -@dataclass() +@dataclass(repr=False) class NVFP4BlockScaling(Recipe): """ Use the NVFP4 scaling strategy. @@ -531,7 +571,7 @@ def __post_init__(self) -> None: fp4_2d_quantization=False, ) - def __repr__(self) -> str: + def _make_repr(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"fp4_format={str(self.fp4_format).split('.')[1]}, " @@ -546,7 +586,7 @@ def __repr__(self) -> str: ) -@dataclass() +@dataclass(repr=False) class CustomRecipe(Recipe): """ Custom recipe that allows users to provide quantizer factories. @@ -608,7 +648,7 @@ def __post_init__(self) -> None: self.backward_override in _BACKWARD_OVERRIDES ), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'." - def __repr__(self) -> str: + def _make_repr(self) -> str: return ( f"recipe_type={self.__class__.__name__}, " f"qfactory={self.qfactory}, " diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py index 82b8274378..0c40723517 100644 --- a/transformer_engine/pytorch/quantization.py +++ b/transformer_engine/pytorch/quantization.py @@ -686,9 +686,8 @@ def reduce_and_update_fp8_tensors( amax_history, scale, get_fp8_max(recipe, forward), recipe ) - @classmethod + @staticmethod def get_unique_autocast_key( - cls, recipe: Optional[Recipe] = None, group: Optional[dist_group_type] = None, ): @@ -697,7 +696,11 @@ def get_unique_autocast_key( Object identity is sufficient since autocast contexts never outlive a single training session. """ - return str((str(recipe), id(group) if group is not None else None)) + recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None + if recipe_repr is None: + recipe_repr = str(recipe) + group_id = id(group) if group is not None else None + return f"recipe={recipe_repr},group={group_id}" @classmethod def autocast_enter( @@ -911,14 +914,13 @@ def quantized_model_init( qstate.high_precision_init_val = _high_precision_init_val -@contextmanager def fp8_autocast( enabled: bool = True, calibrating: bool = False, fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, _graph: bool = False, -) -> None: +) -> "autocast": """ .. warning:: @@ -934,25 +936,16 @@ def fp8_autocast( stacklevel=2, ) - # Call new implementation. - with autocast( + return autocast( enabled=enabled, calibrating=calibrating, recipe=fp8_recipe, amax_reduction_group=fp8_group, _graph=_graph, - ): - yield + ) -@contextmanager -def autocast( - enabled: bool = True, - calibrating: bool = False, - recipe: Optional["Recipe"] = None, - amax_reduction_group: Optional["dist_group_type"] = None, - _graph: bool = False, -) -> None: +class autocast: """ Context manager for quantization schemes like FP8 or FP4. @@ -991,24 +984,60 @@ def autocast( are reduced at the end of each training step. """ - if enabled: - check_recipe_support(recipe) - - # Save current state so we always restore it on exit. - fp8_state = FP8GlobalStateManager.get_autocast_state() - - FP8GlobalStateManager.autocast_enter( - enabled=enabled, - calibrating=calibrating, - fp8_recipe=recipe, - fp8_group=amax_reduction_group, - _graph=_graph, + # Class-based context manager (instead of ``@contextmanager`` from contextlib) + # to avoid overheads. + __slots__ = ( + "_enabled", + "_calibrating", + "_recipe", + "_amax_reduction_group", + "_graph", + "_fp8_state", ) - try: - yield - finally: - FP8GlobalStateManager.set_autocast_state(fp8_state) - FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph) + + def __init__( + self, + enabled: bool = True, + calibrating: bool = False, + recipe: Optional["Recipe"] = None, + amax_reduction_group: Optional["dist_group_type"] = None, + _graph: bool = False, + ) -> None: + self._enabled = enabled + self._calibrating = calibrating + self._recipe = recipe + self._amax_reduction_group = amax_reduction_group + self._graph = _graph + self._fp8_state = None + + def __enter__(self) -> "autocast": + # Disallow nested re-entry of the same instance. + if self._fp8_state is not None: + raise RuntimeError( + "autocast context manager cannot be entered more than once concurrently" + ) + if self._enabled: + check_recipe_support(self._recipe) + # Save current state so we always restore it on exit. + self._fp8_state = FP8GlobalStateManager.get_autocast_state() + FP8GlobalStateManager.autocast_enter( + enabled=self._enabled, + calibrating=self._calibrating, + fp8_recipe=self._recipe, + fp8_group=self._amax_reduction_group, + _graph=self._graph, + ) + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + try: + FP8GlobalStateManager.set_autocast_state(self._fp8_state) + FP8GlobalStateManager.autocast_exit(self._enabled, _graph=self._graph) + finally: + # Clear the saved state so the instance can be entered again + # sequentially (and so a failure inside the restore path does not + # permanently mark the instance as "active"). + self._fp8_state = None def _update_amax_history(amax_history: torch.Tensor) -> torch.Tensor: