diff --git a/mlx_lm/generate.py b/mlx_lm/generate.py index 3573b2640..c8c1a5356 100644 --- a/mlx_lm/generate.py +++ b/mlx_lm/generate.py @@ -1063,12 +1063,12 @@ def extend(self, batch): if not any(self.samplers): self.samplers = [None] * len(self.uids) if not any(self.logits_processors): - self.logits_processors = [None] * len(self.uids) + self.logits_processors = [[]] * len(self.uids) samplers = batch.samplers if any(batch.samplers) else [None] * len(batch.uids) logits_processors = ( batch.logits_processors if any(batch.logits_processors) - else [None] * len(batch.uids) + else [[]] * len(batch.uids) ) self.uids.extend(batch.uids) @@ -1391,8 +1391,12 @@ def filter(self, keep: List[int]): self.tokens = [self.tokens[idx] for idx in keep] if any(self.samplers): self.samplers = [self.samplers[idx] for idx in keep] + else: + self.samplers = [None] * len(keep) if any(self.logits_processors): self.logits_processors = [self.logits_processors[idx] for idx in keep] + else: + self.logits_processors = [[]] * len(keep) self.max_tokens = [self.max_tokens[idx] for idx in keep] self.state_machines = [self.state_machines[idx] for idx in keep] diff --git a/tests/test_generate.py b/tests/test_generate.py index 4f5bb4c91..129df0cde 100644 --- a/tests/test_generate.py +++ b/tests/test_generate.py @@ -9,6 +9,7 @@ from mlx_lm.generate import ( BatchGenerator, GenerationResponse, + PromptProcessingBatch, SequenceStateMachine, batch_generate, generate, @@ -402,6 +403,34 @@ def test_batch_generate_with_logits_processors(self): self.assertEqual(responses[uid1].logprobs[1].item(), 0.0) self.assertEqual(responses[uid2].logprobs[2].item(), 0.0) + def test_prompt_processing_batch_extend_mixes_logits_processors(self): + """Test PromptProcessingBatch.extend produces a per-slot list with no None entries when merging an unconfigured batch with a processor-equipped batch.""" + fallback = lambda x: mx.argmax(x, axis=-1) + a = PromptProcessingBatch.empty(self.model, fallback) + a.uids = [0] + a.tokens = [[]] + a.samplers = [] + a.logits_processors = [] + a.max_tokens = [1] + a.state_machines = [SequenceStateMachine()] + a.prompt_cache = [] + + procs = make_logits_processors({0: 2000.0}) + b = PromptProcessingBatch.empty(self.model, fallback) + b.uids = [1] + b.tokens = [[]] + b.samplers = [] + b.logits_processors = [procs] + b.max_tokens = [1] + b.state_machines = [SequenceStateMachine()] + b.prompt_cache = [] + + a.extend(b) + + self.assertEqual(len(a.logits_processors), 2) + for entry in a.logits_processors: + self.assertIsInstance(entry, list) + def test_batch_generate_processor_tokens_match_prompt_on_first_step(self): prompt = self.tokenizer.encode("hello") seen = []