
generate.py: fix per-slot samplers/logits_processors bookkeeping in filter and extend #1225

Open
mloiterman wants to merge 2 commits into ml-explore:main from mloiterman:fix/generation-batch-filter-stale-length

Conversation

@mloiterman commented Apr 28, 2026

Summary

GenerationBatch.filter (around line 1392 of mlx_lm/generate.py)
uses if any(self.samplers) and if any(self.logits_processors) to
guard whether it trims those per-sequence lists. When every current
slot is None / [] (the common shape for any in-flight request
that does not attach a custom sampler or logits processor),
any(...) is False and the trim is skipped. The lists keep stale
length while self.uids is correctly trimmed.

When new sequences subsequently arrive via extend, their entries
are appended to the over-long list. _step's per-sequence loop
(for e in range(len(self.uids)): for processor in self.logits_processors[e]:)
then reads stale-index slots — silently bypassing the new sequence's
processor for exactly one generation step before the next filter
corrects the length.
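
The sequence described above can be reproduced with a toy model of the bookkeeping alone. This is a minimal sketch, not the mlx_lm source: `ToyBatch`, `processors_seen_by_step`, and `grammar` are hypothetical names, and `extend` is assumed to append one slot per new sequence.

```python
class ToyBatch:
    """Hypothetical stand-in for GenerationBatch: per-slot bookkeeping only."""

    def __init__(self):
        self.uids = []
        self.logits_processors = []

    def extend(self, uid, processors):
        # Assumption: new sequences are appended to the existing lists.
        self.uids.append(uid)
        self.logits_processors.append(processors)

    def filter(self, keep):
        self.uids = [self.uids[idx] for idx in keep]
        # Same guard shape as the unpatched code: skipped when every slot is [].
        if any(self.logits_processors):
            self.logits_processors = [self.logits_processors[idx] for idx in keep]

    def processors_seen_by_step(self):
        # The per-sequence loop indexes slots by position in uids.
        return [self.logits_processors[e] for e in range(len(self.uids))]


def grammar(tokens, logits):
    """Placeholder logits processor."""
    return logits


batch = ToyBatch()
batch.extend(uid=1, processors=[])         # plain request, no processors
batch.filter(keep=[])                      # request finishes; trim is skipped
batch.extend(uid=2, processors=[grammar])  # constrained request arrives

seen = batch.processors_seen_by_step()
print(len(batch.uids), len(batch.logits_processors), seen)
```

In this model the row for uid 2 reads the stale empty slot at index 0, while its real processor sits at index 1 and is never visited.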

The structurally-symmetric PromptProcessingBatch.filter (around
line 1117) already handles this case with explicit else branches
that reset to [None] * len(keep) / [[]] * len(keep). This PR
mirrors those branches onto GenerationBatch.filter.

Impact (downstream)

Discovered while running schema-constrained generation
(response_format=json_schema with strict: true) in a
production-shape FastAPI server. After any tool-bearing chat
completion (which carries no logits_processors), the very next
request's xgrammar-based grammar processor was silently bypassed —
the model returned plain unconstrained text instead of schema-valid
output. The bug is silent (no exception, no log, no error code) and
affects exactly one request before self-recovering, which made it
ordering-dependent and hard to spot.

Diff

--- a/mlx_lm/generate.py
+++ b/mlx_lm/generate.py
@@ -1391,8 +1391,12 @@ class GenerationBatch:
         self.tokens = [self.tokens[idx] for idx in keep]
         if any(self.samplers):
             self.samplers = [self.samplers[idx] for idx in keep]
+        else:
+            self.samplers = [None] * len(keep)
         if any(self.logits_processors):
             self.logits_processors = [self.logits_processors[idx] for idx in keep]
+        else:
+            self.logits_processors = [[]] * len(keep)
         self.max_tokens = [self.max_tokens[idx] for idx in keep]
         self.state_machines = [self.state_machines[idx] for idx in keep]

The [[]] (logits_processors) and [None] (samplers) defaults are
copied verbatim from PromptProcessingBatch.filter so the two
methods stay symmetric.
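
To illustrate the effect of the added else branches, here is a sketch of the same trimming logic as a free function. `filter_slots` is a hypothetical helper written for this example, not part of mlx_lm; it only mirrors the shape of the diff.

```python
def filter_slots(uids, samplers, logits_processors, keep):
    """Toy version of the patched trimming logic (not the mlx_lm source)."""
    uids = [uids[idx] for idx in keep]
    if any(samplers):
        samplers = [samplers[idx] for idx in keep]
    else:
        # Every slot was None: reset to the right length instead of skipping.
        samplers = [None] * len(keep)
    if any(logits_processors):
        logits_processors = [logits_processors[idx] for idx in keep]
    else:
        # Every slot was []: same reset, with the list-typed sentinel.
        logits_processors = [[]] * len(keep)
    return uids, samplers, logits_processors


# All-empty slots, every sequence dropped.
result_drop_all = filter_slots([42], [None], [[]], keep=[])
# All-empty slots, second of two sequences kept.
result_keep_one = filter_slots([7, 8], [None, None], [[], []], keep=[1])
print(result_drop_all, result_keep_one)
```

Note that `[[]] * n` repeats a reference to one list object. That is harmless as long as slots are replaced rather than mutated in place, which this sketch assumes.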

Standalone reproducer (no model required)

"""Standalone reproducer for the GenerationBatch.filter stale-length bug.

Uses object.__new__(GenerationBatch) to bypass the model-requiring
__init__ — runs in milliseconds, no model load, no GPU.

Run:
    python upstream_repro.py

Exit code 0 = bug present (current upstream main).
Exit code 1 = bug fixed.
"""
from __future__ import annotations
from mlx_lm.generate import GenerationBatch


def make_bare_batch() -> GenerationBatch:
    b = object.__new__(GenerationBatch)
    b.model = None
    b.uids = []
    b.prompt_cache = []
    b.tokens = []
    b.samplers = []
    b.fallback_sampler = None
    b.logits_processors = []
    b.state_machines = []
    b.max_tokens = []
    b._next_tokens = None
    b._next_logprobs = []
    b._token_context = []
    b._num_tokens = []
    b._matcher_states = []
    return b


def populate_no_processor_request(b: GenerationBatch, uid: int) -> None:
    """A 'tool-style' request: no logits processors, no custom sampler.
    This is the common shape for any request that does not ask for
    grammar / penalty / thinking-budget logits modification."""
    b.uids.append(uid)
    b.prompt_cache.append(None)
    b.tokens.append([1, 2, 3])
    b.samplers.append(None)
    b.logits_processors.append([])
    b.state_machines.append(None)
    b.max_tokens.append(16)
    b._token_context.append(None)
    b._num_tokens.append(0)
    b._matcher_states.append(None)
    b._next_logprobs.append(None)


def main() -> int:
    b = make_bare_batch()
    populate_no_processor_request(b, uid=42)

    assert len(b.uids) == 1
    assert b.logits_processors == [[]]
    assert b.samplers == [None]

    # Simulate the request finishing — keep=[] means drop everything.
    b.filter(keep=[])

    print("after filter(keep=[]):")
    print(f"  uids               = {b.uids} (len={len(b.uids)})")
    print(f"  logits_processors  = {b.logits_processors} (len={len(b.logits_processors)})")
    print(f"  samplers           = {b.samplers} (len={len(b.samplers)})")

    bug = (
        len(b.logits_processors) != len(b.uids)
        or len(b.samplers) != len(b.uids)
    )
    if bug:
        print("\nBUG REPRODUCED — list length(s) do NOT match len(uids).")
        return 0
    print("\nFIX VERIFIED — all per-sequence lists have len == len(uids).")
    return 1


if __name__ == "__main__":
    raise SystemExit(main())

Verified output (against mlx-lm 0.31.3 and main at the time of
this PR):

after filter(keep=[]):
  uids               = [] (len=0)
  logits_processors  = [[]] (len=1)
  samplers           = [None] (len=1)

BUG REPRODUCED — list length(s) do NOT match len(uids).

With the diff applied:

after filter(keep=[]):
  uids               = [] (len=0)
  logits_processors  = [] (len=0)
  samplers           = [] (len=0)

FIX VERIFIED — all per-sequence lists have len == len(uids).

Suggested unit test

def test_generation_batch_filter_clears_logits_processors_when_all_empty():
    """GenerationBatch.filter must keep self.logits_processors length
    in lockstep with self.uids even when every per-sequence slot is
    empty. Regression: prior versions guarded the trim with
    `if any(self.logits_processors)` and skipped on all-`[]`, leaving
    a stale-length list that would later mis-index `_step`."""
    from mlx_lm.generate import GenerationBatch

    b = object.__new__(GenerationBatch)
    b.uids = [42]
    b.prompt_cache = []
    b.tokens = [[1, 2, 3]]
    b.samplers = [None]
    b.fallback_sampler = None
    b.logits_processors = [[]]
    b.state_machines = [None]
    b.max_tokens = [16]
    b._next_tokens = None
    b._next_logprobs = [None]
    b._token_context = [None]
    b._num_tokens = [0]
    b._matcher_states = [None]

    b.filter(keep=[])

    assert len(b.uids) == 0
    assert len(b.logits_processors) == 0
    assert len(b.samplers) == 0

Related (different bugs)

The filter-list-length bug fixed here appears unreported.


Update 2026-06-10 — now also carries the PromptProcessingBatch.extend fix from #1230

@nastya236 closed #1230 as a duplicate of this PR. The two changes are
companions in the same file but fix different code paths: filter
(this PR's original scope) and extend (#1230's scope). This PR
now includes both, so a single merge covers both fixes:

  • Commit 1 (filter, original): else branches so the per-slot
    samplers / logits_processors lists stay in lockstep with uids
    after trimming. Failure mode: silent — the next request's logits
    processor is bypassed for one step (unconstrained output under
    response_format=json_schema).
  • Commit 2 (extend, from #1230, "fix(generate): avoid None entries in merged logits_processors", credit @BLuchterhand):
    [[]] * N instead of [None] * N for absent per-slot
    logits_processors. Merging an unconfigured batch with a
    processor-equipped batch otherwise produces [None, ..., [fn], ...],
    and _step crashes with TypeError: 'NoneType' object is not iterable.
    Failure mode: fatal for any heterogeneous batched
    workload (hit in production at batch_size=8, mixed plain +
    grammar-constrained requests).

Both use the same asymmetric sentinel pair, matching the existing
PromptProcessingBatch.filter: None for samplers[e]
(type Optional[Callable], consumed via or self.fallback_sampler)
and [] for logits_processors[e] (type List[Callable], consumed
by iteration).
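
The reason the two sentinels differ follows from how each slot is consumed. The sketch below assumes the consumption patterns quoted above (`or self.fallback_sampler` for samplers, plain iteration for processors); `apply_slot` and the helper functions are hypothetical names used only for illustration.

```python
def fallback_sampler(logits):
    return "fallback"


def custom_sampler(logits):
    return "custom"


def double(tokens, logits):
    """Placeholder logits processor."""
    return [x * 2 for x in logits]


def apply_slot(sampler, processors, logits, tokens=()):
    # Iteration: [] means "no processors"; None would raise TypeError here.
    for processor in processors:
        logits = processor(tokens, logits)
    # `or`: None means "use the fallback sampler".
    chosen = sampler or fallback_sampler
    return chosen(logits), logits


plain = apply_slot(None, [], [1.0])
configured = apply_slot(custom_sampler, [double], [1.0])

try:
    apply_slot(None, None, [1.0])
    crashed = False
except TypeError:
    crashed = True

print(plain, configured, crashed)
```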

Commit 2 carries over #1230's regression test
(test_prompt_processing_batch_extend_mixes_logits_processors),
verified red on the unpatched tree and green with the fix; the full
-k batch suite in tests/test_generate.py passes (17/17).

…-sequence lists

`GenerationBatch.filter` (around line 1392) uses `if any(self.samplers)`
and `if any(self.logits_processors)` to guard whether it trims those
per-sequence lists. When every current slot is `None` / `[]` (the common
shape for any in-flight request that does not attach a custom sampler
or logits processor), `any(...)` is `False` and the trim is skipped.
The lists keep stale length while `self.uids` is correctly trimmed.

When new sequences subsequently arrive via `extend`, their entries are
appended to the over-long list. `_step`'s per-sequence loop
(`for e in range(len(self.uids)): for processor in
self.logits_processors[e]:`) then reads stale-index slots — silently
bypassing the new sequence's processor for exactly one generation step
before the next `filter` corrects the length.

The structurally-symmetric `PromptProcessingBatch.filter` (around line
1117) already handles this case with explicit `else` branches that
reset to `[None] * len(keep)` / `[[]] * len(keep)`. This commit
mirrors those branches onto `GenerationBatch.filter`.

Discovered while running schema-constrained generation
(`response_format=json_schema`, `strict: true`) in a production-shape
FastAPI server. After any tool-bearing chat completion (which carries
no `logits_processors`), the very next request's grammar processor
was silently bypassed — model returned plain unconstrained text instead
of schema-valid output. The bug was silent (no exception, no log) and
affected exactly one request before self-recovering, which made it
ordering-dependent and hard to spot.
@mloiterman

Author

Cross-reference: #1230 lands the symmetric fix in PromptProcessingBatch.extend, where [None] * N for absent per-slot logits_processors produces a list shape [None, ..., [fn], ...] after merging with a processor-equipped batch. That shape later crashes GenerationBatch._step at line 1346 with TypeError: 'NoneType' object is not iterable.

The two PRs touch different code paths (filter here, extend there) but share the same per-slot sentinel argument:

  • samplers[e] — type Optional[Callable]. Consumed at line 1358 as self.samplers[e] or self.fallback_sampler, so None is the correct sentinel.
  • logits_processors[e] — type List[Callable]. Consumed at line 1346 as for processor in self.logits_processors[e], so [] is the correct sentinel (None crashes the iterator).

Both #1225 and #1230 use this asymmetric pair (None vs []) and match the existing PromptProcessingBatch.filter pattern at line 1120. Safe to land independently or together.
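
The merged list shape mentioned in this cross-reference can be sketched as follows. This is an illustration under stated assumptions, not the actual extend code: `merge_processors` is a hypothetical helper, and it assumes a side with no per-slot list gets padded with one sentinel per sequence before concatenation.

```python
def merge_processors(left, n_left, right, n_right, absent):
    """Hypothetical merge: pad a side that has no per-slot list, then concatenate."""
    left = left if left else [absent] * n_left
    right = right if right else [absent] * n_right
    return left + right


def fn(tokens, logits):
    """Placeholder logits processor."""
    return logits


# Two unconfigured sequences merged with one processor-equipped sequence.
merged_none = merge_processors([], 2, [[fn]], 1, absent=None)
merged_empty = merge_processors([], 2, [[fn]], 1, absent=[])

# Only the []-padded list can be iterated slot by slot.
iterable_slots = all(isinstance(slot, list) for slot in merged_empty)
print(merged_none[:2], merged_empty[:2], iterable_slots)
```

With `None` padding, the merged list is truthy under `any(...)` because of the `[fn]` slot, so a per-slot loop would reach the `None` entries and fail on iteration.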

@nastya236 added the bug label (Something isn't working) on Jun 8, 2026
…t per-slot logits_processors

extend() fills missing per-slot logits_processors with [None] * N. Merging
an unconfigured batch with a processor-equipped batch produces a mixed list
([None, ..., [fn], ...]); GenerationBatch._step then iterates
self.logits_processors[e] under the any() guard and crashes with
TypeError: 'NoneType' object is not iterable.

Per-slot type is List[Callable], so the absent-value sentinel is [] —
matching the existing [[]] * len(keep) in PromptProcessingBatch.filter.
samplers keep the None sentinel (consumed as `self.samplers[e] or
self.fallback_sampler`, type Optional[Callable]).

Fix and regression test carried over from ml-explore#1230 (closed as duplicate of
this PR; the two changes are companions in the same file — filter here,
extend there).

Co-authored-by: BLuchterhand <benlucht8@gmail.com>
@mloiterman changed the title from "generate.py: GenerationBatch.filter — add else branches so logits_processors / samplers length stays in lockstep with uids" to "generate.py: fix per-slot samplers/logits_processors bookkeeping in filter and extend" on Jun 10, 2026
@mloiterman

Author

@nastya236 Following your close of #1230 as a duplicate of this PR: the two fixes were actually in different code paths (GenerationBatch.filter here, PromptProcessingBatch.extend there), so this PR didn't yet cover #1230's crash. I've now pushed #1230's extend fix and its regression test onto this branch (commit credit to @BLuchterhand), so one merge resolves both.

Both sentinels match the existing PromptProcessingBatch.filter pattern. The carried-over test is red on the unpatched tree, green with the fix; the -k batch suite in tests/test_generate.py passes 17/17 locally. PR title/description updated to reflect the combined scope.

@richgoodson


Independent confirmation. I hit the logits_processors half of this downstream in omlx (a continuous-batching server built on BatchGenerator) while live-testing a thinking-budget logits processor, and traced it to the same missing else branches in GenerationBatch.filter.

What I observed, deterministic at temperature 0: a request carrying a thinking-budget processor produced 299 reasoning tokens (budget 300, enforced) when it hit a fresh decode batch, and 580 reasoning tokens (processor never applied) when the previous request in the batch carried no processors. Instrumenting _step caught the broken state directly: uids=(1,) with logits_processors=([], [ThinkingBudgetProcessor]). The row reads the stale empty slot at index 0 and the real processor sits orphaned at index 1.

One correction to the PR description: the bypass is not limited to one generation step. filter only runs when a row is removed, so nothing corrects the list while the misaligned request is the only one in flight. It ran unconstrained for its entire generation (580 tokens) in my repro. The index math also gives two nastier variants when requests overlap. A later row that extends in behind the orphan reads the previous request's processor slot, so one request's grammar constraint or budget can be applied to a different request's tokens. And when any() is True because of the orphan, the next filter reindexes by row indices that are offset by the stale slots, which can drop the orphan and make the loss permanent for the surviving request. I have only measured the single-request bypass; the overlap variants follow from the same off-by-N indexing the diff fixes.
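
The two overlap variants can be traced in a toy model of the index math. As stated above, only the single-request bypass was measured; this sketch just replays the unpatched guard on plain lists. `buggy_filter`, `seen_by_step`, and `budget` are hypothetical names, and new rows are assumed to be appended.

```python
def buggy_filter(uids, procs, keep):
    """Unpatched guard shape: the reindex is skipped when every slot is []."""
    uids = [uids[i] for i in keep]
    if any(procs):
        procs = [procs[i] for i in keep]
    return uids, procs


def seen_by_step(uids, procs):
    # Row e of the batch reads slot e, regardless of which request owns it.
    return {uid: procs[e] for e, uid in enumerate(uids)}


def budget(tokens, logits):
    """Placeholder for a thinking-budget or grammar processor."""
    return logits


uids, procs = [1], [[]]
uids, procs = buggy_filter(uids, procs, keep=[])  # request 1 ends, stale [] stays
uids, procs = uids + [2], procs + [[budget]]      # request 2 carries a processor
uids, procs = uids + [3], procs + [[]]            # request 3 is plain

# Variant 1: request 3 reads request 2's processor slot.
overlap_view = seen_by_step(uids, procs)

# Variant 2: request 3 ends; any() is now True, so the reindex runs
# against offset indices and drops the orphaned processor.
uids, procs = buggy_filter(uids, procs, keep=[0])
after_view = seen_by_step(uids, procs)
orphan_lost = all(budget not in slot for slot in procs)
print(overlap_view, after_view, orphan_lost)
```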

The bug is silent in all cases: no exception, no log line, output that looks plausible. For schema-constrained or budget-constrained serving that is the worst failure shape.

We shipped a downstream wrap of GenerationBatch.filter in the meantime (jundot/omlx#1845) and will drop it once this lands. The fix here matches what we converged on independently: reset to the per-row sentinel whenever the any() guard would skip the reindex, with [] for processors and None for samplers.

