From 258a238a8b2e4e8d48ffcf0a978a3e79ae663aa9 Mon Sep 17 00:00:00 2001 From: Keshav Santhanam Date: Wed, 22 Apr 2026 21:23:08 -0700 Subject: [PATCH] Add reset function to CudaGraphManager Signed-off-by: Keshav Santhanam --- megatron/core/transformer/cuda_graphs.py | 12 +++++++++--- .../inference/engines/test_dynamic_engine.py | 7 ++----- .../engines/test_mamba_prefix_caching_e2e.py | 4 +--- .../engines/test_prefix_caching_cuda_graphs.py | 8 ++------ tests/unit_tests/rl/test_rl_utils.py | 3 +-- tests/unit_tests/transformer/test_cuda_graphs.py | 12 ++++-------- 6 files changed, 19 insertions(+), 27 deletions(-) diff --git a/megatron/core/transformer/cuda_graphs.py b/megatron/core/transformer/cuda_graphs.py index 8d1972ff175..9ccdf13c72f 100644 --- a/megatron/core/transformer/cuda_graphs.py +++ b/megatron/core/transformer/cuda_graphs.py @@ -348,6 +348,14 @@ class _CudagraphGlobalRecord: """A pool-like data structure to reuse input and output buffers across cudagraph.""" tensor_reuse_pool = TensorReusePool() + @classmethod + def reset(cls): + """Reset all global tracking state. Only necessary for testing.""" + cls.cudagraph_created = False + cls.cudagraph_record = [] + cls.cudagraph_inference_record = [] + cls.mtp_cudagraph_managers.clear() + @classmethod def record_fwd_graph(cls, runner, args, kwargs, out): """Record a fwd graph to 'cudagraph_record""" @@ -533,9 +541,7 @@ def delete_cuda_graphs(): mgr.inference_cudagraphs_lookup_table.clear() # Reset global tracking state - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] - _CudagraphGlobalRecord.cudagraph_inference_record = [] + _CudagraphGlobalRecord.reset() # TODO: Optional?: Force garbage collection to clean up memory gc.collect() diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 02a177c413c..e2de1c2c208 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -468,8 +468,7 @@ def _build_test_env(cls, test_config): ) # Reset global cuda graph state. - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] + _CudagraphGlobalRecord.reset() CudaGraphManager.global_mempool = None # Inference engine. @@ -4500,9 +4499,7 @@ def _create_model(self, model_provider, num_cuda_graphs): def _reset_cuda_graph_state(self, model): """Reset all CUDA graph global and per-module state.""" - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] - _CudagraphGlobalRecord.cudagraph_inference_record = [] + _CudagraphGlobalRecord.reset() CudaGraphManager.global_mempool = None for module in model.modules(): if isinstance(module, CudaGraphManager): diff --git a/tests/unit_tests/inference/engines/test_mamba_prefix_caching_e2e.py b/tests/unit_tests/inference/engines/test_mamba_prefix_caching_e2e.py index ce21c775b73..0b80b6d157b 100644 --- a/tests/unit_tests/inference/engines/test_mamba_prefix_caching_e2e.py +++ b/tests/unit_tests/inference/engines/test_mamba_prefix_caching_e2e.py @@ -219,9 +219,7 @@ def _build_engine( vocab_size=VOCAB_SIZE, detokenize=lambda tokens: "tokenized_prompt" ), ) - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] - _CudagraphGlobalRecord.cudagraph_inference_record = [] + _CudagraphGlobalRecord.reset() CudaGraphManager.global_mempool = None for module in model.modules(): if isinstance(module, CudaGraphManager): diff --git a/tests/unit_tests/inference/engines/test_prefix_caching_cuda_graphs.py b/tests/unit_tests/inference/engines/test_prefix_caching_cuda_graphs.py index 52a05f7f80f..e4ee6fcd5d6 100644 --- a/tests/unit_tests/inference/engines/test_prefix_caching_cuda_graphs.py +++ b/tests/unit_tests/inference/engines/test_prefix_caching_cuda_graphs.py @@ -140,9 +140,7 @@ def _create_model(self, model_type, num_cuda_graphs=None): def _reset_cuda_graph_state(self, model): """Reset all CUDA graph global and per-module state.""" - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] - _CudagraphGlobalRecord.cudagraph_inference_record = [] + _CudagraphGlobalRecord.reset() CudaGraphManager.global_mempool = None for module in model.modules(): if isinstance(module, CudaGraphManager): @@ -360,9 +358,7 @@ def _create_hybrid_model(self, num_cuda_graphs=None): def _reset_cuda_graph_state(self, model): """Reset all CUDA graph global and per-module state.""" - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] - _CudagraphGlobalRecord.cudagraph_inference_record = [] + _CudagraphGlobalRecord.reset() CudaGraphManager.global_mempool = None for module in model.modules(): if isinstance(module, CudaGraphManager): diff --git a/tests/unit_tests/rl/test_rl_utils.py b/tests/unit_tests/rl/test_rl_utils.py index 6bf6e994ffb..27f5f3c6742 100644 --- a/tests/unit_tests/rl/test_rl_utils.py +++ b/tests/unit_tests/rl/test_rl_utils.py @@ -883,8 +883,7 @@ def test_get_logprobs_cuda_graphs(self, initialize_model_parallel): # Ensure all pending work is complete and graph destruction runs now torch.cuda.synchronize() - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] + _CudagraphGlobalRecord.reset() CudaGraphManager.global_mempool = None CudaGraphManager.fwd_mempools = None CudaGraphManager.bwd_mempools = None diff --git a/tests/unit_tests/transformer/test_cuda_graphs.py b/tests/unit_tests/transformer/test_cuda_graphs.py index ee4ff7d152d..02fc0727f97 100644 --- a/tests/unit_tests/transformer/test_cuda_graphs.py +++ b/tests/unit_tests/transformer/test_cuda_graphs.py @@ -78,8 +78,7 @@ def setup_method(self, method): def teardown_method(self, method): Utils.destroy_model_parallel() - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] + _CudagraphGlobalRecord.reset() CudaGraphManager.global_mempool = None @pytest.mark.skipif( @@ -270,8 +269,7 @@ def test_cuda_graph_determine_first_last_layer_logic( # Teardown Utils.destroy_model_parallel() - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] + _CudagraphGlobalRecord.reset() CudaGraphManager.global_mempool = None CudaGraphManager.fwd_mempools = None CudaGraphManager.bwd_mempools = None @@ -358,8 +356,7 @@ def setup_method(self, method): def teardown_method(self, method): Utils.destroy_model_parallel() - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] + _CudagraphGlobalRecord.reset() @pytest.mark.skipif( not (HAVE_TE and is_te_min_version("1.5.0")), @@ -500,8 +497,7 @@ def get_mamba_block(hybrid_layer_pattern): def teardown_method(self, method): Utils.destroy_model_parallel() - _CudagraphGlobalRecord.cudagraph_created = False - _CudagraphGlobalRecord.cudagraph_record = [] + _CudagraphGlobalRecord.reset() @pytest.mark.skipif( not (HAVE_TE and is_te_min_version("1.5.0")),