Skip to content

Conversation

@ptrendx
Copy link
Member

@ptrendx ptrendx commented Dec 1, 2025

Description

This PR includes a few performance optimizations targeting the CPU overhead. The code, perf numbers etc. are WIP. The code gets kind of ugly though :-(.

For the prepare_forward changes I did not touch attention (@cyanguwa FYI) since it has multiple exit points from the forward and was worried that I would miss something there - it would be great if we could refactor that part first to have a single return statement instead.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ptrendx
Copy link
Member Author

ptrendx commented Dec 1, 2025

/te-ci pytorch

Comment on lines 644 to 646
def fast_set_attr(self, name: str, value: Any) -> None:
self.__dict__[name] = value
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I assume we are separating out this function so we can manually avoid overheads from __setattr__ and dict? Doing some benchmarking:

  • dict read: 9 ns
  • dict write: 13 ns
  • dict in: 9 ns
  • dict.get: 14 ns
  • Function call: 9 ns
  • Class attr read: 3 ns
  • Class attr write: 5 ns
  • Class custom getattr: 101 ns
  • Class custom setattr: 134 ns
Benchmarking script

I ran the following on a GB200 node. For the dict times, I subtracted out the overhead from list reads. For the class getattr/setattr times, I subtracted out the overhead from range.

import contextlib
import time

class Timer:
    """Measure time interval."""

    def __init__(self) -> None:
        self._start = None
        self._end = None

    def time(self) -> float:
	"""CPU time interval in seconds."""
        return self._end - self._start

    @contextlib.contextmanager
    def context(self):
        """Context manager to capture time interval."""
	self._start = time.perf_counter()
        yield
        self._end = time.perf_counter()

def main() -> None:

    # Options
    iters = 1024 * 1024

    # Timer
    timer = Timer()

    # Dummy data
    str_list = ["lorem", "ipsum", "dolor", "sit", "amet", "consectetur", "adipiscing", "elit"]
    str_list = [str_list[i % len(str_list)] for i in range(iters)]
    str_dict = {s: len(s) for s in str_list}
    class PlainClass:
        def __init__(self) -> None:
            self.attr = 1
    class CustomGetattrSetattrClass:
        def __init__(self) -> None:
            self.attr = 1
        def __getattribute__(self, name):
            return super().__getattribute__(name)
	def __setattr__(self, name, val):
            super().__setattr__(name, val)

    # Timer overhead
    with timer.context():
        pass
    print(f"Timer overhead: {timer.time() * 1e9 / iters} ns/iter")

    # Range loop
    with timer.context():
        for _ in range(iters):
            pass
    print(f"Range loop: {timer.time() * 1e9 / iters} ns/iter")

    # List loop
    with timer.context():
        for _ in str_list:
            pass
    print(f"List loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty range+enumerate loop
    with timer.context():
        for i, j in enumerate(range(iters)):
            pass
    print(f"Range+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # Empty range+enumerate loop
    with timer.context():
        for i, s in enumerate(str_list):
            pass
    print(f"List+enumerate loop: {timer.time() * 1e9 / iters} ns/iter")

    # List reads
    with timer.context():
        for i in range(iters):
            str_list[i]
    print(f"List reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict reads
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]]
    print(f"Dict reads: {timer.time() * 1e9 / iters} ns/iter")

    # Dict get
    with timer.context():
        for i in range(iters):
            str_dict.get(str_list[i], None)
    print(f"Dict gets: {timer.time() * 1e9 / iters} ns/iter")

    # Dict writes
    with timer.context():
        for i in range(iters):
            str_dict[str_list[i]] = i
    print(f"Dict writes: {timer.time() * 1e9 / iters} ns/iter")

    # Dict membership
    with timer.context():
        for i in range(iters):
            str_list[i] in str_dict
    print(f"Dict membership: {timer.time() * 1e9 / iters} ns/iter")

    # Function call
    def func() -> None:
        pass
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Function call: {timer.time() * 1e9 / iters} ns/iter")

    # Function call
    func = lambda: None
    with timer.context():
        for _ in range(iters):
            func()
    print(f"Lambda call: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr read
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class attr read: {timer.time() * 1e9 / iters} ns/iter")

    # Class attr write
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class attr write: {timer.time() * 1e9 / iters} ns/iter")

    # getattr
    myobj = PlainClass()
    with timer.context():
        for _ in range(iters):
            getattr(myobj, "attr", None)
    print(f"getattr: {timer.time() * 1e9 / iters} ns/iter")

    # getattr
    myobj = PlainClass()
    with timer.context():
        for i in range(iters):
            setattr(myobj, "attr", i)
    print(f"setattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom getattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for _ in range(iters):
            _ = myobj.attr
    print(f"Class custom getattr: {timer.time() * 1e9 / iters} ns/iter")

    # Class custom setattr
    myobj = CustomGetattrSetattrClass()
    with timer.context():
        for i in range(iters):
            myobj.attr = i
    print(f"Class custom setattr: {timer.time() * 1e9 / iters} ns/iter")

if __name__ == "__main__":
    main()

How much perf difference do you observe from fast_set_attr? I could see how it could save us ~1 us of overhead, but it would be good to make sure before making the code messier.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't want to comment too much on the perf results yet since up till now they all come from my machine and not a real cluster, but that anecdotal evidence shows that the time of the small test of just running BF16 Linear layer forward for many iterations after the proposed code changes go from 9.2 to 7.7 s. The fast_set_attr alone brought it to ~8.4s.
I will test it properly and report the timings in the description of the PR.
Now, about introducing the separate function - since ultimately this is the optimization that you came up with at some point, there already was the machinery to not do the expensive Module.set_attr for some parameters. The problem that I see is discoverability - if people do not study that code very cautiously they will not realize that they should not just do self.something = something. Therefore I think we should actually go a more explicit way and in the set_attr of TE module just error out with a message to either use fast_set_attr for the things we are sure are just small values (since the usage of dict directly has some problems BTW since it e.g. bypasses properties and stuff) and use a new function, let's call it just set_attr for anything where we need the full machinery.

Copy link
Collaborator

@timmoon10 timmoon10 Dec 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer not to ban self.something = something. I think readability and safety are more important for non-performance-critical things like initialization and checkpointing. It would be better to make this function an advanced internal implementation with a name like _fast_setattr.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How would we then make sure that this does not resurface in the future?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Went with the explicit setattr calls and having a warning issued when the regular setattr function is used. That way the users can still use the regular setattr call if they want, but for the internal development we make sure during testing that the warning does not trigger. To make the code less ugly we only turn on the warning after the constructor is finished - that way we can still use the nice syntax during construction (where there are the most occurences) since we do not care about the speed there.

@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch from 5eefe3e to 1c7d896 Compare December 2, 2025 22:45
@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch 3 times, most recently from 948747b to c4e380f Compare December 16, 2025 21:20
@ptrendx
Copy link
Member Author

ptrendx commented Jan 10, 2026

/te-ci pytorch

@ptrendx ptrendx marked this pull request as ready for review January 10, 2026 00:48
@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 10, 2026

Greptile Summary

  • Introduces CPU performance optimizations by adding fast_setattr() to bypass PyTorch's expensive __setattr__ overhead and refactoring context managers to explicit method calls
  • Updates pytest configuration across test scripts to treat RuntimeWarnings as errors, enforcing the new fast attribute assignment pattern
  • Optimizes C++ tensor allocator by moving null checks before mutex acquisition to reduce lock contention

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/base.py Added fast_setattr() optimization and setattr warning system; several locations still use direct assignment after initialization violating the new pattern
transformer_engine/pytorch/module/linear.py Refactored from context manager to manual prepare_forward()/end_forward() calls; creates NVTX range imbalance if exceptions occur
transformer_engine/pytorch/module/grouped_linear.py Same context manager refactoring as linear.py; vulnerable to NVTX stack corruption on exceptions
transformer_engine/pytorch/module/layernorm_linear.py Same context manager refactoring pattern; breaks exception safety guarantees for NVTX profiling
tests/pytorch/pytest.ini Treats RuntimeWarnings as errors to enforce fast_setattr() usage; will cause test failures with existing direct assignments in base.py

Confidence score: 1/5

  • This PR has significant implementation issues including exception safety bugs and inconsistent application of the optimization pattern
  • Score severely reduced due to NVTX range imbalance bugs in multiple modules where exceptions between prepare_forward() and end_forward() corrupt profiling stacks
  • Direct attribute assignments still exist in base.py after initialization which will cause test failures when RuntimeWarnings are treated as errors

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 5 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +1402 to +1475
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))

weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The refactoring from context manager to explicit prepare_forward() and end_forward() calls introduces an NVTX range stack leak risk. If an exception is raised anywhere between line 1402 (where prepare_forward() calls nvtx_range_push()) and line 1474 (where end_forward() calls nvtx_range_pop()), the NVTX range will never be popped from the stack.

The original context manager approach (with self.prepare_forward()) was safer because it guaranteed cleanup even when exceptions occurred. Without proper exception handling, NVTX ranges can accumulate on the stack, corrupting profiling data.

Consider wrapping the code between prepare_forward() and end_forward() in a try-finally block:

inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
try:
    # ... all the processing code ...
    out = linear_fn(...)
finally:
    self.end_forward()

This same issue exists in layernorm_linear.py, layernorm_mlp.py, and grouped_linear.py.

Comment on lines +1519 to +1601
inp = self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
)

# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)

self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same NVTX range stack leak vulnerability as in linear.py. If an exception occurs between prepare_forward() (line 1519) and end_forward() (line 1601), the NVTX range pushed in prepare_forward() will never be popped, corrupting the profiling stack.

The original context manager pattern was exception-safe. Consider using try-finally to ensure end_forward() is always called:

inp = self.prepare_forward(inp, allow_non_contiguous=False)
try:
    # ... processing code ...
    out = fwd_fn(...)
finally:
    self.end_forward()

Comment on lines 2057 to 2159
inp = self.prepare_forward(inp, num_gemms=2)

quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)

# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False
# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same NVTX range stack leak vulnerability as in linear.py and layernorm_linear.py. If an exception occurs between prepare_forward() (line 2057) and end_forward() (line 2159), the NVTX range will leak.

Consider using try-finally to ensure cleanup:

inp = self.prepare_forward(inp, num_gemms=2)
try:
    # ... processing code ...
    out = fwd_fn(...)
finally:
    self.end_forward()

Comment on lines +793 to +852
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()

if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()
if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()

if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")
if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")

(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers
(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers

if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same NVTX range stack leak vulnerability. If an exception occurs between prepare_forward() (line 793) and end_forward() (line 847), the NVTX range will leak.

Consider using try-finally to ensure cleanup:

inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
try:
    # ... processing code ...
    out = linear_fn(...)
finally:
    self.end_forward()

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 10, 2026

Greptile Overview

Greptile Summary

This PR attempts to optimize CPU overhead by:

  1. Replacing context managers with manual prepare_forward()/end_forward() calls
  2. Introducing fast_setattr() to bypass PyTorch's __setattr__ overhead
  3. Adding an __setattr__ override that warns when the slow path is used
  4. Optimizing C++ mutex locking by moving null checks before lock acquisition
  5. Configuring pytest to treat RuntimeWarnings as errors

Critical Issues Found

1. NVTX Range Imbalance on Exceptions (HIGH SEVERITY)

The refactoring from context managers to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between these calls (e.g., shape mismatch, CUDA OOM, assertion failure), nvtx_range_pop() is never called, corrupting the NVTX stack. This affects all forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules.

2. setattr Doesn't Actually Optimize (HIGH SEVERITY)

The new __setattr__ override still calls super().__setattr__(name, value) after emitting a warning, meaning every attribute assignment still goes through the slow PyTorch path. This defeats the purpose of the optimization.

3. Multiple RuntimeWarning Violations (CRITICAL SEVERITY)

Six locations in base.py use direct attribute assignment after initialization (lines 965, 966, 1558, 1559, 1565, 1581, 1608). Since pytest.ini now treats RuntimeWarnings as errors, all tests will fail.

Positive Aspects

  • C++ mutex optimization is correct and beneficial
  • Attention module correctly uses prepare_forward_ctx context manager
  • All module subclasses properly set _initialized flag
  • Test scripts correctly updated to use pytest.ini

Recommendation

This PR cannot be merged in its current state due to the RuntimeWarning violations that will cause all tests to fail. The NVTX exception safety issue is also critical for production use.

Confidence Score: 0/5

  • This PR is not safe to merge - it will cause all tests to fail due to RuntimeWarning violations
  • Score of 0 reflects critical issues that will break the build: (1) Six direct attribute assignments trigger RuntimeWarnings which pytest.ini treats as errors, causing all tests to fail immediately; (2) NVTX range imbalance on exceptions will corrupt profiling; (3) setattr optimization doesn't actually work as intended
  • transformer_engine/pytorch/module/base.py requires immediate attention - must fix all direct attribute assignments (lines 965, 966, 1558, 1559, 1565, 1581, 1608) and address exception safety for NVTX ranges

Important Files Changed

File Analysis

Filename Score Overview
transformer_engine/pytorch/module/base.py 1/5 Critical bugs: NVTX range imbalance on exceptions, setattr doesn't optimize, multiple RuntimeWarning-triggering assignments
transformer_engine/pytorch/module/linear.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
transformer_engine/pytorch/module/layernorm_linear.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
transformer_engine/pytorch/module/layernorm_mlp.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
transformer_engine/pytorch/module/grouped_linear.py 2/5 Changed prepare_forward context manager to manual calls, adds _initialized flag, affected by base.py NVTX bug
tests/pytorch/pytest.ini 3/5 New file that treats RuntimeWarning as error - exposes issues with setattr implementation in base.py

Sequence Diagram

sequenceDiagram
    participant User
    participant Module as Linear/LayerNorm Module
    participant Base as TransformerEngineBaseModule
    participant NVTX as NVTX Range Stack
    participant FP8 as FP8GlobalStateManager

    User->>Module: forward(inp)
    Module->>Base: prepare_forward(inp)
    Base->>Base: init_fp8_metadata()
    Base->>Base: get_amp_dtype()
    Base->>FP8: Copy FP8 metadata (if recompute)
    Base->>NVTX: nvtx_range_push("Module forward")
    Note over NVTX: Range pushed - needs cleanup!
    Base-->>Module: returns processed inp
    
    alt Exception occurs during forward
        Module->>Module: Compute quantizers/weights
        Module--xModule: Exception raised!
        Note over NVTX: ⚠️ BUG: nvtx_range_pop() never called!
        Note over NVTX: Stack becomes imbalanced
    else Normal execution
        Module->>Module: Compute quantizers/weights
        Module->>Module: Call linear_fn()
        Module->>Base: end_forward()
        Base->>FP8: Restore FP8 metadata (if recompute)
        Base->>NVTX: nvtx_range_pop()
        Note over NVTX: Range properly cleaned up
        Base-->>Module: return
        Module-->>User: return output
    end
Loading

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7 files reviewed, 7 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +1086 to +1095
def end_forward(self):
"""
Required to be called at the end of the forward function to properly handle
DelayedScaling metadata handling and the NVTX ranges.
"""
delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
nvtx_range_pop()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new non-context-manager approach to prepare_forward() and end_forward() creates a critical bug: if an exception is raised between the prepare_forward() call and the end_forward() call in any forward method, the NVTX range will never be popped, leading to an imbalanced NVTX stack.

The original context manager approach with with self.prepare_forward(...) as inp: guaranteed that cleanup would happen even on exceptions. The new approach breaks this guarantee.

Impact: All forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear are affected. Any exception during forward (e.g., shape mismatch, CUDA OOM, assertion failure) will leave the NVTX range stack corrupted.

Recommendation: Either:

  1. Revert to the context manager approach (accepting the performance cost), or
  2. Wrap the forward logic in try/finally to ensure end_forward() is always called:
inp = self.prepare_forward(inp, ...)
try:
    # forward logic here
    out = linear_fn(...)
finally:
    self.end_forward()

Comment on lines 651 to +661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __setattr__ override doesn't actually optimize anything - it still calls super().__setattr__(name, value) after emitting the warning. This means every attribute assignment still goes through the slow PyTorch Module path, defeating the purpose of the optimization.

The warning is also emitted on EVERY attribute assignment after initialization, which would be extremely noisy and unhelpful for debugging.

Additionally, there are multiple places in the codebase that still use direct attribute assignment after initialization:

  • Line 965: self.tp_group = tp_group in set_tensor_parallel_group()
  • Line 966: self.tp_group_initialized = True
  • Lines 1558, 1559, 1565, 1581, 1608: Various debug-related attribute assignments

Since pytest.ini now treats RuntimeWarnings as errors, all of these will cause test failures.

Recommendation: Either:

  1. Remove the __setattr__ override entirely and rely on developers remembering to use fast_setattr, or
  2. Update all the remaining direct assignments to use fast_setattr or module_setattr

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 10, 2026

Additional Comments (5)

transformer_engine/pytorch/module/base.py
Direct attribute assignments after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
Direct attribute assignments after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
Direct attribute assignment after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py
Direct attribute assignment after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

        self.fast_setattr("next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers))

transformer_engine/pytorch/module/base.py
Direct attribute assignment after initialization will trigger the RuntimeWarning defined in __setattr__, which pytest.ini now treats as an error. This will cause test failures.

            self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}")

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, 8 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +1081 to +1085
nvtx_range_push(self.__class__.__name__ + " forward")
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
return inp
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The prepare_forward() method now uses nvtx_range_push() but the corresponding nvtx_range_pop() is in end_forward(). This creates a critical issue: if an exception occurs in the forward pass between prepare_forward() and end_forward() (e.g., shape mismatch, CUDA OOM, assertion failures), the NVTX range will never be popped, corrupting the NVTX profiling stack.

While the old context manager pattern guaranteed cleanup via __exit__, the new pattern requires manual calls. Consider wrapping the entire forward body in try-finally or document that modules must ensure end_forward() is called even on exceptions.

Comment on lines 651 to +661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The __setattr__ override emits a RuntimeWarning every time it's called after initialization, but it still calls super().__setattr__() to perform the assignment. This means:

  1. The warning is emitted but the "slow" operation still happens - there's no actual optimization here
  2. With pytest.ini treating RuntimeWarning as error, any accidental direct attribute assignment will cause test failures
  3. The warning message suggests using fast_setattr but doesn't prevent the slow path

If the goal is CPU optimization, this should either:

  • Raise an exception instead of warning (fail-fast), OR
  • Actually use the fast path by default after initialization

The current implementation warns about performance but doesn't fix it.

Comment on lines +6 to +7
filterwarnings=
error::RuntimeWarning
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Treating RuntimeWarning as an error will cause test failures for any code that uses direct attribute assignment after module initialization. Several locations in base.py still use direct assignment (see other comments), which will fail with this configuration.

This is appropriate for enforcing the fast_setattr() pattern, but all existing direct assignments must be fixed first, or this change should be in a separate PR after fixing all assignment sites.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (5)

transformer_engine/pytorch/module/base.py
The set_tensor_parallel_group() method directly assigns to self.tp_group and self.tp_group_initialized using the standard = operator. After initialization (when _initialized=True), this will trigger the RuntimeWarning from the overridden __setattr__ method, which pytest.ini now treats as an error.

This method is called after module initialization in user code, so it will fail. The assignments should use fast_setattr():

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
The is_debug_iter() method directly assigns to self.debug_last_iteration and self.debug_enabled_in_this_iteration. Since this is called during forward passes (after _initialized=True), these assignments will trigger RuntimeWarnings that pytest treats as errors.

These should use fast_setattr():

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
Another instance where self.debug_last_iteration is assigned directly during forward pass execution, which will trigger RuntimeWarning treated as error by pytest.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py
The no_debug_features_active() method assigns to self.next_iter_when_debug_should_be_run, which is called during forward execution. This will trigger the RuntimeWarning.

        self.fast_setattr("next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers))

transformer_engine/pytorch/module/base.py
The _validate_name() method assigns self.name directly during forward execution if no name was provided. This will trigger RuntimeWarning.

            self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}")

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

6 files reviewed, 6 comments

Edit Code Review Agent Settings | Greptile

Comment on lines 1040 to 1095
def prepare_forward(
self,
inp: torch.Tensor,
num_gemms: int = 1,
allow_non_contiguous: bool = False,
allow_different_data_and_param_types: bool = False,
) -> Generator[torch.Tensor, None, None]:
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
self.allow_different_data_and_param_types = allow_different_data_and_param_types
self.forwarded_at_least_once = True
) -> torch.Tensor:
"""Checks and prepares for FWD execution."""
self.fast_setattr(
"allow_different_data_and_param_types", allow_different_data_and_param_types
)
self.fast_setattr("forwarded_at_least_once", True)

# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
delayed_scaling_recipe = self.fp8_meta["recipe"].delayed()
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
else:
assert inp.is_cuda, "TransformerEngine needs CUDA."

if self.tp_size > 1:
assert self.tp_group_initialized, "TP group not initialized."

self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence()

delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe:
if self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, (
"Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8."
)

if not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

# Activation recomputation is used and this is the first forward phase.
if self.training and is_fp8_activation_recompute_enabled():
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

with get_nvtx_range_context(self.__class__.__name__ + " forward"):
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
yield inp
nvtx_range_push(self.__class__.__name__ + " forward")
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
return inp

def end_forward(self):
"""
Required to be called at the end of the forward function to properly handle
DelayedScaling metadata handling and the NVTX ranges.
"""
delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
nvtx_range_pop()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX Stack Corruption on Exceptions

The refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. The context manager's __exit__ method guaranteed that cleanup code would run even if exceptions occurred. Now, if ANY exception is raised between these calls, nvtx_range_pop() will never be called, corrupting the NVTX profiling stack.

Exception scenarios that will corrupt the stack:

  1. Line 1058: assert inp.is_cuda - AssertionError if input is on CPU
  2. Line 1061: assert self.tp_group_initialized - AssertionError if TP not initialized
  3. Line 1070: assert self.fp8_meta["recipe"].reduce_amax - AssertionError for sequence parallel
  4. Any CUDA errors in the forward computation (OOM, kernel failures, etc.)
  5. Shape mismatches, dtype errors, or other runtime errors

Impact: NVTX profiling will be permanently broken after the first exception, making debugging extremely difficult.

Solution: Either:

  1. Wrap the entire forward body in try-finally to guarantee end_forward() is called, OR
  2. Revert to using the context manager (prepare_forward_ctx), OR
  3. Remove all assertions and exceptions from the code path between these calls (not realistic)

The attention module correctly uses prepare_forward_ctx (the context manager version) which avoids this issue.

Comment on lines +1402 to +1475
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))

weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: Exception safety issue - NVTX stack corruption risk

Between prepare_forward() (which calls nvtx_range_push) and end_forward() (which calls nvtx_range_pop), there are multiple operations that can raise exceptions:

  1. _get_weight_and_bias_tensors() - could fail with various errors
  2. _get_quantizers() / _get_debug_quantizers() - could raise exceptions
  3. linear_fn() - the actual forward computation can fail with CUDA errors, shape mismatches, OOM, etc.

If ANY exception occurs, end_forward() won't be called, leaving the NVTX stack corrupted.

Solution: Wrap in try-finally or use the context manager version (prepare_forward_ctx) like the attention module does.

Comment on lines +1519 to +1595
inp = self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
)

# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: Exception safety issue - NVTX stack corruption risk

Same issue as Linear module. Between prepare_forward() and end_forward(), multiple operations can raise exceptions, leaving the NVTX stack corrupted if they do.

Solution: Wrap in try-finally or use the context manager version (prepare_forward_ctx).

Comment on lines +2057 to +2167
inp = self.prepare_forward(inp, num_gemms=2)

quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)

# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False
# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.fast_setattr("bias_gelu_nvfusion", False)

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)

self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: Exception safety issue - NVTX stack corruption risk

Same exception safety issue. The forward computation between prepare_forward() and end_forward() can raise exceptions, corrupting the NVTX stack.

Solution: Wrap in try-finally or use the context manager version (prepare_forward_ctx).

Comment on lines +793 to +853
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()

if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()
if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()

if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")
if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")

(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers
(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers

if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)

self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: Exception safety issue - NVTX stack corruption risk

Same exception safety issue as other modules. Any exception between prepare_forward() and end_forward() will corrupt the NVTX stack.

Solution: Wrap in try-finally or use the context manager version (prepare_forward_ctx).

Comment on lines +637 to +660
def fast_setattr(self, name: str, value: Any) -> None:
"""
Fast version of the Module's set attribute function.
Should be used for regular attributes, but not properties nor parameters/buffers.
"""
self.__dict__[name] = value

def module_setattr(self, name: str, value: Any) -> None:
"""
Regular version of the Module's set attribute function.
Should be used only when the fast version cannot be used - for the properties,
parameters and buffers.
"""
super().__setattr__(name, value)

def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Design Issue: The setattr override doesn't actually optimize anything

The new __setattr__ implementation still calls super().__setattr__(), which means it still incurs the full PyTorch Module overhead. The warning is issued but the slow path is always taken.

To actually optimize, the code should:

  1. Skip the warning during __init__ (check if "_initialized" not in self.__dict__)
  2. After init, require explicit use of fast_setattr() or module_setattr()
  3. Consider making __setattr__ raise an error instead of a warning to enforce the pattern

Currently, this just adds overhead (checking _initialized + issuing warning) on top of the existing slow path.

Comment on lines 1111 to 1113
def train(self, mode: bool = True) -> Self:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*CPU overhead.*")
return super().train(mode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Workaround Needed: train() method silences the RuntimeWarning

This method filters out the "CPU overhead" warning because PyTorch's Module.train() internally uses __setattr__ to set the training flag on all child modules. This is a workaround for the fact that the base implementation can't be avoided in some cases.

This reveals a potential issue: other PyTorch internal methods that use __setattr__ will also trigger warnings. Consider:

  • to() device transfers
  • half(), float(), bfloat16() dtype conversions
  • register_module(), register_buffer(), register_parameter()

These may all need similar workarounds.

Comment on lines 1040 to 1096
def prepare_forward(
self,
inp: torch.Tensor,
num_gemms: int = 1,
allow_non_contiguous: bool = False,
allow_different_data_and_param_types: bool = False,
) -> Generator[torch.Tensor, None, None]:
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
self.allow_different_data_and_param_types = allow_different_data_and_param_types
self.forwarded_at_least_once = True
) -> torch.Tensor:
"""Checks and prepares for FWD execution."""
self.fast_setattr(
"allow_different_data_and_param_types", allow_different_data_and_param_types
)
self.fast_setattr("forwarded_at_least_once", True)

# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
delayed_scaling_recipe = self.fp8_meta["recipe"].delayed()
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
else:
assert inp.is_cuda, "TransformerEngine needs CUDA."

if self.tp_size > 1:
assert self.tp_group_initialized, "TP group not initialized."

self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence()

delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe:
if self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, (
"Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8."
)

if not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

# Activation recomputation is used and this is the first forward phase.
if self.training and is_fp8_activation_recompute_enabled():
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

with get_nvtx_range_context(self.__class__.__name__ + " forward"):
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
yield inp
nvtx_range_push(self.__class__.__name__ + " forward")
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
return inp

def end_forward(self):
"""
Required to be called at the end of the forward function to properly handle
DelayedScaling metadata handling and the NVTX ranges.
"""
delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
nvtx_range_pop()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P0] NVTX range imbalance on exceptions: The refactoring from context manager to explicit prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (which calls nvtx_range_push) and end_forward() (which calls nvtx_range_pop), the NVTX range stack will be corrupted.

This affects all forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules. Any exception (shape mismatch, CUDA OOM, assertion failure, etc.) will leave an unclosed NVTX range.

The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires all code paths to call end_forward(), but exceptions bypass this.

Comment on lines 652 to +661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] setattr warning defeats the optimization: The __setattr__ override emits a RuntimeWarning when _initialized=True, which pytest.ini treats as an error. However, __setattr__ still calls super().__setattr__(), meaning it doesn't actually prevent or optimize anything - it just warns.

If the goal is CPU optimization, the warning should either:

  1. Be accompanied by raising an exception to force using fast_setattr(), or
  2. Implement the fast path directly in __setattr__ when appropriate

Currently, any test that triggers the warning will fail, but production code will just emit warnings with no actual optimization benefit from this implementation.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (4)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as error

These direct assignments using self.attribute = will trigger the RuntimeWarning from the overridden __setattr__ method (line 652-660), which pytest.ini now treats as an error.

This method is called after module initialization (when _initialized=True), so these assignments will fail tests.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as error

These direct assignments will trigger RuntimeWarning since is_debug_iter() is called during forward passes (after _initialized=True).

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as error

This assignment will also trigger the RuntimeWarning.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/common/transformer_engine.cpp
Good optimization: Moving the early return check (if (index == 0) return;) before acquiring the mutex lock avoids unnecessary synchronization overhead when freeing null pointers. This is a correct and safe optimization that reduces contention on the allocator mutex.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (6)

transformer_engine/pytorch/module/base.py, line 1096
[P0] NVTX range imbalance on exceptions: The refactoring from context manager to explicit prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (which calls nvtx_range_push) and end_forward() (which calls nvtx_range_pop), the NVTX range stack will be corrupted.

This affects all forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules. Any exception (shape mismatch, CUDA OOM, assertion failure, etc.) will leave an unclosed NVTX range.

The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires all code paths to call end_forward(), but exceptions bypass this.


transformer_engine/pytorch/module/base.py, line 661
[P1] setattr warning defeats the optimization: The __setattr__ override emits a RuntimeWarning when _initialized=True, which pytest.ini treats as an error. However, __setattr__ still calls super().__setattr__(), meaning it doesn't actually prevent or optimize anything - it just warns.

If the goal is CPU optimization, the warning should either:

  1. Be accompanied by raising an exception to force using fast_setattr(), or
  2. Implement the fast path directly in __setattr__ when appropriate

Currently, any test that triggers the warning will fail, but production code will just emit warnings with no actual optimization benefit from this implementation.


transformer_engine/pytorch/module/base.py, line 968
[P0] set_tensor_parallel_group uses slow setattr after init: This method is called after module initialization (when _initialized=True), directly assigning to self.tp_group and self.tp_group_initialized using =, which triggers the RuntimeWarning that pytest treats as an error.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py, line 1578
[P0] is_debug_iter uses slow setattr after init: The is_debug_iter() method directly assigns to self.debug_last_iteration and self.debug_enabled_in_this_iteration at lines 1564-1565 and 1571. Since this is called during forward passes (after _initialized=True), these assignments will trigger RuntimeWarning that pytest treats as errors.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

And at line 1571:

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py, line 1115
[P3] train() method hides warnings inappropriately: The train() method suppresses CPU overhead warnings with a catch_warnings context. This defeats the purpose of the pytest.ini configuration that treats RuntimeWarning as errors, allowing slow paths to be used silently during .train() / .eval() calls.

If the warnings are genuinely problematic here, the underlying issue should be fixed rather than suppressed.


transformer_engine/pytorch/module/layernorm_mlp.py, line 2093
[P0] Attribute assignment will trigger RuntimeWarning: Direct assignment to self.bias_gelu_nvfusion uses the slow __setattr__ path after initialization, which triggers a RuntimeWarning that pytest treats as an error.

            self.fast_setattr("bias_gelu_nvfusion", False)

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

4 files reviewed, 4 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +1082 to +1095
nvtx_range_push(self.__class__.__name__ + " forward")
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
return inp

def end_forward(self):
"""
Required to be called at the end of the forward function to properly handle
DelayedScaling metadata handling and the NVTX ranges.
"""
delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
nvtx_range_pop()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range imbalance on exceptions - will corrupt profiling stack

The refactoring from context manager (with get_nvtx_range_context(...)) to explicit nvtx_range_push() / nvtx_range_pop() calls breaks exception safety. If an exception occurs between prepare_forward() (which calls nvtx_range_push) and end_forward() (which calls nvtx_range_pop), the NVTX range stack will be permanently corrupted.

The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires all code paths to call end_forward(), but exceptions bypass this.

This affects all forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules. Any exception (shape mismatch, CUDA OOM, assertion failure, etc.) will leave an unclosed NVTX range, affecting all subsequent profiling.

Recommended fix: Either:

  1. Restore the context manager pattern, or
  2. Wrap all forward pass code in try/finally blocks to ensure nvtx_range_pop() is called, or
  3. Document that forward passes must be wrapped in exception handlers that call end_forward() on errors

Comment on lines 652 to +661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WARNING: setattr override defeats the optimization it's meant to enforce

The __setattr__ override emits a RuntimeWarning when _initialized=True, but then still calls super().__setattr__(name, value) which goes through PyTorch's slow path anyway. This means:

  1. You pay the cost of the warning check
  2. You still pay the full cost of PyTorch's __setattr__
  3. The warning is treated as an error by pytest, so any code path using it will fail tests

This doesn't actually optimize anything - it just breaks code that uses standard attribute assignment. The optimization only works when code explicitly calls fast_setattr().

