|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +"""Chunked-vs-unchunked prefill equivalence test for the MLX qwen3.5 MoE .pte. |
| 8 | +
|
| 9 | +The MLX C++ runner chunks long prompts and carries the recurrent/conv state and |
| 10 | +KV cache across chunk boundaries (qwen35_moe_engine.cpp prefill_tokens). Chunk |
| 11 | +boundaries are easy to get subtly wrong, so this test asserts that feeding a |
| 12 | +prompt as several sequential `forward` calls produces the same final-position |
| 13 | +logits (and same greedy first token) as a single `forward` call. |
| 14 | +
|
| 15 | +It runs against an already-exported tiny MLX .pte (no tokenizer needed: random |
| 16 | +token ids). Point it at the .pte via the QWEN_TINY_PTE env var, e.g.: |
| 17 | +
|
| 18 | + python -m executorch.examples.models.qwen3_5_moe.export \ |
| 19 | + --tiny-test --backend mlx --qlinear 4w --qlinear-group-size 32 \ |
| 20 | + --output-dir /tmp/qwen35_moe_mlx_tiny |
| 21 | + QWEN_TINY_PTE=/tmp/qwen35_moe_mlx_tiny/model.pte \ |
| 22 | + python -m pytest examples/models/qwen3_5_moe/test_chunked_prefill.py -v |
| 23 | +
|
| 24 | +The test skips (rather than fails) when the .pte env var is unset or the MLX |
| 25 | +runtime is unavailable, so it is a no-op on non-MLX machines. |
| 26 | +""" |
| 27 | + |
| 28 | +import os |
| 29 | +import unittest |
| 30 | + |
| 31 | +import torch |
| 32 | + |
| 33 | +PTE_ENV = "QWEN_TINY_PTE" |
| 34 | + |
| 35 | + |
| 36 | +def _load_forward(pte_path): |
| 37 | + """Load a fresh program instance so mutable state starts zeroed.""" |
| 38 | + from executorch.runtime import Runtime, Verification |
| 39 | + |
| 40 | + runtime = Runtime.get() |
| 41 | + program = runtime.load_program(pte_path, verification=Verification.Minimal) |
| 42 | + return program, program.load_method("forward") |
| 43 | + |
| 44 | + |
| 45 | +def _scalar_metadata(program, name, default): |
| 46 | + try: |
| 47 | + result = program.load_method(name).execute([]) |
| 48 | + except Exception: |
| 49 | + return default |
| 50 | + v = result[0] |
| 51 | + return int(v) if isinstance(v, int) else int(v.item()) |
| 52 | + |
| 53 | + |
| 54 | +def _last_logits(outputs): |
| 55 | + # forward returns logits shaped (1, T, vocab); take the final position. |
| 56 | + return outputs[0][0, -1, :] |
| 57 | + |
| 58 | + |
| 59 | +class TestChunkedPrefill(unittest.TestCase): |
| 60 | + def setUp(self): |
| 61 | + self.pte_path = os.environ.get(PTE_ENV) |
| 62 | + if not self.pte_path: |
| 63 | + self.skipTest(f"{PTE_ENV} not set; export a tiny MLX .pte first") |
| 64 | + if not os.path.exists(self.pte_path): |
| 65 | + self.skipTest(f"{PTE_ENV}={self.pte_path} does not exist") |
| 66 | + try: |
| 67 | + import executorch.runtime # noqa: F401 |
| 68 | + except Exception as e: # pragma: no cover - environment dependent |
| 69 | + self.skipTest(f"executorch.runtime unavailable: {e}") |
| 70 | + |
| 71 | + def test_chunked_prefill_matches_unchunked(self): |
| 72 | + # Read shapes from the model's constant methods. |
| 73 | + program, _ = _load_forward(self.pte_path) |
| 74 | + vocab_size = _scalar_metadata(program, "get_vocab_size", 256) |
| 75 | + max_seq_len = _scalar_metadata(program, "get_max_seq_len", 64) |
| 76 | + del program |
| 77 | + |
| 78 | + prompt_len = min(40, max_seq_len - 1) |
| 79 | + chunk = 8 |
| 80 | + self.assertGreater( |
| 81 | + prompt_len, |
| 82 | + chunk, |
| 83 | + "prompt must exceed chunk size to exercise multiple chunks", |
| 84 | + ) |
| 85 | + |
| 86 | + torch.manual_seed(0) |
| 87 | + tokens = torch.randint(0, vocab_size, (1, prompt_len), dtype=torch.long) |
| 88 | + |
| 89 | + # Unchunked: one forward over the whole prompt (fresh program/state). |
| 90 | + _, forward_full = _load_forward(self.pte_path) |
| 91 | + pos_full = torch.arange(prompt_len, dtype=torch.long) |
| 92 | + logits_full = _last_logits(forward_full.execute([tokens, pos_full])) |
| 93 | + |
| 94 | + # Chunked: sequential forwards advancing input_pos, carrying state across |
| 95 | + # boundaries (fresh program/state). |
| 96 | + _, forward_chunk = _load_forward(self.pte_path) |
| 97 | + logits_chunk = None |
| 98 | + for off in range(0, prompt_len, chunk): |
| 99 | + end = min(off + chunk, prompt_len) |
| 100 | + chunk_tokens = tokens[:, off:end] |
| 101 | + chunk_pos = torch.arange(off, end, dtype=torch.long) |
| 102 | + logits_chunk = _last_logits( |
| 103 | + forward_chunk.execute([chunk_tokens, chunk_pos]) |
| 104 | + ) |
| 105 | + |
| 106 | + # Same greedy first token, and logits numerically close. |
| 107 | + self.assertEqual( |
| 108 | + int(torch.argmax(logits_full)), |
| 109 | + int(torch.argmax(logits_chunk)), |
| 110 | + "chunked prefill produced a different first token than unchunked", |
| 111 | + ) |
| 112 | + torch.testing.assert_close( |
| 113 | + logits_chunk.to(torch.float32), |
| 114 | + logits_full.to(torch.float32), |
| 115 | + rtol=1e-2, |
| 116 | + atol=1e-2, |
| 117 | + ) |
| 118 | + |
| 119 | + |
| 120 | +if __name__ == "__main__": |
| 121 | + unittest.main() |
0 commit comments