Feat: allow dynamic DFlash switching on a per-request basis based on context length #876
sparsely-active wants to merge 13 commits into jundot:main
Conversation
Hi.
Literally my use case, so I love the idea. My two cents: At least z-lab/Qwen3.5-27B-DFlash explicitly mentions this:
Which is one of the main reasons for the noticeable degradation at large context, in my understanding. The moment people see the … If I'm correct with the training-tokens reasoning, then I'd say we need to add some sane default:

```html
<input
  type="number"
  x-model.number="modelSettings.dflash_max_ctx"
  placeholder="4096"
  min="1024"
  step="1024"
  ...
/>
```

No DFlash:
Yes, it gets worse at higher contexts, but people might want to adjust it from 4k to something like 6-8k depending on their setup and hardware, was my thinking.
I've been doing some tests on an M1 Max 64GB and the results on this incredibly modest hardware are not good! I think the real bottleneck for me is memory bandwidth (409 GB/s), and introducing DFlash actually makes the overall E2E throughput significantly worse at every context range. Part of the issue is that DFlash can't benefit from the oMLX KV cache mechanics and batched processing, so TTFT goes up, but I also think there's just more memory bandwidth being used overall, so I hit that wall quicker. Maybe it's a different story on beefier machines?
I have an M2 Max 64 GB and my benchmarks are above. Don't forget to re-download the z-lab model: for example, z-lab/Qwen3.6-35B-A3B-DFlash, which I use, has actually finished its training and was re-uploaded just a day ago. Regarding your question, I'd definitely use the changes in this PR; it fits my flow nicely.
Hmmm okay, after a bit more exploration it looks like the degraded throughput I was seeing was due to some hidden complexity I initially overlooked in how the model is loaded with DFlash. I think the approach I've proposed here can work, it just needs some adjustment to avoid throughput degradation for non-DFlash responses (over the …
- Add 'mode' parameter (default 'dflash') to generate() and stream_generate()
- When mode='batched' or prompt exceeds max_dflash_ctx (now 8192), delegate to fallback engine
- Target model loaded once and shared across both modes (no reload overhead)
- Updated docstring to reflect new per-request switching behavior
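Roughly, the per-request routing this commit describes could look like the following sketch. It is only an illustration: the attribute and method names (`_max_dflash_ctx`, `_fallback_engine`, `_dflash_generate`) are assumptions, not the actual oMLX internals.

```python
class DFlashEngine:
    async def generate(self, prompt: str, mode: str = "dflash", **kwargs):
        """Route a single request to the DFlash or batched path (sketch only)."""
        prompt_tokens = len(self._tokenizer.encode(prompt))
        if mode == "batched" or prompt_tokens > self._max_dflash_ctx:
            # Long prompt (or explicit request): delegate to the shared fallback engine.
            return await self._fallback_engine.generate(prompt, **kwargs)
        # Short prompt: take the speculative DFlash path.
        return await self._dflash_generate(prompt, **kwargs)
```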
…aths
- BatchedEngine accepts optional model/tokenizer params (shared from DFlashEngine)
- start() skips loading when model already provided
- DFlashEngine passes shared target model to fallback engine (no reload)
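A minimal sketch of what this commit describes, assuming a constructor signature of this shape and the standard `mlx_lm.load` loader; the real BatchedEngine may differ:

```python
from mlx_lm import load  # mlx-lm loader: returns (model, tokenizer)

class BatchedEngine:
    def __init__(self, model_name, scheduler_config, model_settings,
                 model=None, tokenizer=None):
        self._model_name = model_name
        self._scheduler_config = scheduler_config
        self._model_settings = model_settings
        # Optionally shared from DFlashEngine so the target model lives in memory once.
        self._model = model
        self._tokenizer = tokenizer

    async def start(self):
        # Skip loading when a shared model was injected by DFlashEngine.
        if self._model is None:
            self._model, self._tokenizer = load(self._model_name)
```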
…er-request switching
- Remove _in_fallback_mode attribute (no longer needed with per-request mode)
- Rename _evict_dflash_and_start_fallback -> _init_fallback_engine
- No model eviction: target model + draft model stay loaded
- generate/stream_generate route based on mode/context, no permanent state
- Remove in_fallback_mode from get_stats()
- stop() cleans up fallback engine, unloads shared models
- Add dflash_max_ctx field to ModelSettings dataclass (Optional[int])
- DFlashEngine reads from model_settings.dflash_max_ctx instead of env var only
- Env var DFLASH_MAX_CTX still takes priority for global override
- UI: add Max Prompt Tokens input in DFlash section (default 4096)
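As a sketch of the resolution order this commit describes (env var first, then the per-model setting, then the default); the helper name is hypothetical:

```python
import os
from typing import Optional

def resolve_dflash_max_ctx(dflash_max_ctx: Optional[int], default: int = 4096) -> int:
    # DFLASH_MAX_CTX env var acts as a global override.
    env_value = os.environ.get("DFLASH_MAX_CTX")
    if env_value:
        return int(env_value)
    # Otherwise use the per-model value from ModelSettings, falling back to the default.
    return dflash_max_ctx if dflash_max_ctx is not None else default
```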
- Add _clear_env() helper to clear DFLASH_MAX_CTX before tests that verify model_settings-based resolution
- Test model settings and dflash routing
… models
Apply omlx post-load transforms (GatedDeltaNet advance, IndexCache, etc.). dflash-mlx installs its own hooks on linear_attn/self_attn but doesn't apply omlx's GatedDeltaNet advance patch.
When a long-context request triggered fallback to BatchedEngine, _uninstall_dflash_hooks permanently disabled dflash-mlx hooks on the shared model. Subsequent short-context requests would silently fall through to baseline performance because _dflash_split_sdpa_enabled was set to False and never restored.

Replace the one-shot uninstall with toggle functions:
- _enable_dflash_hooks(model): sets _dflash_split_sdpa_enabled=True on all self_attn layers, re-wraps _ExactSmallProjPad on linear_attn layers if they were unwrapped.
- _disable_dflash_hooks(model): sets _dflash_split_sdpa_enabled=False, unwraps _ExactSmallProjPad back to raw linear layers.

Call _enable before the DFlash speculative path (generate, stream_generate) and _disable before creating the fallback engine. A threading lock prevents concurrent flag toggling from async request handlers and the MLX executor thread. The cost of toggling should be negligible.
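A rough sketch of the toggle shape described here, using the flag and wrapper names from the commit message; the actual dflash-mlx and oMLX internals (layer layout, wrapping logic) may differ:

```python
import threading

# Serialize toggling between async request handlers and the MLX executor thread.
_hook_lock = threading.Lock()

def _enable_dflash_hooks(model):
    with _hook_lock:
        for layer in getattr(model, "layers", []):
            if hasattr(layer, "self_attn"):
                layer.self_attn._dflash_split_sdpa_enabled = True
            # Re-wrapping linear_attn projections (_ExactSmallProjPad) would happen here.

def _disable_dflash_hooks(model):
    with _hook_lock:
        for layer in getattr(model, "layers", []):
            if hasattr(layer, "self_attn"):
                layer.self_attn._dflash_split_sdpa_enabled = False
            # Unwrapping _ExactSmallProjPad back to the raw linear layers would happen here.
```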
…rved at the same time
Force-pushed from 1e82b76 to bf59f4f.
Benchmarks
Hardware: M1 Max 64GB
- Baseline, DFlash completely disabled
- DFlash enabled, max context: ~4k
- DFlash enabled, max context: ~16k
Sample acceptance rates (~93%)
Notes on dflash-mlx hooks
Hooks are installed once when … The hooks are class-level, meaning they affect all instances of the attention class, including any fallback engine that shares the same model object. This changes the behaviour of the target model.

Approach
The goal is to load the primary/target model only once in memory, so when switching between DFlash and non-DFlash requests (based on context size) we flip these hooks on and off like a light switch. Two of them are basically just boolean flags, so the overhead should be negligible. There's a simple lock with a reference count added to ensure we don't handle both a DFlash and a non-DFlash request at the same time on the same model, to avoid any weird scenarios or race conditions.
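One way the "lock with reference count" could work, as a sketch (class name and API are illustrative, not the PR's actual code): many requests of the same mode may run concurrently, while a request of the other mode waits until the count drops to zero.

```python
import threading

class ModeGuard:
    """Allow concurrent requests of one mode; block the other mode until idle."""

    def __init__(self):
        self._cond = threading.Condition()
        self._mode = None      # "dflash", "batched", or None when idle
        self._active = 0       # reference count of in-flight requests

    def acquire(self, mode: str):
        with self._cond:
            # Wait while requests of the *other* mode are still in flight.
            while self._mode is not None and self._mode != mode:
                self._cond.wait()
            self._mode = mode
            self._active += 1

    def release(self):
        with self._cond:
            self._active -= 1
            if self._active == 0:
                self._mode = None
                self._cond.notify_all()
```

Usage would be acquire("dflash") before the speculative path, acquire("batched") before delegating to the fallback engine, and release() in a finally block once the request completes.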
I ran your PR against oMLX 0.3.6 stable on an M5 Max / 128GB unified memory machine.

ctx sweep, median of 3 trials per size (gen tok/s)
Two bench modes: …

DFlash-path rows (2K/4K) are untouched by the PR, same rate as stock. Good. Fallback-path rows (8K+) show a consistent ~6-10% regression vs stock in the PR. The post-probe row is the whole point of the PR: on small-prompts-after-large-…

The bug: VLM fallback served from a DFlash-loaded (mlx-lm) target
Different Python class, different forward method, different internal structure. The LLM fallback case (…) is fine.

Suggested fix
In …:

```python
if self._fallback_engine_type == "vlm":
    from .vlm import VLMBatchedEngine
    self._fallback_engine = VLMBatchedEngine(
        model_name=self._model_name,
        scheduler_config=self._scheduler_config,
        model_settings=self._model_settings,
        # VLM expects a mlx_vlm-loaded model; sharing DFlash's mlx-lm target
        # forces a slower path. Let VLMBatchedEngine load its own.
    )
else:
    # LLM fallback — BatchedEngine natively consumes mlx-lm models,
    # sharing is type-consistent and saves ~29GB of GPU memory.
    from .batched import BatchedEngine
    self._fallback_engine = BatchedEngine(
        model_name=self._model_name,
        scheduler_config=self._scheduler_config,
        model_settings=self._model_settings,
        model=shared_model,
        tokenizer=shared_tokenizer,
    )
```

Tradeoff: for VLM-classified models, the fix reverts to stock's behaviour of letting the VLM fallback engine load its own copy of the model (at the cost of extra memory).
Side finding: tiered-KV determinism
Pre-existing issue on stock 0.3.6: with DFlash enabled, greedy (temp=0) output …

Two remaining observations for context:
Summary
With the three-line fix, this PR is a strict improvement over stock 0.3.6 on …
Nice catch! Thanks for taking a look. I was testing with the model type explicitly set to LLM and didn't test the VLM path. So we have solid DFlash support for the LLM path via …

Edit: confirmed, …
Interesting — but I think there's a gap between "mlx-vlm has no DFlash support" and "DFlash + VLM-classified models doesn't work." Our bench on …

That speedup is real, because mlx-vlm's missing DFlash support only matters when:
So I'd argue the type-guard fix still stands on its own merits today:
When mlx-vlm ships DFlash support, the …

Separately — on the "weird issues" warning: for anyone hitting this thread later with a VLM-classified text model and DFlash enabled, the stock behaviour is fine for text requests and speculative-decodes correctly. The weirdness surfaces if you start sending images to a DFlash-enabled VLM model, which is a different code path entirely. Happy to open the fixup PR whenever you're ready.
Aha, the new …
@sparsely-active hi! Look what I've found: #866 (comment). The DFlash draft model author claims speculative decoding at 100k tokens on his side.
Related issue and some discussion: #873
Motivation
DFlash (speculative decoding) can provide a substantial increase in token generation speed but it doesn't fit well with agentic workflows. Above ~4k context the benefits taper off, and at large contexts (32k) the effect is negative.
Current status
The current implementation in oMLX uses DFlash for requests below a preset `max_dflash_ctx` of 4096 tokens. If the server receives a request with a context above that threshold, it goes into fallback mode, unloads the draft model, and disables DFlash for all prompts (regardless of context size) for the remainder of the session. This status is only reset once the model is unloaded and loaded again.

Proposal
Per-request routing: instead of a permanent "fallback" mode, DFlashEngine now supports per-request mode switching; each request to the model is handled by either BatchedEngine or DFlashEngine based on context length.
Shared target model: the target model is loaded once in memory and shared across both the DFlash and batched paths. BatchedEngine now accepts optional model/tokenizer params, and its `start()` skips loading when a model is already provided.

No eviction: removed the evict_dflash and start_fallback pattern. Target model and draft model stay loaded; switching between modes is just a routing decision.
Per-model max_dflash_ctx: added a dflash_max_ctx option to ModelSettings, configurable via the admin UI (default 4096). The DFLASH_MAX_CTX env var still takes priority as a global override (it is also used internally by bstnxbt/dflash-mlx).

Outcome
A model configured with DFlash can receive both large-context and small-context prompts over the course of a session and dynamically use batched mode (with its own KV cache) or DFlash mode (separate KV cache), with a configurable threshold.
For agentic workflows (depending on your agent) this means you can keep a long-running, large-context session going and also fire off smaller ad-hoc prompts in between that will benefit from DFlash, potentially through automatic delegation to subagents (for small tasks, with the main context stripped out).