Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions omlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,43 @@ def _patched_generation_batch_step(self):
GenerationBatch._step = _patched_generation_batch_step


# ---------------------------------------------------------------------------
# Monkey-patch GenerationBatch.filter to keep logits_processors aligned with
# uids. mlx-lm's filter only reindexes the processor list when at least one
# row has an active processor:
#
# if any(self.logits_processors):
# self.logits_processors = [self.logits_processors[idx] for idx in keep]
#
# There is no else branch (unlike the prompt-batch class, which resets to
# ``[[]] * len(keep)``), so when every slot is empty — the normal state after
# serving requests without per-request processors — the stale list survives
# while uids/tokens shrink. A later extend() then appends the next request's
# processors BEHIND its own row index: the row reads a leftover empty slot and
# the real processor (thinking budget, grammar constraint) is silently never
# applied. Which requests are affected depends on insertion/removal order,
# and alignment self-heals once the broken request finishes, so the symptom
# is an intermittently ignored thinking_budget or grammar. See #934/#1747
# for the sibling None-slot collapse handled in _patched_generation_batch_step.
_original_generation_batch_filter = GenerationBatch.filter


def _patched_generation_batch_filter(self, keep):
    """Run mlx-lm's ``GenerationBatch.filter`` and realign logits_processors.

    The upstream filter only reindexes ``logits_processors`` when at least
    one slot holds an active processor. When every slot is empty (or the
    attribute is ``None`` / ``[]``) it leaves the list untouched, so the
    stale list survives while the other per-row lists shrink. In exactly
    that case this wrapper rebuilds the list with one empty slot per kept
    row, so a later ``extend()`` appends at the correct indices.

    Args:
        keep: Row indices to retain. It is iterated again after the
            upstream call, so it must be a re-iterable sequence (e.g. a
            list), not a one-shot iterator.

    Returns:
        Whatever the upstream ``filter`` returns, passed through unchanged
        so the wrapper stays transparent to callers.
    """
    processors = self.logits_processors
    # Decide BEFORE delegating: this mirrors the condition under which the
    # upstream filter skips its reindex (no list, or no active slot).
    needs_reset = not processors or not any(processors)
    if processors is None:
        # ``any(None)`` inside the original filter raises TypeError.
        self.logits_processors = []

    result = _original_generation_batch_filter(self, keep)

    if needs_reset:
        # Original filter skipped the reindex; reset to one empty slot per
        # surviving row so extend() appends at the correct indices. A
        # fresh list per row avoids aliasing one shared ``[]``.
        self.logits_processors = [[] for _ in keep]
    return result


GenerationBatch.filter = _patched_generation_batch_filter


# Monkey-patch TurboQuantKVCache.merge so _merge_caches() works
try:
from mlx_vlm.turboquant import TurboQuantKVCache as _TQCache
Expand Down
146 changes: 146 additions & 0 deletions tests/test_scheduler_logits_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,152 @@ def test_scheduler_source_normalises_per_row_slots(self):
)


def _bare_generation_batch(uid, logits_processors):
    """Return a one-row GenerationBatch created without running __init__.

    The instance is allocated with ``__new__`` and populated with plain
    Python lists/placeholders. ``filter()`` and ``extend()`` never touch
    the model, so this is enough to drive the real mlx-lm bookkeeping
    without loading weights. Same ``__class__.__new__`` idiom as
    ``_patched_ppb_split`` in omlx/scheduler.py.

    Args:
        uid: Identifier stored as the batch's only entry in ``uids``.
        logits_processors: Value assigned verbatim to
            ``batch.logits_processors`` (may be ``None``).
    """
    from mlx_lm.generate import GenerationBatch

    batch = GenerationBatch.__new__(GenerationBatch)

    # Attribute name -> initial value, applied in declaration order.
    initial_state = {
        "uids": [uid],
        "prompt_cache": [],
        "tokens": [[1, 2, 3]],
        "samplers": [lambda x: x],
        "fallback_sampler": lambda x: x,
        "logits_processors": logits_processors,
        "state_machines": [object()],
        "max_tokens": [4],
        "_current_tokens": None,
        "_current_logprobs": [],
        "_next_tokens": None,
        "_next_logprobs": [object()],
        "_token_context": [object()],
        "_num_tokens": [0],
        "_matcher_states": [object()],
    }
    for attr_name, value in initial_state.items():
        setattr(batch, attr_name, value)

    return batch


