diff --git a/examples/inference/gpt/gpt_dynamic_inference.py b/examples/inference/gpt/gpt_dynamic_inference.py index 02a257c1b46..9172d137eab 100644 --- a/examples/inference/gpt/gpt_dynamic_inference.py +++ b/examples/inference/gpt/gpt_dynamic_inference.py @@ -508,7 +508,7 @@ def escape_str(s): print( f"{setup_prefix} … " f"throughput: {throughput:.3f} tok/s … ", f"total time: {total_time:.3f}s … " - f"mem {peak_alloc_gb:.1f}/{peak_resvd_gb:.1f} GB … " + f"mem {peak_alloc_gb:.1f} allocated/{peak_resvd_gb:.1f} reserved GB … " f"steps: {engine.context.step_count:d} … " f"capture {capture_str}", ) diff --git a/megatron/core/inference/batch_dimensions_utils.py b/megatron/core/inference/batch_dimensions_utils.py index d8793f01d67..1229d333d0a 100644 --- a/megatron/core/inference/batch_dimensions_utils.py +++ b/megatron/core/inference/batch_dimensions_utils.py @@ -213,57 +213,72 @@ class CUDAGraphBatchDimensionBuilder: """ # Constant for rounding token counts when generating CUDA graph batch dimensions - CUDA_GRAPH_ROUNDER = 8 + CUDA_GRAPH_ROUNDER = 2 @staticmethod def _calculate_cuda_graph_token_counts( - tp_size: int, num_cuda_graphs: int, cuda_graph_max_tokens: int + tp_size: int, + num_cuda_graphs: int, + cuda_graph_max_tokens: int, + sizing_distribution: "CudaGraphSizingDistribution" = None, ) -> List[int]: """ Calculate CUDA graph token counts for a given configuration. - This method computes evenly-spaced token counts from step_size up to - cuda_graph_max_tokens, ensuring proper rounding and TP alignment. + Dispatches on `sizing_distribution`: + - EXPONENTIAL (default): halves from cuda_graph_max_tokens down to tp_size, log-spaced, + creates ~log2(max_tokens) graphs. + - LINEAR: small graphs [1, 2, 4] + range(8, 256, 8) + range(256, max+1, 16); + explicit-N path uses an even stride of ~max_tokens / N up to max. 
Args: tp_size: Tensor parallel size (for alignment) - num_cuda_graphs: Number of CUDA graphs to generate (must be >= 1) + num_cuda_graphs: Number of CUDA graphs to generate (must be >= 1, or -1 to auto-size) cuda_graph_max_tokens: Maximum token count for CUDA graphs (must be > 0) + sizing_distribution: Distribution of cudagraph sizes. Defaults to EXPONENTIAL. Returns: List of token counts in descending order - Example: - >>> _calculate_cuda_graph_token_counts - (tp_size=2, num_cuda_graphs=4, cuda_graph_max_tokens=1000) - [1000, 752, 504, 256] + Example (EXPONENTIAL): + >>> _calculate_cuda_graph_token_counts(tp_size=1, num_cuda_graphs=8, + cuda_graph_max_tokens=128) + [128, 64, 32, 16, 8, 4, 2, 1] """ - if num_cuda_graphs == -1: - # automatically determine the number of CUDA graphs to - # capture based on the `max_requests` value - cuda_graph_token_counts = ( - [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, cuda_graph_max_tokens + 1, 16)) + from megatron.core.inference.config import CudaGraphSizingDistribution + + if sizing_distribution is None: + sizing_distribution = CudaGraphSizingDistribution.EXPONENTIAL + + if sizing_distribution == CudaGraphSizingDistribution.LINEAR: + return CUDAGraphBatchDimensionBuilder._calculate_token_counts_linear( + tp_size, num_cuda_graphs, cuda_graph_max_tokens ) - # Align each entry to TP size - cuda_graph_token_counts = list( - dict.fromkeys( - round_up_to_nearest_multiple(s, tp_size) for s in cuda_graph_token_counts - ) + + # Default path: exponential decay. + if num_cuda_graphs == -1: + # Pick a graph count: we halve from cuda_graph_max_tokens down to 1, so + # log2(max_tokens) halvings are needed. Add a small margin for the two forced endpoints + # (cuda_graph_max_tokens and tp_size) that are unioned into the set after the loop. + # Floor at MIN_GRAPHS so the trim logic always has at least 2 entries to work with. 
+ HEADROOM = 2 + MIN_GRAPHS = 4 + num_halvings = int(math.log2(max(2, cuda_graph_max_tokens))) + auto_n = max(MIN_GRAPHS, num_halvings + HEADROOM) + return CUDAGraphBatchDimensionBuilder._calculate_cuda_graph_token_counts( + tp_size=tp_size, + num_cuda_graphs=auto_n, + cuda_graph_max_tokens=cuda_graph_max_tokens, + sizing_distribution=sizing_distribution, ) - # Clamp to max tokens - cuda_graph_token_counts = [ - s for s in cuda_graph_token_counts if s <= cuda_graph_max_tokens - ] - if not cuda_graph_token_counts or cuda_graph_token_counts[-1] != cuda_graph_max_tokens: - cuda_graph_token_counts.append(cuda_graph_max_tokens) - cuda_graph_token_counts.reverse() - return cuda_graph_token_counts assert num_cuda_graphs >= 1, f"num_cuda_graphs must be >= 1, got {num_cuda_graphs}" assert ( cuda_graph_max_tokens > 0 ), f"cuda_graph_max_tokens must be > 0, got {cuda_graph_max_tokens}" + rounder = CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER + # Cuda graph step size. cuda_graph_step_size = cuda_graph_max_tokens / num_cuda_graphs cuda_graph_step_size = CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER * int( @@ -274,25 +289,87 @@ def _calculate_cuda_graph_token_counts( # Ensure non-zero step size (can happen when max_tokens < num_cuda_graphs). cuda_graph_step_size = max(cuda_graph_step_size, tp_size) - # round down cuda graph max tokens to be multiple of TP size + # Round down cuda graph max tokens to be multiple of TP size cuda_graph_max_tokens = (cuda_graph_max_tokens // tp_size) * tp_size - # Cuda graph token counts. 
if num_cuda_graphs == 1: - cuda_graph_token_counts = [cuda_graph_max_tokens] - else: - cuda_graph_token_counts = list( - range(cuda_graph_step_size, cuda_graph_max_tokens, cuda_graph_step_size) - ) - if ( - len(cuda_graph_token_counts) == 0 - or cuda_graph_token_counts[-1] != cuda_graph_max_tokens - ): - cuda_graph_token_counts.append(cuda_graph_max_tokens) - cuda_graph_token_counts.reverse() + return [cuda_graph_max_tokens] + + # Exponentially decreasing token counts: halve from max_tokens until val drops below 1 or + # num_cuda_graphs sizes have been generated. Dedupe (rounding/TP-alignment can collide for + # small values), then sort descending. + sizes = set() + val = cuda_graph_max_tokens + for _ in range(num_cuda_graphs): + # Round down to multiple of rounder, then up to multiple of TP size + rounded = max(rounder, (val // rounder) * rounder) + rounded = math.ceil(rounded / tp_size) * tp_size + sizes.add(rounded) + val //= 2 + if val < 1: + break + + # Always include the endpoints: cuda_graph_max_tokens (largest) and tp_size (smallest). + sizes.add(cuda_graph_max_tokens) + sizes.add(tp_size) + + cuda_graph_token_counts = sorted(sizes, reverse=True) + + # Drop the second-smallest size while we exceed num_cuda_graphs requested by the user. + # num_cuda_graphs >= 2 here (== 1 returned above), so pop(-2) never removes an endpoint. + while len(cuda_graph_token_counts) > num_cuda_graphs: + cuda_graph_token_counts.pop(-2) + + assert len(cuda_graph_token_counts) <= num_cuda_graphs + assert cuda_graph_max_tokens in cuda_graph_token_counts return cuda_graph_token_counts + @staticmethod + def _calculate_token_counts_linear( + tp_size: int, num_cuda_graphs: int, cuda_graph_max_tokens: int + ) -> List[int]: + """Linear-stride token count distribution. + + For num_cuda_graphs == -1, returns [1, 2, 4] + range(8, 256, 8) + range(256, max+1, 16) + TP-aligned and deduped. + For positive N, returns evenly-spaced sizes with step ~ max_tokens / N. 
+ """ + rounder = CUDAGraphBatchDimensionBuilder.CUDA_GRAPH_ROUNDER + + if num_cuda_graphs == -1: + sizes = ( + [1, 2, 4] + list(range(8, 256, 8)) + list(range(256, cuda_graph_max_tokens + 1, 16)) + ) + # TP-align and dedupe in order; preserve original ordering for parity. + sizes = list(dict.fromkeys(round_up_to_nearest_multiple(s, tp_size) for s in sizes)) + sizes = [s for s in sizes if s <= cuda_graph_max_tokens] + if not sizes or sizes[-1] != cuda_graph_max_tokens: + sizes.append(cuda_graph_max_tokens) + sizes.reverse() + return sizes + + assert num_cuda_graphs >= 1, f"num_cuda_graphs must be >= 1, got {num_cuda_graphs}" + assert ( + cuda_graph_max_tokens > 0 + ), f"cuda_graph_max_tokens must be > 0, got {cuda_graph_max_tokens}" + + # Even stride: step = round_up_to(max / N, rounder), TP-aligned. + step = cuda_graph_max_tokens / num_cuda_graphs + step = rounder * int(math.ceil(int(step) / rounder)) + step = round_up_to_nearest_multiple(step, tp_size) + step = max(step, tp_size) + cuda_graph_max_tokens = (cuda_graph_max_tokens // tp_size) * tp_size + + if num_cuda_graphs == 1: + return [cuda_graph_max_tokens] + + sizes = list(range(step, cuda_graph_max_tokens, step)) + if not sizes or sizes[-1] != cuda_graph_max_tokens: + sizes.append(cuda_graph_max_tokens) + sizes.reverse() + return sizes + @staticmethod def generate_cuda_graph_batch_dimensions_list( tp_size: int, @@ -304,6 +381,7 @@ def generate_cuda_graph_batch_dimensions_list( max_sequence_length: int, use_cuda_graphs_for_non_decode_steps: bool, num_speculative_tokens: int = 0, + sizing_distribution: "CudaGraphSizingDistribution" = None, ) -> Tuple[List[InferenceBatchDimensions], Optional[List[int]]]: """ Generate CUDA graph batch dimensions. @@ -361,6 +439,12 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int cuda_graph_decode_token_counts = None if num_cuda_graphs is not None: + # Lazy import to avoid a circular dependency with config.py. 
+ from megatron.core.inference.config import CudaGraphSizingDistribution + + if sizing_distribution is None: + sizing_distribution = CudaGraphSizingDistribution.EXPONENTIAL + # Ensure valid num_cuda_graphs. if ( cuda_graph_max_tokens is None @@ -387,6 +471,7 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int tp_size=tp_size, num_cuda_graphs=num_cuda_graphs, cuda_graph_max_tokens=cuda_graph_max_tokens, + sizing_distribution=sizing_distribution, ) ) @@ -399,9 +484,33 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int tp_size=tp_size, num_cuda_graphs=num_cuda_graphs, cuda_graph_max_tokens=cuda_graph_max_tokens_decode, + sizing_distribution=sizing_distribution, ) ) + # Include the smallest decode-only graphs when auto-sizing (num_cuda_graphs == -1). + # Without this, TP alignment and the num_speculative_tokens floor division can drop + # the smallest 1- and 2-request shapes from the captured set. + # + # The minimum valid decode token_count is lcm(spec_unit, tp_size): + # - Ensure divisible by tp_size (required so TP / sequence-parallel never produces a + # single-token graph when tp_size > 1). + # - Ensure a multiple of (spec+1) so it accommodates an integer number of decode + # requests when speculative decoding is enabled. 
+ if num_cuda_graphs == -1: + spec_unit = num_speculative_tokens + 1 + min_decode_tokens = math.lcm(spec_unit, tp_size) + for req_count_multiple in (1, 2): + floor_tokens = min_decode_tokens * req_count_multiple + if ( + floor_tokens <= cuda_graph_max_tokens_decode + and floor_tokens not in cuda_graph_decode_token_counts + ): + cuda_graph_decode_token_counts.append(floor_tokens) + cuda_graph_decode_token_counts = sorted( + set(cuda_graph_decode_token_counts), reverse=True + ) + cuda_graph_batch_dimensions_list = [] if num_cuda_graphs is None: cuda_graph_batch_dimensions_list = [] @@ -419,34 +528,62 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int token_count=token_count, prefill_req_count=0, decode_req_count=decode_req_count ) else: - # Mixed prefill and decode mode + # Mixed prefill and decode mode. + # + # Under EXPONENTIAL distribution (default): generate mixed CGs across a + # geometric P-grid {1, 2, 4, ..., max_requests}. This bounds the relative + # overhead per real batch (~2x P slack worst case) and is the structural fix + # that makes mixed CGs usable for real batches with P != fixed_P. + # + # Under LINEAR distribution: use the legacy fixed P value + # (cuda_graph_mixed_prefill_request_count) — same single-P behavior main has + # today, for apples-to-apples benchmarking against vLLM-style configurations. + if sizing_distribution == CudaGraphSizingDistribution.LINEAR: + p_values = [min(cuda_graph_mixed_prefill_request_count, max_requests)] + # In legacy mode, the prefill-only floor uses the fixed P value to match + # main's behavior exactly. 
+ prefill_only_floor = cuda_graph_mixed_prefill_request_count + else: + p_values = [] + p = 1 + while p < max_requests: + p_values.append(p) + p *= 2 + if not p_values or p_values[-1] != max_requests: + p_values.append(max_requests) + prefill_only_floor = 1 + # Create prefill and mixed dimensions with full token counts for size in cuda_graph_prefill_token_counts: assert size % tp_size == 0 - prefill_req_count = min(cuda_graph_mixed_prefill_request_count, max_requests) - decode_req_count = max( - 0, - min( - (size - prefill_req_count) // (num_speculative_tokens + 1), - max_requests - prefill_req_count, - ), - ) - add_if_valid( - token_count=size, - prefill_req_count=prefill_req_count, - decode_req_count=decode_req_count, - ) + for prefill_req_count in p_values: + decode_req_count = max( + 0, + min( + (size - prefill_req_count) // (num_speculative_tokens + 1), + max_requests - prefill_req_count, + ), + ) + # Skip token_count == 1 with prefill_req == 1: the gather kernel asserts + # on index >= 1 against a 1-element tensor at capture time. Larger + # `(size, size, 0)` shapes (each prefill = 1 token, total batch >= 2) are + # fine because the gather has multiple indices to read. 
+ if size < 2: + continue + add_if_valid( + token_count=size, + prefill_req_count=prefill_req_count, + decode_req_count=decode_req_count, + ) # We need to ensure the prefill requests are shorter than the max sequence length, # considering the one decode token is used for prefill request construction prefill_only_minimal_num = max( - cuda_graph_mixed_prefill_request_count, - math.ceil(size / max(1, max_sequence_length - 1)), + prefill_only_floor, math.ceil(size / max(1, max_sequence_length - 1)) ) - if prefill_only_minimal_num < max_requests: + if prefill_only_minimal_num < max_requests and size >= 2: + prefill_req_count = max(prefill_only_minimal_num, min(max_requests, size)) add_if_valid( - token_count=size, - prefill_req_count=max(prefill_only_minimal_num, min(max_requests, size)), - decode_req_count=0, + token_count=size, prefill_req_count=prefill_req_count, decode_req_count=0 ) # Create decode-only dimensions with optimized token counts diff --git a/megatron/core/inference/config.py b/megatron/core/inference/config.py index e8769f3d6e7..ea4b08e5183 100644 --- a/megatron/core/inference/config.py +++ b/megatron/core/inference/config.py @@ -117,6 +117,22 @@ class KVCacheManagementMode(str, Enum): """Deallocate large tensors and recompute them from scratch during allocation.""" +class CudaGraphSizingDistribution(str, Enum): + """How CUDA graph token-count sizes are spaced when generating the captured graphs. + + EXPONENTIAL (default) — token counts halve from `cuda_graph_max_tokens` down to `tp_size`, + giving a log-spaced distribution. Bounded relative padding (~2x worst case) at every scale and + `log2(max_tokens)` total graphs. + + LINEAR — Include size-1 and size-2 graphs where applicable, linear spacing up until 256, and + sparser linear spacing past 256. e.g. `[1, 2, 4] + range(8, 256, 8) + range(256, max+1, 16)`. + Higher graph density at the top end. 
+ """ + + EXPONENTIAL = "exponential" + LINEAR = "linear" + + @dataclass class InferenceConfig: """ @@ -197,10 +213,20 @@ class InferenceConfig: """ cuda_graph_mixed_prefill_count: Optional[int] = 16 - """ + """ The number of mixed prefill graphs to capture if mixed prefill/decode graphs are enabled. """ + cuda_graph_sizing_distribution: CudaGraphSizingDistribution = ( + CudaGraphSizingDistribution.EXPONENTIAL + ) + """ + How CUDA graph token counts are spaced. EXPONENTIAL (default) halves from + `cuda_graph_max_tokens` down to `tp_size` (log-spaced, ~log2(max_tokens) graphs). + LINEAR uses a range of linear strides (includes small graphs + mid-range linearity + + a bigger step size at the top end). + """ + use_cuda_graphs_for_non_decode_steps: bool = True """ Whether to use CUDA graphs for non-decode steps. diff --git a/megatron/core/inference/contexts/dynamic_context.py b/megatron/core/inference/contexts/dynamic_context.py index 4a0d0cba518..ca7cfeed063 100644 --- a/megatron/core/inference/contexts/dynamic_context.py +++ b/megatron/core/inference/contexts/dynamic_context.py @@ -662,6 +662,7 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC max_sequence_length=self.max_sequence_length, use_cuda_graphs_for_non_decode_steps=self.use_cuda_graphs_for_non_decode_steps, num_speculative_tokens=self.num_speculative_tokens, + sizing_distribution=inference_config.cuda_graph_sizing_distribution, ) ) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 92efff36073..8a43eb0f7ae 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -44,7 +44,7 @@ ) from megatron.core.inference.utils import Counter, InferenceMode, await_process_call from megatron.core.process_groups_config import ProcessGroupCollection -from megatron.core.transformer.cuda_graphs import delete_cuda_graphs +from 
megatron.core.transformer.cuda_graphs import CudaGraphManager, delete_cuda_graphs from megatron.core.transformer.enums import InferenceCudaGraphScope from megatron.core.transformer.moe.router_replay import RouterReplay, RouterReplayAction from megatron.core.utils import ( @@ -133,6 +133,8 @@ class EngineSuspendedError(Exception): def format_mem_bytes(mem_bytes): """Convert a byte count to a human-readable string in tb, gb, mb, kb, or bytes.""" + if mem_bytes < 0: + return "-" + format_mem_bytes(-mem_bytes) for power, suffix in [(4, "tb"), (3, "gb"), (2, "mb"), (1, "kb"), (0, "bytes")]: suffix_bytes = 1024**power if mem_bytes >= suffix_bytes: @@ -140,6 +142,33 @@ def format_mem_bytes(mem_bytes): return "%d bytes" % mem_bytes +def _cuda_graph_mempool_bytes() -> Tuple[int, int]: + """Return (reserved, allocated) bytes belonging to the global CUDA graph mempool. + + PyTorch's `torch.cuda.memory_stats()` reports process-wide totals that mix in + every other allocation (KV cache, NCCL workspaces, layer scratch). To isolate + growth caused by graph capture, we walk `torch.cuda.memory_snapshot()` and + filter segments by their `segment_pool_id` against the graph pool handle. + Returns (0, 0) if the pool hasn't been created yet. + """ + pool_id = CudaGraphManager.global_mempool + if pool_id is None: + return 0, 0 + reserved = 0 + allocated = 0 + for seg in torch.cuda.memory_snapshot(): + seg_pool_id = ( + seg.get("segment_pool_id") + or seg.get("private_pool_id") + or seg.get("pool_id") + or seg.get("pool") + ) + if seg_pool_id == pool_id: + reserved += seg.get("total_size", 0) + allocated += seg.get("allocated_size", 0) + return reserved, allocated + + @dataclass(kw_only=True) class RequestEntry: """Entry in the engine's `self.requests` dict.""" @@ -347,6 +376,13 @@ def create_cuda_graphs(self, reset_context: bool = True): time_start = time.time() mem_stats_start = torch.cuda.memory_stats() + # Snapshot of process-wide stats for the "total memory used by capture" summary. 
+ start_proc_reserved = mem_stats_start["reserved_bytes.all.current"] + start_proc_alloc = mem_stats_start["allocated_bytes.all.current"] + + # Pool-scoped baselines for the per-iteration deltas. + prev_pool_reserved, prev_pool_alloc = _cuda_graph_mempool_bytes() + logging.info("> dynamic_engine.py: building cuda graphs for ") for graph in context.cuda_graph_batch_dimensions_list: logging.info(graph) @@ -427,27 +463,43 @@ def create_cuda_graphs(self, reset_context: bool = True): context.reset() + # Per-iteration memory accounting, scoped to the CUDA-graph mempool. + # This isolates pool growth from process-wide scratch churn (KV cache, + # NCCL workspaces, etc.) that pollutes `torch.cuda.memory_stats()`. + pool_reserved, pool_alloc = _cuda_graph_mempool_bytes() + logging.info( + " [graph %d/%d] %s | pool reserved=%s (Δiter=%s) " "pool allocated=%s (Δiter=%s)", + tbar_idx + 1, + len(context.cuda_graph_batch_dimensions_list), + cuda_graph_batch_dimension, + format_mem_bytes(pool_reserved), + format_mem_bytes(pool_reserved - prev_pool_reserved), + format_mem_bytes(pool_alloc), + format_mem_bytes(pool_alloc - prev_pool_alloc), + ) + prev_pool_reserved, prev_pool_alloc = pool_reserved, pool_alloc + if mtp_warmup_enabled and mtp_seen_batch_sizes: logging.info("> MTP CUDA graph warmup: %d batch size(s)", len(mtp_seen_batch_sizes)) # Memory usage. 
time_end = time.time() mem_stats_end = torch.cuda.memory_stats() + final_pool_reserved, final_pool_alloc = _cuda_graph_mempool_bytes() capture_stats = { "time": time_end - time_start, - "allocated_bytes": ( - mem_stats_end["allocated_bytes.all.current"] - - mem_stats_start["allocated_bytes.all.current"] - ), - "reserved_bytes": ( - mem_stats_end["reserved_bytes.all.current"] - - mem_stats_start["reserved_bytes.all.current"] - ), + "allocated_bytes": (mem_stats_end["allocated_bytes.all.current"] - start_proc_alloc), + "reserved_bytes": (mem_stats_end["reserved_bytes.all.current"] - start_proc_reserved), + "pool_reserved_bytes": final_pool_reserved, + "pool_allocated_bytes": final_pool_alloc, } logging.info( - "> built cuda graph(s) in %.2f sec, with total memory usage: " - "allocated %s, reserved %s.", + "> built cuda graph(s) in %.2f sec. " + "Mempool: reserved %s, allocated %s. " + "Process-wide delta: allocated %s, reserved %s.", capture_stats["time"], + format_mem_bytes(capture_stats["pool_reserved_bytes"]), + format_mem_bytes(capture_stats["pool_allocated_bytes"]), format_mem_bytes(capture_stats["allocated_bytes"]), format_mem_bytes(capture_stats["reserved_bytes"]), ) diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index 3f06eb8a301..60b6d9bb0c0 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -9,6 +9,7 @@ from gpt_builders import gpt_builder from hybrid_builders import hybrid_builder from megatron.core.inference.config import ( + CudaGraphSizingDistribution, InferenceConfig, KVCacheManagementMode, MambaInferenceStateConfig, @@ -356,6 +357,9 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): unified_memory_level=args.inference_dynamic_batching_unified_memory_level, kv_cache_management_mode=KVCacheManagementMode(args.rl_kv_cache_management_mode), cuda_graph_mixed_prefill_count=args.inference_dynamic_batching_cuda_graph_mixed_prefill_count, # pylint: disable=line-too-long + 
cuda_graph_sizing_distribution=CudaGraphSizingDistribution( + args.inference_dynamic_batching_cuda_graph_sizing_distribution + ), use_cuda_graphs_for_non_decode_steps=not args.decode_only_cuda_graphs, cuda_graph_all_prefills=args.inference_cuda_graph_all_prefills, static_kv_memory_pointers=args.rl_persist_cuda_graphs, diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index f3c2ded6907..948f8091814 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -1963,6 +1963,14 @@ def _add_inference_args(parser): group.add_argument('--inference-dynamic-batching-cuda-graph-mixed-prefill-count', type=int, default=16, help='Number of mixed prefill requests to capture in a cuda graph.') + group.add_argument('--inference-dynamic-batching-cuda-graph-sizing-distribution', + type=str, default='exponential', + choices=['exponential', 'linear'], + dest='inference_dynamic_batching_cuda_graph_sizing_distribution', + help='Spacing of CUDA graph token counts. "exponential" (default) ' + 'halves from cuda_graph_max_tokens down to tp_size, giving a ' + 'log-spaced distribution with bounded relative padding. ' + '"linear" uses varying linear strides across the range.') group.add_argument('--inference-dynamic-batching-sampling-backend', type=str, default='torch', choices=['torch', 'flashinfer'], diff --git a/tests/unit_tests/inference/contexts/test_dynamic_context.py b/tests/unit_tests/inference/contexts/test_dynamic_context.py index 81a5e02792a..e79df3aaebf 100644 --- a/tests/unit_tests/inference/contexts/test_dynamic_context.py +++ b/tests/unit_tests/inference/contexts/test_dynamic_context.py @@ -1659,10 +1659,16 @@ def test_add_dummy_requests_for_expert_parallel_step_matches_slow_path( num_speculative_tokens=num_speculative_tokens, ) - smallest = min(ctx.cuda_graph_batch_dimensions_list) + # The fast path is decode-only by construction, so pick the smallest decode-only batch_dim. 
+ # With the geometric grid for mixed cudagraphs, the global min may now be a P=1 mixed shape + # (when num_speculative_tokens > 0 makes decode-only token_count > 1). + smallest = min( + batchdim + for batchdim in ctx.cuda_graph_batch_dimensions_list + if batchdim.prefill_req_count == 0 + ) N = smallest.decode_req_count T = smallest.token_count # N * (num_speculative_tokens + 1) - assert smallest.prefill_req_count == 0, "smallest graph must be decode-only" # --- slow path (reference) --- ctx.add_dummy_requests_for_cudagraph_capture(smallest) diff --git a/tests/unit_tests/inference/engines/test_dynamic_engine.py b/tests/unit_tests/inference/engines/test_dynamic_engine.py index 503b0fa52ae..2f03c8cb7aa 100644 --- a/tests/unit_tests/inference/engines/test_dynamic_engine.py +++ b/tests/unit_tests/inference/engines/test_dynamic_engine.py @@ -878,23 +878,28 @@ def test_fixed_output_lengths(self, model_provider: str) -> None: def test_cuda_graph_token_counts(self, use_non_decode: bool) -> None: """Test initialization of `cuda_graph_token_counts` in dynamic context.""" + # Exponential-decay graph distribution (halve from max down to tp_size). + # decode-only path: cuda_graph_max_tokens = max_requests * (spec+1) = 80. + # non-decode path: cuda_graph_max_tokens = self.max_tokens (DEFAULT 16384); + # most large prefill sizes are filtered by is_valid because + # token_count > prefill_req_count * (max_sequence_length - 1). 
decode_only_cases = [ (0, [80]), (1, [80]), - (2, [80, 40]), - (4, [80, 72, 48, 24]), - (8, [80, 64, 48, 32, 16]), - (16, [80, 72, 64, 56, 48, 40, 32, 24, 16, 8]), - (32, [80, 72, 64, 56, 48, 40, 32, 24, 16, 8]), + (2, [80, 1]), + (4, [80, 40, 20, 1]), + (8, [80, 40, 20, 10, 4, 2, 1]), + (16, [80, 40, 20, 10, 4, 2, 1]), + (32, [80, 40, 20, 10, 4, 2, 1]), ] non_decode_cases = [ (0, [80]), (1, [80]), - (2, [80, 40]), - (4, [80, 72, 48, 24]), - (8, [80, 64, 48, 32, 16]), - (16, [1024, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8]), - (32, [1024, 512, 80, 72, 64, 56, 48, 40, 32, 24, 16, 8]), + (2, [80, 1]), + (4, [80, 40, 20, 1]), + (8, [1024, 512, 256, 80, 40, 20, 10, 4, 2, 1]), + (16, [1024, 512, 256, 128, 80, 64, 40, 32, 20, 16, 10, 8, 4, 2, 1]), + (32, [1024, 512, 256, 128, 80, 64, 40, 32, 20, 16, 10, 8, 4, 2, 1]), ] cases = non_decode_cases if use_non_decode else decode_only_cases