Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
a09988a
Change from linear to exponentially decay cudagraph sizes
mathemakitten Feb 20, 2026
a1555ba
Merge branch 'main' into helenn-exponential-decay-cudagraph-sizes
mathemakitten Feb 20, 2026
20798af
Maybe include a size-1 graph
mathemakitten Feb 20, 2026
24bc8d6
Merge branch 'helenn-exponential-decay-cudagraph-sizes' of https://gi…
mathemakitten Feb 20, 2026
bbf2dd1
Merge branch 'main' of https://gitlab-master.nvidia.com/ADLR/megatron…
mathemakitten Feb 20, 2026
3c718e9
Update test_cuda_graph_token_counts
mathemakitten Feb 20, 2026
ffae8ef
Merge branch 'main' into helenn-exponential-decay-cudagraph-sizes
mathemakitten Feb 20, 2026
cafe6af
address comments
mathemakitten Feb 23, 2026
618d016
Merge branch 'helenn-exponential-decay-cudagraph-sizes' of https://gi…
mathemakitten Feb 23, 2026
1f3654d
Merge branch 'main' into helenn-exponential-decay-cudagraph-sizes
mathemakitten Feb 23, 2026
96f5278
keshav comments
mathemakitten Feb 23, 2026
d62f2fc
Merge branch 'helenn-exponential-decay-cudagraph-sizes' of https://gi…
mathemakitten Feb 23, 2026
116a785
Merge branch 'main' into helenn-exponential-decay-cudagraph-sizes
mathemakitten Feb 23, 2026
ad6753e
address comments
mathemakitten Mar 10, 2026
9036827
merge
mathemakitten May 18, 2026
6ea8abc
update print in example script to differentiate reserved / allocated
mathemakitten May 18, 2026
c06c327
exponential decay of graph size
mathemakitten May 18, 2026
b508d08
better logging for checking pool reuse etc
mathemakitten May 18, 2026
b5ec39d
fix merge w main
mathemakitten May 18, 2026
48cda16
fix import
mathemakitten May 18, 2026
9488712
minor cleanup
mathemakitten May 18, 2026
661179a
fix failing test
mathemakitten May 18, 2026
b347294
fix test
mathemakitten May 18, 2026
33e65c4
fix tests and recover small graphs for perf
mathemakitten May 18, 2026
8561960
keshav fixes
mathemakitten May 19, 2026
60e7929
merge main to get new tests
mathemakitten May 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/inference/gpt/gpt_dynamic_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ def escape_str(s):
print(
f"{setup_prefix} … " f"throughput: {throughput:.3f} tok/s … ",
f"total time: {total_time:.3f}s … "
f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … "
f"mem {peak_alloc_gb:.1f} allocated/{peak_resvd_gb:.1f} reserved GB … "
f"steps: {engine.context.step_count:d} … "
f"capture {capture_str}",
)
Expand Down
257 changes: 197 additions & 60 deletions megatron/core/inference/batch_dimensions_utils.py

Large diffs are not rendered by default.

28 changes: 27 additions & 1 deletion megatron/core/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,22 @@ class KVCacheManagementMode(str, Enum):
"""Deallocate large tensors and recompute them from scratch during allocation."""


class CudaGraphSizingDistribution(str, Enum):
    """Spacing policy for the token counts of the captured CUDA graphs.

    Members:
        EXPONENTIAL: The default. Token counts are halved repeatedly, starting at
            `cuda_graph_max_tokens` and stopping at `tp_size`, which yields a
            log-spaced set of `log2(max_tokens)` graphs whose relative padding is
            bounded (~2x in the worst case) at every scale.
        LINEAR: Size-1 and size-2 graphs are included where applicable, followed by
            linear spacing up to 256 and a sparser linear spacing beyond 256, e.g.
            `[1, 2, 4] + range(8, 256, 8) + range(256, max+1, 16)`. This gives a
            higher graph density at the top end.
    """

    EXPONENTIAL = "exponential"
    LINEAR = "linear"


@dataclass
class InferenceConfig:
"""
Expand Down Expand Up @@ -197,10 +213,20 @@ class InferenceConfig:
"""

