Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions examples/models/eagle3/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def _partitioner(name: str):
"get_n_layers": target_config.num_hidden_layers,
"get_max_prefill_chunk": max_prefill,
"get_min_prefill_chunk": target_min,
"get_sliding_window": target_config.sliding_window,
"get_chain_len": chain_len,
"get_draft_vocab_size": draft_vocab_size,
"use_kv_cache": True,
Expand Down Expand Up @@ -308,9 +309,10 @@ def main() -> None:
"--max-prefill",
type=int,
default=512,
help="Max prefill length: AOTI compiles prefill kernels for up to this T "
"and the whole prompt must fit in one prefill (the runner does not chunk). "
"Smaller compiles faster.",
help="Max prefill chunk: AOTI compiles prefill kernels for up to this T. "
"The runner chunks the prompt into <= this many tokens per prefill (a "
"longer prompt is fed as multiple chunks), so this bounds compile time, "
"not prompt length. Smaller compiles faster.",
)
p.add_argument(
"--chain", type=int, default=4, help="Draft chain length K (verify K+1)."
Expand Down
140 changes: 111 additions & 29 deletions examples/models/eagle3/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@
// forward per round (speedup ~= acceptance length tau).
//
// Features round-trip through the host between method calls (D2H copy + re-feed
// as host tensors). They are small (<= max_prefill x H bf16), so the cost is
// negligible next to the INT4 31B target forward, and it keeps device-tensor
// lifetimes simple.
// as host tensors), which keeps device-tensor lifetimes simple. Chunked prefill
// concatenates per-position features for the whole prompt before draft seeding,
// so the host buffer is prompt_len x H bf16 (~672 MiB at 64K context, H=5376),
// scaling with prompt_len rather than max_prefill. That is negligible next to
// the INT4 31B target forward at today's context lengths; stream draft seeding
// as each prefill chunk completes if it becomes a memory/perf concern at larger
// contexts or hidden sizes.
//
// Run (after exporting model.pte + aoti_cuda_blob.ptd via export.py, sourcing
// the CUDA env, and building the eagle3-cuda preset):
Expand Down Expand Up @@ -268,6 +272,22 @@ int main(int argc, char** argv) {
};
const int64_t chain_len = meta("get_chain_len");
const int64_t max_prefill = meta("get_max_prefill_chunk");
// Prefill chunks must not exceed the sliding window: a chunk larger than the
// window overflows the 2*window ring cache across chunk boundaries,
// truncating sliding attention for the first ~(chunk-window) queries of each
// chunk (the global flat-cache layers stay exact). Prefer get_sliding_window
// when the export provides it, else fall back to max_prefill/2.
int64_t prefill_chunk = max_prefill / 2;
{
auto sw = module->get("get_sliding_window");
if (sw.ok()) {
prefill_chunk = sw->toScalar().to<int64_t>();
}
}
// Also bound by the exported prefill range: get_max_prefill_chunk is the
// largest T the prefill kernels were compiled for, which need not be
// 2*sliding_window (small --max-prefill with a larger window), so cap here.
prefill_chunk = std::min(prefill_chunk, max_prefill);
const int64_t min_prefill = meta("get_min_prefill_chunk");
const int64_t max_seq_len = meta("get_max_seq_len");
const int64_t K_req = (FLAGS_chain > 0) ? FLAGS_chain : chain_len;
Expand Down Expand Up @@ -308,16 +328,16 @@ int main(int argc, char** argv) {
prompt.insert(prompt.begin(), static_cast<int64_t>(FLAGS_bos_id));
}
const int64_t L = static_cast<int64_t>(prompt.size());
// The runner does not chunk: the whole prompt must fit one prefill, and its
// length must be within the exported prefill range [min_prefill,
// max_prefill].
if (L > max_prefill) {
// A single prefill forward caps at max_prefill (the sliding-ring 2*window
// limit), so prompts beyond that are looped in <= max_prefill chunks below;
// the flat global KV cache accumulates across chunks. The prompt only has to
// fit the exported context (its features then seed the speculative loop).
if (L >= max_seq_len) {
ET_LOG(
Error,
"Prompt (%" PRId64 " tokens) exceeds max_prefill %" PRId64
"; this runner does not chunk prefill.",
"Prompt (%" PRId64 " tokens) does not fit max_seq_len %" PRId64,
L,
max_prefill);
max_seq_len);
return 1;
}
if (L < min_prefill) {
Expand All @@ -333,8 +353,9 @@ int main(int argc, char** argv) {
// The prefill bonus token is always emittable (no KV write past the prompt).
// Each speculative round, however, writes a K-token verify window, so it
// needs anchor_pos + K <= max_seq_len - 1 (enforced in the loop below). Cap
// the total at the positions available; max_new >= 1 since L <= max_prefill <
// max_seq_len.
// the total at the positions available; max_new >= 1 since L < max_seq_len
// (L may exceed max_prefill -- the prompt is fed as chunks; L >= max_seq_len
// is rejected above).
int64_t max_new = std::min<int64_t>(FLAGS_max_new_tokens, max_seq_len - L);
printf(
"Prompt tokens: %" PRId64 ", chain K=%" PRId64 ", max_new=%" PRId64 "\n",
Expand Down Expand Up @@ -433,22 +454,48 @@ int main(int argc, char** argv) {
stats.model_load_end_ms = llm::time_in_ms();
stats.inference_start_ms = stats.model_load_end_ms;

// --- Prefill: target over the prompt -> bonus token + per-position feature.
// ---
tok_buf = prompt;
pos_buf.resize(L);
for (int64_t i = 0; i < L; i++) {
pos_buf[i] = i;
}
auto pf = module->execute(
"prefill", {EValue(long_tensor(tok_buf)), EValue(pos_tensor(pos_buf))});
if (pf.error() != Error::Ok) {
ET_LOG(Error, "prefill failed");
return 1;
// The exported prefill forward accepts T in [min_prefill, max_prefill]; pick
// the next chunk so the running tail never drops below min_prefill (it would
// be an out-of-range shape). All but the last one or two chunks are
// max_prefill.
auto next_chunk = [&](int64_t done) {
int64_t remaining = L - done;
int64_t len = std::min(remaining, prefill_chunk);
if (remaining - len > 0 && remaining - len < min_prefill) {
len = remaining - min_prefill;
}
return len;
};

// --- Prefill: target over the prompt (chunked to respect the prefill cap) ->
// bonus token + per-position feature. The flat target KV cache accumulates
// across chunks; the bonus token is the last chunk's output, and the
// per-position features of every chunk are concatenated to seed the draft.
HostFeature feat_prompt;
int64_t anchor = 0;
int64_t prefill_pos = 0;
while (prefill_pos < L) {
int64_t chunk_len = next_chunk(prefill_pos);
tok_buf.assign(
prompt.begin() + prefill_pos, prompt.begin() + prefill_pos + chunk_len);
pos_buf.resize(chunk_len);
for (int64_t i = 0; i < chunk_len; i++) {
pos_buf[i] = prefill_pos + i;
}
auto pf = module->execute(
"prefill", {EValue(long_tensor(tok_buf)), EValue(pos_tensor(pos_buf))});
if (pf.error() != Error::Ok) {
ET_LOG(Error, "prefill failed at pos %" PRId64, prefill_pos);
return 1;
}
anchor = read_ids(pf->at(0).toTensor())[0]; // bonus token after the prompt
HostFeature chunk_feat = read_feature(pf->at(1).toTensor());
feat_prompt.H = chunk_feat.H;
feat_prompt.T += chunk_feat.T;
feat_prompt.data.insert(
feat_prompt.data.end(), chunk_feat.data.begin(), chunk_feat.data.end());
prefill_pos += chunk_len;
}
int64_t anchor =
read_ids(pf->at(0).toTensor())[0]; // bonus token at position L
HostFeature feat_prompt = read_feature(pf->at(1).toTensor());
const int64_t H = feat_prompt.H;
int64_t anchor_pos = L;

Expand All @@ -475,10 +522,45 @@ int main(int argc, char** argv) {
if (speculate) {
// Seed the first chain (shifted): draft slot p pairs feat_prompt[p] with
// token_{p+1}; the last slot pairs feat_prompt[L-1] with the bonus and
// predicts position L+1.
// predicts position L+1. Seed in <= max_prefill chunks (draft_decode shares
// the prefill shape range), each contiguous from the previous so the draft
// KV cache fills; the last chunk's last row predicts proposal 0 and carries
// the recurrent feature, then K-1 recurrent steps follow (mirroring chain).
std::vector<int64_t> seed_tokens(prompt.begin() + 1, prompt.end());
seed_tokens.push_back(anchor);
proposals = chain(seed_tokens, feat_prompt, 0);
std::vector<int64_t> ids;
HostFeature last_g;
for (int64_t seed_pos = 0; seed_pos < L;) {
int64_t chunk_len = next_chunk(seed_pos);
std::vector<int64_t> chunk_tokens(
seed_tokens.begin() + seed_pos,
seed_tokens.begin() + seed_pos + chunk_len);
draft_decode(
chunk_tokens,
feat_prompt.data.data() + seed_pos * H,
chunk_len,
H,
seed_pos,
ids,
last_g);
seed_pos += chunk_len;
}
proposals.push_back(ids.back());
int64_t last_pos = L - 1;
for (int64_t k = 1; k < K; k++) {
std::vector<int64_t> step_ids;
HostFeature step_g;
draft_decode(
{proposals.back()},
last_g.data.data(),
1,
last_g.H,
last_pos + k,
step_ids,
step_g);
proposals.push_back(step_ids[0]);
last_g = step_g;
}
}

// Stable buffers for target_verify (fixed length K+1) so the CUDA graph
Expand Down
Loading