generate.py: fix per-slot samplers/logits_processors bookkeeping in filter and extend #1225
…-sequence lists

`GenerationBatch.filter` (around line 1392) uses `if any(self.samplers)` and `if any(self.logits_processors)` to guard whether it trims those per-sequence lists. When every current slot is `None` / `[]` (the common shape for any in-flight request that does not attach a custom sampler or logits processor), `any(...)` is `False` and the trim is skipped. The lists keep stale length while `self.uids` is correctly trimmed.

When new sequences subsequently arrive via `extend`, their entries are appended to the over-long list. `_step`'s per-sequence loop (`for e in range(len(self.uids)): for processor in self.logits_processors[e]:`) then reads stale-index slots, silently bypassing the new sequence's processor for exactly one generation step before the next `filter` corrects the length.

The structurally-symmetric `PromptProcessingBatch.filter` (around line 1117) already handles this case with explicit `else` branches that reset to `[None] * len(keep)` / `[[]] * len(keep)`. This commit mirrors those branches onto `GenerationBatch.filter`.

Discovered while running schema-constrained generation (`response_format=json_schema`, `strict: true`) in a production-shape FastAPI server. After any tool-bearing chat completion (which carries no `logits_processors`), the very next request's grammar processor was silently bypassed: the model returned plain unconstrained text instead of schema-valid output. The bug was silent (no exception, no log) and affected exactly one request before self-recovering, which made it ordering-dependent and hard to spot.
|
Cross-reference: #1230 lands the symmetric fix in The two PRs touch different code paths (
Both #1225 and #1230 use this asymmetric pair ( |
…t per-slot logits_processors

extend() fills missing per-slot logits_processors with [None] * N. Merging an unconfigured batch with a processor-equipped batch produces a mixed list ([None, ..., [fn], ...]); GenerationBatch._step then iterates self.logits_processors[e] under the any() guard and crashes with TypeError: 'NoneType' object is not iterable.

Per-slot type is List[Callable], so the absent-value sentinel is [], matching the existing [[]] * len(keep) in PromptProcessingBatch.filter. samplers keep the None sentinel (consumed as `self.samplers[e] or self.fallback_sampler`, type Optional[Callable]).

Fix and regression test carried over from ml-explore#1230 (closed as duplicate of this PR; the two changes are companions in the same file: filter here, extend there).

Co-authored-by: BLuchterhand <benlucht8@gmail.com>
|
@nastya236 Following your close of #1230 as a duplicate of this PR: the two fixes were actually in different code paths (
Both sentinels match the existing |
|
Independent confirmation. I hit the logits_processors half of this downstream in omlx (a continuous-batching server built on BatchGenerator) while live-testing a thinking-budget logits processor, and traced it to the same missing else branches in GenerationBatch.filter.

What I observed, deterministic at temperature 0: a request carrying a thinking-budget processor produced 299 reasoning tokens (budget 300, enforced) when it hit a fresh decode batch, and 580 reasoning tokens (processor never applied) when the previous request in the batch carried no processors. Instrumenting _step caught the broken state directly: uids=(1,) with logits_processors=([], [ThinkingBudgetProcessor]). The row reads the stale empty slot at index 0 and the real processor sits orphaned at index 1.

One correction to the PR description: the bypass is not limited to one generation step. filter only runs when a row is removed, so nothing corrects the list while the misaligned request is the only one in flight. It ran unconstrained for its entire generation (580 tokens) in my repro.

The index math also gives two nastier variants when requests overlap. A later row that extends in behind the orphan reads the previous request's processor slot, so one request's grammar constraint or budget can be applied to a different request's tokens. And when any() is True because of the orphan, the next filter reindexes by row indices that are offset by the stale slots, which can drop the orphan and make the loss permanent for the surviving request. I have only measured the single-request bypass; the overlap variants follow from the same off-by-N indexing the diff fixes.

The bug is silent in all cases: no exception, no log line, output that looks plausible. For schema-constrained or budget-constrained serving that is the worst failure shape. We shipped a downstream wrap of GenerationBatch.filter in the meantime (jundot/omlx#1845) and will drop it once this lands.

The fix here matches what we converged on independently: reset to the per-row sentinel whenever the any() guard would skip the reindex, with [] for processors and None for samplers.
Summary
`GenerationBatch.filter` (around line 1392 of `mlx_lm/generate.py`) uses `if any(self.samplers)` and `if any(self.logits_processors)` to guard whether it trims those per-sequence lists. When every current slot is `None` / `[]` (the common shape for any in-flight request that does not attach a custom sampler or logits processor), `any(...)` is `False` and the trim is skipped. The lists keep stale length while `self.uids` is correctly trimmed.

When new sequences subsequently arrive via `extend`, their entries are appended to the over-long list. `_step`'s per-sequence loop (`for e in range(len(self.uids)): for processor in self.logits_processors[e]:`) then reads stale-index slots, silently bypassing the new sequence's processor for exactly one generation step before the next `filter` corrects the length.

The structurally-symmetric `PromptProcessingBatch.filter` (around line 1117) already handles this case with explicit `else` branches that reset to `[None] * len(keep)` / `[[]] * len(keep)`. This PR mirrors those branches onto `GenerationBatch.filter`.

Impact (downstream)
Discovered while running schema-constrained generation (`response_format=json_schema` with `strict: true`) in a production-shape FastAPI server. After any tool-bearing chat completion (which carries no `logits_processors`), the very next request's xgrammar-based grammar processor was silently bypassed: the model returned plain unconstrained text instead of schema-valid output. The bug is silent (no exception, no log, no error code) and affects exactly one request before self-recovering, which made it ordering-dependent and hard to spot.
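The stale-length sequence described in the Summary can be modelled without mlx-lm at all. The sketch below is a toy stand-in, not the library code: `ToyBatch`, its method signatures, and the `grammar` processor are hypothetical names chosen for illustration, and only the `any(...)` guard and the `range(len(self.uids))` indexing are taken from the description above.

```python
# Toy model of the per-slot bookkeeping. NOT the mlx-lm implementation:
# class name, signatures, and the processor are illustrative assumptions.
class ToyBatch:
    def __init__(self, uids, logits_processors):
        self.uids = list(uids)
        self.logits_processors = list(logits_processors)

    def filter(self, keep):
        """Keep only the rows whose indices appear in `keep`."""
        self.uids = [self.uids[k] for k in keep]
        # Guard without an else branch: skipped when every slot is empty.
        if any(self.logits_processors):
            self.logits_processors = [self.logits_processors[k] for k in keep]

    def extend(self, uids, logits_processors):
        """Append new rows to the end of each per-slot list."""
        self.uids.extend(uids)
        self.logits_processors.extend(logits_processors)

    def processors_seen_by_step(self):
        # Same indexing pattern as the loop quoted in the Summary.
        return [self.logits_processors[e] for e in range(len(self.uids))]


def grammar(tokens, logits):  # hypothetical logits processor
    return logits


batch = ToyBatch(uids=[0], logits_processors=[[]])
batch.filter(keep=[])            # request 0 finishes; the processor list is not trimmed
batch.extend([1], [[grammar]])   # request 1 arrives with a processor attached
seen = batch.processors_seen_by_step()

print(len(batch.uids), len(batch.logits_processors))  # 1 2
print(seen)                                           # [[]]
```

Row 0 now belongs to request 1 but reads the empty slot left behind by request 0, while the real processor sits at index 1 where no row ever looks.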
Diff
The `[[]]` (`logits_processors`) and `[None]` (`samplers`) defaults are copied verbatim from `PromptProcessingBatch.filter` so the two methods stay symmetric.
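As a rough illustration of the shape of the change, here is a minimal sketch of a trim with the `else` branches in place. It is a free function over plain lists, not the actual patch: `fixed_filter`, its signature, and the `grammar` processor are assumptions made for the example, and only the guards and the two reset expressions come from the text above.

```python
# Sketch of the guarded trim with explicit else branches.
# `fixed_filter` and `grammar` are hypothetical names, not mlx-lm code.
def fixed_filter(uids, samplers, logits_processors, keep):
    uids = [uids[k] for k in keep]

    if any(samplers):
        samplers = [samplers[k] for k in keep]
    else:
        # Every slot was None: reset so the length still follows `keep`.
        samplers = [None] * len(keep)

    if any(logits_processors):
        logits_processors = [logits_processors[k] for k in keep]
    else:
        # Every slot was []: reset so the length still follows `keep`.
        logits_processors = [[]] * len(keep)

    return uids, samplers, logits_processors


def grammar(tokens, logits):  # hypothetical logits processor
    return logits


# Two unconfigured rows; row 0 finishes and row 1 survives.
uids, samplers, procs = fixed_filter([0, 1], [None, None], [[], []], keep=[1])

# A new request with a processor is appended, as an extend would do.
uids.append(2)
samplers.append(None)
procs.append([grammar])

print(len(uids), len(samplers), len(procs))  # 2 2 2
```

With the lengths in lockstep, index 1 of every list refers to the new request, so its processor is the one that gets applied.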
Standalone reproducer (no model required)
Verified output (against `mlx-lm 0.31.3` and `main` at the time of this PR):
With the diff applied:
Suggested unit test
Related (different bugs)
- "mtp_generate_step: logits processors see stale prev_tokens on draft calls": different mechanism (token history not updated in the MTP path), distinct from the list-length bug fixed here.
- "Stateful logits processors see stale tokens due to lazy evaluation in stream context": also a processor-staleness story, but in a different code path (lazy eval ordering in streaming generation).
The `filter`-list-length bug fixed here appears unreported.

Update 2026-06-10: now also carries the `PromptProcessingBatch.extend` fix from #1230

@nastya236 closed #1230 as a duplicate of this PR. The two changes are companions in the same file but fix different code paths, `filter` (this PR's original scope) and `extend` (#1230's scope), so this PR now includes both, making the single-merge view correct:

1. `filter` (original): `else` branches so the per-slot `samplers` / `logits_processors` lists stay in lockstep with `uids` after trimming. Failure mode: silent; the next request's logits processor is bypassed for one step (unconstrained output under `response_format=json_schema`).
2. `extend` (from #1230, "fix(generate): avoid None entries in merged logits_processors", credit @BLuchterhand): `[[]] * N` instead of `[None] * N` for absent per-slot `logits_processors`. Merging an unconfigured batch with a processor-equipped batch otherwise produces `[None, ..., [fn], ...]`, and `_step` crashes with `TypeError: 'NoneType' object is not iterable`. Failure mode: fatal for any heterogeneous batched workload (hit in production at batch_size=8, mixed plain + grammar-constrained requests).
Both use the same asymmetric sentinel pair, matching the existing `PromptProcessingBatch.filter`: `None` for `samplers[e]` (type `Optional[Callable]`, consumed via `or self.fallback_sampler`) and `[]` for `logits_processors[e]` (type `List[Callable]`, consumed by iteration).
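The reason the two sentinels differ follows from how each slot is consumed. The sketch below shows both consumption patterns in isolation; `pick_sampler`, `apply_processors`, `greedy`, and `add_one` are hypothetical helpers written for this example, not functions from mlx-lm.

```python
# Why None works for samplers but not for logits_processors.
# All helper names here are illustrative assumptions.
def pick_sampler(slot_sampler, fallback_sampler):
    # Optional[Callable]: a None slot falls through to the fallback.
    return slot_sampler or fallback_sampler


def apply_processors(slot_processors, tokens, logits):
    # List[Callable]: consumed by iteration, so "absent" has to be iterable.
    for processor in slot_processors:
        logits = processor(tokens, logits)
    return logits


def greedy(logits):
    return max(range(len(logits)), key=logits.__getitem__)


def add_one(tokens, logits):
    return [x + 1 for x in logits]


chosen = pick_sampler(None, greedy)                     # None is a valid sentinel here
untouched = apply_processors([], [], [0.0, 1.0])        # [] is a no-op
shifted = apply_processors([add_one], [], [0.0, 1.0])   # a real processor runs

try:
    apply_processors(None, [], [0.0, 1.0])              # None cannot be iterated
    error_message = ""
except TypeError as exc:
    error_message = str(exc)

print(error_message)  # 'NoneType' object is not iterable
```

An empty list behaves as "no processors" under iteration, while `None` raises the `TypeError` quoted above, which is why the absent value for `logits_processors[e]` has to be `[]`.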
Commit 2 carries over #1230's regression test (`test_prompt_processing_batch_extend_mixes_logits_processors`), verified red on the unpatched tree and green with the fix; the full `-k batch` suite in `tests/test_generate.py` passes (17/17).