cuda_graph_mixed_prefill_count: Optional[int] = 16
"""
"""
The number of mixed prefill graphs to capture if mixed prefill/decode graphs are enabled.
"""

cuda_graph_sizing_distribution: CudaGraphSizingDistribution = (
CudaGraphSizingDistribution.EXPONENTIAL
)
"""
How CUDA graph token counts are spaced. EXPONENTIAL (default) halves from
`cuda_graph_max_tokens` down to `tp_size` (log-spaced, ~log2(max_tokens) graphs).
LINEAR uses a range of linear strides (includes small graphs + mid-range linearity +
a bigger step size at the top end).
"""

use_cuda_graphs_for_non_decode_steps: bool = True
"""
Whether to use CUDA graphs for non-decode steps.
Expand Down
1 change: 1 addition & 0 deletions megatron/core/inference/contexts/dynamic_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,6 +662,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
max_sequence_length=self.max_sequence_length,
use_cuda_graphs_for_non_decode_steps=self.use_cuda_graphs_for_non_decode_steps,
num_speculative_tokens=self.num_speculative_tokens,
sizing_distribution=inference_config.cuda_graph_sizing_distribution,
)
)

Expand Down
74 changes: 63 additions & 11 deletions megatron/core/inference/engines/dynamic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
)
from megatron.core.inference.utils import Counter, InferenceMode, await_process_call
from megatron.core.process_groups_config import ProcessGroupCollection
from megatron.core.transformer.cuda_graphs import delete_cuda_graphs
from megatron.core.transformer.cuda_graphs import CudaGraphManager, delete_cuda_graphs
from megatron.core.transformer.enums import InferenceCudaGraphScope
from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction
from megatron.core.utils import (
Expand Down Expand Up @@ -133,13 +133,42 @@ class EngineSuspendedError(Exception):

def format_mem_bytes(mem_bytes):
    """Convert a byte count to a human-readable string in tb, gb, mb, kb, or bytes.

    Args:
        mem_bytes: Byte count to format. May be negative (e.g. a memory delta).

    Returns:
        A string such as "1.5 gb" for counts of at least one kilobyte, or an
        integer rendering such as "512 bytes" for anything smaller. Negative
        inputs are formatted by magnitude and prefixed with "-".
    """
    # Deltas can be negative; format the magnitude and put the sign back in front.
    if mem_bytes < 0:
        return "-" + format_mem_bytes(-mem_bytes)
    # Largest unit first, so the first threshold that fits is the right one.
    for power, suffix in ((4, "tb"), (3, "gb"), (2, "mb"), (1, "kb")):
        suffix_bytes = 1024**power
        if mem_bytes >= suffix_bytes:
            return "%.1f %s" % (mem_bytes / suffix_bytes, suffix)
    # Anything below one kilobyte is printed as a whole number of bytes, so that
    # 0 and 512 render consistently ("0 bytes", "512 bytes") rather than "512.0 bytes".
    return "%d bytes" % mem_bytes


def _cuda_graph_mempool_bytes() -> Tuple[int, int]:
    """Sum the bytes held by the global CUDA graph mempool.

    `torch.cuda.memory_stats()` only reports process-wide totals, which also
    include every other allocation (KV cache, NCCL workspaces, layer scratch).
    To attribute memory to graph capture alone, this walks the segments from
    `torch.cuda.memory_snapshot()` and keeps those whose pool id equals
    `CudaGraphManager.global_mempool`.

    Returns:
        A `(reserved, allocated)` pair of byte counts, or `(0, 0)` when the
        graph pool has not been created yet.
    """
    graph_pool = CudaGraphManager.global_mempool
    if graph_pool is None:
        return 0, 0

    # Keys probed, in order, for a segment's owning pool; the first truthy value wins.
    candidate_keys = ("segment_pool_id", "private_pool_id", "pool_id", "pool")

    reserved_total = 0
    allocated_total = 0
    for segment in torch.cuda.memory_snapshot():
        owner = None
        for key in candidate_keys:
            owner = segment.get(key)
            if owner:
                break
        if owner == graph_pool:
            reserved_total += segment.get("total_size", 0)
            allocated_total += segment.get("allocated_size", 0)
    return reserved_total, allocated_total


@dataclass(kw_only=True)
class RequestEntry:
"""Entry in the engine's `self.requests` dict."""
Expand Down Expand Up @@ -347,6 +376,13 @@ def create_cuda_graphs(self, reset_context: bool = True):
time_start = time.time()
mem_stats_start = torch.cuda.memory_stats()

# Snapshot of process-wide stats for the "total memory used by capture" summary.
start_proc_reserved = mem_stats_start["reserved_bytes.all.current"]
start_proc_alloc = mem_stats_start["allocated_bytes.all.current"]

# Pool-scoped baselines for the per-iteration deltas.
prev_pool_reserved, prev_pool_alloc = _cuda_graph_mempool_bytes()

logging.info("> dynamic_engine.py: building cuda graphs for ")
for graph in context.cuda_graph_batch_dimensions_list:
logging.info(graph)
Expand Down Expand Up @@ -427,27 +463,43 @@ def create_cuda_graphs(self, reset_context: bool = True):

context.reset()

# Per-iteration memory accounting, scoped to the CUDA-graph mempool.
# This isolates pool growth from process-wide scratch churn (KV cache,
# NCCL workspaces, etc.) that pollutes `torch.cuda.memory_stats()`.
pool_reserved, pool_alloc = _cuda_graph_mempool_bytes()
logging.info(
" [graph %d/%d] %s | pool reserved=%s (Δiter=%s) " "pool allocated=%s (Δiter=%s)",
tbar_idx + 1,
len(context.cuda_graph_batch_dimensions_list),
cuda_graph_batch_dimension,
format_mem_bytes(pool_reserved),
format_mem_bytes(pool_reserved - prev_pool_reserved),
format_mem_bytes(pool_alloc),
format_mem_bytes(pool_alloc - prev_pool_alloc),
)
prev_pool_reserved, prev_pool_alloc = pool_reserved, pool_alloc

if mtp_warmup_enabled and mtp_seen_batch_sizes:
logging.info("> MTP CUDA graph warmup: %d batch size(s)", len(mtp_seen_batch_sizes))

# Memory usage.
time_end = time.time()
mem_stats_end = torch.cuda.memory_stats()
final_pool_reserved, final_pool_alloc = _cuda_graph_mempool_bytes()
capture_stats = {
"time": time_end - time_start,
"allocated_bytes": (
mem_stats_end["allocated_bytes.all.current"]
- mem_stats_start["allocated_bytes.all.current"]
),
"reserved_bytes": (
mem_stats_end["reserved_bytes.all.current"]
- mem_stats_start["reserved_bytes.all.current"]
),
"allocated_bytes": (mem_stats_end["allocated_bytes.all.current"] - start_proc_alloc),
"reserved_bytes": (mem_stats_end["reserved_bytes.all.current"] - start_proc_reserved),
"pool_reserved_bytes": final_pool_reserved,
"pool_allocated_bytes": final_pool_alloc,
}
logging.info(
"> built cuda graph(s) in %.2f sec, with total memory usage: "
"allocated %s, reserved %s.",
"> built cuda graph(s) in %.2f sec. "
"Mempool: reserved %s, allocated %s. "
"Process-wide delta: allocated %s, reserved %s.",
capture_stats["time"],
format_mem_bytes(capture_stats["pool_reserved_bytes"]),
format_mem_bytes(capture_stats["pool_allocated_bytes"]),
format_mem_bytes(capture_stats["allocated_bytes"]),
format_mem_bytes(capture_stats["reserved_bytes"]),
)
Expand Down
4 changes: 4 additions & 0 deletions megatron/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from gpt_builders import gpt_builder
from hybrid_builders import hybrid_builder
from megatron.core.inference.config import (
CudaGraphSizingDistribution,
InferenceConfig,
KVCacheManagementMode,
MambaInferenceStateConfig,
Expand Down Expand Up @@ -356,6 +357,9 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args):
unified_memory_level=args.inference_dynamic_batching_unified_memory_level,
kv_cache_management_mode=KVCacheManagementMode(args.rl_kv_cache_management_mode),
cuda_graph_mixed_prefill_count=args.inference_dynamic_batching_cuda_graph_mixed_prefill_count, # pylint: disable=line-too-long
cuda_graph_sizing_distribution=CudaGraphSizingDistribution(
args.inference_dynamic_batching_cuda_graph_sizing_distribution
),
use_cuda_graphs_for_non_decode_steps=not args.decode_only_cuda_graphs,
cuda_graph_all_prefills=args.inference_cuda_graph_all_prefills,
static_kv_memory_pointers=args.rl_persist_cuda_graphs,
Expand Down
8 changes: 8 additions & 0 deletions megatron/training/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,6 +1963,14 @@ def _add_inference_args(parser):
group.add_argument('--inference-dynamic-batching-cuda-graph-mixed-prefill-count',
type=int, default=16,
help='Number of mixed prefill requests to capture in a cuda graph.')
group.add_argument('--inference-dynamic-batching-cuda-graph-sizing-distribution',
type=str, default='exponential',
choices=['exponential', 'linear'],
dest='inference_dynamic_batching_cuda_graph_sizing_distribution',
help='Spacing of CUDA graph token counts. "exponential" (default) '
'halves from cuda_graph_max_tokens down to tp_size, giving a '
'log-spaced distribution with bounded relative padding. '
'"linear" uses varying linear strides across the range.')
group.add_argument('--inference-dynamic-batching-sampling-backend',
type=str, default='torch',
choices=['torch', 'flashinfer'],
Expand Down
10 changes: 8 additions & 2 deletions tests/unit_tests/inference/contexts/test_dynamic_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1659,10 +1659,16 @@ def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path(
num_speculative_tokens=num_speculative_tokens,
)

smallest = min(ctx.cuda_graph_batch_dimensions_list)
# The fast path is decode-only by construction, so pick the smallest decode-only batch_dim.
# With the geometric grid for mixed cudagraphs, the global min may now be a P=1 mixed shape
# (this happens when num_speculative_tokens > 0 makes the decode-only token_count > 1).
smallest = min(
batchdim
for batchdim in ctx.cuda_graph_batch_dimensions_list
if batchdim.prefill_req_count == 0
)
N = smallest.decode_req_count
T = smallest.token_count # N * (num_speculative_tokens + 1)
assert smallest.prefill_req_count == 0, "smallest graph must be decode-only"

# --- slow path (reference) ---
ctx.add_dummy_requests_for_cudagraph_capture(smallest)
Expand Down
25 changes: 15 additions & 10 deletions tests/unit_tests/inference/engines/test_dynamic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,23 +878,28 @@ def test_fixed_output_lengths(self, model_provider: str) -> None:
def test_cuda_graph_token_counts(self, use_non_decode: bool) -> None:
"""Test initialization of `cuda_graph_token_counts` in dynamic context."""

# Exponential-decay graph distribution (halve from max down to tp_size).
# decode-only path: cuda_graph_max_tokens = max_requests * (spec+1) = 80.
# non-decode path: cuda_graph_max_tokens = self.max_tokens (DEFAULT 16384);
# most large prefill sizes are filtered by is_valid because
# token_count > prefill_req_count * (max_sequence_length - 1).
decode_only_cases = [
(0, [80]),
(1, [80]),
(2, [80, 40]),
(4, [80, 72, 48, 24]),
(8, [80, 64, 48, 32, 16]),
(16, [80, 72, 64, 56, 48, 40, 32, 24, 16, 8]),
(32, [80, 72, 64, 56, 48, 40, 32, 24, 16, 8]),
(2, [80, 1]),
(4, [80, 40, 20, 1]),
(8, [80, 40, 20, 10, 4, 2, 1]),
(16, [80, 40, 20, 10, 4, 2, 1]),
(32, [80, 40, 20, 10, 4, 2, 1]),
]
non_decode_cases = [
(0, [80]),
(1, [80]),
(2, [80, 40]),
(4, [80, 72, 48, 24]),
(8, [80, 64, 48, 32, 16]),
(16, [1024, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8]),
(32, [1024, 512, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8]),
(2, [80, 1]),
(4, [80, 40, 20, 1]),
(8, [1024, 512, 256, 80, 40, 20, 10, 4, 2, 1]),
(16, [1024, 512, 256, 128, 80, 64, 40, 32, 20, 16, 10, 8, 4, 2, 1]),
(32, [1024, 512, 256, 128, 80, 64, 40, 32, 20, 16, 10, 8, 4, 2, 1]),
]
cases = non_decode_cases if use_non_decode else decode_only_cases

Expand Down
Loading