From 2fdf41c3ac30fedc2cad03e0cf71312e51691eb3 Mon Sep 17 00:00:00 2001 From: efortin Date: Thu, 11 Jun 2026 17:33:51 +0200 Subject: [PATCH 1/8] fix(scheduler): realign per-row samplers and logits processors by uid #1799 made the generation step crash-safe by normalising None row slots, but the positional drift behind it was still there: a stale or offset slot left in samplers/logits_processors by batch extend/filter/split shifts every row after it, so a request can decode with another request's - or no - sampler and processors. Under concurrent mixed load this silently disables grammar constraints (json_schema responses come back as prose) and thinking budgets (unbounded reasoning), with no error anywhere. Record at insert time what each uid must run, and realign the positional lists from that registry at the step chokepoint - the same place as the #1799 normalisation. A rate-limited warning fires whenever a drift is actually corrected, so the silent corruption becomes observable. The registry is a bounded OrderedDict so a missed cleanup can never grow it unbounded. Repro (concurrent plain + constrained request, 0.3s apart): 10/10 thinking_budget violations and 5/5 grammar violations on main; 0/15 with this fix. Solo and sequential behavior unchanged, byte-identical at temperature 0. --- omlx/scheduler.py | 61 +++++++++++- tests/test_scheduler_logits_processors.py | 109 ++++++++++++++++++++++ 2 files changed, 169 insertions(+), 1 deletion(-) diff --git a/omlx/scheduler.py b/omlx/scheduler.py index fddf60806..803771da6 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -19,7 +19,7 @@ import os import threading import time -from collections import defaultdict, deque +from collections import OrderedDict, defaultdict, deque from collections.abc import Callable from contextlib import contextmanager from dataclasses import dataclass, field @@ -408,6 +408,30 @@ class _PrefillState: # After _original_step returns, self._next_tokens holds the freshly sampled # tokens. We eval them synchronously and accept in grammar processors. # --------------------------------------------------------------------------- +# Authoritative per-uid row state for the generation batch. +# +# mlx-lm keeps ``samplers`` / ``logits_processors`` as positional lists that +# must stay aligned with ``uids``. Heterogeneous continuous batching +# (extend/filter/split across prompt and generation batches) can leave stale +# or offset row slots behind; #1799 made the step crash-safe by normalising +# ``None`` slots, but a misaligned row silently runs the WRONG sampler and +# logits processors (e.g. a grammar/thinking-budget request decoding with no +# constraints at all). The registry below records, at insert time, what each +# uid is supposed to run; the step chokepoint realigns the positional lists +# from it. Bounded so a missing cleanup can never grow it unbounded. +_UID_ROW_REGISTRY_MAX = 4096 +_uid_row_registry: "OrderedDict[int, tuple[Any, list]]" = OrderedDict() + + +def _register_uid_rows(uids, samplers, lps_rows) -> None: + """Record the sampler and logits processors each freshly-inserted uid must run.""" + for uid, sampler, lps in zip(uids, samplers, lps_rows): + _uid_row_registry[uid] = (sampler, list(lps) if lps else []) + _uid_row_registry.move_to_end(uid) + while len(_uid_row_registry) > _UID_ROW_REGISTRY_MAX: + _uid_row_registry.popitem(last=False) + + _original_generation_batch_step = GenerationBatch._step @@ -446,6 +470,37 @@ def _patched_generation_batch_step(self): procs if procs is not None else [] for procs in self.logits_processors ] + # Realign per-row samplers and logits processors with ``uids`` from the + # per-uid registry. Positional drift (stale/offset slots left by batch + # extend/filter/split) otherwise makes a row run another request's — or + # no — sampler and processors, which silently disables grammar + # constraints and thinking budgets under concurrent mixed load. + cur_lps = self.logits_processors + cur_samplers = getattr(self, "samplers", None) or [] + new_lps = [] + new_samplers = [] + drift = len(cur_lps) != len(self.uids) + for i, uid in enumerate(self.uids): + entry = _uid_row_registry.get(uid) + if entry is not None: + sampler, lps = entry + if not drift and i < len(cur_lps) and cur_lps[i] is not lps and (cur_lps[i] or lps): + drift = drift or (cur_lps[i] != lps) + new_samplers.append(sampler) + new_lps.append(lps) + else: + new_samplers.append(cur_samplers[i] if i < len(cur_samplers) else None) + new_lps.append(cur_lps[i] if i < len(cur_lps) else []) + if drift: + logger.warning( + "Realigned generation-batch row state from the uid registry " + f"(uids={list(self.uids)}, had {len(cur_lps)} processor slots); " + "stale or offset slots would have run the wrong sampler/processors." + ) + self.logits_processors = new_lps + self.samplers = new_samplers + + result = _original_generation_batch_step(self) # self._next_tokens contains the just-sampled tokens (async eval pending). @@ -3479,6 +3534,8 @@ def _insert_prefilled_request( logits_processors=[per_row_lps], state_machines=[state.sm], ) + if uids: + _register_uid_rows(uids, [state.sampler], [per_row_lps]) if uids: uid = uids[0] @@ -7031,6 +7088,8 @@ def _sparse_progress(processed: int, total: int) -> None: logits_processors=[per_row_lps], state_machines=[sm], ) + if uids: + _register_uid_rows(uids, [sampler], [per_row_lps]) if uids: uid = uids[0] diff --git a/tests/test_scheduler_logits_processors.py b/tests/test_scheduler_logits_processors.py index d5b1e707f..b612db1c2 100644 --- a/tests/test_scheduler_logits_processors.py +++ b/tests/test_scheduler_logits_processors.py @@ -411,3 +411,112 @@ def identity_processor(token_context, logits): bg.next_generated() bg.close() + + +class TestRowRealignment: + """Pin the uid-registry realignment (#1823). + + Stale or offset row slots left by batch extend/filter/split shift every + row after them, so a request silently runs another request's — or no — + sampler and logits processors. The #1799 normalisation makes the step + crash-safe but cannot restore alignment; the chokepoint must realign + the positional lists from the per-uid registry.""" + + def test_patched_step_realigns_offset_rows_from_registry(self, monkeypatch): + """The #1823 probe scenario: three processor slots for two uids. + + A stale leading slot (left by a finished request) offsets every row: + the constrained request's processors sit in a slot nothing reads, + and its row runs an empty one. Red before the registry realignment + (the wrapped step sees the offset rows: uid 2 runs no processors); + green after (uid 2's row runs its own sampler and processors). + """ + from collections import OrderedDict + + import omlx.scheduler as scheduler + + monkeypatch.setattr(scheduler, "_uid_row_registry", OrderedDict()) + + captured = {} + + def fake_original_step(self): + captured["logits_processors"] = list(self.logits_processors) + captured["samplers"] = list(self.samplers) + return "stepped" + + monkeypatch.setattr( + scheduler, "_original_generation_batch_step", fake_original_step + ) + + def budget_processor(token_context, logits): + return logits + + def grammar_processor(token_context, logits): + return logits + + sampler_uid2 = object() + + # What the insert sites record: uid 1 is a plain request, uid 2 is + # the constrained one (grammar + thinking budget). + scheduler._register_uid_rows([1], [None], [[]]) + scheduler._register_uid_rows( + [2], [sampler_uid2], [[budget_processor, grammar_processor]] + ) + + class FakeModel: + pass + + class FakeBatch: + model = FakeModel() + uids = [1, 2] + # Stale leading slot from a finished request: 3 slots, 2 uids. + logits_processors = [[], [], [budget_processor, grammar_processor]] + samplers = [None, None, sampler_uid2] + _next_tokens = None + + batch = FakeBatch() + result = scheduler._patched_generation_batch_step(batch) + + assert result == "stepped" + # uid 2's row must run ITS processors and sampler, not the offset ones. + assert captured["logits_processors"][1] == [ + budget_processor, + grammar_processor, + ] + assert captured["samplers"][1] is sampler_uid2 + # Alignment restored: exactly one slot per uid. + assert len(batch.logits_processors) == len(batch.uids) + assert len(batch.samplers) == len(batch.uids) + + def test_registry_is_bounded(self): + """A missed cleanup must never grow the registry unbounded.""" + from collections import OrderedDict + + import omlx.scheduler as scheduler + + registry = OrderedDict() + original = scheduler._uid_row_registry + scheduler._uid_row_registry = registry + try: + for uid in range(scheduler._UID_ROW_REGISTRY_MAX + 100): + scheduler._register_uid_rows([uid], [None], [[]]) + assert len(registry) == scheduler._UID_ROW_REGISTRY_MAX + # Oldest entries evicted first. + assert 0 not in registry + assert scheduler._UID_ROW_REGISTRY_MAX + 99 in registry + finally: + scheduler._uid_row_registry = original + + def test_scheduler_source_registers_rows_at_insert(self): + """Source-level guard: both insert sites must record what each uid + is supposed to run, or the chokepoint has nothing to realign from.""" + from pathlib import Path + + scheduler_src = ( + Path(__file__).resolve().parents[1] / "omlx" / "scheduler.py" + ).read_text() + assert scheduler_src.count("_register_uid_rows(uids") >= 2, ( + "every batch_generator.insert call site must register the " + "per-uid sampler and logits processors; the step chokepoint " + "realigns rows from that registry. See #1823." + ) From e6f53e4f71a866565c29bd815d34489aaea1aaf1 Mon Sep 17 00:00:00 2001 From: efortin Date: Thu, 11 Jun 2026 19:04:24 +0200 Subject: [PATCH 2/8] test(scheduler): pin registry fallback and short-slot padding --- tests/test_scheduler_logits_processors.py | 44 +++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/tests/test_scheduler_logits_processors.py b/tests/test_scheduler_logits_processors.py index b612db1c2..0ce3c7ed9 100644 --- a/tests/test_scheduler_logits_processors.py +++ b/tests/test_scheduler_logits_processors.py @@ -520,3 +520,47 @@ def test_scheduler_source_registers_rows_at_insert(self): "per-uid sampler and logits processors; the step chokepoint " "realigns rows from that registry. See #1823." ) + + def test_unregistered_uid_keeps_current_row_and_short_slots_pad(self, monkeypatch): + """Realignment must not invent state: a uid missing from the registry + keeps its current row, and missing trailing slots pad to empty + instead of raising.""" + from collections import OrderedDict + + import omlx.scheduler as scheduler + + monkeypatch.setattr(scheduler, "_uid_row_registry", OrderedDict()) + + captured = {} + + def fake_original_step(self): + captured["logits_processors"] = list(self.logits_processors) + captured["samplers"] = list(self.samplers) + return "stepped" + + monkeypatch.setattr( + scheduler, "_original_generation_batch_step", fake_original_step + ) + + def legacy_processor(token_context, legacy_logits): + return legacy_logits + + class FakeModel: + pass + + class FakeBatch: + model = FakeModel() + uids = [7, 8] + # uid 7 is not registered but carries a live row: keep it. + # uid 8 has no slot at all (shorter list): pad to []. + logits_processors = [[legacy_processor]] + samplers = [None] + _next_tokens = None + + batch = FakeBatch() + result = scheduler._patched_generation_batch_step(batch) + + assert result == "stepped" + assert captured["logits_processors"][0] == [legacy_processor] + assert captured["logits_processors"][1] == [] + assert len(batch.samplers) == len(batch.uids) From 191a5cb373b1546e0851a13873bebe1bdfd64968 Mon Sep 17 00:00:00 2001 From: efortin Date: Thu, 11 Jun 2026 19:39:38 +0200 Subject: [PATCH 3/8] fix(scheduler): key the row registry by model and release rows on completion mlx-lm's BatchGenerator numbers uids per instance starting at zero, so two engines serving concurrently (or an engine reload) produce colliding uid sequences; a registry keyed by uid alone could install one model's sampler and processors on another model's row. Key entries by (id(model), uid) instead - the model object is the one identity both the insert sites and the step chokepoint can see. Also guard the registry with a lock (engines run on separate executor threads and share the module-level dict) and release a request's row at the existing uid-map cleanup site, so heavy grammar processors are not pinned until LRU eviction; the bounded size stays as the backstop. New tests: same uid on two models must not cross-contaminate, and completion unregister drops the row (double-unregister is a no-op). --- omlx/scheduler.py | 40 +++++++--- tests/test_scheduler_logits_processors.py | 96 ++++++++++++++++++++--- 2 files changed, 116 insertions(+), 20 deletions(-) diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 803771da6..6f3423519 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -420,16 +420,32 @@ class _PrefillState: # uid is supposed to run; the step chokepoint realigns the positional lists # from it. Bounded so a missing cleanup can never grow it unbounded. _UID_ROW_REGISTRY_MAX = 4096 -_uid_row_registry: "OrderedDict[int, tuple[Any, list]]" = OrderedDict() +# Keyed by (id(model), uid): mlx-lm's BatchGenerator numbers uids per +# instance starting at 0, so two engines serving concurrently (or an engine +# reload) produce colliding uid sequences. The model object is the one +# identity both the insert sites and the step chokepoint can see. +_uid_row_registry: "OrderedDict[tuple[int, int], tuple[Any, list]]" = OrderedDict() +# Engines run on separate executor threads and share this module-level +# registry; a plain OrderedDict is not safe under concurrent mutation. +_uid_row_registry_lock = threading.Lock() -def _register_uid_rows(uids, samplers, lps_rows) -> None: +def _register_uid_rows(model, uids, samplers, lps_rows) -> None: """Record the sampler and logits processors each freshly-inserted uid must run.""" - for uid, sampler, lps in zip(uids, samplers, lps_rows): - _uid_row_registry[uid] = (sampler, list(lps) if lps else []) - _uid_row_registry.move_to_end(uid) - while len(_uid_row_registry) > _UID_ROW_REGISTRY_MAX: - _uid_row_registry.popitem(last=False) + with _uid_row_registry_lock: + for uid, sampler, lps in zip(uids, samplers, lps_rows): + key = (id(model), uid) + _uid_row_registry[key] = (sampler, list(lps) if lps else []) + _uid_row_registry.move_to_end(key) + while len(_uid_row_registry) > _UID_ROW_REGISTRY_MAX: + _uid_row_registry.popitem(last=False) + + +def _unregister_uid_row(model, uid) -> None: + """Drop a finished request's row so heavy processors are not pinned + until LRU eviction; the bounded size stays as the backstop.""" + with _uid_row_registry_lock: + _uid_row_registry.pop((id(model), uid), None) _original_generation_batch_step = GenerationBatch._step @@ -480,8 +496,11 @@ def _patched_generation_batch_step(self): new_lps = [] new_samplers = [] drift = len(cur_lps) != len(self.uids) + model_id = id(getattr(self, "model", None)) + with _uid_row_registry_lock: + registry_rows = [_uid_row_registry.get((model_id, uid)) for uid in self.uids] for i, uid in enumerate(self.uids): - entry = _uid_row_registry.get(uid) + entry = registry_rows[i] if entry is not None: sampler, lps = entry if not drift and i < len(cur_lps) and cur_lps[i] is not lps and (cur_lps[i] or lps): @@ -1797,6 +1816,7 @@ def _drain_pending_async_removes(self) -> bool: e, ) # Cleanup uid maps now that the slot is reclaimable. + _unregister_uid_row(self.model, uid) if uid in self.uid_to_request_id: del self.uid_to_request_id[uid] if request_id in self.request_id_to_uid: @@ -3535,7 +3555,7 @@ def _insert_prefilled_request( state_machines=[state.sm], ) if uids: - _register_uid_rows(uids, [state.sampler], [per_row_lps]) + _register_uid_rows(self.model, uids, [state.sampler], [per_row_lps]) if uids: uid = uids[0] @@ -7089,7 +7109,7 @@ def _sparse_progress(processed: int, total: int) -> None: state_machines=[sm], ) if uids: - _register_uid_rows(uids, [sampler], [per_row_lps]) + _register_uid_rows(self.model, uids, [sampler], [per_row_lps]) if uids: uid = uids[0] diff --git a/tests/test_scheduler_logits_processors.py b/tests/test_scheduler_logits_processors.py index 0ce3c7ed9..1d024a136 100644 --- a/tests/test_scheduler_logits_processors.py +++ b/tests/test_scheduler_logits_processors.py @@ -456,13 +456,6 @@ def grammar_processor(token_context, logits): sampler_uid2 = object() - # What the insert sites record: uid 1 is a plain request, uid 2 is - # the constrained one (grammar + thinking budget). - scheduler._register_uid_rows([1], [None], [[]]) - scheduler._register_uid_rows( - [2], [sampler_uid2], [[budget_processor, grammar_processor]] - ) - class FakeModel: pass @@ -474,6 +467,13 @@ class FakeBatch: samplers = [None, None, sampler_uid2] _next_tokens = None + # What the insert sites record: uid 1 is a plain request, uid 2 is + # the constrained one (grammar + thinking budget). + scheduler._register_uid_rows(FakeBatch.model, [1], [None], [[]]) + scheduler._register_uid_rows( + FakeBatch.model, [2], [sampler_uid2], [[budget_processor, grammar_processor]] + ) + batch = FakeBatch() result = scheduler._patched_generation_batch_step(batch) @@ -497,13 +497,14 @@ def test_registry_is_bounded(self): registry = OrderedDict() original = scheduler._uid_row_registry scheduler._uid_row_registry = registry + model = object() try: for uid in range(scheduler._UID_ROW_REGISTRY_MAX + 100): - scheduler._register_uid_rows([uid], [None], [[]]) + scheduler._register_uid_rows(model, [uid], [None], [[]]) assert len(registry) == scheduler._UID_ROW_REGISTRY_MAX # Oldest entries evicted first. - assert 0 not in registry - assert scheduler._UID_ROW_REGISTRY_MAX + 99 in registry + assert (id(model), 0) not in registry + assert (id(model), scheduler._UID_ROW_REGISTRY_MAX + 99) in registry finally: scheduler._uid_row_registry = original @@ -564,3 +565,78 @@ class FakeBatch: assert captured["logits_processors"][0] == [legacy_processor] assert captured["logits_processors"][1] == [] assert len(batch.samplers) == len(batch.uids) + + + def test_same_uid_on_two_models_does_not_cross_contaminate(self, monkeypatch): + """mlx-lm numbers uids per BatchGenerator instance, so two engines + serving concurrently produce colliding uid values. The registry must + key by model so engine A's realignment never installs engine B's + sampler and processors.""" + from collections import OrderedDict + + import omlx.scheduler as scheduler + + monkeypatch.setattr(scheduler, "_uid_row_registry", OrderedDict()) + + captured = {} + + def fake_original_step(self): + captured[id(self.model)] = list(self.logits_processors) + return "stepped" + + monkeypatch.setattr( + scheduler, "_original_generation_batch_step", fake_original_step + ) + + def qwen_processor(token_context, logits): + return logits + + def gemma_processor(token_context, logits): + return logits + + class FakeModel: + pass + + model_a, model_b = FakeModel(), FakeModel() + # SAME uid value on both engines, different processors. + scheduler._register_uid_rows(model_a, [7], [None], [[qwen_processor]]) + scheduler._register_uid_rows(model_b, [7], [None], [[gemma_processor]]) + + def make_batch(model): + class FakeBatch: + pass + + b = FakeBatch() + b.model = model + b.uids = [7] + b.logits_processors = [[]] + b.samplers = [None] + b._next_tokens = None + return b + + scheduler._patched_generation_batch_step(make_batch(model_a)) + scheduler._patched_generation_batch_step(make_batch(model_b)) + + assert captured[id(model_a)][0] == [qwen_processor] + assert captured[id(model_b)][0] == [gemma_processor] + + def test_unregister_drops_the_row(self): + """Completion cleanup must release the row so heavy processors are + not pinned until LRU eviction.""" + from collections import OrderedDict + + import omlx.scheduler as scheduler + + registry = OrderedDict() + original = scheduler._uid_row_registry + scheduler._uid_row_registry = registry + model = object() + try: + scheduler._register_uid_rows(model, [3], [None], [[object()]]) + assert (id(model), 3) in registry + scheduler._unregister_uid_row(model, 3) + assert (id(model), 3) not in registry + # Unregistering twice (or an unknown uid) is a no-op. + scheduler._unregister_uid_row(model, 3) + finally: + scheduler._uid_row_registry = original From bba0c1236970be56555b87d24f153c004c488d47 Mon Sep 17 00:00:00 2001 From: efortin Date: Thu, 11 Jun 2026 19:47:36 +0200 Subject: [PATCH 4/8] test(scheduler): update the insert-site guard for the model-keyed registry signature --- tests/test_scheduler_logits_processors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_scheduler_logits_processors.py b/tests/test_scheduler_logits_processors.py index 1d024a136..0b81141d5 100644 --- a/tests/test_scheduler_logits_processors.py +++ b/tests/test_scheduler_logits_processors.py @@ -516,7 +516,7 @@ def test_scheduler_source_registers_rows_at_insert(self): scheduler_src = ( Path(__file__).resolve().parents[1] / "omlx" / "scheduler.py" ).read_text() - assert scheduler_src.count("_register_uid_rows(uids") >= 2, ( + assert scheduler_src.count("_register_uid_rows(self.model, uids") >= 2, ( "every batch_generator.insert call site must register the " "per-uid sampler and logits processors; the step chokepoint " "realigns rows from that registry. See #1823." From 4448f37d2e52a3ef4f8887005936d3a1418ca0f1 Mon Sep 17 00:00:00 2001 From: efortin Date: Thu, 11 Jun 2026 20:02:42 +0200 Subject: [PATCH 5/8] refactor(scheduler): drop the redundant move_to_end in the row registry Each (model, uid) key is inserted exactly once per request, so the OrderedDict already appends it last; plain insertion order is all the oldest-first backstop eviction needs. One dict operation per row instead of two, and the lps copy reads as list(lps or ()). --- omlx/scheduler.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 6f3423519..2ddbd2373 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -431,12 +431,14 @@ class _PrefillState: def _register_uid_rows(model, uids, samplers, lps_rows) -> None: - """Record the sampler and logits processors each freshly-inserted uid must run.""" + """Record the sampler and logits processors each freshly-inserted uid must run. + + Each (model, uid) key is inserted exactly once per request, so plain + insertion order is enough for the oldest-first backstop eviction. + """ with _uid_row_registry_lock: for uid, sampler, lps in zip(uids, samplers, lps_rows): - key = (id(model), uid) - _uid_row_registry[key] = (sampler, list(lps) if lps else []) - _uid_row_registry.move_to_end(key) + _uid_row_registry[(id(model), uid)] = (sampler, list(lps or ())) while len(_uid_row_registry) > _UID_ROW_REGISTRY_MAX: _uid_row_registry.popitem(last=False) From b019488c8ddd1c2cf2ec8a8ea4c78fce69e0c12f Mon Sep 17 00:00:00 2001 From: efortin Date: Thu, 11 Jun 2026 21:59:52 +0200 Subject: [PATCH 6/8] fix(scheduler): release registry rows on abort, recovery, reset, and shutdown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The row registry was only released in the deferred async-remove path, so a request finishing through the synchronous fallback, an aborted request, or a generator reset (failure recovery, reset(), shutdown()) kept its sampler and logits processors pinned until the FIFO backstop — and entries surviving an engine unload were exactly the residue a recycled id(model) could later match. Release the row in every path that retires a uid, and release by model wherever the uid maps are cleared wholesale. Rate-limit the drift warning to one WARNING per 60s window (DEBUG otherwise) so a pathological merge pattern cannot flood the logs. Tests: structural AST checks that every retirement path releases its rows, a model-scoped clear unit test, the pre-fix offset behavior pinned through the empty-registry fallback, and the warning rate limit. --- omlx/scheduler.py | 40 ++++- tests/test_scheduler_logits_processors.py | 174 +++++++++++++++++++++- 2 files changed, 210 insertions(+), 4 deletions(-) diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 2ddbd2373..793be8274 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -428,6 +428,11 @@ class _PrefillState: # Engines run on separate executor threads and share this module-level # registry; a plain OrderedDict is not safe under concurrent mutation. _uid_row_registry_lock = threading.Lock() +# Drift corrections are worth one log line each, but a pathological batching +# pattern could correct on every merge; cap the WARNING rate and route the +# rest to DEBUG so the signal survives without flooding the logs. +_UID_ROW_DRIFT_WARNING_INTERVAL_S = 60.0 +_uid_row_drift_last_warning = float("-inf") def _register_uid_rows(model, uids, samplers, lps_rows) -> None: @@ -445,11 +450,25 @@ def _register_uid_rows(model, uids, samplers, lps_rows) -> None: def _unregister_uid_row(model, uid) -> None: """Drop a finished request's row so heavy processors are not pinned - until LRU eviction; the bounded size stays as the backstop.""" + until FIFO eviction; the bounded size stays as the backstop.""" with _uid_row_registry_lock: _uid_row_registry.pop((id(model), uid), None) +def _unregister_uid_rows_for_model(model) -> None: + """Drop every registry row for a model (generator reset, recovery, shutdown). + + The recovery and reset paths clear the uid maps wholesale instead of + finishing requests one by one; releasing by model covers them, and leaves + nothing behind that a later engine load could match if ``id(model)`` were + recycled. + """ + model_id = id(model) + with _uid_row_registry_lock: + for key in [key for key in _uid_row_registry if key[0] == model_id]: + del _uid_row_registry[key] + + _original_generation_batch_step = GenerationBatch._step @@ -513,7 +532,12 @@ def _patched_generation_batch_step(self): new_samplers.append(cur_samplers[i] if i < len(cur_samplers) else None) new_lps.append(cur_lps[i] if i < len(cur_lps) else []) if drift: - logger.warning( + global _uid_row_drift_last_warning + now = time.monotonic() + rate_limited = now - _uid_row_drift_last_warning < _UID_ROW_DRIFT_WARNING_INTERVAL_S + if not rate_limited: + _uid_row_drift_last_warning = now + (logger.debug if rate_limited else logger.warning)( "Realigned generation-batch row state from the uid registry " f"(uids={list(self.uids)}, had {len(cur_lps)} processor slots); " "stale or offset slots would have run the wrong sampler/processors." @@ -5912,6 +5936,7 @@ def _do_abort_request(self, request_id: str) -> bool: close = getattr(mtp_state.generator, "close", None) if callable(close): close() + _unregister_uid_row(self.model, uid) del self.uid_to_request_id[uid] del self.request_id_to_uid[request.request_id] @@ -6077,7 +6102,9 @@ def fail_all_requests(self) -> list[str]: if uid is not None: self.uid_to_request_id.pop(uid, None) self._generation_overflow_recovery_ids.difference_update(failed_ids) - # Reset batch generator only (cache is not corrupted) + # Reset batch generator only (cache is not corrupted). Every row dies + # with it; survivors re-register at re-insert. + _unregister_uid_rows_for_model(self.model) self.batch_generator = None self._current_sampler_params = None # Reclaim fragmented Metal buffers after generation failure. @@ -7649,6 +7676,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None: self._remove_uid_from_active_batch(uid) if hasattr(self.model, "unregister_rope_delta"): self.model.unregister_rope_delta(uid) + _unregister_uid_row(self.model, uid) if uid in self.uid_to_request_id: del self.uid_to_request_id[uid] del self.request_id_to_uid[request_id] @@ -7766,6 +7794,7 @@ def _recover_from_cache_error(self) -> None: self._cache_rate_tracker.clear() # Clear UID mappings + _unregister_uid_rows_for_model(self.model) self.request_id_to_uid.clear() self.uid_to_request_id.clear() @@ -7795,6 +7824,7 @@ def _recover_from_generation_overflow_error(self) -> None: if hasattr(self.model, "clear_pending_embeddings"): self.model.clear_pending_embeddings() + _unregister_uid_rows_for_model(self.model) self.request_id_to_uid.clear() self.uid_to_request_id.clear() self._deferred_clear_at = None @@ -8336,6 +8366,7 @@ def reset(self) -> None: self.running.clear() self.requests.clear() self.finished_req_ids.clear() + _unregister_uid_rows_for_model(self.model) self.request_id_to_uid.clear() self.uid_to_request_id.clear() self._generation_overflow_recovery_ids.clear() @@ -8450,6 +8481,9 @@ def shutdown(self) -> None: if self.paged_ssd_cache_manager is not None: self.paged_ssd_cache_manager.close() self.paged_ssd_cache_manager = None + # Release whatever the per-path unregisters did not reach, so nothing + # survives this engine in the module-level row registry. + _unregister_uid_rows_for_model(self.model) logger.info("Scheduler shutdown completed") def adjust_store_cache_cap(self, pressure_level: str) -> None: diff --git a/tests/test_scheduler_logits_processors.py b/tests/test_scheduler_logits_processors.py index 0b81141d5..7b19e5c62 100644 --- a/tests/test_scheduler_logits_processors.py +++ b/tests/test_scheduler_logits_processors.py @@ -622,7 +622,7 @@ class FakeBatch: def test_unregister_drops_the_row(self): """Completion cleanup must release the row so heavy processors are - not pinned until LRU eviction.""" + not pinned until FIFO eviction.""" from collections import OrderedDict import omlx.scheduler as scheduler @@ -640,3 +640,175 @@ def test_unregister_drops_the_row(self): scheduler._unregister_uid_row(model, 3) finally: scheduler._uid_row_registry = original + + def test_model_scoped_clear_drops_only_that_model(self): + """Reset/recovery/shutdown release by model: every row of the reset + engine goes, every row of the other engine stays.""" + from collections import OrderedDict + + import omlx.scheduler as scheduler + + registry = OrderedDict() + original = scheduler._uid_row_registry + scheduler._uid_row_registry = registry + model_a, model_b = object(), object() + try: + scheduler._register_uid_rows(model_a, [0, 1], [None, None], [[], []]) + scheduler._register_uid_rows(model_b, [0], [None], [[]]) + scheduler._unregister_uid_rows_for_model(model_a) + assert (id(model_a), 0) not in registry + assert (id(model_a), 1) not in registry + assert (id(model_b), 0) in registry + # Clearing an unknown model is a no-op. + scheduler._unregister_uid_rows_for_model(object()) + assert (id(model_b), 0) in registry + finally: + scheduler._uid_row_registry = original + + def test_offset_rows_pass_through_without_registry(self, monkeypatch): + """The pre-fix behavior, pinned through the fallback path: with + nothing registered, the chokepoint cannot restore alignment, so the + #1823 probe shape (three slots for two uids) reaches the step with + uid 2 running no processors. This is the exact silent failure the + registry realignment corrects in + ``test_patched_step_realigns_offset_rows_from_registry``.""" + from collections import OrderedDict + + import omlx.scheduler as scheduler + + monkeypatch.setattr(scheduler, "_uid_row_registry", OrderedDict()) + + captured = {} + + def fake_original_step(self): + captured["logits_processors"] = list(self.logits_processors) + return "stepped" + + monkeypatch.setattr( + scheduler, "_original_generation_batch_step", fake_original_step + ) + + def grammar_processor(token_context, logits): + return logits + + class FakeModel: + pass + + class FakeBatch: + model = FakeModel() + uids = [1, 2] + # Stale leading slot: uid 2's processors sit in slot 2, which the + # two-uid loop never reads. + logits_processors = [[], [], [grammar_processor]] + samplers = [None, None, object()] + _next_tokens = None + + scheduler._patched_generation_batch_step(FakeBatch()) + + # Without registry rows the constrained request silently decodes + # unconstrained — the pre-#1824 behavior. + assert captured["logits_processors"][1] == [] + + def test_drift_warning_is_rate_limited(self, monkeypatch, caplog): + """One drift correction per window logs at WARNING; the rest go to + DEBUG so a pathological merge pattern cannot flood the logs.""" + import logging + from collections import OrderedDict + + import omlx.scheduler as scheduler + + monkeypatch.setattr(scheduler, "_uid_row_registry", OrderedDict()) + monkeypatch.setattr(scheduler, "_uid_row_drift_last_warning", float("-inf")) + monkeypatch.setattr( + scheduler, "_original_generation_batch_step", lambda self: "stepped" + ) + + def make_misaligned_batch(): + class FakeModel: + pass + + class FakeBatch: + pass + + batch = FakeBatch() + batch.model = FakeModel() + batch.uids = [1] + # One stale slot too many: drift on every call. + batch.logits_processors = [[], []] + batch.samplers = [None, None] + batch._next_tokens = None + return batch + + with caplog.at_level(logging.DEBUG, logger=scheduler.logger.name): + scheduler._patched_generation_batch_step(make_misaligned_batch()) + scheduler._patched_generation_batch_step(make_misaligned_batch()) + + realign_levels = [ + record.levelno + for record in caplog.records + if "Realigned generation-batch row state" in record.getMessage() + ] + assert realign_levels == [logging.WARNING, logging.DEBUG] + + +class TestRegistryCleanupPaths: + """Every path that retires a uid — or the whole generator — must release + its registry rows. A finished, aborted, or failed request that stays + registered pins its (possibly heavy, stateful) processors until the FIFO + backstop, and entries surviving a generator reset or engine unload are + exactly the residue an ``id(model)`` recycle could later match. + + Structural AST checks: cheaper than spinning up a Scheduler per path, + and immune to formatting churn (unlike substring counting).""" + + PER_UID_RELEASE_PATHS = [ + "_drain_pending_async_removes", + "_do_abort_request", + "_cleanup_finished", + ] + MODEL_WIDE_RELEASE_PATHS = [ + "fail_all_requests", + "_recover_from_cache_error", + "_recover_from_generation_overflow_error", + "reset", + "shutdown", + ] + + @staticmethod + def _called_names(func_name: str) -> set: + import ast + from pathlib import Path + + source = ( + Path(__file__).resolve().parents[1] / "omlx" / "scheduler.py" + ).read_text() + tree = ast.parse(source) + for node in ast.walk(tree): + if ( + isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) + and node.name == func_name + ): + return { + call.func.id + if isinstance(call.func, ast.Name) + else getattr(call.func, "attr", None) + for call in ast.walk(node) + if isinstance(call, ast.Call) + } + raise AssertionError(f"{func_name} not found in scheduler.py") + + @pytest.mark.parametrize("func_name", PER_UID_RELEASE_PATHS) + def test_per_uid_paths_release_the_row(self, func_name): + assert "_unregister_uid_row" in self._called_names(func_name), ( + f"{func_name} retires a uid from the batch but does not release " + "its registry row; the processors stay pinned until the FIFO " + "backstop. See #1823." + ) + + @pytest.mark.parametrize("func_name", MODEL_WIDE_RELEASE_PATHS) + def test_model_wide_paths_release_every_row(self, func_name): + assert "_unregister_uid_rows_for_model" in self._called_names(func_name), ( + f"{func_name} clears the uid maps (or retires the generator) " + "wholesale but leaves the registry rows behind; release by model " + "so nothing survives a reset, recovery, or shutdown. See #1823." + ) From 5f59d6abbac30e05c5ffdcd0a861cf448ec8578a Mon Sep 17 00:00:00 2001 From: efortin Date: Fri, 12 Jun 2026 08:09:38 +0200 Subject: [PATCH 7/8] refactor(scheduler): extract row realignment helpers Pull the chokepoint realignment into a pure _realigned_rows function, name the drift comparison rules (_row_drifted: identity fast path stays inline at the caller, two fresh empty lists are equivalent, content comparison last), move the rate-limited log into _log_drift_correction, and type the registry entries as a NamedTuple. No behavior change; the rebuilt lists are still installed unconditionally and drift only drives logging. Two direct unit tests on the pure rebuild. --- omlx/scheduler.py | 110 +++++++++++++++------- tests/test_scheduler_logits_processors.py | 48 ++++++++++ 2 files changed, 122 insertions(+), 36 deletions(-) diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 793be8274..55c43cad1 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -25,7 +25,7 @@ from dataclasses import dataclass, field from enum import Enum from pathlib import Path -from typing import Any, Optional +from typing import Any, NamedTuple, Optional import mlx.core as mx from mlx_lm.generate import ( @@ -419,12 +419,19 @@ class _PrefillState: # constraints at all). The registry below records, at insert time, what each # uid is supposed to run; the step chokepoint realigns the positional lists # from it. Bounded so a missing cleanup can never grow it unbounded. +class _RegisteredRow(NamedTuple): + """What a uid is supposed to run, recorded at request insert.""" + + sampler: Any + logits_processors: list + + _UID_ROW_REGISTRY_MAX = 4096 # Keyed by (id(model), uid): mlx-lm's BatchGenerator numbers uids per # instance starting at 0, so two engines serving concurrently (or an engine # reload) produce colliding uid sequences. The model object is the one # identity both the insert sites and the step chokepoint can see. -_uid_row_registry: "OrderedDict[tuple[int, int], tuple[Any, list]]" = OrderedDict() +_uid_row_registry: "OrderedDict[tuple[int, int], _RegisteredRow]" = OrderedDict() # Engines run on separate executor threads and share this module-level # registry; a plain OrderedDict is not safe under concurrent mutation. _uid_row_registry_lock = threading.Lock() @@ -443,7 +450,7 @@ def _register_uid_rows(model, uids, samplers, lps_rows) -> None: """ with _uid_row_registry_lock: for uid, sampler, lps in zip(uids, samplers, lps_rows): - _uid_row_registry[(id(model), uid)] = (sampler, list(lps or ())) + _uid_row_registry[(id(model), uid)] = _RegisteredRow(sampler, list(lps or ())) while len(_uid_row_registry) > _UID_ROW_REGISTRY_MAX: _uid_row_registry.popitem(last=False) @@ -469,6 +476,61 @@ def _unregister_uid_rows_for_model(model) -> None: del _uid_row_registry[key] +def _row_drifted(current_lps, expected_lps) -> bool: + """True when a slot's processors genuinely differ from the registered row. + + Two distinct empty lists are equivalent — the #1799 normalisation mints + fresh ``[]`` objects every step — so only differing content counts. The + caller's identity check is the steady-state fast path; this only runs + past it. + """ + if not current_lps and not expected_lps: + return False + return current_lps != expected_lps + + +def _log_drift_correction(uids, slot_count) -> None: + """Log a corrected drift: one WARNING per window, the rest at DEBUG.""" + global _uid_row_drift_last_warning + now = time.monotonic() + rate_limited = now - _uid_row_drift_last_warning < _UID_ROW_DRIFT_WARNING_INTERVAL_S + if not rate_limited: + _uid_row_drift_last_warning = now + (logger.debug if rate_limited else logger.warning)( + "Realigned generation-batch row state from the uid registry " + f"(uids={list(uids)}, had {slot_count} processor slots); " + "stale or offset slots would have run the wrong sampler/processors." + ) + + +def _realigned_rows(model, uids, cur_samplers, cur_lps): + """Rebuild the positional row lists in uid order from the registry. + + Registered uids take their recorded row; unregistered uids keep their + current slot (the #1799 fallback), padding when the lists are shorter + than ``uids``. Returns ``(samplers, logits_processors, drift)`` — drift + only drives logging, the rebuilt lists are always installed. In steady + state the slots already are the registry lists, so the identity check + skips any comparison work. + """ + model_id = id(model) + with _uid_row_registry_lock: + rows = [_uid_row_registry.get((model_id, uid)) for uid in uids] + + drift = len(cur_lps) != len(uids) + samplers, lps = [], [] + for i, row in enumerate(rows): + if row is not None: + if not drift and i < len(cur_lps) and cur_lps[i] is not row.logits_processors: + drift = _row_drifted(cur_lps[i], row.logits_processors) + samplers.append(row.sampler) + lps.append(row.logits_processors) + else: + samplers.append(cur_samplers[i] if i < len(cur_samplers) else None) + lps.append(cur_lps[i] if i < len(cur_lps) else []) + return samplers, lps, drift + + _original_generation_batch_step = GenerationBatch._step @@ -508,40 +570,16 @@ def _patched_generation_batch_step(self): ] # Realign per-row samplers and logits processors with ``uids`` from the - # per-uid registry. Positional drift (stale/offset slots left by batch - # extend/filter/split) otherwise makes a row run another request's — or - # no — sampler and processors, which silently disables grammar - # constraints and thinking budgets under concurrent mixed load. - cur_lps = self.logits_processors - cur_samplers = getattr(self, "samplers", None) or [] - new_lps = [] - new_samplers = [] - drift = len(cur_lps) != len(self.uids) - model_id = id(getattr(self, "model", None)) - with _uid_row_registry_lock: - registry_rows = [_uid_row_registry.get((model_id, uid)) for uid in self.uids] - for i, uid in enumerate(self.uids): - entry = registry_rows[i] - if entry is not None: - sampler, lps = entry - if not drift and i < len(cur_lps) and cur_lps[i] is not lps and (cur_lps[i] or lps): - drift = drift or (cur_lps[i] != lps) - new_samplers.append(sampler) - new_lps.append(lps) - else: - new_samplers.append(cur_samplers[i] if i < len(cur_samplers) else None) - new_lps.append(cur_lps[i] if i < len(cur_lps) else []) + # per-uid registry; stale or offset slots left by batch extend/filter/ + # split would otherwise run another request's — or no — rows. See #1823. + new_samplers, new_lps, drift = _realigned_rows( + getattr(self, "model", None), + self.uids, + getattr(self, "samplers", None) or [], + self.logits_processors, + ) if drift: - global _uid_row_drift_last_warning - now = time.monotonic() - rate_limited = now - _uid_row_drift_last_warning < _UID_ROW_DRIFT_WARNING_INTERVAL_S - if not rate_limited: - _uid_row_drift_last_warning = now - (logger.debug if rate_limited else logger.warning)( - "Realigned generation-batch row state from the uid registry " - f"(uids={list(self.uids)}, had {len(cur_lps)} processor slots); " - "stale or offset slots would have run the wrong sampler/processors." - ) + _log_drift_correction(self.uids, len(self.logits_processors)) self.logits_processors = new_lps self.samplers = new_samplers diff --git a/tests/test_scheduler_logits_processors.py b/tests/test_scheduler_logits_processors.py index 7b19e5c62..7d64c2f64 100644 --- a/tests/test_scheduler_logits_processors.py +++ b/tests/test_scheduler_logits_processors.py @@ -641,6 +641,54 @@ def test_unregister_drops_the_row(self): finally: scheduler._uid_row_registry = original + def test_realigned_rows_rebuilds_in_uid_order(self): + """Direct unit coverage of the pure rebuild: offset slots are + replaced by the registered rows and the drift flag is set.""" + from collections import OrderedDict + + import omlx.scheduler as scheduler + + registry = OrderedDict() + original = scheduler._uid_row_registry + scheduler._uid_row_registry = registry + model = object() + proc = object() + sampler = object() + try: + scheduler._register_uid_rows(model, [1], [None], [[]]) + scheduler._register_uid_rows(model, [2], [sampler], [[proc]]) + # The #1823 probe shape: a stale leading slot, 3 slots for 2 uids. + samplers, lps, drift = scheduler._realigned_rows( + model, [1, 2], [None, None, sampler], [[], [], [proc]] + ) + assert drift is True + assert samplers == [None, sampler] + assert lps == [[], [proc]] + finally: + scheduler._uid_row_registry = original + + def test_realigned_rows_steady_state_reports_no_drift(self): + """Feeding the rebuilt lists back in (the post-realignment state) + must report no drift: the identity fast path short-circuits.""" + from collections import OrderedDict + + import omlx.scheduler as scheduler + + registry = OrderedDict() + original = scheduler._uid_row_registry + scheduler._uid_row_registry = registry + model = object() + proc = object() + try: + scheduler._register_uid_rows(model, [1], [None], [[proc]]) + samplers, lps, drift = scheduler._realigned_rows(model, [1], [], []) + assert drift is True # short slots on the first pass + samplers, lps, drift = scheduler._realigned_rows(model, [1], samplers, lps) + assert drift is False + assert lps == [[proc]] + finally: + scheduler._uid_row_registry = original + def test_model_scoped_clear_drops_only_that_model(self): """Reset/recovery/shutdown release by model: every row of the reset engine goes, every row of the other engine stays.""" From 7a35d4fc25dd4cfde270e8d367c7591668b4f56a Mon Sep 17 00:00:00 2001 From: efortin Date: Sun, 14 Jun 2026 07:43:01 +0200 Subject: [PATCH 8/8] fix(scheduler): log sampler-only row realignment --- omlx/scheduler.py | 10 +++++----- tests/test_scheduler_logits_processors.py | 23 +++++++++++++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/omlx/scheduler.py b/omlx/scheduler.py index 55c43cad1..8a4961e39 100644 --- a/omlx/scheduler.py +++ b/omlx/scheduler.py @@ -521,6 +521,11 @@ def _realigned_rows(model, uids, cur_samplers, cur_lps): samplers, lps = [], [] for i, row in enumerate(rows): if row is not None: + if not drift: + if i >= len(cur_samplers): + drift = row.sampler is not None + elif cur_samplers[i] is not row.sampler: + drift = True if not drift and i < len(cur_lps) and cur_lps[i] is not row.logits_processors: drift = _row_drifted(cur_lps[i], row.logits_processors) samplers.append(row.sampler) @@ -583,7 +588,6 @@ def _patched_generation_batch_step(self): self.logits_processors = new_lps self.samplers = new_samplers - result = _original_generation_batch_step(self) # self._next_tokens contains the just-sampled tokens (async eval pending). @@ -3620,8 +3624,6 @@ def _insert_prefilled_request( ) if uids: _register_uid_rows(self.model, uids, [state.sampler], [per_row_lps]) - - if uids: uid = uids[0] self.request_id_to_uid[request.request_id] = uid self.uid_to_request_id[uid] = request.request_id @@ -7177,8 +7179,6 @@ def _sparse_progress(processed: int, total: int) -> None: ) if uids: _register_uid_rows(self.model, uids, [sampler], [per_row_lps]) - - if uids: uid = uids[0] self.request_id_to_uid[request.request_id] = uid self.uid_to_request_id[uid] = request.request_id diff --git a/tests/test_scheduler_logits_processors.py b/tests/test_scheduler_logits_processors.py index 7d64c2f64..719b2946d 100644 --- a/tests/test_scheduler_logits_processors.py +++ b/tests/test_scheduler_logits_processors.py @@ -689,6 +689,29 @@ def test_realigned_rows_steady_state_reports_no_drift(self): finally: scheduler._uid_row_registry = original + def test_realigned_rows_reports_sampler_only_drift(self): + """A corrected sampler-only mismatch is still row-state drift.""" + from collections import OrderedDict + + import omlx.scheduler as scheduler + + registry = OrderedDict() + original = scheduler._uid_row_registry + scheduler._uid_row_registry = registry + model = object() + expected_sampler = object() + wrong_sampler = object() + try: + scheduler._register_uid_rows(model, [1], [expected_sampler], [[]]) + samplers, lps, drift = scheduler._realigned_rows( + model, [1], [wrong_sampler], [[]] + ) + assert drift is True + assert samplers == [expected_sampler] + assert lps == [[]] + finally: + scheduler._uid_row_registry = original + def test_model_scoped_clear_drops_only_that_model(self): """Reset/recovery/shutdown release by model: every row of the reset engine goes, every row of the other engine stays."""