From a97ffacb37710aef8270147b7c5f197521aee66d Mon Sep 17 00:00:00 2001 From: Digant Desai Date: Wed, 17 Jun 2026 12:02:41 -0700 Subject: [PATCH] Update [ghstack-poisoned] --- examples/models/eagle3/export.py | 8 +- examples/models/eagle3/main.cpp | 140 ++++++++++++++++++++++++------- 2 files changed, 116 insertions(+), 32 deletions(-) diff --git a/examples/models/eagle3/export.py b/examples/models/eagle3/export.py index e26f4d91538..2cee66cbff6 100644 --- a/examples/models/eagle3/export.py +++ b/examples/models/eagle3/export.py @@ -251,6 +251,7 @@ def _partitioner(name: str): "get_n_layers": target_config.num_hidden_layers, "get_max_prefill_chunk": max_prefill, "get_min_prefill_chunk": target_min, + "get_sliding_window": target_config.sliding_window, "get_chain_len": chain_len, "get_draft_vocab_size": draft_vocab_size, "use_kv_cache": True, @@ -308,9 +309,10 @@ def main() -> None: "--max-prefill", type=int, default=512, - help="Max prefill length: AOTI compiles prefill kernels for up to this T " - "and the whole prompt must fit in one prefill (the runner does not chunk). " - "Smaller compiles faster.", + help="Max prefill chunk: AOTI compiles prefill kernels for up to this T. " + "The runner chunks the prompt into <= this many tokens per prefill (a " + "longer prompt is fed as multiple chunks), so this bounds compile time, " + "not prompt length. Smaller compiles faster.", ) p.add_argument( "--chain", type=int, default=4, help="Draft chain length K (verify K+1)." diff --git a/examples/models/eagle3/main.cpp b/examples/models/eagle3/main.cpp index 8bb5997402c..c1b22bc3b4f 100644 --- a/examples/models/eagle3/main.cpp +++ b/examples/models/eagle3/main.cpp @@ -31,9 +31,13 @@ // forward per round (speedup ~= acceptance length tau). // // Features round-trip through the host between method calls (D2H copy + re-feed -// as host tensors). 
They are small (<= max_prefill x H bf16), so the cost is -// negligible next to the INT4 31B target forward, and it keeps device-tensor -// lifetimes simple. +// as host tensors), which keeps device-tensor lifetimes simple. Chunked prefill +// concatenates per-position features for the whole prompt before draft seeding, +// so the host buffer is prompt_len x H bf16 (~672 MiB at 64K context, H=5376), +// scaling with prompt_len rather than max_prefill. That is negligible next to +// the INT4 31B target forward at today's context lengths; stream draft seeding +// as each prefill chunk completes if it becomes a memory/perf concern at larger +// contexts or hidden sizes. // // Run (after exporting model.pte + aoti_cuda_blob.ptd via export.py, sourcing // the CUDA env, and building the eagle3-cuda preset): @@ -268,6 +272,22 @@ int main(int argc, char** argv) { }; const int64_t chain_len = meta("get_chain_len"); const int64_t max_prefill = meta("get_max_prefill_chunk"); + // Prefill chunks must not exceed the sliding window: a chunk larger than the + // window overflows the 2*window ring cache across chunk boundaries, + // truncating sliding attention for the first ~(chunk-window) queries of each + // chunk (the global flat-cache layers stay exact). Prefer get_sliding_window + // when the export provides it, else fall back to max_prefill/2. + int64_t prefill_chunk = max_prefill / 2; + { + auto sw = module->get("get_sliding_window"); + if (sw.ok()) { + prefill_chunk = sw->toScalar().to(); + } + } + // Also bound by the exported prefill range: get_max_prefill_chunk is the + // largest T the prefill kernels were compiled for, which need not be + // 2*sliding_window (small --max-prefill with a larger window), so cap here. + prefill_chunk = std::min(prefill_chunk, max_prefill); const int64_t min_prefill = meta("get_min_prefill_chunk"); const int64_t max_seq_len = meta("get_max_seq_len"); const int64_t K_req = (FLAGS_chain > 0) ? 
FLAGS_chain : chain_len; @@ -308,16 +328,16 @@ int main(int argc, char** argv) { prompt.insert(prompt.begin(), static_cast(FLAGS_bos_id)); } const int64_t L = static_cast(prompt.size()); - // The runner does not chunk: the whole prompt must fit one prefill, and its - // length must be within the exported prefill range [min_prefill, - // max_prefill]. - if (L > max_prefill) { + // A single prefill forward caps at max_prefill (the sliding-ring 2*window + // limit), so prompts beyond that are looped in <= max_prefill chunks below; + // the flat global KV cache accumulates across chunks. The prompt only has to + // fit the exported context (its features then seed the speculative loop). + if (L >= max_seq_len) { ET_LOG( Error, - "Prompt (%" PRId64 " tokens) exceeds max_prefill %" PRId64 - "; this runner does not chunk prefill.", + "Prompt (%" PRId64 " tokens) does not fit max_seq_len %" PRId64, L, - max_prefill); + max_seq_len); return 1; } if (L < min_prefill) { @@ -333,8 +353,9 @@ int main(int argc, char** argv) { // The prefill bonus token is always emittable (no KV write past the prompt). // Each speculative round, however, writes a K-token verify window, so it // needs anchor_pos + K <= max_seq_len - 1 (enforced in the loop below). Cap - // the total at the positions available; max_new >= 1 since L <= max_prefill < - // max_seq_len. + // the total at the positions available; max_new >= 1 since L < max_seq_len + // (L may exceed max_prefill -- the prompt is fed as chunks; L >= max_seq_len + // is rejected above). int64_t max_new = std::min(FLAGS_max_new_tokens, max_seq_len - L); printf( "Prompt tokens: %" PRId64 ", chain K=%" PRId64 ", max_new=%" PRId64 "\n", @@ -433,22 +454,48 @@ int main(int argc, char** argv) { stats.model_load_end_ms = llm::time_in_ms(); stats.inference_start_ms = stats.model_load_end_ms; - // --- Prefill: target over the prompt -> bonus token + per-position feature. 
- --- - tok_buf = prompt; - pos_buf.resize(L); - for (int64_t i = 0; i < L; i++) { - pos_buf[i] = i; - } - auto pf = module->execute( - "prefill", {EValue(long_tensor(tok_buf)), EValue(pos_tensor(pos_buf))}); - if (pf.error() != Error::Ok) { - ET_LOG(Error, "prefill failed"); - return 1; + // The exported prefill forward accepts T in [min_prefill, max_prefill]; pick + // the next chunk so the running tail never drops below min_prefill (it would + // be an out-of-range shape). All but the last one or two chunks are + // prefill_chunk (the sliding-window cap, itself bounded by max_prefill). + auto next_chunk = [&](int64_t done) { + int64_t remaining = L - done; + int64_t len = std::min(remaining, prefill_chunk); + if (remaining - len > 0 && remaining - len < min_prefill) { + len = remaining - min_prefill; + } + return len; + }; + + // --- Prefill: target over the prompt (chunked to respect the prefill cap) -> + // bonus token + per-position feature. The flat target KV cache accumulates + // across chunks; the bonus token is the last chunk's output, and the + // per-position features of every chunk are concatenated to seed the draft. 
+ HostFeature feat_prompt; + int64_t anchor = 0; + int64_t prefill_pos = 0; + while (prefill_pos < L) { + int64_t chunk_len = next_chunk(prefill_pos); + tok_buf.assign( + prompt.begin() + prefill_pos, prompt.begin() + prefill_pos + chunk_len); + pos_buf.resize(chunk_len); + for (int64_t i = 0; i < chunk_len; i++) { + pos_buf[i] = prefill_pos + i; + } + auto pf = module->execute( + "prefill", {EValue(long_tensor(tok_buf)), EValue(pos_tensor(pos_buf))}); + if (pf.error() != Error::Ok) { + ET_LOG(Error, "prefill failed at pos %" PRId64, prefill_pos); + return 1; + } + anchor = read_ids(pf->at(0).toTensor())[0]; // bonus token after the prompt + HostFeature chunk_feat = read_feature(pf->at(1).toTensor()); + feat_prompt.H = chunk_feat.H; + feat_prompt.T += chunk_feat.T; + feat_prompt.data.insert( + feat_prompt.data.end(), chunk_feat.data.begin(), chunk_feat.data.end()); + prefill_pos += chunk_len; } - int64_t anchor = - read_ids(pf->at(0).toTensor())[0]; // bonus token at position L - HostFeature feat_prompt = read_feature(pf->at(1).toTensor()); const int64_t H = feat_prompt.H; int64_t anchor_pos = L; @@ -475,10 +522,45 @@ int main(int argc, char** argv) { if (speculate) { // Seed the first chain (shifted): draft slot p pairs feat_prompt[p] with // token_{p+1}; the last slot pairs feat_prompt[L-1] with the bonus and - // predicts position L+1. + // predicts position L+1. Seed in <= max_prefill chunks (draft_decode shares + // the prefill shape range), each contiguous from the previous so the draft + // KV cache fills; the last chunk's last row predicts proposal 0 and carries + // the recurrent feature, then K-1 recurrent steps follow (mirroring chain). 
std::vector seed_tokens(prompt.begin() + 1, prompt.end()); seed_tokens.push_back(anchor); - proposals = chain(seed_tokens, feat_prompt, 0); + std::vector ids; + HostFeature last_g; + for (int64_t seed_pos = 0; seed_pos < L;) { + int64_t chunk_len = next_chunk(seed_pos); + std::vector chunk_tokens( + seed_tokens.begin() + seed_pos, + seed_tokens.begin() + seed_pos + chunk_len); + draft_decode( + chunk_tokens, + feat_prompt.data.data() + seed_pos * H, + chunk_len, + H, + seed_pos, + ids, + last_g); + seed_pos += chunk_len; + } + proposals.push_back(ids.back()); + int64_t last_pos = L - 1; + for (int64_t k = 1; k < K; k++) { + std::vector step_ids; + HostFeature step_g; + draft_decode( + {proposals.back()}, + last_g.data.data(), + 1, + last_g.H, + last_pos + k, + step_ids, + step_g); + proposals.push_back(step_ids[0]); + last_g = step_g; + } } // Stable buffers for target_verify (fixed length K+1) so the CUDA graph