Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion docs/serving/deepseek-v4.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
DP=4 + expert parallel + mega_moe + FP8 KV cache (B200, 4× SM100):

```bash
CUDA_VISIBLE_DEVICES=0,1,2,3 tokenspeed serve deepseek-ai/DeepSeek-V4-Flash \
CUDA_VISIBLE_DEVICES=0,1,2,3 exec ts serve \
--model deepseek-ai/DeepSeek-V4-Flash \
--host localhost --port 30100 \
--dist-init-addr 127.0.0.1:4013 \
--trust-remote-code \
Expand Down Expand Up @@ -50,6 +51,20 @@ also be bumped to 256.)
- `--deepseek-v4-indexer-prefill-max-logits-mb N`: caps the FP4 indexer
prefill logits buffer in MB (default 512).

## MTP speculative decoding

DeepSeek V4 can use the checkpoint's NextN/MTP draft layers through the standard
speculative flags. For `num_steps > 1`, keep the main V4 launch flags and add:

```bash
--speculative-algorithm MTP \
--speculative-num-steps 3
```

When `--speculative-draft-model-path` is omitted for MTP, TokenSpeed uses the
same V4 checkpoint as the draft source and loads the `DeepseekV4ForCausalLMNextN`
architecture.

## Hardware / dependency requirements

- 4× NVIDIA Blackwell SM100 (B200) GPUs.
Expand Down
35 changes: 25 additions & 10 deletions python/tokenspeed/runtime/configs/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
_DEEPSEEK_V4_ARCHITECTURES = frozenset(
{
"DeepseekV4ForCausalLM",
"DeepseekV4ForCausalLMNextN",
}
)
_MLA_ARCHITECTURES = frozenset(
Expand Down Expand Up @@ -85,10 +86,13 @@ def override_model_config(model_config, ext_yaml):


def is_deepseek_v4(config: PretrainedConfig) -> bool:
return (
config.architectures is not None
and config.architectures[0] in _DEEPSEEK_V4_ARCHITECTURES
)
architectures = getattr(config, "architectures", None) or []
return len(architectures) > 0 and architectures[0] in _DEEPSEEK_V4_ARCHITECTURES


def is_deepseek_v4_nextn(config: PretrainedConfig) -> bool:
architectures = getattr(config, "architectures", None) or []
return len(architectures) > 0 and architectures[0] == "DeepseekV4ForCausalLMNextN"


def configure_deepseek_v4_attention(model_config) -> None:
Expand All @@ -111,6 +115,19 @@ def configure_deepseek_v4_attention(model_config) -> None:
model_config.scaling = model_config.scaling * mscale * mscale


def _derive_num_attention_layers(
hf_config: PretrainedConfig,
num_hidden_layers: int,
) -> int:
architectures = getattr(hf_config, "architectures", None) or []
num_attention_layers = num_hidden_layers
if is_deepseek_v4_nextn(hf_config):
num_attention_layers = int(getattr(hf_config, "num_nextn_predict_layers", 1))
if any(arch in _DOUBLE_ATTENTION_LAYER_ARCHITECTURES for arch in architectures):
num_attention_layers = num_hidden_layers * 2
return num_attention_layers


class ModelConfig:
def __init__(
self,
Expand Down Expand Up @@ -249,12 +266,10 @@ def __init__(
self.num_hidden_layers = getattr(self.hf_text_config, "num_hidden_layers", None)
if self.num_hidden_layers is None:
self.num_hidden_layers = self.hf_text_config.num_layers
self.num_attention_layers = self.num_hidden_layers
if any(
arch in _DOUBLE_ATTENTION_LAYER_ARCHITECTURES
for arch in self.hf_config.architectures
):
self.num_attention_layers = self.num_hidden_layers * 2
self.num_attention_layers = _derive_num_attention_layers(
self.hf_config,
self.num_hidden_layers,
)
self.vocab_size = self.hf_text_config.vocab_size

# Verify quantization
Expand Down
20 changes: 13 additions & 7 deletions python/tokenspeed/runtime/engine/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,9 @@ def __init__(
f"(ratio={server_args.mamba_full_memory_ratio})."
)

enable_mixed_prefill_decode = (
server_args.enable_mixed_chunk and server_args.speculative_algorithm is None
)
scheduler_cfg = make_config(
num_device_pages=self.max_total_num_tokens // server_args.block_size,
max_scheduled_tokens=server_args.chunked_prefill_size,
Expand All @@ -293,6 +296,7 @@ def __init__(
mamba_cache_chunk_size=server_args.mamba_cache_chunk_size,
mamba_pool_total_chunks=mamba_pool_total_chunks,
paged_cache_groups=pool_to_paged_cache_groups(token_to_kv_pool),
enable_mixed_prefill_decode=enable_mixed_prefill_decode,
)
logger.info(
"Scheduler config: page_size=%s num_device_pages=%s "
Expand Down Expand Up @@ -785,8 +789,10 @@ def _commit_forward_results(
on_first_token=None,
):
self.request_handler.forward_ct += 1
forward_mode = (
ForwardMode.EXTEND if forward_op.num_extends() > 0 else ForwardMode.DECODE
forward_mode = ForwardMode.from_num_extends(
forward_op.num_extends(),
len(forward_op.request_ids),
has_drafter=self.server_args.speculative_algorithm is not None,
)
self.request_handler._profile_batch_predicate(forward_mode)

Expand Down Expand Up @@ -859,12 +865,12 @@ def _dp_sync_and_check(self, forward_op) -> DpForwardMetadata:
batch_size = len(forward_op.request_ids) if forward_op is not None else 0
if forward_op is None:
forward_mode = ForwardMode.IDLE
elif forward_op.num_extends() > 0:
forward_mode = ForwardMode.EXTEND
elif self.server_args.speculative_algorithm is not None:
forward_mode = ForwardMode.TARGET_VERIFY
else:
forward_mode = ForwardMode.DECODE
forward_mode = ForwardMode.from_num_extends(
forward_op.num_extends(),
batch_size,
has_drafter=self.server_args.speculative_algorithm is not None,
)

self._dp_local_info[0, 0] = num_tokens
self._dp_local_info[0, 1] = batch_size
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,8 @@ def post_process_forward_op(
forward_op.input_lengths,
forward_op.extend_prefix_lens,
)
is_decode_op = forward_op.num_extends() <= 0
num_extends = forward_op.num_extends()
is_decode_op = num_extends <= 0

request_changes = []
stream_out_rids = []
Expand All @@ -504,6 +505,7 @@ def post_process_forward_op(
if output_logprobs_list is not None
else None
)
is_decode_slot = i >= num_extends
if self.spec_num_tokens is not None and is_decode_op:
pt += self.spec_num_tokens
else:
Expand All @@ -524,7 +526,7 @@ def post_process_forward_op(
if on_first_token is not None and model_output_ids:
on_first_token(forward_op.request_pool_indices[i], model_output_ids[0])

if is_decode_op and self.spec_algorithm is not None:
if is_decode_slot and self.spec_algorithm is not None:
request_state.spec_verify_ct += 1

# With the capturable grammar pipeline the matcher is
Expand Down Expand Up @@ -597,7 +599,7 @@ def post_process_forward_op(
else:
stream_out_rids.append(rid)
stream_out_states.append(request_state)
if is_decode_op:
if is_decode_slot:
request_changes.append(
make_update_reserve_tokens_event(rid, output_length)
)
Expand Down
4 changes: 2 additions & 2 deletions python/tokenspeed/runtime/engine/scheduler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

"""Helper functions for constructing scheduler specs and events."""

import os
from collections.abc import Sequence
from typing import Any, Mapping

Expand All @@ -39,7 +38,6 @@
"WriteBackDoneEvent": Cache.WriteBackDoneEvent,
"PrefetchDoneEvent": Cache.PrefetchDoneEvent,
}
_TRUTHY_ENV_VALUES = {"1", "true", "yes", "on"}


def make_spec(rid: str, tokens: list[int]) -> RequestSpec:
Expand All @@ -66,6 +64,7 @@ def make_config(
mamba_cache_chunk_size: int = 64,
mamba_pool_total_chunks: int = 0,
paged_cache_groups: Sequence["PagedCacheGroupConfig"] | None = None,
enable_mixed_prefill_decode: bool = False,
) -> SchedulerConfig:
cfg = SchedulerConfig()
cfg.num_device_pages = num_device_pages
Expand All @@ -92,6 +91,7 @@ def make_config(
cfg.enable_mamba = enable_mamba
cfg.mamba_cache_chunk_size = mamba_cache_chunk_size
cfg.mamba_pool_total_chunks = mamba_pool_total_chunks
cfg.enable_mixed_prefill_decode = enable_mixed_prefill_decode
if paged_cache_groups:
cfg.paged_cache_groups = list(paged_cache_groups)
return cfg
Expand Down
21 changes: 14 additions & 7 deletions python/tokenspeed/runtime/execution/cuda_graph_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,6 @@ def _capture_one(self, bs: int):
grammar_backend=self.grammar_backend,
)

self._init_capture_metadata(bs)

def run_once():
# Dummy add_batch keeps the grammar queue 1:1 with replays —
# fetch_batch pops once per forward, so warmup + capture
Expand All @@ -377,6 +375,7 @@ def run_once():
self.sampling_backend.prepare_capture(
bs=bs, num_tokens_per_req=self.max_tokens_per_req
)
self._init_capture_metadata(bs)
run_once()

# Clear any per-pool state that warm-up dirtied at pool row 0,
Expand All @@ -392,6 +391,7 @@ def run_once():
self.sampling_backend.prepare_capture(
bs=bs, num_tokens_per_req=self.max_tokens_per_req
)
self._init_capture_metadata(bs)

self.deepep_adapter.capture()

Expand Down Expand Up @@ -546,6 +546,7 @@ def _pad_offsets_to_padded_bs(
def _init_replay_metadata(
self,
padded_bs: int,
actual_bs: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
req_to_page: torch.Tensor,
Expand All @@ -562,7 +563,7 @@ def _init_replay_metadata(
"uses_paged_cache_groups",
False,
):
actual_bs = next(
table_bs = next(
(
int(table.shape[0])
for table in paged_cache_block_tables.values()
Expand All @@ -572,7 +573,7 @@ def _init_replay_metadata(
)
paged_cache_block_tables = self._pad_block_tables_to_padded_bs(
paged_cache_block_tables,
actual_bs=actual_bs,
actual_bs=table_bs,
padded_bs=padded_bs,
)
kwargs["paged_cache_block_tables"] = paged_cache_block_tables
Expand All @@ -585,6 +586,8 @@ def _init_replay_metadata(
kwargs["paged_cache_block_table_base_offsets"] = (
paged_cache_block_table_base_offsets
)
if getattr(self.attn_backend, "uses_padded_decode_token_mask", False):
kwargs["actual_bs"] = actual_bs
self.attn_backend.init_forward_metadata_replay_cuda_graph(
padded_bs,
req_pool_indices,
Expand All @@ -594,14 +597,16 @@ def _init_replay_metadata(
**kwargs,
)
if self.draft_attn_backend is not None:
# DRAFT_EXTEND covers step 0 + N-1 decode steps (drafter syncs per step).
draft_attn_kwargs = {}
if getattr(self.draft_attn_backend, "uses_padded_decode_token_mask", False):
draft_attn_kwargs["actual_bs"] = actual_bs
self.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
padded_bs,
req_pool_indices,
seq_lens,
req_to_page=self.drafter.req_to_page,
forward_mode=ForwardMode.DRAFT_EXTEND,
**kwargs,
**draft_attn_kwargs,
)

@nvtx_range("attn_meta_prep", color="orange")
Expand All @@ -625,7 +630,7 @@ def _init_forward_metadata(
**kwargs,
)
if self.draft_attn_backend is not None:
if forward_mode.is_extend():
if forward_mode == ForwardMode.EXTEND or forward_mode.is_mixed():
# Initial prefill: draft step 0 uses EXTEND (regular prefill)
# kernel with the caller's prefix kwargs. Step 0 and the
# subsequent decode steps have structurally different
Expand Down Expand Up @@ -785,6 +790,7 @@ def __call__(
)
self._init_replay_metadata(
padded_bs,
bs,
req_pool_indices,
seq_lens,
req_to_page=req_to_page,
Expand Down Expand Up @@ -831,6 +837,7 @@ def __call__(
extend_prefix_lens_cpu=extend_prefix_lens_cpu,
extend_seq_lens=extend_seq_lens,
extend_seq_lens_cpu=extend_seq_lens_cpu,
num_extends=ctx.num_extends,
positions=positions,
out_cache_loc=out_cache_loc,
global_num_tokens=ctx.global_num_tokens,
Expand Down
1 change: 1 addition & 0 deletions python/tokenspeed/runtime/execution/drafter/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,7 @@ def _run_multi_step_decode(
)

out_cache_loc = cache_locs[:, i - 1].contiguous()
ctx.attn_backend.advance_draft_forward_metadata()

with nvtx_range("draft_forward", color="red"):
logits_output = self.draft_model_runner.forward(
Expand Down
16 changes: 16 additions & 0 deletions python/tokenspeed/runtime/execution/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def is_extend(self):
def is_decode(self):
return self == ForwardMode.DECODE

def is_mixed(self):
return self == ForwardMode.MIXED

def is_idle(self):
return self == ForwardMode.IDLE

Expand All @@ -67,6 +70,19 @@ def is_draft_extend(self):
def is_decode_or_idle(self):
return self == ForwardMode.DECODE or self == ForwardMode.IDLE

@staticmethod
def from_num_extends(
num_extends: int,
batch_size: int,
*,
has_drafter: bool = False,
) -> "ForwardMode":
if batch_size <= 0:
return ForwardMode.IDLE
if num_extends > 0:
return ForwardMode.MIXED if num_extends < batch_size else ForwardMode.EXTEND
return ForwardMode.TARGET_VERIFY if has_drafter else ForwardMode.DECODE


class CaptureHiddenMode(IntEnum):
NULL = auto()
Expand Down
Loading
Loading