Skip to content

Commit b808e26

Browse files
committed
up
1 parent 5af5b19 commit b808e26

2 files changed

Lines changed: 130 additions & 5 deletions

File tree

examples/models/qwen3_5_moe/qwen35_moe_engine.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@ namespace {
4747
// and calls `forward` for both phases.
4848
constexpr const char* kPrefillMethod = "forward";
4949
constexpr const char* kDecodeMethod = "forward";
50-
// Prefill is chunked on MLX to cap peak memory and the compiled prefill shape.
51-
constexpr int64_t kPrefillChunkSize = 1024;
5250
#else
5351
// CUDA/Metal exports emit two separate methods.
5452
constexpr const char* kPrefillMethod = "prefill";
@@ -268,9 +266,15 @@ class Qwen35MoESession : public LLMSession {
268266
// pass. Only the final chunk's sampled token is kept; the recurrence/KV
269267
// state from earlier chunks persists via pos_ advancement.
270268
#ifdef EXECUTORCH_BUILD_MLX
271-
// Chunk size = compiled max prefill chunk from model metadata, falling back
272-
// to the default if the model didn't export it. Clamp to >= 1.
273-
int64_t chunk_size = kPrefillChunkSize;
269+
// Chunk size: default to the compiled max (kMaxSeqLen - 1), overridden by
270+
// the exported get_max_prefill_chunk constant when present (mirrors
271+
// gemma4_31b). Falls back to T (single pass) if no metadata is available at
272+
// all.
273+
int64_t chunk_size = T;
274+
if (auto it = metadata_.find(kMaxSeqLen);
275+
it != metadata_.end() && it->second > 1) {
276+
chunk_size = it->second - 1;
277+
}
274278
if (auto it = metadata_.find(kMaxPrefillChunk);
275279
it != metadata_.end() && it->second > 0) {
276280
chunk_size = it->second;
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""Chunked-vs-unchunked prefill equivalence test for the MLX qwen3.5 MoE .pte.
8+
9+
The MLX C++ runner chunks long prompts and carries the recurrent/conv state and
10+
KV cache across chunk boundaries (qwen35_moe_engine.cpp prefill_tokens). Chunk
11+
boundaries are easy to get subtly wrong, so this test asserts that feeding a
12+
prompt as several sequential `forward` calls produces the same final-position
13+
logits (and same greedy first token) as a single `forward` call.
14+
15+
It runs against an already-exported tiny MLX .pte (no tokenizer needed: random
16+
token ids). Point it at the .pte via the QWEN_TINY_PTE env var, e.g.:
17+
18+
python -m executorch.examples.models.qwen3_5_moe.export \
19+
--tiny-test --backend mlx --qlinear 4w --qlinear-group-size 32 \
20+
--output-dir /tmp/qwen35_moe_mlx_tiny
21+
QWEN_TINY_PTE=/tmp/qwen35_moe_mlx_tiny/model.pte \
22+
python -m pytest examples/models/qwen3_5_moe/test_chunked_prefill.py -v
23+
24+
The test skips (rather than fails) when the .pte env var is unset, the file it
25+
points to does not exist, or `executorch.runtime` cannot be imported, so it is
a no-op on machines without an exported MLX .pte.
26+
"""
27+
28+
import os
29+
import unittest
30+
31+
import torch
32+
33+
PTE_ENV = "QWEN_TINY_PTE"
34+
35+
36+
def _load_forward(pte_path):
    """Load the program at *pte_path* and return ``(program, forward_method)``.

    Each call loads the program again, so callers get a fresh instance whose
    mutable state starts zeroed. The program object is returned alongside the
    method so the caller decides how long to keep it.
    """
    # Imported lazily so this module can still be collected (and the test
    # skipped) on machines where the executorch runtime is not installed.
    from executorch.runtime import Runtime, Verification

    loaded = Runtime.get().load_program(
        pte_path, verification=Verification.Minimal
    )
    forward = loaded.load_method("forward")
    return loaded, forward
43+
44+
45+
def _scalar_metadata(program, name, default):
    """Return the integer produced by the zero-argument method *name*.

    Args:
        program: Object exposing ``load_method(name)``; the loaded method must
            expose ``execute(inputs)``.
        name: Name of the constant method to run (e.g. ``"get_vocab_size"``).
        default: Value returned when the method cannot be loaded or executed,
            or when it yields no outputs.

    Returns:
        The first output converted to ``int``, or *default*.
    """
    try:
        result = program.load_method(name).execute([])
    except Exception:
        # Deliberately best-effort: a model that did not export this constant
        # method should fall back to the default rather than fail the test.
        return default
    # A method that ran but produced nothing is treated like a missing one.
    if result is None or len(result) == 0:
        return default
    value = result[0]
    # Tensor-like and numpy-like scalars expose ``.item()``; plain Python
    # numbers do not and are converted directly.
    item = getattr(value, "item", None)
    return int(item()) if callable(item) else int(value)
52+
53+
54+
def _last_logits(outputs):
    """Return the logits for the final sequence position of a forward call.

    Args:
        outputs: Sequence returned by ``forward``; its first element is the
            logits tensor, shaped ``(1, T, vocab)``.

    Returns:
        The 1-D slice ``logits[0, -1, :]``.

    Raises:
        ValueError: If the logits tensor is not 3-dimensional.
    """
    logits = outputs[0]
    # Fail with an explicit message rather than an opaque indexing error when
    # the model returns logits in an unexpected layout.
    if logits.ndim != 3:
        raise ValueError(
            "expected logits shaped (1, T, vocab), got shape "
            f"{tuple(logits.shape)}"
        )
    return logits[0, -1, :]
57+
58+
59+
class TestChunkedPrefill(unittest.TestCase):
    """Compare chunked prefill against single-pass prefill on an exported .pte."""

    def setUp(self):
        """Skip the test unless the .pte path and executorch runtime are usable."""
        self.pte_path = os.environ.get(PTE_ENV)
        if not self.pte_path:
            self.skipTest(f"{PTE_ENV} not set; export a tiny MLX .pte first")
        if not os.path.exists(self.pte_path):
            self.skipTest(f"{PTE_ENV}={self.pte_path} does not exist")
        try:
            import executorch.runtime  # noqa: F401
        except Exception as e:  # pragma: no cover - environment dependent
            self.skipTest(f"executorch.runtime unavailable: {e}")

    def test_chunked_prefill_matches_unchunked(self):
        """Feed one prompt whole and in chunks; final-position logits must agree."""
        chunk = 8

        # Shapes come from the model's constant methods; the program used to
        # read them is dropped before the real runs load their own instances.
        meta_program, _ = _load_forward(self.pte_path)
        vocab_size = _scalar_metadata(meta_program, "get_vocab_size", 256)
        max_seq_len = _scalar_metadata(meta_program, "get_max_seq_len", 64)
        del meta_program

        prompt_len = min(40, max_seq_len - 1)
        self.assertGreater(
            prompt_len,
            chunk,
            "prompt must exceed chunk size to exercise multiple chunks",
        )

        # Fixed seed so the random token ids are reproducible.
        torch.manual_seed(0)
        tokens = torch.randint(
            0, vocab_size, (1, prompt_len), dtype=torch.long
        )

        # Reference run: the whole prompt in a single forward call on a
        # freshly loaded program.
        _, forward_full = _load_forward(self.pte_path)
        all_positions = torch.arange(prompt_len, dtype=torch.long)
        full_outputs = forward_full.execute([tokens, all_positions])
        logits_full = _last_logits(full_outputs)

        # Chunked run: a second freshly loaded program receives the prompt in
        # consecutive slices with matching absolute positions. Only the logits
        # of the last slice are kept.
        _, forward_chunk = _load_forward(self.pte_path)
        logits_chunk = None
        for start in range(0, prompt_len, chunk):
            stop = min(start + chunk, prompt_len)
            piece = tokens[:, start:stop]
            piece_positions = torch.arange(start, stop, dtype=torch.long)
            piece_outputs = forward_chunk.execute([piece, piece_positions])
            logits_chunk = _last_logits(piece_outputs)

        # The greedy first token must match, and the logits must be close.
        self.assertEqual(
            logits_full.argmax().item(),
            logits_chunk.argmax().item(),
            "chunked prefill produced a different first token than unchunked",
        )
        torch.testing.assert_close(
            logits_chunk.float(),
            logits_full.float(),
            rtol=1e-2,
            atol=1e-2,
        )
118+
119+
120+
if __name__ == "__main__":
121+
unittest.main()

0 commit comments

Comments
 (0)