class TestFilterStaleProcessorAlignment:
    """Regression tests for the GenerationBatch.filter alignment patch.

    Background: mlx-lm's ``GenerationBatch.filter`` reindexes
    ``logits_processors`` only when ``any(self.logits_processors)`` is
    True, and has no else branch (the prompt-batch class does:
    ``[[]] * len(keep)``). Once a request without per-request processors
    finishes (every slot ``[]``, the shape omlx inserts), removal shrinks
    ``uids`` but the processor list is left as it was. The next request
    then ``extend()``s in BEHIND its own index: row 0 reads the leftover
    empty slot, and the real processor (thinking budget, grammar
    constraint) is silently never applied. The misalignment heals itself
    when the affected request finishes (the orphan makes ``any()`` True
    again), so what users see is a thinking_budget / grammar that is
    ignored intermittently, depending on request order.

    ``_patched_generation_batch_filter`` rebuilds the list with one empty
    slot per surviving row whenever the original guard would have skipped
    the reindex.
    """

    def test_filter_resets_stale_list_when_all_slots_inert(self):
        """Removing the only row of an all-empty batch empties the list.

        Before the fix ``logits_processors`` stays ``[[]]`` while ``uids``
        becomes ``[]``; after the fix both are empty.
        """
        import omlx.scheduler  # noqa: F401 (installs the filter patch)

        finished = _bare_generation_batch(uid=0, logits_processors=[[]])
        finished.filter([])

        assert finished.uids == []
        assert finished.logits_processors == []

    def test_processor_lands_on_its_own_row_after_remove_then_extend(self):
        """Remove-then-extend sequence of the live repro (#1825 follow-up).

        Request A (no processors) finishes and is removed, then request B
        (carrying a thinking-budget-style processor) joins via extend().
        B's processor has to sit at B's row index. Before the fix the
        state is ``logits_processors == [[], [processor]]`` against
        ``uids == [1]``: row 0 reads the stale empty slot and the
        processor is never called.
        """
        import omlx.scheduler  # noqa: F401 (installs the filter patch)

        def thinking_budget(tokens, logits):
            return logits

        # Request A: served without processors, then removed.
        long_lived = _bare_generation_batch(uid=0, logits_processors=[[]])
        long_lived.filter([])

        # Request B: joins the now-empty long-lived batch.
        incoming = _bare_generation_batch(
            uid=1, logits_processors=[[thinking_budget]]
        )
        long_lived.extend(incoming)

        assert long_lived.uids == [1]
        assert len(long_lived.logits_processors) == len(long_lived.uids)
        assert long_lived.logits_processors[0] == [thinking_budget]

    def test_filter_preserves_active_processor_reindex(self):
        """With an active slot the upstream reindex runs; the patch must
        leave its (correct) result alone."""
        import omlx.scheduler  # noqa: F401 (installs the filter patch)
        import mlx.core as mx

        def grammar_processor(tokens, logits):
            return logits

        # Start from the one-row helper, then widen every per-row list to
        # two rows: row 0 has no processor, row 1 has the grammar one.
        batch = _bare_generation_batch(
            uid=0, logits_processors=[[], [grammar_processor]]
        )
        batch.uids = [0, 1]
        batch.tokens = [[1], [2]]
        batch.samplers = [lambda x: x, lambda x: x]
        batch.max_tokens = [4, 4]
        batch._num_tokens = [0, 0]
        placeholder_attrs = (
            "state_machines",
            "_next_logprobs",
            "_token_context",
            "_matcher_states",
        )
        for attr_name in placeholder_attrs:
            setattr(batch, attr_name, [object(), object()])
        batch._next_tokens = mx.array([1, 2])

        batch.filter([1])

        assert batch.uids == [1]
        assert batch.logits_processors == [[grammar_processor]]

    def test_filter_normalises_none_list(self):
        """``logits_processors = None`` must not crash the upstream filter
        (``any(None)`` raises TypeError) and must end up aligned."""
        import omlx.scheduler  # noqa: F401 (installs the filter patch)

        unset = _bare_generation_batch(uid=0, logits_processors=None)
        unset.filter([])

        assert unset.logits_processors == []

    def test_scheduler_source_installs_filter_patch(self):
        """Source-level guard: the patch installation line must still be
        present in scheduler.py. Cheap; needs no model, so it runs in CI."""
        from pathlib import Path

        repo_root = Path(__file__).resolve().parent.parent
        scheduler_src = repo_root.joinpath("omlx", "scheduler.py").read_text()

        install_line = "GenerationBatch.filter = _patched_generation_batch_filter"
        failure_hint = (
            "scheduler.py must install _patched_generation_batch_filter on "
            "GenerationBatch.filter: mlx-lm's filter leaves a stale "
            "logits_processors list behind when every slot is empty, which "
            "silently drops the next request's processors after a "
            "remove-then-extend."
        )
        assert install_line in scheduler_src, failure_hint


class TestCorruptionPatternRecovery:
"""Pin the recovery contract: 'not iterable' is a known corruption."""

Expand Down