Open
96 commits
bf1b92f
Speculative decoding simple implementation
Feb 2, 2026
194e0e4
Speculative decoding vectorized implementation
Feb 5, 2026
e911a17
Added comments and cleaned up code
Feb 6, 2026
6aab08f
Bug fix
Feb 6, 2026
fceb983
Bug fix
Feb 6, 2026
6aeaced
Bug fix
Feb 6, 2026
f5d52db
Rebase to main
shanmugamr1992 Feb 22, 2026
537a1fe
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Feb 23, 2026
b2718b8
WIP MTP for mamba
santhnm2 Feb 23, 2026
d2ac237
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Feb 24, 2026
043e7c1
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Feb 24, 2026
397fd54
WIP debugging
santhnm2 Feb 24, 2026
af60402
Add SGLang kernels
santhnm2 Feb 24, 2026
7795737
More debugging
santhnm2 Feb 24, 2026
ae03747
Working causal_conv1d_update triton kernel
santhnm2 Feb 25, 2026
5915cc2
Mamba almost working
santhnm2 Feb 25, 2026
a011c73
Fix non-consecutive acceptance bug
santhnm2 Feb 25, 2026
06b08d7
More progress
santhnm2 Feb 26, 2026
6b8835a
Working with cuda graphs
santhnm2 Feb 27, 2026
9057442
Fix cuda graphs and chunked prefill
santhnm2 Feb 27, 2026
cfc6282
Merge with main
santhnm2 Feb 28, 2026
4d5fe5d
Add speculative decode unit tests
santhnm2 Mar 2, 2026
8e3710f
Minor fix
santhnm2 Mar 2, 2026
867a137
Minimize diff
santhnm2 Mar 2, 2026
9727533
Formatting
santhnm2 Mar 2, 2026
9917e35
Linting
santhnm2 Mar 2, 2026
eff0fa1
Linting / copyright
santhnm2 Mar 2, 2026
905c7e3
Merge branch 'main' into spec_mamba
santhnm2 Mar 2, 2026
0d05f8b
Linting
santhnm2 Mar 2, 2026
6b86d00
Merge with main
santhnm2 Mar 3, 2026
19921fc
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Mar 4, 2026
42cb956
Merge with main
santhnm2 Mar 5, 2026
5947e3a
Bug fixes
santhnm2 Mar 5, 2026
789f6e8
More fixes
santhnm2 Mar 5, 2026
56b84f5
Minor fixes
santhnm2 Mar 5, 2026
1d028e2
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Mar 5, 2026
2962941
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Mar 5, 2026
1a852c6
Add softplus
santhnm2 Mar 5, 2026
3549711
Remove dead code
santhnm2 Mar 5, 2026
7faee83
Fix flaky test
santhnm2 Mar 5, 2026
fc806ef
More flaky test fixes
santhnm2 Mar 5, 2026
fadbc0c
Address claude's comments
santhnm2 Mar 5, 2026
5f3c141
Chunked prefill fix
santhnm2 Mar 6, 2026
a51d979
Move cache_seqlens_decode into mamba_metadata.py
santhnm2 Mar 6, 2026
307fad5
Add triton kernels (possibly revert)
santhnm2 Mar 6, 2026
e0b0a8c
Fix bug
santhnm2 Mar 6, 2026
f4112ea
Undo formatting changes
santhnm2 Mar 6, 2026
83a9726
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Mar 6, 2026
33e1910
Cleanup
santhnm2 Mar 6, 2026
9b4dadf
Add unit tests
santhnm2 Mar 6, 2026
1277af4
Merge remote-tracking branch 'upstream' into spec_mamba
santhnm2 Mar 6, 2026
e59d6e9
Test cache_seqlens in mamba_metadata.py
santhnm2 Mar 6, 2026
213e5d7
Add spec decode + prefix caching unit tests
santhnm2 Mar 6, 2026
d26450b
Fix speculative decode engine test
santhnm2 Mar 6, 2026
d942c0f
Enable prefix caching in the config
santhnm2 Mar 6, 2026
06aa6d2
Merge with main
santhnm2 Mar 7, 2026
c616203
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Mar 9, 2026
8b3fd10
Fix tests
santhnm2 Mar 9, 2026
e6f61d6
Linting
santhnm2 Mar 9, 2026
5e02618
Address review comments
santhnm2 Mar 9, 2026
ff8721d
Update clones
santhnm2 Mar 9, 2026
83f526c
Linting
santhnm2 Mar 9, 2026
1194663
Update clones
santhnm2 Mar 9, 2026
72b1f68
Delete extraneous tokens after stop sequence
santhnm2 Mar 9, 2026
22e8db3
Add engine test for deleting speculative tokens after stop token
santhnm2 Mar 9, 2026
7ce9546
Address review comments
santhnm2 Mar 9, 2026
89e55c0
Remove restriction on materialize_only_last_token_logits
santhnm2 Mar 9, 2026
a878759
Revert circular buffer logic for conv
santhnm2 Mar 9, 2026
47195a1
Linting
santhnm2 Mar 9, 2026
6d7da58
Remove references to cache_seqlens
santhnm2 Mar 9, 2026
b65dbbe
Linting and misc review comments
santhnm2 Mar 9, 2026
712824f
Linting
santhnm2 Mar 9, 2026
61e16e5
Revert materialize_only_last_token_logits changes
santhnm2 Mar 9, 2026
456853c
Formatting
santhnm2 Mar 9, 2026
c3e697f
Minor fix
santhnm2 Mar 9, 2026
ab28eb7
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Mar 9, 2026
955f404
Remove outdated assertion on test
santhnm2 Mar 9, 2026
1a0584d
Nits
santhnm2 Mar 9, 2026
00a7dcc
Fix event tracking for speculative tokens
santhnm2 Mar 9, 2026
8204991
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Mar 9, 2026
9ca68d0
Update text_generation_controller tests
santhnm2 Mar 9, 2026
ed7667c
Fix text generation controller tests
santhnm2 Mar 9, 2026
b92f79c
Fix new_speculative_tokens + eviction, add tests
santhnm2 Mar 9, 2026
9eb639a
Log speculative token acceptance rates
santhnm2 Mar 9, 2026
675aa01
Don't overcount spec proposed tokens for prefill requests
santhnm2 Mar 9, 2026
30d3b62
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Mar 9, 2026
4f119d2
Linting
santhnm2 Mar 9, 2026
fe1372c
Linting
santhnm2 Mar 10, 2026
3be6e4d
Fixing logprobs, stop words and track_generated_token_events
Mar 10, 2026
e3e5ca2
Fix dynamic_engine unit tests
santhnm2 Mar 10, 2026
f656155
Merge remote-tracking branch 'upstream/main' into spec_mamba
santhnm2 Mar 10, 2026
0ecfb4e
Fix speculative decoding
santhnm2 Mar 10, 2026
a267a9c
Restore dynamic_engine unit test changes
santhnm2 Mar 10, 2026
9922180
Bug fix
santhnm2 Mar 10, 2026
42126a8
Merge santhnm2/spec_mamba into spec_mamba
Mar 10, 2026
4097bf1
Minimize diff
santhnm2 Mar 10, 2026
64 changes: 45 additions & 19 deletions megatron/core/inference/batch_dimensions_utils.py
@@ -73,7 +73,9 @@ def is_applicable_for_batch_dim(
>= real_batch_dim.prefill_req_count + real_batch_dim.decode_req_count
)

def is_valid(self, max_requests: int, max_sequence_length: int) -> bool:
def is_valid(
self, max_requests: int, max_sequence_length: int, num_speculative_tokens: int
) -> bool:
"""
Checks if the batch dimension is valid based on resource constraints.

@@ -92,11 +94,17 @@ def is_valid(self, max_requests: int, max_sequence_length: int) -> bool:
return False

# Check if token count is sufficient for requests
if self.token_count < self.prefill_req_count + self.decode_req_count:
if self.token_count < self.prefill_req_count + self.decode_req_count * (
num_speculative_tokens + 1
):
return False

# Check if the prefill requests are shorter than the max sequence length
if self.token_count > self.prefill_req_count * max_sequence_length + self.decode_req_count:
if (
self.token_count
> self.prefill_req_count * max_sequence_length
+ self.decode_req_count * (num_speculative_tokens + 1)
):
return False

return True
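The two token-count bounds in the hunk above can be sketched in isolation. This is a minimal illustration, not the PR's code: `BatchDims` is a hypothetical stand-in for `InferenceBatchDimensions`, and `token_bounds_ok` is a made-up name that reproduces only the two checks shown in this hunk (the earlier checks in `is_valid` are not visible here and are omitted).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BatchDims:
    """Hypothetical stand-in for InferenceBatchDimensions (illustration only)."""

    token_count: int
    prefill_req_count: int
    decode_req_count: int


def token_bounds_ok(
    dims: BatchDims, max_sequence_length: int, num_speculative_tokens: int
) -> bool:
    """Reproduce the two token-count checks from the hunk above."""
    # Each decode request now carries one regular token plus the speculative ones.
    decode_tokens = dims.decode_req_count * (num_speculative_tokens + 1)
    # Lower bound: at least one token per prefill request, plus all decode tokens.
    if dims.token_count < dims.prefill_req_count + decode_tokens:
        return False
    # Upper bound: a prefill request contributes at most max_sequence_length tokens.
    if dims.token_count > dims.prefill_req_count * max_sequence_length + decode_tokens:
        return False
    return True
```

With `num_speculative_tokens=0` both bounds reduce to the pre-change expressions, so the default leaves existing behavior unchanged for these two checks.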
@@ -308,6 +316,7 @@ def generate_cuda_graph_batch_dimensions_list(
max_tokens: int,
max_sequence_length: int,
use_cuda_graphs_for_non_decode_steps: bool,
num_speculative_tokens: int = 0,
) -> Tuple[List[InferenceBatchDimensions], Optional[List[int]]]:
"""
Generate CUDA graph batch dimensions.
@@ -344,6 +353,7 @@ def generate_cuda_graph_batch_dimensions_list(
max_tokens: Maximum total tokens
max_sequence_length: Maximum sequence length
use_cuda_graphs_for_non_decode_steps: Whether to use CUDA graphs for non-decode steps
num_speculative_tokens: Number of speculative tokens

Returns:
Tuple containing:
@@ -355,7 +365,7 @@ def generate_cuda_graph_batch_dimensions_list(
def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int) -> None:
"""Helper to create and append batch dimension to list only if it's valid."""
batch_dim = InferenceBatchDimensions(token_count, prefill_req_count, decode_req_count)
if batch_dim.is_valid(max_requests, max_sequence_length):
if batch_dim.is_valid(max_requests, max_sequence_length, num_speculative_tokens):
cuda_graph_batch_dimensions_list.append(batch_dim)

# Cuda graph token-counts
@@ -372,9 +382,10 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int
):
cuda_graph_max_tokens = max_tokens

assert cuda_graph_max_tokens == max_requests, (
f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must equal max_requests "
f"({max_requests}). This is required for correctly syncing EP ranks: "
assert cuda_graph_max_tokens == max_requests * (num_speculative_tokens + 1), (
f"cuda_graph_max_tokens ({cuda_graph_max_tokens}) must equal max_requests *"
f"(num_speculative_tokens + 1) ({max_requests * (num_speculative_tokens + 1)}). "
"This is required for correctly syncing EP ranks: "
f"prefill and decode graph pools must have the same token count granularity."
)

@@ -395,8 +406,9 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int
)

# Calculate separate token counts for decode-only graphs.
# Decode graphs can be more conservative since each request uses exactly 1 token.
cuda_graph_max_tokens_decode = min(cuda_graph_max_tokens, max_requests)
cuda_graph_max_tokens_decode = min(
cuda_graph_max_tokens, max_requests * (num_speculative_tokens + 1)
)
cuda_graph_decode_token_counts = (
CUDAGraphBatchDimensionBuilder._calculate_cuda_graph_token_counts(
tp_size=tp_size,
@@ -415,20 +427,29 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int
): # decode only
# Use decode-specific token counts for decode-only graphs
for size in cuda_graph_decode_token_counts:
decode_req_count = min(size // (num_speculative_tokens + 1), max_requests)
token_count = decode_req_count * (num_speculative_tokens + 1)
token_count = token_count // tp_size * tp_size
add_if_valid(
token_count=min(size, max_requests),
prefill_req_count=0,
decode_req_count=min(size, max_requests),
token_count=token_count, prefill_req_count=0, decode_req_count=decode_req_count
)
else:
# Mixed prefill and decode mode
# Create prefill and mixed dimensions with full token counts
for size in cuda_graph_prefill_token_counts:
assert size % tp_size == 0
prefill_req_count = min(cuda_graph_mixed_prefill_request_count, max_requests)
decode_req_count = max(
0,
min(
(size - prefill_req_count) // (num_speculative_tokens + 1),
max_requests - prefill_req_count,
),
)
add_if_valid(
token_count=size,
prefill_req_count=min(cuda_graph_mixed_prefill_request_count, max_requests),
decode_req_count=min(size, max_requests)
- min(cuda_graph_mixed_prefill_request_count, max_requests),
prefill_req_count=prefill_req_count,
decode_req_count=decode_req_count,
)
# We need to ensure the prefill requests are shorter than the max sequence length,
# considering the one decode token is used for prefill request construction
@@ -445,16 +466,21 @@ def add_if_valid(token_count: int, prefill_req_count: int, decode_req_count: int

# Create decode-only dimensions with optimized token counts
for size in cuda_graph_decode_token_counts:
decode_req_count = min(size // (num_speculative_tokens + 1), max_requests)
token_count = decode_req_count * (num_speculative_tokens + 1)
token_count = token_count // tp_size * tp_size
add_if_valid(
token_count=min(size, max_requests),
prefill_req_count=0,
decode_req_count=min(size, max_requests),
token_count=token_count, prefill_req_count=0, decode_req_count=decode_req_count
)

# Remove duplicates and sort by prefill token count
cuda_graph_batch_dimensions_list = list(set(cuda_graph_batch_dimensions_list))
cuda_graph_batch_dimensions_list.sort(
key=lambda x: ((x.token_count - x.decode_req_count), x.decode_req_count), reverse=True
key=lambda x: (
(x.token_count - x.decode_req_count * (num_speculative_tokens + 1)),
x.decode_req_count,
),
reverse=True,
)

# Collect actual token counts from batch dimensions, then unique and sort
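The decode-only sizing in the hunks above can be isolated as a small function. This is a sketch under stated assumptions: `decode_only_dims` is a hypothetical helper name, and it copies only the three arithmetic lines from the diff, not the surrounding graph-building logic.

```python
from typing import Tuple


def decode_only_dims(
    size: int, max_requests: int, num_speculative_tokens: int, tp_size: int
) -> Tuple[int, int]:
    """Return (token_count, decode_req_count) for one candidate graph size."""
    tokens_per_decode = num_speculative_tokens + 1
    # How many whole decode requests fit in `size` tokens, capped by max_requests.
    decode_req_count = min(size // tokens_per_decode, max_requests)
    token_count = decode_req_count * tokens_per_decode
    # Round down to a multiple of the tensor-parallel size, as in the diff.
    token_count = token_count // tp_size * tp_size
    return token_count, decode_req_count
```

One thing that appears worth double-checking: when `decode_req_count * (num_speculative_tokens + 1)` is not a multiple of `tp_size`, the round-down leaves `token_count` below the lower bound in `is_valid`, so the candidate would seemingly be dropped by `add_if_valid` rather than adjusted. For example, `size=16, max_requests=8, num_speculative_tokens=2, tp_size=4` gives 5 decode requests but only 12 tokens, while 15 are needed.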
3 changes: 3 additions & 0 deletions megatron/core/inference/config.py
@@ -228,6 +228,9 @@ class InferenceConfig:
enable_chunked_prefill: bool = False
"""Whether to enable chunked prefill."""

num_speculative_tokens: int = 0
"""The number of speculative tokens to generate for decode steps."""

enable_prefix_caching: bool = False
"""Whether to enable prefix caching for KV cache block sharing."""

@@ -259,12 +259,17 @@ def update(
self.cu_seqlens = self._cu_seqlens_buffer[: padded_prefill_count + 1]

if padded_decode_count > 0 and padded_prefill_count > 0:
self._device_decode_prefill_buffer[0] = real_decode_count
self._device_decode_prefill_buffer[0] = cu_seqlens[real_decode_count]
# This describes the number of items in the prefill tensor relative to the
# decode tensor. If chunked prefill is present, it is included in the
# "prefill" part of the main split.
self._device_decode_prefill_buffer[1] = regular_prefill_count + (
1 if has_chunked_prefill_req else 0
self._device_decode_prefill_buffer[1] = (
cu_seqlens[
real_decode_count
+ regular_prefill_count
+ (1 if has_chunked_prefill_req else 0)
]
- cu_seqlens[real_decode_count]
)
self.device_decode_prefill = self._device_decode_prefill_buffer

12 changes: 10 additions & 2 deletions megatron/core/inference/contexts/attention_context/mha_metadata.py
@@ -41,6 +41,7 @@ def update(
request_to_kv_block_ids: torch.Tensor,
batch_dimensions: InferenceBatchDimensions,
padded_batch_dimensions: InferenceBatchDimensions,
num_speculative_tokens: int = 0,
):
"""
Args:
@@ -49,6 +50,7 @@ def update(
request_to_kv_block_ids: (>real_batch_size, max_kv_blocks)
batch_dimensions: Configuration object containing real batch settings
padded_batch_dimensions: Configuration object containing padded batch settings
num_speculative_tokens: Number of speculative tokens
"""
# Extract values from configs
real_batch_size = batch_dimensions.req_count
@@ -99,7 +101,7 @@ def update(
)

if padded_batch_dimensions.prefill_req_count == 0:
self._max_seqlen_q = 1
self._max_seqlen_q = num_speculative_tokens + 1
else:
# Make sure we will launch the prefill kernel for prefill graphs
self._max_seqlen_q = max(2, padded_batch_dimensions.token_count)
@@ -150,6 +152,7 @@ def update(
request_to_kv_block_ids: torch.Tensor,
batch_dimensions: InferenceBatchDimensions,
padded_batch_dimensions: InferenceBatchDimensions,
num_speculative_tokens: int = 0,
):
"""
Args:
@@ -158,13 +161,15 @@ def update(
request_to_kv_block_ids: (>real_batch_size, max_kv_blocks)
batch_dimensions: Configuration object containing real batch settings
padded_batch_dimensions: Configuration object containing padded batch settings
num_speculative_tokens: Number of speculative tokens
"""
super().update(
request_query_lengths,
request_kv_length_offsets,
request_to_kv_block_ids,
batch_dimensions,
padded_batch_dimensions,
num_speculative_tokens,
)

def reset(self):
@@ -183,6 +188,7 @@ def update(
request_to_kv_block_ids: torch.Tensor,
batch_dimensions: InferenceBatchDimensions,
padded_batch_dimensions: InferenceBatchDimensions,
num_speculative_tokens: int = 0,
):
"""
Args:
@@ -191,17 +197,19 @@ def update(
request_to_kv_block_ids: (>real_batch_size, max_kv_blocks)
batch_dimensions: Configuration object containing real batch settings
padded_batch_dimensions: Configuration object containing padded batch settings
num_speculative_tokens: Number of speculative tokens
"""
super().update(
request_query_lengths,
request_kv_length_offsets,
request_to_kv_block_ids,
batch_dimensions,
padded_batch_dimensions,
num_speculative_tokens,
)
if len(self.state_data["query_lengths"]) > 0:
self.state_data["max_seqlen_q"] = torch.max(self.state_data["query_lengths"]).item()
self.state_data["max_seqlen_k"] = torch.max(self.state_data["kv_seq_lengths"]).item()
else:
self.state_data["max_seqlen_q"] = 1
self.state_data["max_seqlen_q"] = num_speculative_tokens + 1
self.state_data["max_seqlen_k"] = 1
40 changes: 30 additions & 10 deletions megatron/core/inference/contexts/dynamic_block_allocator.py
@@ -85,19 +85,39 @@ def get_total_used(self):

def get_active_used(self):
"""Compute number of active blocks used."""
return (
self.context.request_kv_block_counts[
self.context.paused_request_count : self.context.total_request_count
]
.sum()
.item()
)
if not self.enable_prefix_caching:
return (
self.context.request_kv_block_counts[
self.context.paused_request_count : self.context.total_request_count
]
.sum()
.item()
)

active_start = self.context.paused_request_count
active_end = self.context.total_request_count
if active_end > active_start:
active_rows = self.context.request_to_kv_block_ids[active_start:active_end]
valid_ids = active_rows[active_rows >= 0]
if valid_ids.numel() > 0:
return int(torch.unique(valid_ids).numel())
return 0

def get_paused_used(self):
"""Compute number of paused blocks used."""
return (
self.context.request_kv_block_counts[: self.context.paused_request_count].sum().item()
)
if not self.enable_prefix_caching:
return (
self.context.request_kv_block_counts[: self.context.paused_request_count]
.sum()
.item()
)

if self.context.paused_request_count > 0:
paused_rows = self.context.request_to_kv_block_ids[: self.context.paused_request_count]
valid_ids = paused_rows[paused_rows >= 0]
if valid_ids.numel() > 0:
return int(torch.unique(valid_ids).numel())
return 0

def get_active_avail(self):
"""Compute number of active blocks available."""
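The prefix-caching branches of `get_active_used` and `get_paused_used` above count distinct block IDs instead of summing per-request block counts, presumably because shared prefix blocks would otherwise be counted once per request. The difference can be shown with plain Python lists. This is an illustrative sketch: `count_unique_blocks` and `count_summed_blocks` are hypothetical names, and `-1` is assumed to mark an unassigned slot, matching the `>= 0` filter in the diff.

```python
from typing import List


def count_unique_blocks(block_id_rows: List[List[int]]) -> int:
    """Count distinct non-negative block IDs across all request rows."""
    return len({block_id for row in block_id_rows for block_id in row if block_id >= 0})


def count_summed_blocks(block_id_rows: List[List[int]]) -> int:
    """Count blocks per request and add them up (the non-prefix-caching view)."""
    return sum(1 for row in block_id_rows for block_id in row if block_id >= 0)


# Two requests that share prefix blocks 0 and 1 and each own one extra block.
rows = [[0, 1, 2, -1], [0, 1, 3, -1]]
unique_count = count_unique_blocks(rows)
summed_count = count_summed_blocks(rows)
```

Here the summed count is 6 while only 4 physical blocks are in use, which is the overcount the unique-ID path avoids.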