Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 143 additions & 43 deletions megatron/core/inference/engines/dynamic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
import torch
from torch import Tensor

from megatron.core.inference.batch_dimensions_utils import (
CUDAGraphBatchDimensionBuilder,
InferenceBatchDimensions,
)
from megatron.core.inference.config import KVCacheManagementMode
from megatron.core.inference.contexts.dynamic_context import (
BlockOverflowError,
Expand Down Expand Up @@ -217,6 +221,7 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen
self.track_paused_request_events = inference_config.track_paused_request_events
self.track_generated_token_events = inference_config.track_generated_token_events
self.enable_chunked_prefill = inference_config.enable_chunked_prefill
self.cuda_graph_all_prefills = inference_config.cuda_graph_all_prefills
self.metrics_writer = inference_config.metrics_writer
self.logging_step_interval = inference_config.logging_step_interval
self.unified_memory_level = inference_config.unified_memory_level
Expand All @@ -226,6 +231,9 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen
self.cuda_graph_impl = model_config.cuda_graph_impl
self.inference_cuda_graph_scope = model_config.inference_cuda_graph_scope
self.cuda_graph_modules = model_config.cuda_graph_modules
# Throw a cudagraph-admission warning if deferred for > max_sequence_length steps.
# The floor value of 100 avoids warnings in test configs where max_sequence_length < 100.
self._cg_admission_warn_after = max(100, self.context.max_sequence_length)
# Initialize engine.
self.reset()

Expand Down Expand Up @@ -1520,6 +1528,20 @@ def schedule_non_chunked_prefill(self):
self.context.check_availability(req)
)
if request_can_be_added and request_tokens_can_be_added and kv_cache_available:
# CUDA graph-aware admission gating: defer if the resulting batch shape lacks a
# matching captured CG. Non-chunked admit takes the request whole, so the
# candidate token_count is active + remaining_prompt_tokens.
if self._cg_admission_gating_active():
candidate = InferenceBatchDimensions(
token_count=(
self.context.active_token_count + len(req.remaining_prompt_tokens)
),
prefill_req_count=self.context.num_prefill_requests + 1,
decode_req_count=self.context.num_decode_requests,
)
if not self._cg_admission_check(req, candidate):
break

# Add these hashes to pending.
if prefix_caching_enabled:
for block_hash in req.precomputed_block_hashes:
Expand All @@ -1539,6 +1561,88 @@ def schedule_non_chunked_prefill(self):
if prefix_caching_enabled and pending_request_ids:
self.waiting_request_ids.extendleft(reversed(pending_request_ids))

def _cg_admission_gating_active(self) -> bool:
"""Cudagraph-aware admission gating is active when --inference-cuda-graph-all-prefills
is set, the engine has prefill/mixed CGs, and the batch-dim list is populated.

All are required so legacy tests that exercise the scheduler without intending to run on
captured graphs are unaffected. Gating is opt-in via `cuda_graph_all_prefills`.
"""
return (
self.cuda_graph_all_prefills
and self.context.use_cuda_graphs_for_non_decode_steps
and bool(self.context.cuda_graph_batch_dimensions_list)
)

def _find_cg_chunk_size(self, max_chunk_tokens: int) -> Optional[int]:
"""Return the largest chunk size <= max_chunk_tokens where batch matches a captured graph,
or None if no graph covers any chunk in the budget.

Walks the captured-CG list (sorted descending by token_count) and returns the first chunk
that falls within budget and produces an applicable batch_dim under the engine's matching
mode (strict for hybrid models). Callers must explicitly handle the None case by deferring
the admission rather than scheduling eagerly.
"""
active_tok = self.context.active_token_count
active_p = self.context.num_prefill_requests
active_d = self.context.num_decode_requests
strict = self.context.is_hybrid_model

for cg in self.context.cuda_graph_batch_dimensions_list:
chunk = cg.token_count - active_tok
if chunk < 1:
continue
if chunk > max_chunk_tokens:
continue
candidate = InferenceBatchDimensions(
token_count=cg.token_count,
prefill_req_count=active_p + 1,
decode_req_count=active_d,
)
if cg.is_applicable_for_batch_dim(candidate, strict=strict):
return chunk

return None

def _register_cg_wait(self, req) -> None:
"""Track a deferred admission attempt and emit a starvation warning at the threshold.

Decode is bounded by the number of decode steps, so deferring is bounded in practice.
Persistent waits past `_cg_admission_warn_after` consecutive steps signal a problem.
"""
req.cg_wait_iters += 1
if req.cg_wait_iters % self._cg_admission_warn_after == 0:
logging.warning(
"request %d has been deferred by CG-aware admission for %d steps — "
"possible starvation (strict=%s, active P=%d D=%d tok=%d)",
req.request_id,
req.cg_wait_iters,
self.context.is_hybrid_model,
self.context.num_prefill_requests,
self.context.num_decode_requests,
self.context.active_token_count,
)

def _cg_admission_check(self, req, candidate: InferenceBatchDimensions) -> bool:
"""Return True if the candidate batch shape matches a captured cudagraph.

On miss, registers a wait + warning via `_register_cg_wait`. On hit, resets the counter.
Caller is responsible for breaking the scheduler loop on False.
Passes match_ep_token_counts=False so this local admission probe doesn't force a per-attempt
NCCL all-reduce — the step-time matcher does its own EP sync.
"""
matched = CUDAGraphBatchDimensionBuilder.match_graph_config(
real_batch_dim=candidate,
cuda_graph_batch_dimensions_list=self.context.cuda_graph_batch_dimensions_list,
strict=self.context.is_hybrid_model,
match_ep_token_counts=False,
)
if matched is not None:
req.cg_wait_iters = 0
return True
self._register_cg_wait(req)
return False

