diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 84aff86a9..74bb5f2b2 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -473,6 +473,43 @@ def _patched_generation_batch_step(self): GenerationBatch._step = _patched_generation_batch_step +# --------------------------------------------------------------------------- +# Monkey-patch GenerationBatch.filter to keep logits_processors aligned with +# uids. mlx-lm's filter only reindexes the processor list when at least one +# row has an active processor: +# +# if any(self.logits_processors): +# self.logits_processors = [self.logits_processors[idx] for idx in keep] +# +# There is no else branch (unlike the prompt-batch class, which resets to +# ``[[]] * len(keep)``), so when every slot is empty — the normal state after +# serving requests without per-request processors — the stale list survives +# while uids/tokens shrink. A later extend() then appends the next request's +# processors BEHIND its own row index: the row reads a leftover empty slot and +# the real processor (thinking budget, grammar constraint) is silently never +# applied. Which requests are affected depends on insertion/removal order, +# and alignment self-heals once the broken request finishes, so the symptom +# is an intermittently ignored thinking_budget or grammar. See #934/#1747 +# for the sibling None-slot collapse handled in _patched_generation_batch_step. +_original_generation_batch_filter = GenerationBatch.filter + + +def _patched_generation_batch_filter(self, keep): + lps = self.logits_processors + lps_inert = not lps or not any(lps) + if lps is None: + # ``any(None)`` inside the original filter raises TypeError. + self.logits_processors = [] + _original_generation_batch_filter(self, keep) + if lps_inert: + # Original filter skipped the reindex; reset to one empty slot per + # surviving row so extend() appends at the correct indices. + self.logits_processors = [[] for _ in keep] + + +GenerationBatch.filter = _patched_generation_batch_filter + + # Monkey-patch TurboQuantKVCache.merge so _merge_caches() works try: from mlx_vlm.turboquant import TurboQuantKVCache as _TQCache diff --git a/tests/test_scheduler_logits_processors.py b/tests/test_scheduler_logits_processors.py index 36a1b8583..d5b1e707f 100644 --- a/tests/test_scheduler_logits_processors.py +++ b/tests/test_scheduler_logits_processors.py @@ -132,6 +132,152 @@ def test_scheduler_source_normalises_per_row_slots(self): ) +def _bare_generation_batch(uid, logits_processors): + """Build a GenerationBatch via __new__ with plain-list state. + + ``filter()`` and ``extend()`` never touch the model, so a bare instance + is enough to exercise the real mlx-lm bookkeeping without loading + weights. Mirrors the ``__class__.__new__`` idiom of + ``_patched_ppb_split`` in omlx/scheduler.py. + """ + from mlx_lm.generate import GenerationBatch + + batch = GenerationBatch.__new__(GenerationBatch) + batch.uids = [uid] + batch.prompt_cache = [] + batch.tokens = [[1, 2, 3]] + batch.samplers = [lambda x: x] + batch.fallback_sampler = lambda x: x + batch.logits_processors = logits_processors + batch.state_machines = [object()] + batch.max_tokens = [4] + batch._current_tokens = None + batch._current_logprobs = [] + batch._next_tokens = None + batch._next_logprobs = [object()] + batch._token_context = [object()] + batch._num_tokens = [0] + batch._matcher_states = [object()] + return batch + + +class TestFilterStaleProcessorAlignment: + """Pin the GenerationBatch.filter alignment patch. + + mlx-lm's ``GenerationBatch.filter`` reindexes ``logits_processors`` only + when ``any(self.logits_processors)`` is True; there is no else branch + (the prompt-batch class has one: ``[[]] * len(keep)``). After a request + with no per-request processors finishes — every slot ``[]``, the shape + omlx inserts — removal shrinks ``uids`` but leaves the stale processor + list behind. The next request's row then ``extend()``s in BEHIND its own + index: row 0 reads the leftover empty slot and its real processor + (thinking budget, grammar constraint) is silently never applied. The + misalignment self-heals when the affected request finishes (the orphan + makes ``any()`` True again), so the symptom is an intermittently ignored + thinking_budget / grammar that depends on request order. + + ``_patched_generation_batch_filter`` resets the list to one empty slot + per surviving row whenever the original guard would have skipped the + reindex. + """ + + def test_filter_resets_stale_list_when_all_slots_inert(self): + """filter(keep=[]) on an all-empty-slot batch must empty the list. + + Fails before the fix: ``logits_processors`` stays ``[[]]`` while + ``uids`` becomes ``[]``. Passes after: both are empty. + """ + import omlx.scheduler # noqa: F401 (installs the filter patch) + + batch = _bare_generation_batch(uid=0, logits_processors=[[]]) + batch.filter([]) + + assert batch.uids == [] + assert batch.logits_processors == [] + + def test_processor_lands_on_its_own_row_after_remove_then_extend(self): + """End-to-end shape of the live reproduction (#1825 follow-up). + + Request A (no processors) finishes and is removed; request B (with a + thinking-budget-style processor) joins via extend(). B's processor + must sit at B's row index. Fails before the fix with + ``logits_processors == [[], [processor]]`` against ``uids == [1]`` — + row 0 reads the stale empty slot and the processor is never called. + """ + import omlx.scheduler # noqa: F401 (installs the filter patch) + + def budget_processor(tokens, logits): + return logits + + survivor = _bare_generation_batch(uid=0, logits_processors=[[]]) + survivor.filter([]) # request A removed; batch now empty + + joiner = _bare_generation_batch( + uid=1, logits_processors=[[budget_processor]] + ) + survivor.extend(joiner) # request B joins the long-lived batch + + assert survivor.uids == [1] + assert len(survivor.logits_processors) == len(survivor.uids) + assert survivor.logits_processors[0] == [budget_processor] + + def test_filter_preserves_active_processor_reindex(self): + """When any slot is active the original reindex path runs; the patch + must not clobber its (correct) result.""" + import omlx.scheduler # noqa: F401 (installs the filter patch) + + def grammar_processor(tokens, logits): + return logits + + batch = _bare_generation_batch(uid=0, logits_processors=None) + batch.uids = [0, 1] + batch.tokens = [[1], [2]] + batch.samplers = [lambda x: x, lambda x: x] + batch.logits_processors = [[], [grammar_processor]] + batch.state_machines = [object(), object()] + batch.max_tokens = [4, 4] + batch._next_logprobs = [object(), object()] + batch._token_context = [object(), object()] + batch._num_tokens = [0, 0] + batch._matcher_states = [object(), object()] + import mlx.core as mx + + batch._next_tokens = mx.array([1, 2]) + batch.filter([1]) + + assert batch.uids == [1] + assert batch.logits_processors == [[grammar_processor]] + + def test_filter_normalises_none_list(self): + """A None logits_processors list must not crash the original filter + (``any(None)`` raises TypeError) and must come out aligned.""" + import omlx.scheduler # noqa: F401 (installs the filter patch) + + batch = _bare_generation_batch(uid=0, logits_processors=None) + batch.filter([]) + + assert batch.logits_processors == [] + + def test_scheduler_source_installs_filter_patch(self): + """Source-level guard against silent removal of the patch + installation. Cheap; runs without a model in CI.""" + from pathlib import Path + + scheduler_src = ( + Path(__file__).resolve().parents[1] / "omlx" / "scheduler.py" + ).read_text() + assert ( + "GenerationBatch.filter = _patched_generation_batch_filter" + in scheduler_src + ), ( + "scheduler.py must install _patched_generation_batch_filter on " + "GenerationBatch.filter: mlx-lm's filter leaves a stale " + "logits_processors list behind when every slot is empty, which " + "silently drops the next request's processors after a " + "remove-then-extend." + ) + + class TestCorruptionPatternRecovery: """Pin the recovery contract: 'not iterable' is a known corruption."""