From 423301be61922ea6db8deee47c06d06b3cb667c5 Mon Sep 17 00:00:00 2001 From: BLuchterhand Date: Wed, 29 Apr 2026 17:29:44 -0400 Subject: [PATCH] fix(generate): avoid None entries in merged logits_processors PromptProcessingBatch.extend filled missing per-slot logits_processors with [None] when either side lacked configured processors. Merging an unconfigured batch with a processor-equipped batch then produced a list shaped like [None, ..., [fn], ...]. GenerationBatch._step at line 1346 iterates self.logits_processors[e] under the any() guard at line 1337, which raises TypeError on the None slots. Fill with [[]] instead. Matches the existing pattern at line 1120 (filter() restoring [[]] * len(keep)) and the per-slot type List[Callable]. Reproduce: construct two PromptProcessingBatch instances, one without processors and one with, then call extend; the merged self.logits_processors contains None entries. New unit test covers this shape directly. --- mlx_lm/generate.py | 4 ++-- tests/test_generate.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..5cafc1aa5 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1063,12 +1063,12 @@ def extend(self, batch): if not any(self.samplers): self.samplers = [None] * len(self.uids) if not any(self.logits_processors): - self.logits_processors = [None] * len(self.uids) + self.logits_processors = [[]] * len(self.uids) samplers = batch.samplers if any(batch.samplers) else [None] * len(batch.uids) logits_processors = ( batch.logits_processors if any(batch.logits_processors) - else [None] * len(batch.uids) + else [[]] * len(batch.uids) ) self.uids.extend(batch.uids) diff --git a/tests/test_generate.py b/tests/test_generate.py index 4f5bb4c91..129df0cde 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -9,6 +9,7 @@ from mlx_lm.generate import ( BatchGenerator, GenerationResponse, + PromptProcessingBatch, SequenceStateMachine, batch_generate, generate, @@ -402,6 +403,34 @@ def test_batch_generate_with_logits_processors(self): self.assertEqual(responses[uid1].logprobs[1].item(), 0.0) self.assertEqual(responses[uid2].logprobs[2].item(), 0.0) + def test_prompt_processing_batch_extend_mixes_logits_processors(self): + """Test PromptProcessingBatch.extend produces a per-slot list with no None entries when merging an unconfigured batch with a processor-equipped batch.""" + fallback = lambda x: mx.argmax(x, axis=-1) + a = PromptProcessingBatch.empty(self.model, fallback) + a.uids = [0] + a.tokens = [[]] + a.samplers = [] + a.logits_processors = [] + a.max_tokens = [1] + a.state_machines = [SequenceStateMachine()] + a.prompt_cache = [] + + procs = make_logits_processors({0: 2000.0}) + b = PromptProcessingBatch.empty(self.model, fallback) + b.uids = [1] + b.tokens = [[]] + b.samplers = [] + b.logits_processors = [procs] + b.max_tokens = [1] + b.state_machines = [SequenceStateMachine()] + b.prompt_cache = [] + + a.extend(b) + + self.assertEqual(len(a.logits_processors), 2) + for entry in a.logits_processors: + self.assertIsInstance(entry, list) + def test_batch_generate_processor_tokens_match_prompt_on_first_step(self): prompt = self.tokenizer.encode("hello") seen = []