diff --git a/autoparallel/graph_passes/graph_pp_runner.py b/autoparallel/graph_passes/graph_pp_runner.py index 12ad9cd9..1640773c 100644 --- a/autoparallel/graph_passes/graph_pp_runner.py +++ b/autoparallel/graph_passes/graph_pp_runner.py @@ -672,6 +672,9 @@ def _post_backward_common( stage_index_to_stage: Dictionary mapping stage indices to GraphPipelineStage objects. is_prev_stage_on_this_rank: True if the previous stage exists on this rank. """ + assert bw_stage._stage_meta.inputs is not None + num_fwd_args = len(bw_stage._stage_meta.inputs) + input_grads = input_grads[:num_fwd_args] bw_stage.bwd_cache[bw_mb_index] = ( tuple(input_grads) if not isinstance(input_grads, tuple) else input_grads )