From 38de2ec12ca259963f949de270911a8ac96d8f42 Mon Sep 17 00:00:00 2001 From: Aditya Venkataraman Date: Wed, 22 Apr 2026 03:18:01 +0000 Subject: [PATCH] Trim non-pipeline input grads before caching in bwd_cache The backward graph may produce gradients for inputs beyond the pipeline activations (e.g. labels when loss is fused into the last stage). get_bwd_send_ops zips bwd_cache with grad_send_info using strict=True, and grad_send_info only has entries for pipeline activation inputs, so extra grads cause a ValueError. Mirror the trimming that upstream PipelineStage does at torch/distributed/pipelining/stage.py:997. Authored with Claude. Co-Authored-By: Claude Opus 4.6 (1M context) --- autoparallel/graph_passes/graph_pp_runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/autoparallel/graph_passes/graph_pp_runner.py b/autoparallel/graph_passes/graph_pp_runner.py index 12ad9cd9..1640773c 100644 --- a/autoparallel/graph_passes/graph_pp_runner.py +++ b/autoparallel/graph_passes/graph_pp_runner.py @@ -672,6 +672,9 @@ def _post_backward_common( stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects. is_prev_stage_on_this_rank: True if the previous stage exists on this rank. """ + assert bw_stage._stage_meta.inputs is not None + num_fwd_args = len(bw_stage._stage_meta.inputs) + input_grads = input_grads[:num_fwd_args] bw_stage.bwd_cache[bw_mb_index] = ( tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads )