Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 15 additions & 9 deletions backends/cuda/triton/kernels/sdpa_midm.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@
# path is appropriate (enough rows to amortize a tiled kernel).
MIDM_MAX_M = 8

# Number of key-range partitions for split-K. The verify method exports a static
# M / B / H / D, so the partial buffers and grid are static-shaped; only the
# Number of key-range partitions for split-K. B / H / D are static for the
# exported verify method; M is the dynamic verify length (bounded by MIDM_MAX_M,
# BLOCK_M covers it), so the grid (NUM_SPLITS x B*H) is static-shaped; the
# per-split chunk size (derived from the dynamic valid_len) is a runtime scalar.
# 32 splits x (B*H) heads gives ~1K CTAs at the gemma4 global shape -- ample
# occupancy on an A100 while keeping the fp32 partials small.
Expand Down Expand Up @@ -89,9 +90,9 @@ def _sdpa_midm_splitk_kernel(
valid_len,
chunk_size,
scale,
M,
H: tl.constexpr,
HKV: tl.constexpr,
M: tl.constexpr,
D: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
Expand Down Expand Up @@ -200,8 +201,8 @@ def _sdpa_midm_reduce_kernel(
soh,
som,
sod,
M,
NUM_SPLITS: tl.constexpr,
M: tl.constexpr,
D: tl.constexpr,
BLOCK_M: tl.constexpr,
):
Expand Down Expand Up @@ -267,13 +268,17 @@ def _sdpa_midm_op(

``valid_len`` (max valid position + 1) bounds the key range; it is split into
NUM_SPLITS chunks of ``chunk_size`` keys computed in parallel, then reduced.
M / B / H / D are static for the exported verify method, so only chunk_size is
a runtime (backed-SymInt) scalar -- the grid and partial buffers are static.
B / H / D are static for the exported verify method; M is the dynamic verify
length (bounded by MIDM_MAX_M). chunk_size (from the dynamic valid_len) is a
runtime (backed-SymInt) scalar; the grid (NUM_SPLITS x B*H) is static.
"""
B, H, M, D = q.shape
HKV = k.shape[1]
out = torch.empty_like(q)
BLOCK_M = max(16, triton.next_power_of_2(M))
# M <= MIDM_MAX_M (8) => next_pow2(M) <= 8 => max(16, .) is always 16. Hardcode
# so M can be a runtime (dynamic verify) dim -- next_power_of_2 can't take a
# SymInt, and M is a kernel runtime arg used only for the offs_m < M masks.
BLOCK_M = 16
# gemma4 global layers use D=512; a wide key tile + pipelining overflow SMEM
# there, so shrink both. Small D can afford more.
BLOCK_N, num_stages = (32, 1) if D >= 512 else (64, 2)
Expand Down Expand Up @@ -381,8 +386,9 @@ def midm_sdpa(
) -> torch.Tensor:
"""Dispatch: the mid-M op for a small query window when enabled; otherwise
the standard F.sdpa the model already uses (which the replacement pass swaps
for triton::sdpa). M is static per exported method, so the branch resolves at
trace time. ``valid_len`` is the shared per-forward key bound."""
for triton::sdpa). M (q.shape[2]) is the dynamic verify length; its exported
range [2, MIDM_MAX_M] satisfies this guard, so the branch resolves at export.
``valid_len`` is the shared per-forward key bound."""
M = q.shape[2]
if enable and 2 <= M <= MIDM_MAX_M:
return sdpa_midm(q, k, v, input_pos, scale, valid_len=valid_len)
Expand Down
57 changes: 37 additions & 20 deletions examples/models/eagle3/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
Three methods are lowered together so they share mutable state:
- "prefill": target prompt prefill (T in [get_min_prefill_chunk,
get_max_prefill_chunk]) -> next token + fused feature.
- "target_verify": target forward over the candidate chain (static T=chain+1)
- "target_verify": target forward over the candidate chain (dynamic T in
[2, MATVEC_MAX_M] = K+1; --chain selects K at runtime)
-> per-position greedy ids + fused feature.
- "draft_decode": draft proposal over its KV cache (T>=1; seed with T>1, step
with T=1) -> proposed target ids + recurrent feature.
Expand All @@ -34,9 +35,12 @@
supported.

Scope (this is a fixed-shape ExecuTorch artifact, not a generic EAGLE runtime):
chain length, the chain_len+1 verify window, the prefill/draft dynamic ranges,
the CUDA backend, and the small-M INT4 dispatch policy are all baked at export —
vary the target, chain length, or backend by re-exporting. The caller is
the target, the prefill/draft/verify dynamic ranges, the CUDA backend, and the
small-M INT4 dispatch policy are all baked at export — vary the target or backend
by re-exporting. Chain length K is NOT baked: target_verify is dynamic over
T in [2, MATVEC_MAX_M], so one .pte serves any K in [1, MATVEC_MAX_M - 1]
(get_chain_len is only the default) and the runner selects K with --chain. The
caller is
responsible for pairing a target, draft, and tokenizer that were trained
together: only target/draft hidden size is checked here; tokenizer identity,
target vocab size, the d2t/t2d mapping, the tap-layer convention, and the draft's
Expand All @@ -56,8 +60,9 @@
from executorch.examples.models.eagle3.speculator import Eagle3Speculator
from executorch.examples.models.eagle3.target import TARGETS

# Route the static chain_len+1 verify forward to the small-M INT4 GEMM. Must be
# <= the shim's GEMM_MAX_M (8 in int4_plain_mm.cuh) and >= the largest chain+1.
# Route the verify forward to the small-M INT4 GEMM. target_verify is dynamic
# over T in [2, _MATVEC_MAX_M] (chain_len+1 is only the export example), and the
# whole range must be <= the shim's GEMM_MAX_M (8 in int4_plain_mm.cuh).
# Set locally on int4_dispatch (not the global default) so other models' exports
# keep MATVEC_MAX_M=4 and their dynamic prefill ranges are unaffected.
_MATVEC_MAX_M = 8
Expand Down Expand Up @@ -138,8 +143,9 @@ def _lap(msg: str) -> None:
hidden = spec.draft.config.hidden_size
draft_vocab_size = spec.draft.config.draft_vocab_size
# Verify re-feeds the last confirmed token (its logits are the folded bonus)
# plus the K proposals: a fixed chain_len+1 window in one target forward. With
# chain_len+1 <= MATVEC_MAX_M the verify forward stays on the small-M GEMM
# plus the K proposals: a chain_len+1 window -- only the export example.
# target_verify is lowered dynamic over T in [2, MATVEC_MAX_M], and with the
# whole range <= MATVEC_MAX_M the verify forward stays on the small-M GEMM
# rather than the dequant path.
verify_len = chain_len + 1
# prefill's dynamic length must take a single INT4 dispatch branch over its
Expand All @@ -164,10 +170,18 @@ def _lap(msg: str) -> None:
)
_lap("export prefill")

print(f"Exporting target_verify (T = {verify_len})...")
# Dynamic chain length: verify window T = K+1 dynamic in [2, MATVEC_MAX_M]
# so K is a runtime parameter (one .pte serves K in [1, MATVEC_MAX_M-1], the
# runner picks it with --chain). max == MATVEC_MAX_M so M never straddles the
# INT4 dispatch threshold -> resolves to the small-M GEMM over the whole
# range. min=2 is the K=1 window; the target's min_forward_len was a
# conservative export note -- the gemma4 mask traces correctly down to T=2.
verify_max = int4_dispatch.MATVEC_MAX_M
verify_dim = Dim("verify_len", min=2, max=verify_max)
print(f"Exporting target_verify (T in [2, {verify_max}], example {verify_len})...")
# The mid-M SDPA key bound is the dynamic length of kv_window: valid KV
# positions = anchor_pos + chain + 1, in [verify_len, max_seq_len].
kv_dim = Dim("kv_len", min=verify_len, max=target_config.max_seq_len)
# positions = anchor_pos + K + 1, in [2, max_seq_len].
kv_dim = Dim("kv_len", min=2, max=target_config.max_seq_len)
with torch.no_grad():
verify_ep = export(
_TargetVerify(spec),
Expand All @@ -176,7 +190,7 @@ def _lap(msg: str) -> None:
torch.arange(verify_len, dtype=torch.long),
torch.zeros((8 * verify_len,), dtype=torch.int32),
),
dynamic_shapes=({}, {}, {0: kv_dim}),
dynamic_shapes=({1: verify_dim}, {0: verify_dim}, {0: kv_dim}),
strict=True,
)
_lap("export target_verify")
Expand Down Expand Up @@ -359,28 +373,31 @@ def main() -> None:
f"--max-prefill (got {args.max_prefill}) or --max-seq-len (got "
f"{args.max_seq_len})"
)
# target_verify is a single static forward of chain+1 tokens: it must fit the
# small-M GEMM (chain+1 <= _MATVEC_MAX_M) and the target's per-forward bounds
# [min_forward_len, max_forward].
# target_verify is exported dynamic over T in [2, _MATVEC_MAX_M] (see
# verify_dim), so --chain only sets the default/example K baked as
# get_chain_len; one .pte serves any K in [1, _MATVEC_MAX_M - 1]. The example
# K+1 must still fit the small-M GEMM (<= _MATVEC_MAX_M), the dynamic lower
# bound (K >= 1 => window >= 2), and the target's per-forward max.
# min_forward_len is a conservative prefill note and does NOT bound verify.
verify_len = args.chain + 1
if verify_len > _MATVEC_MAX_M:
p.error(
f"--chain {args.chain} (verify window {verify_len}) exceeds the "
f"INT4 small-M GEMM limit {_MATVEC_MAX_M}"
)
if verify_len < spec_t.min_forward_len:
if verify_len < 2:
p.error(
f"--chain {args.chain} (verify window {verify_len}) is below the "
f"target's minimum forward length {spec_t.min_forward_len}"
f"minimum verify window of 2 (need --chain >= 1)"
)
if verify_len > min(args.max_seq_len - 1, max_forward):
p.error(
f"--chain {args.chain} (verify window {verify_len}) exceeds the "
f"target's per-forward limit {min(args.max_seq_len - 1, max_forward)}"
)
# Route the static chain_len+1 verify forward to the small-M INT4 GEMM by
# raising the dispatch threshold for this export only; restore it so the
# process-global default (4) is unchanged for any later use.
# Route the verify forward (dynamic T in [2, _MATVEC_MAX_M]) to the small-M
# INT4 GEMM by raising the dispatch threshold for this export only; restore
# it so the process-global default (4) is unchanged for any later use.
import executorch.backends.cuda.int4_dispatch as int4_dispatch

saved_threshold = int4_dispatch.MATVEC_MAX_M
Expand Down
8 changes: 7 additions & 1 deletion examples/models/eagle3/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@ DEFINE_bool(
"current export feeds target_verify a kv_window whose length changes every "
"round, so capture is unsafe (stale-shape replay). Only enable for an "
"export whose target_verify inputs all have stable shapes.");
DEFINE_int32(
chain,
-1,
"Override chain length K at runtime (<=0 uses the .pte's get_chain_len). "
"Requires a dynamic-T verify export; clamped to [1, 7] (verify M=K+1<=8).");
// Chat template + stop tokens default to Gemma 4 IT; override for other models.
DEFINE_string(
chat_prefix,
Expand Down Expand Up @@ -265,7 +270,8 @@ int main(int argc, char** argv) {
const int64_t max_prefill = meta("get_max_prefill_chunk");
const int64_t min_prefill = meta("get_min_prefill_chunk");
const int64_t max_seq_len = meta("get_max_seq_len");
const int64_t K = chain_len;
const int64_t K_req = (FLAGS_chain > 0) ? FLAGS_chain : chain_len;
const int64_t K = (K_req < 1) ? 1 : (K_req > 7 ? 7 : K_req);

// EOS: tokenizer/metadata ids, the configured eos, any --stop_ids, and the
// encoded --stop_token delimiter (all default to the Gemma 4 IT conventions).
Expand Down
5 changes: 3 additions & 2 deletions examples/models/eagle3/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,10 @@ class TargetSpec:
# config -> max tokens accepted in one target forward (e.g. a sliding ring
# buffer caps it at 2*window; a flat-cache model uses max_seq_len-1).
max_forward_len: Callable[[Any], int]
# Minimum tokens in ANY single target forward the export accepts (some
# Minimum tokens the export specializes for a target forward (some
# attention-mask implementations specialize a lower bound under
# torch.export). Applies to both prefill and the static target_verify window.
# torch.export). Bounds prefill only; target_verify is exported dynamic over
# T in [2, MATVEC_MAX_M] and is not constrained by it.
min_forward_len: int


Expand Down
5 changes: 3 additions & 2 deletions examples/models/gemma4_31b/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,9 @@ def forward(
# layers; falls back to F.sdpa otherwise (M==1 decode, large-M prefill,
# sliding layers, or when disabled). Imported lazily and only when
# enabled so a CPU / non-mid-M import of the model never pulls in triton
# or the CUDA backend. M is static per exported method, so the mid-M
# branch resolves at trace time.
# or the CUDA backend. M (the verify window) is the dynamic verify length
# bounded to [2, MIDM_MAX_M] by the export, so the mid-M branch resolves
# at trace time.
if self.use_midm_sdpa:
from executorch.backends.cuda.triton.kernels.sdpa_midm import midm_sdpa

Expand Down
Loading