Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions megatron/core/inference/contexts/kv_block_allocator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from collections import deque
from typing import Callable, Dict, Optional

import numpy as np
import torch
from torch import Tensor

Expand Down Expand Up @@ -73,6 +74,12 @@ def __init__(
(self.total_count,), dtype=torch.int64, device=torch.cuda.current_device()
)

# Per-block routing storage for MoE routing replay.
# Maps block_id -> ndarray [block_size_tokens, num_layers, topk].
# Routing data persists through block release/deregister and is only
# cleared when a block is re-allocated or the allocator is reset.
self.block_routing: Dict[int, np.ndarray] = {}

def __str__(self):
return (
f"using: total {self.get_total_used()}/{self.total_count - 1}"
Expand Down Expand Up @@ -183,6 +190,11 @@ def allocate_memory_blocks(self, num_blocks: int) -> Optional[Tensor]:
if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU:
self.update_timestamps(block_ids)

# Clear stale routing data for newly allocated blocks
if self.block_routing:
for bid in block_ids.tolist():
self.block_routing.pop(bid, None)

return block_ids

def release_memory_blocks(self, blocks: Tensor) -> None:
Expand Down Expand Up @@ -245,6 +257,9 @@ def reset(self) -> None:

self.total_avail = self.total_count - 1

# Clear per-block routing data
self.block_routing.clear()

if self.enable_prefix_caching:
# Reset all block hashes
self.block_hashes.fill_(-1)
Expand Down Expand Up @@ -358,3 +373,139 @@ def evict_lru_blocks(self, num_blocks_needed: int) -> bool:
self._deregister_blocks(blocks_to_evict)

return True

# =========================================================================
# Per-block routing storage methods (for MoE routing replay)
# =========================================================================

def store_routing_per_block(
    self, routing_indices_per_request: Optional[Dict[int, np.ndarray]]
) -> None:
    """Scatter the current step's per-request routing into per-block storage.

    The routing arrays of the active requests are concatenated in
    active-request order so that row ``i`` lines up with active token ``i``.
    Each row is then written, via ``store_block_routing``, to the block and
    in-block position given by the context's token-to-block mapping. Tokens
    mapped to the dummy block are dropped.

    Args:
        routing_indices_per_request: Dict mapping request_id to a routing
            ndarray [num_tokens, num_layers, topk], or None (no-op).
    """
    if routing_indices_per_request is None:
        return

    ctx = self.context
    token_count = ctx.active_token_count
    if token_count == 0:
        return

    # Request ids in the [paused_request_count, total_request_count) range,
    # i.e. the slice the context uses for active requests.
    request_ids = ctx.request_ids[ctx.paused_request_count : ctx.total_request_count].tolist()
    per_request = []
    for rid in request_ids:
        if rid in routing_indices_per_request:
            per_request.append(routing_indices_per_request[rid])
    if not per_request:
        return

    # One row per active token: [token_count, num_layers, topk].
    flat_routing = np.concatenate(per_request, axis=0)
    assert flat_routing.shape[0] == token_count, (
        f"Routing token count {flat_routing.shape[0]} != active token count {token_count}"
    )

    # Bring the token -> (block, position) mapping to host memory as numpy.
    token_blocks = ctx.token_to_block_idx[:token_count].cpu().numpy()
    token_positions = ctx.token_to_local_position_within_kv_block[:token_count].cpu().numpy()

    # A stable sort by block id makes every block's tokens contiguous while
    # keeping their original relative order inside the block.
    order = np.argsort(token_blocks, kind='stable')
    grouped_positions = token_positions[order]
    grouped_routing = flat_routing[order]

    # np.unique returns the block ids in the same ascending order as the
    # sorted layout, so cumulative counts give each block's slice bounds.
    unique_blocks, counts = np.unique(token_blocks, return_counts=True)
    stops = np.cumsum(counts)
    starts = stops - counts

    dummy = self.dummy_block_idx
    for bid, lo, hi in zip(unique_blocks.tolist(), starts.tolist(), stops.tolist()):
        if bid == dummy:
            continue
        self.store_block_routing(bid, grouped_positions[lo:hi], grouped_routing[lo:hi])

def reconstruct_routing_from_blocks(
    self, block_ids: list[int], total_routing_tokens: int
) -> Optional[np.ndarray]:
    """Reconstruct routing indices from per-block storage.

    Walks ``block_ids`` in order and takes up to ``block_size_tokens`` rows
    from each block's stored routing until exactly ``total_routing_tokens``
    rows have been gathered; the last contributing block is trimmed.

    Blocks that lie entirely past the requested token range are never
    looked up. A trailing block with no routing stored (for example one
    whose routing was cleared on allocation and never written) therefore
    does not invalidate an otherwise complete reconstruction.

    Args:
        block_ids: Ordered list of block IDs for the request.
        total_routing_tokens: Expected number of routing tokens
            (total_tokens - 1, since the last generated token has no
            forward-pass routing).

    Returns:
        ndarray [total_routing_tokens, num_layers, topk], or None when a
        block needed to cover the requested tokens has no routing data,
        when the blocks cannot supply enough tokens, or when
        ``total_routing_tokens`` is not positive.
    """
    block_size = self.context.block_size_tokens
    routing_parts = []
    remaining = total_routing_tokens

    for bid in block_ids:
        # Check the budget BEFORE the lookup: once every requested token is
        # collected, later blocks are irrelevant and may legitimately have
        # no routing stored.
        if remaining <= 0:
            break
        routing = self.get_block_routing(bid)
        if routing is None:
            return None  # A block we actually need has no routing data
        take = min(block_size, remaining)
        routing_parts.append(routing[:take])
        remaining -= take

    # Nothing gathered, or the blocks ran out before the request was covered.
    if not routing_parts or remaining != 0:
        return None

    return np.concatenate(routing_parts, axis=0)

def store_block_routing(self, block_id: int, positions: np.ndarray, routing: np.ndarray) -> None:
    """Write routing rows into a block's buffer at the given token positions.

    The block's buffer is created on first write as a zero-filled array of
    shape [block_size_tokens, num_layers, topk], with the layer/top-k sizes
    and dtype taken from ``routing``. Later writes reuse the same buffer.

    Args:
        block_id: The block ID.
        positions: ndarray of token positions within the block (1D, int).
        routing: ndarray of routing data [num_positions, num_layers, topk].
    """
    table = self.block_routing
    if block_id in table:
        buffer = table[block_id]
    else:
        # Lazily allocate so blocks that never receive routing cost nothing.
        buffer_shape = (self.context.block_size_tokens, routing.shape[-2], routing.shape[-1])
        buffer = np.zeros(buffer_shape, dtype=routing.dtype)
        table[block_id] = buffer
    # Fancy-index assignment: row i of `routing` lands at positions[i].
    buffer[positions] = routing

def get_block_routing(self, block_id: int) -> Optional[np.ndarray]:
    """Look up the routing buffer stored for a block.

    Args:
        block_id: The block ID.

    Returns:
        The stored ndarray of shape [block_size_tokens, num_layers, topk],
        or None when nothing is stored for ``block_id``.
    """
    routing_table = self.block_routing
    return routing_table.get(block_id, None)
23 changes: 22 additions & 1 deletion megatron/core/inference/engines/dynamic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,6 +1042,7 @@ def post_process_requests(
routing_indices_per_request: Optional[Dict[int, torch.Tensor]] = None,
pre_fwd_active_token_count: Optional[int] = None,
pre_fwd_step_count: Optional[int] = None,
finished_routing_block_ids: Optional[Dict[int, list[int]]] = None,
) -> Tuple[List[DynamicInferenceRequest], List[DynamicInferenceRequest]]:
"""
Handles post-processing for requests after a step.
Expand All @@ -1058,7 +1059,11 @@ def post_process_requests(
list of (top_n_logprobs, top_n_indices) tuples.
routing_indices_per_request: (Dict[int, Tensor]): MoE routing indices
pre-mapped by request_id. Each value is a tensor of shape
[num_tokens_this_step, num_layers, topk].
[num_tokens_this_step, num_layers, topk]. Unused when per-block
routing is active (routing_indices_per_request will be None).
finished_routing_block_ids: (Dict[int, List[int]]): Block IDs for
finished requests, saved before update_requests released them.
Used for per-block routing reconstruction.

Returns:
A list of active requests and completed requests as `DynamicInferenceRequest` objects
Expand Down Expand Up @@ -1182,6 +1187,20 @@ def post_process_requests(
self._spec_tokens_accepted += actual_accepted

if request_id in finished_request_ids:
# Reconstruct routing from per-block storage before popping.
if (
finished_routing_block_ids
and request_id in finished_routing_block_ids
and len(self.requests[request_id].record.requests) == 1
):
block_ids = finished_routing_block_ids[request_id]
total_tokens = len(request.prompt_tokens) + len(
request.generated_tokens
)
request.routing_indices = self.context.kv_block_allocator.reconstruct_routing_from_blocks(
block_ids, total_tokens - 1
)

# Request finished by normal means (termination_id, max_length, or stop word from previous step)
request.generated_length = len(request.generated_tokens)
request.status = Status.COMPLETED
Expand Down Expand Up @@ -1711,6 +1730,7 @@ async def async_bookkeep(
log_probs = step_result["log_probs"]
top_n_logprobs = step_result.get("top_n_logprobs", None)
routing_indices_per_request = step_result.get("routing_indices_per_request", None)
finished_routing_block_ids = step_result.get("finished_routing_block_ids", None)
cuda_graph_request_count = step_result["cuda_graph_request_count"]

# Add paused events.
Expand All @@ -1731,6 +1751,7 @@ async def async_bookkeep(
routing_indices_per_request,
pre_fwd_active_token_count=context_state.get("active_token_count"),
pre_fwd_step_count=context_state.get("step_count"),
finished_routing_block_ids=finished_routing_block_ids,
)

else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1266,7 +1266,7 @@ def _router_record_bookkeeping(self) -> Optional[Dict[int, Tensor]]:
axis=0,
)

return routing_splits
return dict(zip(active_request_ids, routing_splits))

def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]:
"""Calculate log probs from logits."""
Expand Down Expand Up @@ -1715,6 +1715,17 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]:
)
finished_request_ids = context.request_ids[finished_idxs]

# Save block IDs for finished requests before update_requests releases them.
# Needed for per-block routing reconstruction in the engine.
finished_routing_block_ids = {}
if context.kv_block_allocator.block_routing and finished_idxs.numel() > 0:
for fidx in finished_idxs.tolist():
req_id = int(context.request_ids[fidx].item())
blocks = context.request_to_kv_block_ids[fidx]
valid = blocks[blocks >= 0].tolist()
if valid:
finished_routing_block_ids[req_id] = valid

# Clone needed: update_requests mutates next_tokens in-place via tensor_swap,
# which would corrupt the reused _sampled_tokens_cuda buffer.
new_sample_copy = self._sampled_tokens_cuda[:active_request_count].clone()
Expand All @@ -1732,6 +1743,7 @@ def _dynamic_step_context_bookkeeping(self) -> Dict[str, Tensor]:
return {
"active_request_ids": active_request_ids,
"finished_request_ids": finished_request_ids,
"finished_routing_block_ids": finished_routing_block_ids,
**(update_result or {}),
}

Expand Down Expand Up @@ -1785,6 +1797,13 @@ async def async_generate_output_tokens_dynamic_batch(
# Collect routing indices per request (must be done before context transitions)
routing_indices_per_request = self._router_record_bookkeeping()

# Store routing per-block for MoE routing replay reconstruction.
# Must be done while token-to-block mappings are still valid (before update_requests).
context = self.inference_wrapped_model.inference_context
context.kv_block_allocator.store_routing_per_block(routing_indices_per_request)
# Per-step routing is no longer needed; reconstruction happens from blocks at completion.
routing_indices_per_request = None

# This is the best place to yield control back to event loop.
# At this point we have enqueued FW pass GPU kernels asynchronously.
# While they are running, we can do other useful CPU work.
Expand Down
Loading