From 42ee17bafaf19c97f214bace0374046df32e9954 Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 17 Jun 2026 12:02:30 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- backends/cuda/triton/kernels/sdpa_midm.py | 24 ++++++---- examples/models/eagle3/export.py | 57 +++++++++++++++-------- examples/models/eagle3/main.cpp | 8 +++- examples/models/eagle3/target.py | 5 +- examples/models/gemma4_31b/model.py | 5 +- 5 files changed, 65 insertions(+), 34 deletions(-) diff --git a/backends/cuda/triton/kernels/sdpa_midm.py b/backends/cuda/triton/kernels/sdpa_midm.py index 66de8fbf731..177774f9a52 100644 --- a/backends/cuda/triton/kernels/sdpa_midm.py +++ b/backends/cuda/triton/kernels/sdpa_midm.py @@ -44,8 +44,9 @@ # path is appropriate (enough rows to amortize a tiled kernel). MIDM_MAX_M = 8 -# Number of key-range partitions for split-K. The verify method exports a static -# M / B / H / D, so the partial buffers and grid are static-shaped; only the +# Number of key-range partitions for split-K. B / H / D are static for the +# exported verify method; M is the dynamic verify length (bounded by MIDM_MAX_M, +# BLOCK_M covers it), so the grid (NUM_SPLITS x B*H) is static-shaped; the # per-split chunk size (derived from the dynamic valid_len) is a runtime scalar. # 32 splits x (B*H) heads gives ~1K CTAs at the gemma4 global shape -- ample # occupancy on an A100 while keeping the fp32 partials small. @@ -89,9 +90,9 @@ def _sdpa_midm_splitk_kernel( valid_len, chunk_size, scale, + M, H: tl.constexpr, HKV: tl.constexpr, - M: tl.constexpr, D: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, @@ -200,8 +201,8 @@ def _sdpa_midm_reduce_kernel( soh, som, sod, + M, NUM_SPLITS: tl.constexpr, - M: tl.constexpr, D: tl.constexpr, BLOCK_M: tl.constexpr, ): @@ -267,13 +268,17 @@ def _sdpa_midm_op( ``valid_len`` (max valid position + 1) bounds the key range; it is split into NUM_SPLITS chunks of ``chunk_size`` keys computed in parallel, then reduced. - M / B / H / D are static for the exported verify method, so only chunk_size is - a runtime (backed-SymInt) scalar -- the grid and partial buffers are static. + B / H / D are static for the exported verify method; M is the dynamic verify + length (bounded by MIDM_MAX_M). chunk_size (from the dynamic valid_len) is a + runtime (backed-SymInt) scalar; the grid (NUM_SPLITS x B*H) is static. """ B, H, M, D = q.shape HKV = k.shape[1] out = torch.empty_like(q) - BLOCK_M = max(16, triton.next_power_of_2(M)) + # M <= MIDM_MAX_M (8) => next_pow2(M) <= 8 => max(16, .) is always 16. Hardcode + # so M can be a runtime (dynamic verify) dim -- next_power_of_2 can't take a + # SymInt, and M is a kernel runtime arg used only for the offs_m < M masks. + BLOCK_M = 16 # gemma4 global layers use D=512; a wide key tile + pipelining overflow SMEM # there, so shrink both. Small D can afford more. BLOCK_N, num_stages = (32, 1) if D >= 512 else (64, 2) @@ -381,8 +386,9 @@ def midm_sdpa( ) -> torch.Tensor: """Dispatch: the mid-M op for a small query window when enabled; otherwise the standard F.sdpa the model already uses (which the replacement pass swaps - for triton::sdpa). M is static per exported method, so the branch resolves at - trace time. ``valid_len`` is the shared per-forward key bound.""" + for triton::sdpa). M (q.shape[2]) is the dynamic verify length; its exported + range [2, MIDM_MAX_M] satisfies this guard, so the branch resolves at export. + ``valid_len`` is the shared per-forward key bound.""" M = q.shape[2] if enable and 2 <= M <= MIDM_MAX_M: return sdpa_midm(q, k, v, input_pos, scale, valid_len=valid_len) diff --git a/examples/models/eagle3/export.py b/examples/models/eagle3/export.py index e171cc4b505..e26f4d91538 100644 --- a/examples/models/eagle3/export.py +++ b/examples/models/eagle3/export.py @@ -9,7 +9,8 @@ Three methods are lowered together so they share mutable state: - "prefill": target prompt prefill (T in [get_min_prefill_chunk, get_max_prefill_chunk]) -> next token + fused feature. - - "target_verify": target forward over the candidate chain (static T=chain+1) + - "target_verify": target forward over the candidate chain (dynamic T in + [2, MATVEC_MAX_M] = K+1; --chain selects K at runtime) -> per-position greedy ids + fused feature. - "draft_decode": draft proposal over its KV cache (T>=1; seed with T>1, step with T=1) -> proposed target ids + recurrent feature. @@ -34,9 +35,12 @@ supported. Scope (this is a fixed-shape ExecuTorch artifact, not a generic EAGLE runtime): -chain length, the chain_len+1 verify window, the prefill/draft dynamic ranges, -the CUDA backend, and the small-M INT4 dispatch policy are all baked at export — -vary the target, chain length, or backend by re-exporting. The caller is +the target, the prefill/draft/verify dynamic ranges, the CUDA backend, and the +small-M INT4 dispatch policy are all baked at export — vary the target or backend +by re-exporting. Chain length K is NOT baked: target_verify is dynamic over +T in [2, MATVEC_MAX_M], so one .pte serves any K in [1, MATVEC_MAX_M - 1] +(get_chain_len is only the default) and the runner selects K with --chain. The +caller is responsible for pairing a target, draft, and tokenizer that were trained together: only target/draft hidden size is checked here; tokenizer identity, target vocab size, the d2t/t2d mapping, the tap-layer convention, and the draft's @@ -56,8 +60,9 @@ from executorch.examples.models.eagle3.speculator import Eagle3Speculator from executorch.examples.models.eagle3.target import TARGETS -# Route the static chain_len+1 verify forward to the small-M INT4 GEMM. Must be -# <= the shim's GEMM_MAX_M (8 in int4_plain_mm.cuh) and >= the largest chain+1. +# Route the verify forward to the small-M INT4 GEMM. target_verify is dynamic +# over T in [2, _MATVEC_MAX_M] (chain_len+1 is only the export example), and the +# whole range must be <= the shim's GEMM_MAX_M (8 in int4_plain_mm.cuh). # Set locally on int4_dispatch (not the global default) so other models' exports # keep MATVEC_MAX_M=4 and their dynamic prefill ranges are unaffected. _MATVEC_MAX_M = 8 @@ -138,8 +143,9 @@ def _lap(msg: str) -> None: hidden = spec.draft.config.hidden_size draft_vocab_size = spec.draft.config.draft_vocab_size # Verify re-feeds the last confirmed token (its logits are the folded bonus) - # plus the K proposals: a fixed chain_len+1 window in one target forward. With - # chain_len+1 <= MATVEC_MAX_M the verify forward stays on the small-M GEMM + # plus the K proposals: a chain_len+1 window -- only the export example. + # target_verify is lowered dynamic over T in [2, MATVEC_MAX_M], and with the + # whole range <= MATVEC_MAX_M the verify forward stays on the small-M GEMM # rather than the dequant path. verify_len = chain_len + 1 # prefill's dynamic length must take a single INT4 dispatch branch over its @@ -164,10 +170,18 @@ def _lap(msg: str) -> None: ) _lap("export prefill") - print(f"Exporting target_verify (T = {verify_len})...") + # Dynamic chain length: verify window T = K+1 dynamic in [2, MATVEC_MAX_M] + # so K is a runtime parameter (one .pte serves K in [1, MATVEC_MAX_M-1], the + # runner picks it with --chain). max == MATVEC_MAX_M so M never straddles the + # INT4 dispatch threshold -> resolves to the small-M GEMM over the whole + # range. min=2 is the K=1 window; the target's min_forward_len was a + # conservative export note -- the gemma4 mask traces correctly down to T=2. + verify_max = int4_dispatch.MATVEC_MAX_M + verify_dim = Dim("verify_len", min=2, max=verify_max) + print(f"Exporting target_verify (T in [2, {verify_max}], example {verify_len})...") # The mid-M SDPA key bound is the dynamic length of kv_window: valid KV - # positions = anchor_pos + chain + 1, in [verify_len, max_seq_len]. - kv_dim = Dim("kv_len", min=verify_len, max=target_config.max_seq_len) + # positions = anchor_pos + K + 1, in [2, max_seq_len]. + kv_dim = Dim("kv_len", min=2, max=target_config.max_seq_len) with torch.no_grad(): verify_ep = export( _TargetVerify(spec), @@ -176,7 +190,7 @@ def _lap(msg: str) -> None: torch.arange(verify_len, dtype=torch.long), torch.zeros((8 * verify_len,), dtype=torch.int32), ), - dynamic_shapes=({}, {}, {0: kv_dim}), + dynamic_shapes=({1: verify_dim}, {0: verify_dim}, {0: kv_dim}), strict=True, ) _lap("export target_verify") @@ -359,28 +373,31 @@ def main() -> None: f"--max-prefill (got {args.max_prefill}) or --max-seq-len (got " f"{args.max_seq_len})" ) - # target_verify is a single static forward of chain+1 tokens: it must fit the - # small-M GEMM (chain+1 <= _MATVEC_MAX_M) and the target's per-forward bounds - # [min_forward_len, max_forward]. + # target_verify is exported dynamic over T in [2, _MATVEC_MAX_M] (see + # verify_dim), so --chain only sets the default/example K baked as + # get_chain_len; one .pte serves any K in [1, _MATVEC_MAX_M - 1]. The example + # K+1 must still fit the small-M GEMM (<= _MATVEC_MAX_M), the dynamic lower + # bound (K >= 1 => window >= 2), and the target's per-forward max. + # min_forward_len is a conservative prefill note and does NOT bound verify. verify_len = args.chain + 1 if verify_len > _MATVEC_MAX_M: p.error( f"--chain {args.chain} (verify window {verify_len}) exceeds the " f"INT4 small-M GEMM limit {_MATVEC_MAX_M}" ) - if verify_len < spec_t.min_forward_len: + if verify_len < 2: p.error( f"--chain {args.chain} (verify window {verify_len}) is below the " - f"target's minimum forward length {spec_t.min_forward_len}" + f"minimum verify window of 2 (need --chain >= 1)" ) if verify_len > min(args.max_seq_len - 1, max_forward): p.error( f"--chain {args.chain} (verify window {verify_len}) exceeds the " f"target's per-forward limit {min(args.max_seq_len - 1, max_forward)}" ) - # Route the static chain_len+1 verify forward to the small-M INT4 GEMM by - # raising the dispatch threshold for this export only; restore it so the - # process-global default (4) is unchanged for any later use. + # Route the verify forward (dynamic T in [2, _MATVEC_MAX_M]) to the small-M + # INT4 GEMM by raising the dispatch threshold for this export only; restore + # it so the process-global default (4) is unchanged for any later use. import executorch.backends.cuda.int4_dispatch as int4_dispatch saved_threshold = int4_dispatch.MATVEC_MAX_M diff --git a/examples/models/eagle3/main.cpp b/examples/models/eagle3/main.cpp index 758a8bbf6b1..8bb5997402c 100644 --- a/examples/models/eagle3/main.cpp +++ b/examples/models/eagle3/main.cpp @@ -109,6 +109,11 @@ DEFINE_bool( "current export feeds target_verify a kv_window whose length changes every " "round, so capture is unsafe (stale-shape replay). Only enable for an " "export whose target_verify inputs all have stable shapes."); +DEFINE_int32( + chain, + -1, + "Override chain length K at runtime (<=0 uses the .pte's get_chain_len). " + "Requires a dynamic-T verify export; clamped to [1, 7] (verify M=K+1<=8)."); // Chat template + stop tokens default to Gemma 4 IT; override for other models. DEFINE_string( chat_prefix, @@ -265,7 +270,8 @@ int main(int argc, char** argv) { const int64_t max_prefill = meta("get_max_prefill_chunk"); const int64_t min_prefill = meta("get_min_prefill_chunk"); const int64_t max_seq_len = meta("get_max_seq_len"); - const int64_t K = chain_len; + const int64_t K_req = (FLAGS_chain > 0) ? FLAGS_chain : chain_len; + const int64_t K = (K_req < 1) ? 1 : (K_req > 7 ? 7 : K_req); // EOS: tokenizer/metadata ids, the configured eos, any --stop_ids, and the // encoded --stop_token delimiter (all default to the Gemma 4 IT conventions). diff --git a/examples/models/eagle3/target.py b/examples/models/eagle3/target.py index 59dc73aac92..556748e756d 100644 --- a/examples/models/eagle3/target.py +++ b/examples/models/eagle3/target.py @@ -73,9 +73,10 @@ class TargetSpec: # config -> max tokens accepted in one target forward (e.g. a sliding ring # buffer caps it at 2*window; a flat-cache model uses max_seq_len-1). max_forward_len: Callable[[Any], int] - # Minimum tokens in ANY single target forward the export accepts (some + # Minimum tokens the export specializes for a target forward (some # attention-mask implementations specialize a lower bound under - # torch.export). Applies to both prefill and the static target_verify window. + # torch.export). Bounds prefill only; target_verify is exported dynamic over + # T in [2, MATVEC_MAX_M] and is not constrained by it. min_forward_len: int diff --git a/examples/models/gemma4_31b/model.py b/examples/models/gemma4_31b/model.py index 80555f01d26..79b526938c6 100644 --- a/examples/models/gemma4_31b/model.py +++ b/examples/models/gemma4_31b/model.py @@ -380,8 +380,9 @@ def forward( # layers; falls back to F.sdpa otherwise (M==1 decode, large-M prefill, # sliding layers, or when disabled). Imported lazily and only when # enabled so a CPU / non-mid-M import of the model never pulls in triton - # or the CUDA backend. M is static per exported method, so the mid-M - # branch resolves at trace time. + # or the CUDA backend. M (the verify window) is the dynamic verify length + # bounded to [2, MIDM_MAX_M] by the export, so the mid-M branch resolves + # at trace time. if self.use_midm_sdpa: from executorch.backends.cuda.triton.kernels.sdpa_midm import midm_sdpa