Implications:

  • Any code path (including external libraries or user code) that does module.attribute = value will fail in tests
  • The train() method workaround (lines 1111-1114) shows this is already a problem
  • Properties, descriptors, and any dynamic attribute access will trigger this warning

Recommended approach:

  • Remove the warning from __setattr__ - it creates more problems than it solves
  • Keep fast_setattr() as an opt-in optimization
  • Add comments/documentation encouraging use of fast_setattr() where performance matters

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (2)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will trigger RuntimeWarning treated as test failure

These direct assignments using self.attribute = will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error. This method is called after module initialization (when _initialized=True), so these assignments will cause tests to fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will trigger RuntimeWarning treated as test failure

These direct assignments to self.debug_last_iteration and self.debug_enabled_in_this_iteration will trigger RuntimeWarning since is_debug_iter() is called during forward passes (after _initialized=True). The pytest.ini configuration now treats RuntimeWarnings as errors, causing test failures.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)
        else:
            # If this is the same iteration as previous invocation of the module,
            # we use the debug value from the first invocation in the iteration.
            debug = self.debug_enabled_in_this_iteration

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

3 files reviewed, 3 comments

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (3)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will trigger RuntimeWarning treated as test failure

The set_tensor_parallel_group() method uses direct attribute assignment (self.tp_group = tp_group and self.tp_group_initialized = True). This method is called after module initialization (when _initialized=True), so these assignments will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will trigger RuntimeWarning treated as test failure

The is_debug_iter() method uses direct attribute assignment for self.debug_last_iteration and self.debug_enabled_in_this_iteration. Since this method is called during forward passes (after _initialized=True), these assignments will trigger RuntimeWarnings that pytest.ini treats as errors.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as test failure

This direct assignment to self.debug_last_iteration in the is_debug_iter() method will trigger a RuntimeWarning when called during forward passes (after _initialized=True).

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

9 files reviewed, 9 comments

Edit Code Review Agent Settings | Greptile

Comment on lines 1399 to 1478
).is_fp8_ubuf():
fp8_grad = True

with self.prepare_forward(
inp,
allow_non_contiguous=isinstance(inp, QuantizedTensor),
) as inp:
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))

weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
self.end_forward()
if self.gemm_bias_unfused_add:
out = out + cast_if_needed(bias_tensor, self.activation_dtype)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] The refactoring from context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If any exception occurs between these calls (e.g., shape mismatch, CUDA OOM, assertion failure in the autograd function), end_forward() is never called, leaving an unclosed NVTX range via nvtx_range_push() without matching nvtx_range_pop(). This corrupts the NVTX stack for profiling.

The original context manager pattern guaranteed cleanup via __exit__. Consider either wrapping in try-finally or reverting to the context manager pattern (prepare_forward_ctx).

Comment on lines 1516 to 1603
).is_fp8_ubuf():
fp8_grad = True

with self.prepare_forward(
inp = self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
)

# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)

self.end_forward()

if self.return_layernorm_output:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] Same NVTX range imbalance issue as Linear. If an exception occurs between prepare_forward() (which calls nvtx_range_push) and end_forward() (which calls nvtx_range_pop), the NVTX stack will be corrupted. The context manager automatically handled this cleanup.

Comment on lines 2054 to +2167
if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf():
fp8_output = True

with self.prepare_forward(inp, num_gemms=2) as inp:
inp = self.prepare_forward(inp, num_gemms=2)

quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)

# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False
# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.fast_setattr("bias_gelu_nvfusion", False)

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)

self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] Same NVTX range imbalance issue. Exception between prepare_forward() and end_forward() will leave unclosed NVTX range, corrupting profiling stack.

Comment on lines 790 to +852

is_grad_enabled = torch.is_grad_enabled()

with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp:
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()

if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()
if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()

if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")
if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")

(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers
(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers

if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] Same NVTX range imbalance issue. Exception between prepare_forward() and end_forward() will leave unclosed NVTX range.

Comment on lines 1111 to 1113
def train(self, mode: bool = True) -> Self:
with warnings.catch_warnings():
warnings.filterwarnings("ignore", ".*CPU overhead.*")
return super().train(mode)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P2] The train() method suppresses the CPU overhead warning by filtering it out. However, this means PyTorch's __setattr__ is still being called for all the child module state changes during train()/eval() mode switching. This defeats the optimization goal since the expensive __setattr__ path is still taken - it just doesn't warn about it.

The proper fix would be to override train() to use fast_setattr for the training attribute on all child modules, rather than suppressing the warning.

Comment on lines 652 to +661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P2] The __setattr__ override still calls super().__setattr__(name, value) after emitting the warning, meaning the expensive PyTorch __setattr__ is always executed. This doesn't optimize anything - it only adds overhead by checking _initialized and potentially emitting a warning before doing the expensive operation anyway.

If the goal is CPU optimization, the __setattr__ should use fast_setattr for appropriate attributes instead of just warning. The current implementation only serves as a detection mechanism, not an optimization.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (3)

transformer_engine/pytorch/module/base.py
[P0] These direct attribute assignments will trigger the RuntimeWarning defined in __setattr__ (lines 652-661), which pytest.ini now treats as an error. The set_tensor_parallel_group() method is called after module initialization (when _initialized=True), causing tests to fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
[P0] These direct attribute assignments will trigger RuntimeWarning since is_debug_iter() is called during forward passes (after _initialized=True). With pytest.ini treating RuntimeWarnings as errors, this will cause test failures.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
[P0] This direct assignment will trigger the RuntimeWarning. Called during forward passes when _initialized=True.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7 files reviewed, 7 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +1402 to +1475
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))

weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] NVTX range imbalance on exceptions. The refactoring from context manager to manual prepare_forward()/end_forward() breaks exception safety. If any exception occurs between line 1402 (prepare_forward() which calls nvtx_range_push) and line 1474 (end_forward() which calls nvtx_range_pop), the NVTX range stack will be corrupted. Possible exceptions include shape mismatches, CUDA OOM, assertion failures (e.g., line 1061 in base.py: assert self.tp_group_initialized), or errors in quantizer initialization. The original context manager guaranteed cleanup via __exit__, but the new pattern has no exception handling.

Comment on lines +1519 to +1595
inp = self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
)

# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] NVTX range imbalance on exceptions. Same issue as Linear module - if any exception occurs between prepare_forward() (line 1519, calls nvtx_range_push) and end_forward() (line 1595, calls nvtx_range_pop), the NVTX stack is corrupted. Possible exceptions include shape mismatches, CUDA errors, or assertion failures.

Comment on lines +793 to +852
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()

if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()
if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()

if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")
if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")

(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers
(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers

if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] NVTX range imbalance on exceptions. Same issue as other modules - if any exception occurs between prepare_forward() (line 793, calls nvtx_range_push) and end_forward() (line 847, calls nvtx_range_pop), the NVTX stack is corrupted.

Comment on lines +2057 to +2167
inp = self.prepare_forward(inp, num_gemms=2)

quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)

# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False
# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.fast_setattr("bias_gelu_nvfusion", False)

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)

self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[P1] NVTX range imbalance on exceptions. Same issue as other modules - if any exception occurs between prepare_forward() (line 2057, calls nvtx_range_push) and end_forward() (line 2167, calls nvtx_range_pop), the NVTX stack is corrupted.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (3)

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignments will trigger RuntimeWarning treated as test failure. This method is called post-initialization (when _initialized=True), so these assignments trigger the __setattr__ override (line 653) that emits RuntimeWarning. With pytest.ini now treating RuntimeWarnings as errors, all tests calling this method will fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignments will trigger RuntimeWarning treated as test failure. The is_debug_iter() method is called during forward passes (post-initialization), so these assignments trigger the __setattr__ override that emits RuntimeWarning, causing pytest to fail.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignment will trigger RuntimeWarning treated as test failure. This line is reached during forward passes when _initialized=True, triggering the __setattr__ override.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7 files reviewed, 7 comments

Edit Code Review Agent Settings | Greptile

Comment on lines 652 to +660
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: setattr override doesn't actually optimize anything

This implementation defeats the intended optimization. The goal was to bypass PyTorch's expensive __setattr__ by using direct __dict__ access, but this override:

  1. Still calls super().__setattr__(name, value) which has the exact same overhead as before
  2. Only adds a warning without actually optimizing the assignment
  3. Makes every attribute assignment after initialization trigger both the warning AND the slow path

