Skip to content
163 changes: 158 additions & 5 deletions omlx/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import os
import threading
import time
from collections import defaultdict, deque
from collections import OrderedDict, defaultdict, deque
from collections.abc import Callable
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Optional
from typing import Any, NamedTuple, Optional

import mlx.core as mx
from mlx_lm.generate import (
Expand Down Expand Up @@ -408,6 +408,134 @@ class _PrefillState:
# After _original_step returns, self._next_tokens holds the freshly sampled
# tokens. We eval them synchronously and accept in grammar processors.
# ---------------------------------------------------------------------------
# Authoritative per-uid row state for the generation batch.
#
# mlx-lm keeps ``samplers`` / ``logits_processors`` as positional lists that
# must stay aligned with ``uids``. Heterogeneous continuous batching
# (extend/filter/split across prompt and generation batches) can leave stale
# or offset row slots behind; #1799 made the step crash-safe by normalising
# ``None`` slots, but a misaligned row silently runs the WRONG sampler and
# logits processors (e.g. a grammar/thinking-budget request decoding with no
# constraints at all). The registry below records, at insert time, what each
# uid is supposed to run; the step chokepoint realigns the positional lists
# from it. Bounded so a missing cleanup can never grow it unbounded.
class _RegisteredRow(NamedTuple):
"""What a uid is supposed to run, recorded at request insert."""

sampler: Any
logits_processors: list


_UID_ROW_REGISTRY_MAX = 4096
# Keyed by (id(model), uid): mlx-lm's BatchGenerator numbers uids per
# instance starting at 0, so two engines serving concurrently (or an engine
# reload) produce colliding uid sequences. The model object is the one
# identity both the insert sites and the step chokepoint can see.
_uid_row_registry: "OrderedDict[tuple[int, int], _RegisteredRow]" = OrderedDict()
# Engines run on separate executor threads and share this module-level
# registry; a plain OrderedDict is not safe under concurrent mutation.
_uid_row_registry_lock = threading.Lock()
# Drift corrections are worth one log line each, but a pathological batching
# pattern could correct on every merge; cap the WARNING rate and route the
# rest to DEBUG so the signal survives without flooding the logs.
_UID_ROW_DRIFT_WARNING_INTERVAL_S = 60.0
_uid_row_drift_last_warning = float("-inf")


def _register_uid_rows(model, uids, samplers, lps_rows) -> None:
"""Record the sampler and logits processors each freshly-inserted uid must run.

Each (model, uid) key is inserted exactly once per request, so plain
insertion order is enough for the oldest-first backstop eviction.
"""
with _uid_row_registry_lock:
for uid, sampler, lps in zip(uids, samplers, lps_rows):
_uid_row_registry[(id(model), uid)] = _RegisteredRow(sampler, list(lps or ()))
while len(_uid_row_registry) > _UID_ROW_REGISTRY_MAX:
_uid_row_registry.popitem(last=False)


def _unregister_uid_row(model, uid) -> None:
"""Drop a finished request's row so heavy processors are not pinned
until FIFO eviction; the bounded size stays as the backstop."""
with _uid_row_registry_lock:
_uid_row_registry.pop((id(model), uid), None)


def _unregister_uid_rows_for_model(model) -> None:
"""Drop every registry row for a model (generator reset, recovery, shutdown).

The recovery and reset paths clear the uid maps wholesale instead of
finishing requests one by one; releasing by model covers them, and leaves
nothing behind that a later engine load could match if ``id(model)`` were
recycled.
"""
model_id = id(model)
with _uid_row_registry_lock:
for key in [key for key in _uid_row_registry if key[0] == model_id]:
del _uid_row_registry[key]


def _row_drifted(current_lps, expected_lps) -> bool:
"""True when a slot's processors genuinely differ from the registered row.

Two distinct empty lists are equivalent — the #1799 normalisation mints
fresh ``[]`` objects every step — so only differing content counts. The
caller's identity check is the steady-state fast path; this only runs
past it.
"""
if not current_lps and not expected_lps:
return False
return current_lps != expected_lps


def _log_drift_correction(uids, slot_count) -> None:
"""Log a corrected drift: one WARNING per window, the rest at DEBUG."""
global _uid_row_drift_last_warning
now = time.monotonic()
rate_limited = now - _uid_row_drift_last_warning < _UID_ROW_DRIFT_WARNING_INTERVAL_S
if not rate_limited:
_uid_row_drift_last_warning = now
(logger.debug if rate_limited else logger.warning)(
"Realigned generation-batch row state from the uid registry "
f"(uids={list(uids)}, had {slot_count} processor slots); "
"stale or offset slots would have run the wrong sampler/processors."
)


def _realigned_rows(model, uids, cur_samplers, cur_lps):
"""Rebuild the positional row lists in uid order from the registry.

Registered uids take their recorded row; unregistered uids keep their
current slot (the #1799 fallback), padding when the lists are shorter
than ``uids``. Returns ``(samplers, logits_processors, drift)`` — drift
only drives logging, the rebuilt lists are always installed. In steady
state the slots already are the registry lists, so the identity check
skips any comparison work.
"""
model_id = id(model)
with _uid_row_registry_lock:
rows = [_uid_row_registry.get((model_id, uid)) for uid in uids]

drift = len(cur_lps) != len(uids)
samplers, lps = [], []
for i, row in enumerate(rows):
if row is not None:
if not drift:
if i >= len(cur_samplers):
drift = row.sampler is not None
elif cur_samplers[i] is not row.sampler:
drift = True
if not drift and i < len(cur_lps) and cur_lps[i] is not row.logits_processors:
drift = _row_drifted(cur_lps[i], row.logits_processors)
samplers.append(row.sampler)
lps.append(row.logits_processors)
else:
samplers.append(cur_samplers[i] if i < len(cur_samplers) else None)
lps.append(cur_lps[i] if i < len(cur_lps) else [])
return samplers, lps, drift


_original_generation_batch_step = GenerationBatch._step


Expand Down Expand Up @@ -446,6 +574,20 @@ def _patched_generation_batch_step(self):
procs if procs is not None else [] for procs in self.logits_processors
]