def schedule_chunked_prefill(self):
"""
This function schedules chunked prefill requests.
Expand Down Expand Up @@ -1593,69 +1697,65 @@ def schedule_chunked_prefill(self):

# Use remaining prompt tokens for scheduling decisions
remaining_len = len(req.remaining_prompt_tokens)
token_fully_can_be_added = (
self.context.active_token_count + remaining_len <= self.context.max_tokens
)
token_partially_can_be_added = self.context.active_token_count < self.context.max_tokens
request_can_be_added, _, kv_cache_available = self.context.check_availability(req)
request_can_be_added = is_continuing_chunked_prefill or request_can_be_added

if request_can_be_added and kv_cache_available:
if token_fully_can_be_added:
# Add these hashes to pending.
if prefix_caching_enabled:
for block_hash in req.precomputed_block_hashes:
if (
block_hash
not in self.context.kv_block_allocator.kv_hash_to_block_id
):
pending_block_hashes.add(block_hash)
if request_can_be_added and kv_cache_available and token_partially_can_be_added:
# How many tokens we can admit this step.
token_budget = self.context.max_tokens - self.context.active_token_count
max_chunk = min(remaining_len, token_budget)

# Skip CG gating for the continuation of an in-flight chunked prefill:
# the request is already mid-flight, deferring it would deadlock progress.
if self._cg_admission_gating_active() and not is_continuing_chunked_prefill:
# Snap chunk size to the largest captured-CG boundary within budget,
# or defer if no shape covers any chunk in the budget.
snapped_chunk = self._find_cg_chunk_size(max_chunk)
if snapped_chunk is None:
self._register_cg_wait(req)
can_schedule = False
break
prefill_chunk_length = snapped_chunk
req.cg_wait_iters = 0
else:
prefill_chunk_length = max_chunk

# Flash-attn guard: if this chunk would leave exactly 1 token for the
# final chunk, reduce by 1 (or defer if we only have 1 token of budget).
# See https://github.com/Dao-AILab/flash-attention/issues/1537
if remaining_len - prefill_chunk_length == 1:
if prefill_chunk_length > 1:
prefill_chunk_length -= 1
else:
can_schedule = False
break

# Add hashes to pending set (prefix-caching bookkeeping).
if prefix_caching_enabled:
for block_hash in req.precomputed_block_hashes:
if block_hash not in self.context.kv_block_allocator.kv_hash_to_block_id:
pending_block_hashes.add(block_hash)

if prefill_chunk_length >= remaining_len:
self.context.chunked_prefill_request_id = -1
self.context.add_request(req)
self._loop.call_soon_threadsafe(
self._loop.create_task, self._notify_cond_for_new_request()
)
req.remaining_prompt_tokens = req.remaining_prompt_tokens.new_empty(0)
req.add_event_add_context()
# Fully scheduled, so we remove from waiting pool
self.waiting_request_ids.popleft()
# Only this case we keep checking the rest of the waiting queue
can_schedule = True
elif token_partially_can_be_added:
# Add these hashes to pending.
if prefix_caching_enabled:
for block_hash in req.precomputed_block_hashes:
if (
block_hash
not in self.context.kv_block_allocator.kv_hash_to_block_id
):
pending_block_hashes.add(block_hash)
prefill_chunk_length = self.context.max_tokens - self.context.active_token_count

# If this chunk would leave exactly 1 token for the final chunk, reduce
# this chunk by 1 or skip scheduling so the final chunk has 2 tokens.
# This avoids the edge case where max_seqlen_q=1 which results in a bug
# with the Flash Attention kernel.
# See https://github.com/Dao-AILab/flash-attention/issues/1537
if remaining_len - prefill_chunk_length == 1:
if prefill_chunk_length > 1:
prefill_chunk_length -= 1
else:
# We only have space for 1 token, but remaining is 2.
# Delay scheduling to avoid leaving exactly 1 token for the final chunk.
can_schedule = False
break

else:
# Partial admit: schedule this chunk and keep the request at the queue head.
self.context.add_request(req, prefill_chunk_length=prefill_chunk_length)
self._loop.call_soon_threadsafe(
self._loop.create_task, self._notify_cond_for_new_request()
)
self.context.chunked_prefill_request_id = req.request_id
req.remaining_prompt_tokens = req.remaining_prompt_tokens[prefill_chunk_length:]
req.finished_chunk_token_count += prefill_chunk_length
# Still have tokens to prefill, so we break and keep the
# chunked prefill request at the head of the waiting queue
# Note that we do not need to continue check the queue, as the tokens are full

# Prepend pending request ids to waiting queue.
if prefix_caching_enabled and pending_request_ids:
Expand Down
3 changes: 3 additions & 0 deletions megatron/core/inference/inference_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,9 @@ class DynamicInferenceRequest(InferenceRequest):
routing_indices: Optional[np.ndarray] = None
finished_chunk_token_count: int = 0
stop_word_ids: Optional[List[List[int]]] = None # Tokenized stop words (populated internally)
# Consecutive steps this request has been deferred by CG-aware admission gating.
# Reset to 0 on successful admission. Used only for starvation logging.
cg_wait_iters: int = 0

# Prefix caching fields
block_size_tokens: Optional[int] = None # Block size for hash computation
Expand Down
Loading
Loading