The intended optimization pattern should be:

  • Use fast_setattr() for regular attributes (bypasses this method entirely via __dict__)
  • Use module_setattr() for parameters/buffers/properties (calls super)
  • Don't override __setattr__ at all, or only override it during initialization to set _initialized

The current implementation provides no performance benefit and breaks existing code.

Comment on lines +1082 to +1096
nvtx_range_push(self.__class__.__name__ + " forward")
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
return inp

def end_forward(self):
"""
Required to be called at the end of the forward function to properly handle
DelayedScaling metadata handling and the NVTX ranges.
"""
delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
nvtx_range_pop()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: Exception safety broken - NVTX ranges will be corrupted on exceptions

The refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If any exception occurs between these calls (e.g., shape mismatch, CUDA OOM, assertion failure, type error), nvtx_range_pop() will never be called, corrupting the NVTX profiling stack.

This affects all forward methods in Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules.

The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires wrapping the code between prepare_forward() and end_forward() in a try-finally block, or keeping the context manager pattern (the new prepare_forward_ctx preserves the context manager but isn't being used).

Consider either:

  1. Using prepare_forward_ctx() instead in all forward methods
  2. Adding try-finally blocks around all forward implementations
  3. Reverting the context manager refactoring if the performance benefit doesn't justify the complexity

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (5)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will cause test failures

These direct assignments using self.attribute = value will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-660), which pytest.ini now treats as an error.

This method is called after module initialization (when _initialized=True), so these assignments will cause all tests calling set_tensor_parallel_group() to fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will cause test failures

These direct assignments will trigger RuntimeWarning since is_debug_iter() is called during forward passes (after _initialized=True). With pytest.ini treating RuntimeWarnings as errors, this will cause tests to fail.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will cause test failures

This assignment occurs during forward passes (post-initialization) and will trigger RuntimeWarning, causing pytest to fail.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will cause test failures

This assignment occurs during forward passes and will trigger the RuntimeWarning.

        self.fast_setattr("next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers))

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will cause test failures

This assignment can occur post-initialization and will trigger RuntimeWarning.

            self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}")

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

8 files reviewed, 8 comments

Edit Code Review Agent Settings | Greptile

Comment on lines 1040 to 1096
def prepare_forward(
self,
inp: torch.Tensor,
num_gemms: int = 1,
allow_non_contiguous: bool = False,
allow_different_data_and_param_types: bool = False,
) -> Generator[torch.Tensor, None, None]:
"""Checks and prep for FWD.
The context manager is needed because there isn't a way for a module to know
if it's the last FP8 module in the forward autocast. It is useful
to setup the forward aggregated amax reduction for every module
just in case. The autocast exit will pick up the most recent one.
"""
self.allow_different_data_and_param_types = allow_different_data_and_param_types
self.forwarded_at_least_once = True
) -> torch.Tensor:
"""Checks and prepares for FWD execution."""
self.fast_setattr(
"allow_different_data_and_param_types", allow_different_data_and_param_types
)
self.fast_setattr("forwarded_at_least_once", True)

# Activation recomputation is used and this is the second forward phase.
if self.fp8 and in_fp8_activation_recompute_phase():
delayed_scaling_recipe = self.fp8_meta["recipe"].delayed()
FP8GlobalStateManager.get_old_fp8_meta_tensors_for_recompute(self.fp8_meta)
else:
assert inp.is_cuda, "TransformerEngine needs CUDA."

if self.tp_size > 1:
assert self.tp_group_initialized, "TP group not initialized."

self.set_activation_dtype(inp)
self.init_fp8_metadata(num_gemms=num_gemms)
self._check_weight_tensor_recipe_correspondence()

delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe:
if self.sequence_parallel:
assert self.fp8_meta["recipe"].reduce_amax, (
"Amax reduction across tensor parallel group is "
"necessary when using sequence parallelism with FP8."
)

if not FP8GlobalStateManager.fp8_graph_capturing():
FP8GlobalStateManager.add_fp8_tensors_to_global_buffer(self.fp8_meta)

# Activation recomputation is used and this is the first forward phase.
if self.training and is_fp8_activation_recompute_enabled():
FP8GlobalStateManager.copy_forward_fp8_meta_tensors_for_recompute(self.fp8_meta)

with get_nvtx_range_context(self.__class__.__name__ + " forward"):
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
yield inp
nvtx_range_push(self.__class__.__name__ + " forward")
if not allow_non_contiguous and not inp.is_contiguous():
inp = inp.contiguous()
return inp

def end_forward(self):
"""
Required to be called at the end of the forward function to properly handle
DelayedScaling metadata handling and the NVTX ranges.
"""
delayed_scaling_recipe = self.fp8 and self.fp8_meta["recipe"].delayed()
if delayed_scaling_recipe and self.fp8 and in_fp8_activation_recompute_phase():
FP8GlobalStateManager.restore_fp8_meta_tensors(self.fp8_meta)
nvtx_range_pop()

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range imbalance on exceptions - breaks profiling

The refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. The original context manager pattern guaranteed that nvtx_range_pop() would be called via __exit__ even if an exception occurred.

With the new pattern:

  • prepare_forward() calls nvtx_range_push() on line 1082
  • end_forward() calls nvtx_range_pop() on line 1095
  • If ANY exception occurs between these calls, nvtx_range_pop() is never executed

Impact: This affects Linear, LayerNormLinear, LayerNormMLP, and GroupedLinear modules. Any exception (shape mismatch, CUDA OOM, assertion failure, invalid tensor, etc.) will leave an unclosed NVTX range, corrupting the profiling stack.

Why this happened: The PR author mentions they avoided touching attention because it has multiple return statements and they were worried about missing something. The same issue exists here - the forward methods have many potential exception points.

Recommendation: Either:

  1. Keep using context manager (prepare_forward_ctx) like the attention module does, OR
  2. Wrap the forward logic in try/finally to ensure end_forward() is always called, OR
  3. Revert this optimization for modules with complex forward methods

Comment on lines 652 to +661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: setattr override defeats the optimization it's trying to enable

This __setattr__ override is intended to warn developers when they use the slow path, but it actually creates the performance overhead that the PR aims to eliminate:

  1. Every attribute assignment (even with fast_setattr) triggers __setattr__
  2. __setattr__ checks "_initialized" in self.__dict__ - this is a dictionary lookup
  3. Then it checks self._initialized - another dictionary lookup
  4. Then it calls warnings.warn() - string formatting, stack inspection, filtering
  5. Finally calls super().__setattr__() - the slow PyTorch path

The problem: Even when developers correctly use fast_setattr(), the __setattr__ method is STILL invoked because fast_setattr just does self.__dict__[name] = value, which doesn't trigger __setattr__. But any code that still uses direct assignment (like the bugs found in this PR) will pay a huge penalty.

Why this is worse:

  • During __init__, every self.x = y now does 2 dict lookups + warning machinery + slow setattr
  • The warning itself is expensive (stack inspection, string formatting)
  • With pytest.ini treating warnings as errors, this will abort tests

The train() method workaround (lines 1111-1114) shows this was already causing problems - it filters out the warning, but the overhead remains.

Recommendation: Remove the __setattr__ override entirely or only enable it in debug/development mode with an environment variable. Use static analysis tools (linters, mypy plugins) to catch incorrect usage instead of runtime checks.

Comment on lines +1402 to +1475
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))

weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range will not be closed if exception occurs

The manual prepare_forward()/end_forward() pattern breaks exception safety. If ANY exception occurs between line 1402 (prepare_forward() which calls nvtx_range_push) and line 1474 (end_forward() which calls nvtx_range_pop), the NVTX profiling stack will be corrupted.

Potential exception points in this code:

  • Line 1404: _get_weight_and_bias_tensors() could fail
  • Lines 1406-1414: quantizer construction could fail
  • Line 1467: linear_fn() could raise exceptions (shape mismatches, CUDA OOM, invalid tensors, assertion failures, etc.)

Example failure scenario:

inp = self.prepare_forward(...)  # Pushes NVTX range
# ... setup code ...
out = linear_fn(...)  # Raises RuntimeError: shape mismatch
# end_forward() is NEVER called, NVTX range never popped
# All subsequent NVTX ranges are corrupted

Fix: Wrap in try/finally:

inp = self.prepare_forward(...)
try:
    # ... all forward logic ...
finally:
    self.end_forward()

