78 changes: 78 additions & 0 deletions docs/experimental/dflash_mlx_hook_details.md
@@ -0,0 +1,78 @@
# DFlash-MLX Hook Details

Date: 2026-04-22

Notes on how dflash-mlx inserts hooks into a model at load time.

dflash-mlx installs three types of hooks at the **class level** on model layers when `load_target_bundle()` is called. These hooks wrap attention and projection layers to enable speculative decoding, but they also affect any code path that uses the same model — including fallback engines.

## Hook Types

### 1. Split Full Attention Hook (`_install_split_full_attention_hook`)

Wraps `self_attn.__call__` with a custom split attention path. When `_dflash_split_sdpa_enabled=True`, the hook runs its own manual scaled-dot-product attention (SDPA) computation (separate q/k/v projection, RoPE, attention) instead of the original optimized path.

**Impact on fallback engine:** Adds Python wrapper overhead and runs a less-optimized attention path even during prefill, causing roughly a 25% token-generation (TG) TPS regression and 30-40% slower prefill throughput on the fallback path.

**Mechanism:**
```python
# In dflash_mlx.runtime.py line 652
cls.__call__ = split_call # wraps original __call__

# split_call checks _dflash_split_sdpa_enabled flag
# When True: runs custom attention path (manual q/k/v proj, rope, SDPA)
# When False: calls original_call (unhooked path)
```
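
To make the dispatch concrete, here is a minimal sketch of how such a class-level wrapper could be installed. The installer name matches the one above, but the wrapper body, the guard attribute, and the `_split_sdpa_attention` helper are illustrative placeholders, not the exact dflash-mlx source:

```python
def _install_split_full_attention_hook(cls):
    # Hypothetical guard: avoid wrapping the same class twice
    if getattr(cls, "_dflash_split_hooked", False):
        return
    original_call = cls.__call__  # unbound method, captured by the closure below

    def split_call(self, x, mask=None, cache=None):
        if getattr(self, "_dflash_split_sdpa_enabled", False):
            # Custom path: manual q/k/v projection, RoPE, then SDPA
            # (_split_sdpa_attention is a placeholder for that logic)
            return _split_sdpa_attention(self, x, mask=mask, cache=cache)
        # Flag off: defer to the original optimized attention
        return original_call(self, x, mask=mask, cache=cache)

    cls.__call__ = split_call
    cls._dflash_split_hooked = True
```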

### 2. Speculative Linear Cache Hook (`_install_speculative_linear_cache_hook`)

Wraps `linear_attn.__call__` with speculative logic. Checks if cache is `RecurrentRollbackCache` and armed; otherwise calls original.

**Impact on fallback engine:** Minimal — the cache won't be `RecurrentRollbackCache` in BatchedEngine, so it already calls `original_call`. But the hook wrapper still adds a Python call frame.

**Mechanism:**
```python
# In dflash_mlx.runtime.py line 381-400
def speculative_call(self, x, cache=None):
if isinstance(cache, RecurrentRollbackCache) and cache.is_armed:
# Run speculative path with tape replay
return _speculative_linear_attn(self, x, cache)
    return original_call(self, x, cache)  # fall back to the original (the unbound method captured at install time)
```

### 3. Exact Small Proj Hooks (`_install_exact_small_proj_hooks`)

Wraps `in_proj_b` and `in_proj_a` with the `_ExactSmallProjPad` class. Pads short sequences to a minimum length (`pad_m=16`) before projection, ensuring the draft model's small-sequence assumptions hold.

**Impact on fallback engine:** Changes the layer's shape handling: the wrapped projection adds padding-and-slicing logic that isn't needed for normal inference.

**Mechanism:**
```python
# In dflash_mlx.runtime.py line 274-303
class _ExactSmallProjPad(nn.Module):
    def __init__(self, linear: nn.Module, *, pad_m: int = 16):
        super().__init__()
        self.linear = linear  # stores the original wrapped layer
        self.pad_m = pad_m

    def __call__(self, x: mx.array) -> mx.array:
        if x.ndim == 3 and x.shape[1] < self.pad_m:
            batch_size, seq_len, hidden_dim = x.shape
            # Pad short sequences up to pad_m before projecting,
            # then slice the output back to the original length
            pad = mx.zeros((batch_size, self.pad_m - seq_len, hidden_dim), dtype=x.dtype)
            out = self.linear(mx.concatenate([x, pad], axis=1))
            return out[:, :seq_len, :]
        return self.linear(x)  # no padding needed
```
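
For intuition, a quick usage sketch of the wrapper's observable behavior (hypothetical layer sizes; assumes an MLX environment):

```python
import mlx.core as mx
import mlx.nn as nn

linear = nn.Linear(64, 64)
wrapped = _ExactSmallProjPad(linear, pad_m=16)

x = mx.random.normal((1, 3, 64))  # short sequence: seq_len=3 < pad_m=16
y = wrapped(x)
print(y.shape)  # (1, 3, 64): padded to 16 internally, sliced back to 3
```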

## Hook Lifecycle

Hooks are installed once when `load_target_bundle()` is called in dflash-mlx's runtime. They persist on the model class for the lifetime of the process.

```
load_target_bundle()
├─ _install_split_full_attention_hook(self_attn)
├─ _install_speculative_linear_cache_hook(linear_attn)
└─ _install_exact_small_proj_hooks(linear_attn)
```

The hooks are **class-level**, meaning they affect all instances of the attention class — including any fallback engine that shares the same model object.
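
A tiny, self-contained illustration of that semantics (plain Python, not dflash-mlx code): patching `__call__` on the class rewires instances created before and after the patch alike.

```python
class Attention:
    def __call__(self, x):
        return f"original({x})"

a = Attention()  # created before the hook is installed

original_call = Attention.__call__

def hooked_call(self, x):
    return f"hooked({original_call(self, x)})"

Attention.__call__ = hooked_call  # class-level patch
b = Attention()  # created after the hook is installed

print(a("x"))  # hooked(original(x)): the pre-existing instance is affected too
print(b("x"))  # hooked(original(x))
```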
4 changes: 4 additions & 0 deletions omlx/admin/routes.py
@@ -126,6 +126,7 @@ class ModelSettingsRequest(BaseModel):
dflash_enabled: Optional[bool] = None
dflash_draft_model: Optional[str] = None
dflash_draft_quant_bits: Optional[int] = None
dflash_max_ctx: Optional[int] = None
reasoning_parser: Optional[str] = None
is_pinned: Optional[bool] = None
is_default: Optional[bool] = None
@@ -1428,6 +1429,7 @@ async def list_models(is_admin: bool = Depends(require_admin)):
"dflash_enabled": settings.dflash_enabled,
"dflash_draft_model": settings.dflash_draft_model,
"dflash_draft_quant_bits": settings.dflash_draft_quant_bits,
"dflash_max_ctx": settings.dflash_max_ctx,
"is_pinned": settings.is_pinned,
"is_default": settings.is_default,
"display_name": settings.display_name,
@@ -1655,6 +1657,8 @@ async def update_model_settings(
current_settings.dflash_draft_model = request.dflash_draft_model or None
if "dflash_draft_quant_bits" in sent:
current_settings.dflash_draft_quant_bits = request.dflash_draft_quant_bits or None
if "dflash_max_ctx" in sent:
current_settings.dflash_max_ctx = request.dflash_max_ctx or None

if "reasoning_parser" in sent:
current_settings.reasoning_parser = request.reasoning_parser or None
5 changes: 5 additions & 0 deletions omlx/admin/static/js/dashboard.js
@@ -1532,6 +1532,7 @@
dflash_enabled: settings.dflash_enabled || false,
dflash_draft_model: settings.dflash_draft_model || '',
dflash_draft_quant_bits: settings.dflash_draft_quant_bits ? String(settings.dflash_draft_quant_bits) : '',
dflash_max_ctx: settings.dflash_max_ctx ? String(settings.dflash_max_ctx) : '',
ctKwargEntries,
};
this.showModelSettingsModal = true;
@@ -1612,6 +1613,9 @@
dflash_draft_quant_bits: this.modelSettings.dflash_enabled && this.modelSettings.dflash_draft_quant_bits
? parseInt(this.modelSettings.dflash_draft_quant_bits)
: null,
dflash_max_ctx: this.modelSettings.dflash_enabled && this.modelSettings.dflash_max_ctx
? parseInt(this.modelSettings.dflash_max_ctx)
: null,
};
})()),
});
@@ -1674,6 +1678,7 @@
this.modelSettings.dflash_enabled = false;
this.modelSettings.dflash_draft_model = null;
this.modelSettings.dflash_draft_quant_bits = null;
this.modelSettings.dflash_max_ctx = null;
} else if (response.status === 404) {
alert(window.t('js.error.no_config_defaults'));
} else if (response.status === 401) {
6 changes: 6 additions & 0 deletions omlx/admin/templates/dashboard/_modal_model_settings.html
@@ -693,6 +693,12 @@ <h4 class="text-xs font-bold uppercase tracking-widest text-neutral-400 mb-3">{{
<option value="8">8-bit</option>
</select>
</div>
<div>
<label class="block text-xs font-bold uppercase tracking-wider text-neutral-500 mb-2">Max Prompt Tokens</label>
<input type="number" x-model.number="modelSettings.dflash_max_ctx" placeholder="4096" min="1024"
class="w-full px-4 py-2.5 border border-neutral-200 rounded-xl text-sm focus:ring-2 focus:ring-neutral-900 focus:border-transparent transition-all">
<p class="text-xs text-neutral-400 mt-1">Switch to batched mode when prompt exceeds this (default 4096)</p>
</div>
</div>
</div>
</div>
53 changes: 29 additions & 24 deletions omlx/engine/batched.py
@@ -45,6 +45,8 @@ def __init__(
stream_interval: int = 1,
enable_thinking: bool | None = None,
model_settings: Any | None = None,
model: Any | None = None,
tokenizer: Any | None = None,
):
"""
Initialize the batched engine.
@@ -56,6 +58,8 @@
stream_interval: Tokens to batch before streaming (1=every token)
enable_thinking: Enable thinking mode for reasoning models (passed to chat_template_kwargs)
model_settings: Optional per-model settings for post-load transforms
model: Optional already-loaded MLX model (shared from another engine)
tokenizer: Optional already-loaded tokenizer (shared from another engine)
"""
self._model_name = model_name
self._trust_remote_code = trust_remote_code
@@ -64,8 +68,8 @@
self._enable_thinking = enable_thinking
self._model_settings = model_settings

self._model = None
self._tokenizer = None
self._model = model
self._tokenizer = tokenizer
self._engine = None
self._loaded = False
self._grammar_compiler = None
@@ -193,33 +197,34 @@ async def start(self) -> None:
from ..engine_core import AsyncEngineCore, EngineConfig
from ..scheduler import SchedulerConfig

# Build tokenizer config with model-specific fixes
tokenizer_config = get_tokenizer_config(
self._model_name,
trust_remote_code=self._trust_remote_code,
)

# Load model on the global MLX executor to avoid blocking the event loop
# while ensuring no concurrent Metal operations. See issue #85.
from ..engine_core import get_mlx_executor

def _load_model_sync():
return load(
# If model/tokenizer already provided (shared from another engine), skip loading
if self._model is None:
# Build tokenizer config with model-specific fixes
tokenizer_config = get_tokenizer_config(
self._model_name,
tokenizer_config=tokenizer_config,
trust_remote_code=self._trust_remote_code,
)

loop = asyncio.get_running_loop()
self._model, self._tokenizer = await loop.run_in_executor(
get_mlx_executor(), _load_model_sync
)
# Load model on the global MLX executor to avoid blocking the event loop
# while ensuring no concurrent Metal operations. See issue #85.
from ..engine_core import get_mlx_executor

# Apply post-load transforms (e.g., IndexCache for DSA models)
from ..utils.model_loading import apply_post_load_transforms
def _load_model_sync():
return load(
self._model_name,
tokenizer_config=tokenizer_config,
)

self._model = apply_post_load_transforms(
self._model, self._model_settings
)
loop = asyncio.get_running_loop()
self._model, self._tokenizer = await loop.run_in_executor(
get_mlx_executor(), _load_model_sync
)

from ..utils.model_loading import apply_post_load_transforms

self._model = apply_post_load_transforms(
self._model, self._model_settings
)

# TurboQuant KV cache: patch attention and set kv_bits on scheduler
if self._model_settings is not None: