fix(scheduler): keep logits processors aligned after batch row removal#1845
Conversation
mlx-lm's GenerationBatch.filter only reindexes logits_processors when at least one row slot is non-empty. After a request with no per-request processors finishes (every slot is [], the shape omlx inserts), removal shrinks uids but leaves the stale processor list behind. The next request then extends in behind its own row index: the row reads a leftover empty slot and its real processor (thinking budget, grammar constraint) is silently never applied. Alignment self-heals when the affected request finishes, so the symptom is an intermittently ignored thinking_budget or grammar constraint that depends on request order. Wrap GenerationBatch.filter alongside the existing _step chokepoint patch (jundot#934/jundot#1747): when the original guard would have skipped the reindex, reset the list to one empty slot per surviving row. Found while live-testing jundot#1844. thinking_budget was enforced on a fresh server but ignored when the previous request in the decode batch had no processors. Reproduced on /v1/chat/completions as well, so the bug is independent of the completions forwarding change.
|
Upstream status: this is a known mlx-lm bug with an open fix, ml-explore/mlx-lm#1225 (open since April, no review yet; the maintainer closed the sibling #1230 as a duplicate and its extend fix was folded into #1225 two days ago). I commented there with our repro and measurements. When the mlx-lm pin moves past that fix, _patched_generation_batch_filter becomes a no-op and can be dropped together with the source-level pin test; the behavior tests in tests/test_scheduler_logits_processors.py pin the aligned outcome rather than the patch mechanics, so they keep passing on a fixed mlx-lm. |
|
Thanks for the detailed repro and upstream context. I confirmed the current mlx-lm pin still has the GenerationBatch.filter stale-list behavior. I also checked the VLM path. mlx-vlm's own GenerationBatch.filter already reindexes/clears logits_processors, and this oMLX patch wraps mlx-lm's GenerationBatch used by the normal oMLX decode path, so I don't see a VLM conflict. I ran the logits-processor, scheduler, and MTP-adjacent tests locally, and this looks good to me. I'm going to merge it. |
Symptom
Per-request logits processors (thinking budget, grammar constraints) are silently dropped depending on request order. A request with
thinking_budget: 300reasons to its natural length as if the parameter were absent, but only when the previous request in the decode batch carried no processors. The same request on a fresh server enforces the budget exactly. Grammar-constrained requests can lose their constraint the same way.Cause
mlx-lm's
GenerationBatch.filteronly reindexeslogits_processorswhen at least one row slot is non-empty:There is no else branch (the prompt-batch class has one:
[[]] * len(keep)). After a request with no per-request processors finishes, every slot is[](the shape omlx inserts),any()is False, and removal shrinksuidswhile the stale processor list survives. The next request's row thenextend()s in behind its own index. I instrumented the patched_stepand caught the broken state directly:uids=(1,)with processor slots([], [ThinkingBudgetProcessor]). Row 0 reads the leftover empty slot; the real processor sits orphaned at index 1 and is never applied.The misalignment self-heals when the affected request finishes (the orphan makes
any()True again, so the next filter reindexes correctly). That makes the symptom intermittent and order-dependent, which is why it has been easy to miss.samplerssit behind the sameany()guard, but omlx always passes a real sampler object per row, so that list cannot go inert. I left it alone.Fix
Wrap
GenerationBatch.filternext to the existing_stepchokepoint patch (#934/#1747): snapshot whether the list was inert before the original filter runs, and if the original guard skipped the reindex, reset the list to one empty slot per surviving row. Also normalizes a None list, which would crash the original filter'sany().Testing
tests/test_scheduler_logits_processors.pyexercise the real mlx-lmfilter/extendbookkeeping on bareGenerationBatchinstances (no model load): the stale-list reset, the remove-then-extend alignment that reproduces the live failure, the active-processor reindex path (must not be clobbered), the None normalization, and a source-level pin on the patch installation. The behavior tests fail without the fix exactly as described (logits_processors == [[]]againstuids == []).thinking_budget: 300chat request produced 580 reasoning tokens before the fix (budget ignored) and 299 after. On a fresh server both builds produce 299, which is the order-dependence in one line.Context
Found while live-testing #1844 (thinking_budget on /v1/completions, #1825). The drop reproduces on /v1/chat/completions on current main, so it is independent of that PR. The underlying bug is in mlx-lm; if you would rather fix it upstream I can write that up, but the monkey-patch follows the established #934/#1747 pattern and protects omlx on the current pin either way.