Or use the context manager pattern that attention module correctly uses (prepare_forward_ctx).

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (5)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will trigger RuntimeWarning treated as test failure

These direct assignments using self.tp_group = tp_group and self.tp_group_initialized = True will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

The set_tensor_parallel_group() method is called after module initialization (when _initialized=True), so these assignments will cause all tests that call this method to fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will trigger RuntimeWarning treated as test failure

These direct assignments to self.debug_last_iteration and self.debug_enabled_in_this_iteration will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

The is_debug_iter() method is called during forward passes (after _initialized=True), so these assignments will cause tests to fail whenever debug mode is used.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as test failure

This direct assignment to self.debug_last_iteration will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

This line is in the else branch of is_debug_iter() and is called during forward passes (after _initialized=True).

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as test failure

This direct assignment to self.next_iter_when_debug_should_be_run will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

The no_debug_features_active() method is called during forward passes (after _initialized=True).

        self.fast_setattr("next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers))

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will trigger RuntimeWarning treated as test failure

This direct assignment to self.name will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

The _validate_name() method is called during forward passes (after _initialized=True), specifically when debug is enabled and no name has been assigned to the module.

            self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}")

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 5 comments

Edit Code Review Agent Settings | Greptile

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (5)

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignments will trigger RuntimeWarning treated as test failure. The new __setattr__ override (line 653) emits RuntimeWarning when _initialized=True, and pytest.ini now treats RuntimeWarnings as errors. Since set_tensor_parallel_group() is called after module initialization in user code, these assignments will cause all tests using this method to fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignments will trigger RuntimeWarning treated as test failure. The is_debug_iter() method is called during forward passes (after _initialized=True), so these assignments will trigger the __setattr__ override that emits RuntimeWarning. With pytest.ini treating RuntimeWarnings as errors, this will cause test failures.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignment will trigger RuntimeWarning treated as test failure. This assignment occurs during forward passes after initialization.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignment will trigger RuntimeWarning treated as test failure. The no_debug_features_active() method is called during forward passes after initialization.

        self.fast_setattr("next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers))

transformer_engine/pytorch/module/base.py
[P0] Direct attribute assignment will trigger RuntimeWarning treated as test failure. The _validate_name() method is called during forward passes when debug is enabled.

            self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}")

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

10 files reviewed, 10 comments

Edit Code Review Agent Settings | Greptile

Comment on lines 652 to +661
def __setattr__(self, name: str, value: Any) -> None:
if name in TransformerEngineBaseModule._fast_setattr_names:
# torch.nn.Module has a custom __setattr__ that handles
# modules, parameters, and buffers. This is unnecessary
# overhead when setting plain attrs.
self.__dict__[name] = value
else:
# Default case
super().__setattr__(name, value)
if "_initialized" in self.__dict__ and self._initialized:
warnings.warn(
"""The default implementation of torch.nn.Module introduces significant CPU overhead
when setting attributes and is therefore not recommended. Please use the explicit
calls (fast_setattr for setting regular values and module_setattr for setting
parameters, children modules and buffers).""",
RuntimeWarning,
)
super().__setattr__(name, value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MAJOR ISSUE: The __setattr__ warning defeats the optimization purpose

The __setattr__ override emits a RuntimeWarning whenever it's called after initialization. However, emitting warnings has significant CPU overhead - it involves stack trace inspection, formatting, filtering, etc. This defeats the entire purpose of the optimization!

The warning will fire every time the slow path is used, which is precisely when you're trying to avoid overhead. This makes the "optimization" actually slower than before in cases where the slow path is accidentally used.

Consider one of these alternatives:

  1. Make __setattr__ raise an exception immediately (fail-fast during development)
  2. Remove the warning entirely and rely on code review to catch issues
  3. Add a debug mode that can be disabled in production

The current implementation is problematic because:

  • Warning overhead during runtime defeats the optimization
  • pytest treats it as an error anyway, so it's not really a "warning"
  • It will cause legitimate code paths to fail (see direct assignment issues in this file)

Comment on lines +1402 to +1475
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))

weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range imbalance on exceptions will corrupt profiling

The refactoring from context manager to manual prepare_forward()/end_forward() calls breaks exception safety. prepare_forward() calls nvtx_range_push() (line 1082 in base.py), and end_forward() calls nvtx_range_pop() (line 1095).

If any exception occurs between lines 1402-1473 (before end_forward() is called on line 1474), the NVTX range stack will be corrupted because nvtx_range_pop() is never called. This includes:

  • Shape mismatches in tensors
  • CUDA out of memory errors
  • Assertion failures in quantizers
  • Type errors in _Linear.apply/forward
  • Any exception from user code

The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires wrapping in try-finally:

Suggested change
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]
non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]
non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
self.end_forward()
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))
try:
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]
non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
finally:
self.end_forward()

This same issue affects LayerNormLinear, LayerNormMLP, and GroupedLinear.

Comment on lines +1519 to +1600
inp = self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
)

# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range imbalance on exceptions

Same exception safety issue as in Linear module. If any exception occurs between prepare_forward() and end_forward(), the NVTX range stack will be corrupted. This code needs try-finally wrapping:

Suggested change
inp = self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
)
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers
if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
inp = self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
)
try:
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
# ... rest of the forward code ...
finally:
self.end_forward()

Comment on lines +2057 to +2160
inp = self.prepare_forward(inp, num_gemms=2)

quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)

# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False
# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.fast_setattr("bias_gelu_nvfusion", False)

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range imbalance on exceptions

Same exception safety issue as in Linear module. The code between prepare_forward() and end_forward() needs try-finally wrapping to ensure NVTX ranges are properly closed even when exceptions occur.

Comment on lines +793 to +852
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()

if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()
if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()

if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")
if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")

(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers
(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers

if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range imbalance on exceptions

Same exception safety issue as in Linear module. The code between prepare_forward() and end_forward() needs try-finally wrapping to ensure NVTX ranges are properly closed even when exceptions occur.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (5)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will cause test failures

These direct assignments using self.tp_group = tp_group and self.tp_group_initialized = True will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

The set_tensor_parallel_group() method is called after module initialization (when _initialized=True), so these assignments will cause all tests that call this method to fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will cause test failures

These direct assignments to self.debug_last_iteration and self.debug_enabled_in_this_iteration will trigger the RuntimeWarning since is_debug_iter() is called during forward passes (after _initialized=True). With pytest.ini treating RuntimeWarnings as errors, this will cause tests to fail.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will cause test failures

This direct assignment to self.debug_last_iteration will trigger RuntimeWarning (treated as error by pytest).

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will cause test failures

This direct assignment to self.next_iter_when_debug_should_be_run will trigger RuntimeWarning (treated as error by pytest). This method is called during forward passes after initialization.

        self.fast_setattr("next_iter_when_debug_should_be_run", next_iter_when_debug_should_be_run(quantizers))

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will cause test failures

This direct assignment to self.name will trigger RuntimeWarning (treated as error by pytest). This code path executes during forward passes when debug is enabled.

            self.fast_setattr("name", f"Layer_{TEDebugState.get_layer_count()}")

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

7 files reviewed, 7 comments

Edit Code Review Agent Settings | Greptile

Comment on lines +1402 to +1475
inp = self.prepare_forward(inp, allow_non_contiguous=isinstance(inp, QuantizedTensor))

weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
linear_fn = _Linear.apply
autograd_ctx = []
else:
linear_fn = _Linear.forward
autograd_ctx = [None]

non_tensor_args = (
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
is_grad_enabled,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.ub_name,
fp8_output,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.save_original_input,
debug,
)
out = linear_fn(
*autograd_ctx,
weight_tensor,
inp,
bias_tensor if (self.apply_bias and not self.gemm_bias_unfused_add) else None,
non_tensor_args,
)
self.end_forward()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range imbalance on exceptions

The refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (line 1402, which calls nvtx_range_push) and end_forward() (line 1474, which calls nvtx_range_pop), the NVTX range stack will be corrupted.

Exceptions can occur in many places:

  • _get_weight_and_bias_tensors() (line 1404)
  • _get_quantizers() or _get_debug_quantizers() (lines 1406-1414)
  • linear_fn() (line 1467) - shape mismatches, CUDA OOM, assertions, etc.

The original context manager pattern guaranteed cleanup via __exit__. The new pattern requires all code paths to call end_forward(), but exceptions bypass this, leaving an unclosed NVTX range that corrupts the stack for all subsequent operations.

Consider using a try-finally block or reverting to the context manager pattern for this module.

Comment on lines +1519 to +1596
inp = self.prepare_forward(
inp, allow_non_contiguous=False # removed .contiguous from inside the layer
) as inp:
)

# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()
# Get concatenated weight and bias tensors
weight_tensor, bias_tensor = self._get_weight_and_bias_tensors()

quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, fp8_grad, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, fp8_grad, is_grad_enabled)

