Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions examples/models/gemma4_31b/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,9 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -

# Prefill (T in [min_prefill, max_prefill]): shim does dequant+cuBLAS (optimal for large M).
max_prefill = min(config.max_seq_len - 1, config.sliding_window * 2)
seq_dim = Dim("seq_len", min=5, max=max_prefill)
print(f"Exporting prefill (T in [2, {max_prefill}])...")
min_prefill = 5
seq_dim = Dim("seq_len", min=min_prefill, max=max_prefill)
print(f"Exporting prefill (T in [{min_prefill}, {max_prefill}])...")
with torch.no_grad():
prefill_ep = export(
model,
Expand Down Expand Up @@ -250,6 +251,8 @@ def _export_cuda(model: Gemma4_31B, config: Gemma4_31BConfig, output_dir: str) -
"get_vocab_size": config.vocab_size,
"get_n_layers": config.num_hidden_layers,
"get_max_prefill_chunk": max_prefill,
"get_min_prefill_chunk": min_prefill,
"get_sliding_window": config.sliding_window,
"use_kv_cache": True,
"use_sdpa_with_kv_cache": False,
"enable_dynamic_shape": True,
Expand Down Expand Up @@ -364,6 +367,7 @@ def _export_mlx(
"get_vocab_size": config.vocab_size,
"get_n_layers": config.num_hidden_layers,
"get_max_prefill_chunk": max_prefill,
"get_sliding_window": config.sliding_window,
"use_kv_cache": True,
"use_sdpa_with_kv_cache": False,
"enable_dynamic_shape": True,
Expand Down
76 changes: 72 additions & 4 deletions examples/models/gemma4_31b/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ DEFINE_double(temperature, 0.8, "Sampling temperature (0 = near-greedy).");
DEFINE_int32(max_new_tokens, 128, "Maximum tokens to generate.");
DEFINE_int32(bos_id, 2, "BOS token id to prepend (Gemma convention: 2).");
DEFINE_int32(eos_id, 1, "EOS token id (Gemma convention: 1).");
DEFINE_int32(
max_prefill_chunk,
0,
"Override the prefill chunk size (<=0 uses metadata). Experiment: chunking "
"above sliding_window is inexact for sliding layers across boundaries.");
DEFINE_bool(
raw_prompt,
false,
Expand Down Expand Up @@ -168,13 +173,55 @@ int main(int argc, char** argv) {
return 1;
}

int64_t max_prefill_chunk = (*metadata_result)[llm::kMaxSeqLen] - 1;
int64_t exported_max_prefill = (*metadata_result)[llm::kMaxSeqLen] - 1;
{
auto get_result = module->get("get_max_prefill_chunk");
if (get_result.ok()) {
max_prefill_chunk = get_result->toScalar().to<int64_t>();
exported_max_prefill = get_result->toScalar().to<int64_t>();
}
}
// Cap prefill chunks at the sliding window: a chunk larger than the window
// overflows the 2*window ring cache across chunk boundaries, truncating
// sliding attention for the first ~(chunk-window) queries of each chunk (the
// global flat-cache layers stay exact). The export sets max_prefill =
// 2*sliding_window, so window = max_prefill/2 (prefer get_sliding_window
// metadata when present).
int64_t sliding_window = exported_max_prefill / 2;
{
auto sw = module->get("get_sliding_window");
if (sw.ok()) {
sliding_window = sw->toScalar().to<int64_t>();
}
}
int64_t max_prefill_chunk = std::min(sliding_window, exported_max_prefill);
if (FLAGS_max_prefill_chunk > 0) {
max_prefill_chunk =
std::min<int64_t>(FLAGS_max_prefill_chunk, exported_max_prefill);
}
// The exported prefill accepts T in [min_prefill, max_prefill]; a final chunk
// shorter than min_prefill (and > 1) is an out-of-range shape. Read the lower
// bound so chunking can avoid it (fallback 1 keeps older exports working: a
// length-1 tail already routes to decode).
int64_t min_prefill = 1;
{
auto r = module->get("get_min_prefill_chunk");
if (r.ok()) {
min_prefill = r->toScalar().to<int64_t>();
}
}
// A --max_prefill_chunk below the exported minimum has no valid prefill shape
// (and a cap of 1 would make the tail adjustment compute chunk_len == 0 and
// loop forever), so reject it rather than feed an out-of-range / zero chunk.
if (FLAGS_max_prefill_chunk > 0 && max_prefill_chunk < min_prefill) {
ET_LOG(
Error,
"--max_prefill_chunk (%d) is below the exported prefill minimum "
"(%" PRId64 "); use a value >= %" PRId64 " or omit it.",
FLAGS_max_prefill_chunk,
min_prefill,
min_prefill);
return 1;
}

auto S = [](int64_t v) -> SizesType { return static_cast<SizesType>(v); };

Expand Down Expand Up @@ -280,6 +327,21 @@ int main(int argc, char** argv) {
printf("Prompt tokens: %" PRId64 "\n", num_prompt_tokens);
stats.num_prompt_tokens = num_prompt_tokens;

// A prompt of 2..min_prefill-1 tokens has no valid prefill shape (the CUDA
// export specializes prefill to T >= min_prefill) and is too long for the
// single-token decode path, so reject it. A 1-token prompt is fine: it goes
// through decode below.
if (num_prompt_tokens > 1 && num_prompt_tokens < min_prefill) {
ET_LOG(
Error,
"Prompt (%" PRId64
" tokens) is below the exported prefill minimum %" PRId64
"; use a longer prompt.",
num_prompt_tokens,
min_prefill);
return 1;
}

stats.inference_start_ms = llm::time_in_ms();

// ---------------------------------------------------------------
Expand All @@ -288,8 +350,14 @@ int main(int argc, char** argv) {
uint64_t cur_token = 0;
int64_t prefill_pos = 0;
while (prefill_pos < num_prompt_tokens) {
int64_t chunk_len =
std::min(num_prompt_tokens - prefill_pos, max_prefill_chunk);
int64_t remaining = num_prompt_tokens - prefill_pos;
int64_t chunk_len = std::min(remaining, max_prefill_chunk);
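// Shrink this chunk so the tail it leaves is never in (1, min_prefill):
// such a tail would be an out-of-range prefill shape. A length-1 tail is
// fine (routed to decode below); a >= min_prefill tail is fine too.
// NOTE(review): the shrunk chunk is max_prefill_chunk + tail - min_prefill,
// so it stays >= min_prefill only when max_prefill_chunk >= 2 * min_prefill - 2;
// a smaller cap (e.g. cap == min_prefill == 5 with remaining == 7) yields an
// out-of-range chunk of 2 here.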
if (remaining - chunk_len > 1 && remaining - chunk_len < min_prefill) {
chunk_len = remaining - min_prefill;
}

std::vector<int64_t> token_data(
prompt_tokens.begin() + prefill_pos,
Expand Down
Loading