# Realign per-row samplers and logits processors with ``uids`` from the
# per-uid registry; stale or offset slots left by batch extend/filter/
# split would otherwise run another request's — or no — rows. See #1823.
new_samplers, new_lps, drift = _realigned_rows(
getattr(self, "model", None),
self.uids,
getattr(self, "samplers", None) or [],
self.logits_processors,
)
if drift:
_log_drift_correction(self.uids, len(self.logits_processors))
self.logits_processors = new_lps
self.samplers = new_samplers

result = _original_generation_batch_step(self)

# self._next_tokens contains the just-sampled tokens (async eval pending).
Expand Down Expand Up @@ -1742,6 +1884,7 @@ def _drain_pending_async_removes(self) -> bool:
e,
)
# Cleanup uid maps now that the slot is reclaimable.
_unregister_uid_row(self.model, uid)
if uid in self.uid_to_request_id:
del self.uid_to_request_id[uid]
if request_id in self.request_id_to_uid:
Expand Down Expand Up @@ -3479,8 +3622,8 @@ def _insert_prefilled_request(
logits_processors=[per_row_lps],
state_machines=[state.sm],
)

if uids:
_register_uid_rows(self.model, uids, [state.sampler], [per_row_lps])
uid = uids[0]
self.request_id_to_uid[request.request_id] = uid
self.uid_to_request_id[uid] = request.request_id
Expand Down Expand Up @@ -5833,6 +5976,7 @@ def _do_abort_request(self, request_id: str) -> bool:
close = getattr(mtp_state.generator, "close", None)
if callable(close):
close()
_unregister_uid_row(self.model, uid)
del self.uid_to_request_id[uid]
del self.request_id_to_uid[request.request_id]

Expand Down Expand Up @@ -5998,7 +6142,9 @@ def fail_all_requests(self) -> list[str]:
if uid is not None:
self.uid_to_request_id.pop(uid, None)
self._generation_overflow_recovery_ids.difference_update(failed_ids)
# Reset batch generator only (cache is not corrupted)
# Reset batch generator only (cache is not corrupted). Every row dies
# with it; survivors re-register at re-insert.
_unregister_uid_rows_for_model(self.model)
self.batch_generator = None
self._current_sampler_params = None
# Reclaim fragmented Metal buffers after generation failure.
Expand Down Expand Up @@ -7031,8 +7177,8 @@ def _sparse_progress(processed: int, total: int) -> None:
logits_processors=[per_row_lps],
state_machines=[sm],
)

if uids:
_register_uid_rows(self.model, uids, [sampler], [per_row_lps])
uid = uids[0]
self.request_id_to_uid[request.request_id] = uid
self.uid_to_request_id[uid] = request.request_id
Expand Down Expand Up @@ -7568,6 +7714,7 @@ def _cleanup_finished(self, finished_ids: set[str]) -> None:
self._remove_uid_from_active_batch(uid)
if hasattr(self.model, "unregister_rope_delta"):
self.model.unregister_rope_delta(uid)
_unregister_uid_row(self.model, uid)
if uid in self.uid_to_request_id:
del self.uid_to_request_id[uid]
del self.request_id_to_uid[request_id]
Expand Down Expand Up @@ -7685,6 +7832,7 @@ def _recover_from_cache_error(self) -> None:
self._cache_rate_tracker.clear()

# Clear UID mappings
_unregister_uid_rows_for_model(self.model)
self.request_id_to_uid.clear()
self.uid_to_request_id.clear()

Expand Down Expand Up @@ -7714,6 +7862,7 @@ def _recover_from_generation_overflow_error(self) -> None:
if hasattr(self.model, "clear_pending_embeddings"):
self.model.clear_pending_embeddings()

_unregister_uid_rows_for_model(self.model)
self.request_id_to_uid.clear()
self.uid_to_request_id.clear()
self._deferred_clear_at = None
Expand Down Expand Up @@ -8255,6 +8404,7 @@ def reset(self) -> None:
self.running.clear()
self.requests.clear()
self.finished_req_ids.clear()
_unregister_uid_rows_for_model(self.model)
self.request_id_to_uid.clear()
self.uid_to_request_id.clear()
self._generation_overflow_recovery_ids.clear()
Expand Down Expand Up @@ -8369,6 +8519,9 @@ def shutdown(self) -> None:
if self.paged_ssd_cache_manager is not None:
self.paged_ssd_cache_manager.close()
self.paged_ssd_cache_manager = None
# Release whatever the per-path unregisters did not reach, so nothing
# survives this engine in the module-level row registry.
_unregister_uid_rows_for_model(self.model)
logger.info("Scheduler shutdown completed")

def adjust_store_cache_cap(self, pressure_level: str) -> None:
Expand Down
Loading