(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
bias_tensor if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
(
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormLinear.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormLinear.forward
autograd_ctx = [None]
non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
input_quantizer,
weight_quantizer,
output_quantizer,
grad_input_quantizer,
grad_weight_quantizer,
grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.normalization,
self.ub_overlap_ag_fprop,
self.ub_overlap_rs_fprop,
self.ub_overlap_ag_dgrad,
self.ub_overlap_rs_dgrad,
self.ub_bulk_wgrad,
self.ub_bulk_dgrad,
self.ub_name,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
weight_tensor,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range imbalance on exceptions

Same issue as in Linear module: the refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (line 1519) and end_forward() (line 1596), the NVTX range stack will be corrupted.

Exceptions can occur in:

  • _get_weight_and_bias_tensors() (line 1522)
  • _get_quantizers() or _get_debug_quantizers() (lines 1524-1532)
  • fwd_fn() (line 1586) - shape mismatches, CUDA OOM, assertions, etc.

Consider using a try-finally block or reverting to the context manager pattern.

Comment on lines +2057 to +2157
inp = self.prepare_forward(inp, num_gemms=2)

quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)
quantizers = (
self._get_quantizers(fp8_output, is_grad_enabled)
if not debug
else self._get_debug_quantizers(fp8_output, is_grad_enabled)
)
if debug:
if self.no_debug_features_active(quantizers):
debug = False
quantizers = self._get_quantizers(fp8_output, is_grad_enabled)

# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.bias_gelu_nvfusion = False
# Get quantizers
(
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
) = quantizers

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
self.layer_norm_weight,
self.layer_norm_bias,
fc1_weight,
fc1_bias,
fc2_weight,
fc2_bias if self.apply_bias and not self.gemm_bias_unfused_add else None,
non_tensor_args,
)
# Get weight tensors
fc1_weight, fc2_weight = self._get_weight_tensors()
fc1_bias = self.fc1_bias if self.use_bias else None
fc2_bias = self.fc2_bias if self.use_bias else None
if not self.fp8:
if isinstance(fc1_weight, Float8Tensor):
fc1_weight = fc1_weight.dequantize()
if isinstance(fc2_weight, Float8Tensor):
fc2_weight = fc2_weight.dequantize()

# Disable bias_gelu_nvfusion for determinism checkpointing in non-reentrant mode
if self.bias_gelu_nvfusion and not use_reentrant_activation_recompute():
self.fast_setattr("bias_gelu_nvfusion", False)

if is_grad_enabled:
fwd_fn = _LayerNormMLP.apply
autograd_ctx = []
else:
fwd_fn = _LayerNormMLP.forward
autograd_ctx = [None]

non_tensor_args = (
self.eps,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
self.fuse_wgrad_accumulation,
fc1_input_quantizer,
fc1_weight_quantizer,
fc1_output_quantizer,
fc1_grad_input_quantizer,
fc1_grad_weight_quantizer,
fc1_grad_output_quantizer,
fc2_input_quantizer,
fc2_weight_quantizer,
fc2_output_quantizer,
fc2_grad_input_quantizer,
fc2_grad_weight_quantizer,
fc2_grad_output_quantizer,
is_cpu_offload_enabled(),
self.tp_group,
self.tp_size,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.return_layernorm_output_gathered,
self.bias_gelu_nvfusion and not self.fp8 and not debug,
self.set_parallel_mode,
is_grad_enabled,
self.fwd_ln_sm_margin if is_grad_enabled else self.inf_ln_sm_margin,
self.bwd_ln_sm_margin,
self.zero_centered_gamma,
self.activation,
self.activation_params,
self.normalization,
self.ub_overlap_ag,
self.ub_overlap_rs,
self.ub_overlap_rs_dgrad,
self.ub_bulk_dgrad,
self.ub_bulk_wgrad,
self.gemm_gelu_fusion and not debug,
self.fsdp_group,
self,
skip_fp8_weight_update,
self.symmetric_ar_type,
self.checkpoint,
debug,
)
out = fwd_fn(
*autograd_ctx,
inp,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range imbalance on exceptions

Same issue as in Linear module: the refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (line 2057) and end_forward() (line 2157), the NVTX range stack will be corrupted.

Exceptions can occur in:

  • _get_quantizers() or _get_debug_quantizers() (lines 2059-2067)
  • _get_weight_tensors() (line 2075)
  • fwd_fn() (line 2145) - shape mismatches, CUDA OOM, assertions, etc.

Consider using a try-finally block or reverting to the context manager pattern.

Comment on lines +793 to +852
inp = self.prepare_forward(inp, num_gemms=self.num_gemms)
weight_tensors = self._get_weight_tensors()
bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)]

quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()
quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers()

if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()
if debug:
if self.no_debug_features_active(list(chain(*quantizers))):
debug = False
quantizers = self._get_quantizers()

if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")
if isinstance(weight_tensors, QuantizedTensorStorage):
raise RuntimeError("FP8 weights are not supported in debug mode.")

(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers
(
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
) = quantizers

if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)
if is_grad_enabled:
linear_fn = _GroupedLinear.apply
autograd_ctx = []
else:
linear_fn = _GroupedLinear.forward
autograd_ctx = [None]

non_tensor_args = (
m_splits,
self.apply_bias,
is_first_microbatch,
self.fp8,
self.fp8_calibration,
self.wgrad_store,
input_quantizers,
weight_quantizers,
output_quantizers,
grad_input_quantizers,
grad_weight_quantizers,
grad_output_quantizers,
self.fuse_wgrad_accumulation,
is_cpu_offload_enabled(),
self.sequence_parallel,
self.activation_dtype,
is_grad_enabled,
self,
None, # skip_fp8_weight_update
self.save_original_input,
debug,
)
out = linear_fn(*autograd_ctx, inp, non_tensor_args, *weight_tensors, *bias_tensors)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CRITICAL: NVTX range imbalance on exceptions

Same issue as in Linear module: the refactoring from a context manager to manual prepare_forward()/end_forward() calls breaks exception safety. If an exception occurs between prepare_forward() (line 793) and end_forward() (line 847), the NVTX range stack will be corrupted.

Exceptions can occur in:

  • _get_weight_tensors() (line 794)
  • _get_quantizers() or _get_debug_quantizers() (lines 797-805)
  • linear_fn() (line 846) - shape mismatches, CUDA OOM, assertions, etc.

Consider using a try-finally block or reverting to the context manager pattern.

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Jan 12, 2026

Additional Comments (3)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will cause test failures

These direct assignments using self.tp_group = tp_group and self.tp_group_initialized = True will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

The set_tensor_parallel_group() method is called after module initialization (when _initialized=True), so these assignments will cause all tests that call this method to fail.

        self.fast_setattr("tp_group", tp_group)
        self.fast_setattr("tp_group_initialized", True)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignments will cause test failures

These direct assignments to self.debug_last_iteration and self.debug_enabled_in_this_iteration will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

The is_debug_iter() method is called during forward passes (after _initialized=True), so these assignments will cause tests to fail when debug mode is enabled.

            self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())
            self.fast_setattr("debug_enabled_in_this_iteration", debug)

transformer_engine/pytorch/module/base.py
CRITICAL: Direct attribute assignment will cause test failures

This direct assignment to self.debug_last_iteration will trigger the RuntimeWarning from the overridden __setattr__ method (lines 652-661), which pytest.ini now treats as an error.

        self.fast_setattr("debug_last_iteration", TEDebugState.get_iteration())

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx force-pushed the pr_python_cpu_optimization branch from f96abbd to 2fb6ee3 Compare January 13, 2026 20:05
Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Additional Comments (1)

  1. transformer_engine/pytorch/module/base.py, line 1570 (link)

    syntax: CRITICAL: RuntimeWarning will cause test failures

    This direct attribute assignment will also trigger the RuntimeWarning.

17 files reviewed, 1 comment

Edit Code Review Agent Settings | Greptile

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants