From 07337e58e54a13de351ddf7aae71dfdb8c8fd56a Mon Sep 17 00:00:00 2001 From: "Michael G. Loiterman" Date: Tue, 28 Apr 2026 12:16:04 -0400 Subject: [PATCH 1/2] =?UTF-8?q?generate.py:=20GenerationBatch.filter=20?= =?UTF-8?q?=E2=80=94=20add=20else=20branches=20for=20empty=20per-sequence?= =?UTF-8?q?=20lists?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `GenerationBatch.filter` (around line 1392) uses `if any(self.samplers)` and `if any(self.logits_processors)` to guard whether it trims those per-sequence lists. When every current slot is `None` / `[]` (the common shape for any in-flight request that does not attach a custom sampler or logits processor), `any(...)` is `False` and the trim is skipped. The lists keep stale length while `self.uids` is correctly trimmed. When new sequences subsequently arrive via `extend`, their entries are appended to the over-long list. `_step`'s per-sequence loop (`for e in range(len(self.uids)): for processor in self.logits_processors[e]:`) then reads stale-index slots — silently bypassing the new sequence's processor for exactly one generation step before the next `filter` corrects the length. The structurally-symmetric `PromptProcessingBatch.filter` (around line 1117) already handles this case with explicit `else` branches that reset to `[None] * len(keep)` / `[[]] * len(keep)`. This commit mirrors those branches onto `GenerationBatch.filter`. Discovered while running schema-constrained generation (`response_format=json_schema`, `strict: true`) in a production-shape FastAPI server. After any tool-bearing chat completion (which carries no `logits_processors`), the very next request's grammar processor was silently bypassed — model returned plain unconstrained text instead of schema-valid output. The bug was silent (no exception, no log) and affected exactly one request before self-recovering, which made it ordering-dependent and hard to spot. --- mlx_lm/generate.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..e30686f5f 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1391,8 +1391,12 @@ def filter(self, keep: List[int]): self.tokens = [self.tokens[idx] for idx in keep] if any(self.samplers): self.samplers = [self.samplers[idx] for idx in keep] + else: + self.samplers = [None] * len(keep) if any(self.logits_processors): self.logits_processors = [self.logits_processors[idx] for idx in keep] + else: + self.logits_processors = [[]] * len(keep) self.max_tokens = [self.max_tokens[idx] for idx in keep] self.state_machines = [self.state_machines[idx] for idx in keep] From 60f88f648a3cd412a369f8e082e9b11073bda0c1 Mon Sep 17 00:00:00 2001 From: "Michael G. Loiterman" Date: Wed, 10 Jun 2026 15:52:30 -0400 Subject: [PATCH 2/2] =?UTF-8?q?generate.py:=20PromptProcessingBatch.extend?= =?UTF-8?q?=20=E2=80=94=20use=20[]=20sentinel=20for=20absent=20per-slot=20?= =?UTF-8?q?logits=5Fprocessors?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit extend() fills missing per-slot logits_processors with [None] * N. Merging an unconfigured batch with a processor-equipped batch produces a mixed list ([None, ..., [fn], ...]); GenerationBatch._step then iterates self.logits_processors[e] under the any() guard and crashes with TypeError: 'NoneType' object is not iterable. Per-slot type is List[Callable], so the absent-value sentinel is [] — matching the existing [[]] * len(keep) in PromptProcessingBatch.filter. samplers keep the None sentinel (consumed as `self.samplers[e] or self.fallback_sampler`, type Optional[Callable]). Fix and regression test carried over from #1230 (closed as duplicate of this PR; the two changes are companions in the same file — filter here, extend there). Co-authored-by: BLuchterhand --- mlx_lm/generate.py | 4 ++-- tests/test_generate.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index e30686f5f..c8c1a5356 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1063,12 +1063,12 @@ def extend(self, batch): if not any(self.samplers): self.samplers = [None] * len(self.uids) if not any(self.logits_processors): - self.logits_processors = [None] * len(self.uids) + self.logits_processors = [[]] * len(self.uids) samplers = batch.samplers if any(batch.samplers) else [None] * len(batch.uids) logits_processors = ( batch.logits_processors if any(batch.logits_processors) - else [None] * len(batch.uids) + else [[]] * len(batch.uids) ) self.uids.extend(batch.uids) diff --git a/tests/test_generate.py b/tests/test_generate.py index 4f5bb4c91..129df0cde 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -9,6 +9,7 @@ from mlx_lm.generate import ( BatchGenerator, GenerationResponse, + PromptProcessingBatch, SequenceStateMachine, batch_generate, generate, @@ -402,6 +403,34 @@ def test_batch_generate_with_logits_processors(self): self.assertEqual(responses[uid1].logprobs[1].item(), 0.0) self.assertEqual(responses[uid2].logprobs[2].item(), 0.0) + def test_prompt_processing_batch_extend_mixes_logits_processors(self): + """Test PromptProcessingBatch.extend produces a per-slot list with no None entries when merging an unconfigured batch with a processor-equipped batch.""" + fallback = lambda x: mx.argmax(x, axis=-1) + a = PromptProcessingBatch.empty(self.model, fallback) + a.uids = [0] + a.tokens = [[]] + a.samplers = [] + a.logits_processors = [] + a.max_tokens = [1] + a.state_machines = [SequenceStateMachine()] + a.prompt_cache = [] + + procs = make_logits_processors({0: 2000.0}) + b = PromptProcessingBatch.empty(self.model, fallback) + b.uids = [1] + b.tokens = [[]] + b.samplers = [] + b.logits_processors = [procs] + b.max_tokens = [1] + b.state_machines = [SequenceStateMachine()] + b.prompt_cache = [] + + a.extend(b) + + self.assertEqual(len(a.logits_processors), 2) + for entry in a.logits_processors: + self.assertIsInstance(entry, list) + def test_batch_generate_processor_tokens_match_prompt_on_first_step(self): prompt = self.tokenizer.encode("hello") seen = []