
Feat: allow dynamic DFlash switching on a per-request basis based on context length #876

Open
sparsely-active wants to merge 13 commits into jundot:main from sparsely-active:dynamic-dflash

Conversation

@sparsely-active

Related issue and some discussion: #873

Motivation

DFlash (speculative decoding) can provide a substantial increase in token generation speed, but it doesn't fit well with agentic workflows. Above ~4k tokens of context the benefits taper off, and at large contexts (32k) the effect is negative.

Current status

The current implementation in oMLX uses DFlash for requests below a preset max_dflash_ctx of 4096 tokens. If the server receives a request with a context above that threshold, it goes into fallback mode: it unloads the draft model and disables DFlash for all prompts (regardless of context size) for the remainder of the session. This state is only reset once the model is unloaded and loaded again.

Proposal

Per-request routing: instead of a permanent "fallback" mode, DFlashEngine now supports per-request mode switching; each request to the model is handled by either BatchedEngine or DFlashEngine based on context length.

Shared target model: the target model is loaded once and shared across both the DFlash and batched paths. BatchedEngine now accepts optional model/tokenizer params, and its start() skips loading when a model is already provided.

No eviction: removed the evict_dflash and start_fallback pattern. Target model and draft model stay loaded; switching between modes is just a routing decision.

Per-model max_dflash_ctx: added dflash_max_ctx option to ModelSettings, configurable via the admin UI (default 4096). The DFLASH_MAX_CTX env var still takes priority as a global override (also used internally by bstnxbt/dflash-mlx).
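As a rough illustration of the routing decision (a minimal sketch, not the actual engine code; helper names like _resolve_max_dflash_ctx and _dflash_stream are illustrative):

import os

def _resolve_max_dflash_ctx(model_settings) -> int:
    # DFLASH_MAX_CTX env var stays a global override; otherwise use the
    # per-model setting from ModelSettings, then the default of 4096.
    env = os.environ.get("DFLASH_MAX_CTX")
    if env:
        return int(env)
    return getattr(model_settings, "dflash_max_ctx", None) or 4096

async def stream_generate(self, prompt_tokens, mode="dflash", **kwargs):
    # Per-request routing: long prompts (or an explicit mode="batched")
    # delegate to the fallback engine that shares the target model;
    # everything else takes the speculative DFlash path.
    if mode == "batched" or len(prompt_tokens) > _resolve_max_dflash_ctx(self._model_settings):
        async for chunk in self._fallback_engine.stream_generate(prompt_tokens, **kwargs):
            yield chunk
        return
    async for chunk in self._dflash_stream(prompt_tokens, **kwargs):
        yield chunk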

Outcome

A model configured with DFlash can receive both large-context and small-context prompts over the course of a session and dynamically use batched mode (with its own KV cache) or DFlash mode (separate KV cache), with a configurable threshold.

For agentic workflows (depending on your agent) this means you can keep a long-running, large-context session going and also fire off smaller ad-hoc prompts in between that will benefit from DFlash, potentially through automatic delegation to subagents (for small tasks, with the main context stripped out).

@deepsweet

deepsweet commented Apr 21, 2026

Hi.

small-context prompts over the course of a session

Literally my use case, so I love the idea.

My two cents:

At least z-lab/Qwen3.5-27B-DFlash explicitly mentions this:

It was trained with a context length of 4096 tokens.

Which is one of the main reasons for the noticeable degradation on large contexts, in my understanding.

The moment people see dflash_max_ctx configurable in the UI, they'll definitely want to increase it, but the experience won't be as straightforward as they expect.

If I'm right about the training-context reasoning, then I'd say we need to add some sane max to that input:

<input
  type="number"
  x-model.number="modelSettings.dflash_max_ctx"
  placeholder="4096"
  min="1024"
  max="4096"
  step="1024"
  ...
/>

No DFlash:

Benchmark Model: Qwen3.6-35B-A3B-MLX-oQ5-FP16
================================================================================

Single Request Results
--------------------------------------------------------------------------------
Test                TTFT(ms)    TPOT(ms)        pp TPS        tg TPS      E2E(s)    Throughput    Peak Mem
pp1024/tg128          1208.2       13.32   847.6 tok/s    75.7 tok/s       2.899   397.3 tok/s    23.86 GB
pp4096/tg128          4249.8       14.62   963.8 tok/s    69.0 tok/s       6.106   691.8 tok/s    24.64 GB
pp8192/tg128          9851.2       15.50   831.6 tok/s    65.0 tok/s      11.820   703.9 tok/s    24.99 GB
pp16384/tg128        20957.0       17.87   781.8 tok/s    56.4 tok/s      23.227   710.9 tok/s    25.61 GB
pp32768/tg128        49014.8       23.50   668.5 tok/s    42.9 tok/s      51.999   632.6 tok/s    26.95 GB

DFLASH_MAX_CTX=32768:

Benchmark Model: Qwen3.6-35B-A3B-MLX-oQ5-FP16
================================================================================

Single Request Results
--------------------------------------------------------------------------------
Test                TTFT(ms)    TPOT(ms)        pp TPS        tg TPS      E2E(s)    Throughput    Peak Mem
pp1024/tg128          1318.2        6.84   776.8 tok/s   147.3 tok/s       2.187   526.8 tok/s    25.87 GB
pp4096/tg128          5964.4        7.26   686.7 tok/s   138.9 tok/s       6.886   613.4 tok/s    27.01 GB
pp8192/tg128         14474.0        8.12   566.0 tok/s   124.2 tok/s      15.505   536.6 tok/s    27.52 GB
pp16384/tg128        37778.8        9.34   433.7 tok/s   107.9 tok/s      38.965   423.8 tok/s    28.15 GB
pp32768/tg128        50907.5       22.04   643.7 tok/s    45.7 tok/s      53.706   612.5 tok/s    26.95 GB

@sparsely-active
Author

Yes, it gets worse at higher contexts, but my thinking was that people might want to adjust it from 4k to something like 6-8k depending on their setup and hardware.

@sparsely-active
Author

sparsely-active commented Apr 21, 2026

I've been doing some tests on an M1 Max 64GB and the results on this incredibly modest hardware are not good! I think the real bottleneck for me is memory bandwidth (409 GB/s): introducing DFlash actually makes the overall E2E throughput significantly worse at every context range. Part of the issue is that DFlash can't benefit from the oMLX KV cache mechanics and batched processing, so TTFT goes up, but I also think there's just more memory bandwidth being used overall, so I'm hitting that wall sooner. Maybe it's a different story on beefier machines?

@deepsweet

I have an M2 Max 64 GB and my benchmarks are above. Don't forget to re-download the z-lab model: for example, the z-lab/Qwen3.6-35B-A3B-DFlash that I use actually finished its training and was re-uploaded just 1 day ago.

Regarding your question – I'd definitely use the changes in this PR, it fits my flow nicely.

@sparsely-active
Author

sparsely-active commented Apr 22, 2026

Hmmm okay, after a bit more exploration it looks like the degraded throughput I was seeing was due to some hidden complexity I initially overlooked in how the model is loaded with DFlash: when load_target_bundle() is called, dflash-mlx patches the model at the class level with three types of hooks on model layers. These hooks wrap parts of the attention and projection interfaces to enable speculative decoding, but they also affect any code path that re-uses the same model, including the "fallback engine" in the case of this PR.

I think the approach I've proposed here can work, it just needs some adjustment to avoid throughput degradation for non-DFlash responses (over the max_dflash_ctx threshold).
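To illustrate why that matters (a generic Python sketch of class-level patching, not the actual dflash-mlx hook code):

class Attention:
    def __call__(self, x):
        return "baseline attention"

shared_instance = Attention()  # shared by the DFlash path and the fallback engine

# What load_target_bundle() effectively does: replace the method on the class.
def _patched_call(self, x):
    return "speculative-decoding attention"

Attention.__call__ = _patched_call

# Every existing and future instance now sees the patched behaviour,
# including the one the fallback engine re-uses.
print(shared_instance("x"))  # -> "speculative-decoding attention"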

- Add 'mode' parameter (default 'dflash') to generate() and stream_generate()
- When mode='batched' or prompt exceeds max_dflash_ctx (now 8192), delegate to fallback engine
- Target model loaded once and shared across both modes (no reload overhead)
- Updated docstring to reflect new per-request switching behavior
…aths

- BatchedEngine accepts optional model/tokenizer params (shared from DFlashEngine)
- start() skips loading when model already provided
- DFlashEngine passes shared target model to fallback engine (no reload)
…er-request switching

- Remove _in_fallback_mode attribute (no longer needed with per-request mode)
- Rename _evict_dflash_and_start_fallback -> _init_fallback_engine
- No model eviction: target model + draft model stay loaded
- generate/stream_generate route based on mode/context, no permanent state
- Remove in_fallback_mode from get_stats()
- stop() cleans up fallback engine, unloads shared models
- Add dflash_max_ctx field to ModelSettings dataclass (Optional[int])
- DFlashEngine reads from model_settings.dflash_max_ctx instead of env var only
- Env var DFLASH_MAX_CTX still takes priority for global override
- UI: add Max Prompt Tokens input in DFlash section (default 4096)
- add _clear_env() helper to clear DFLASH_MAX_CTX before tests that
verify model_settings-based resolution.
- test model settings and dflash routing
… models

Apply omlx post-load transforms (GatedDeltaNet advance, IndexCache, etc).
dflash-mlx installs its own hooks on linear_attn/self_attn but doesn't apply omlx's GatedDeltaNet advance patch.
When a long-context request triggered fallback to BatchedEngine,  _uninstall_dflash_hooks permanently disabled dflash-mlx hooks on the shared model. Subsequent short-context requests would silently fall through to baseline performance because _dflash_split_sdpa_enabled was set to False and never restored.

Replace the one-shot uninstall with toggle functions:
- _enable_dflash_hooks(model): sets _dflash_split_sdpa_enabled=True on all self_attn layers, re-wraps _ExactSmallProjPad on linear_attn layers if they were unwrapped.
- _disable_dflash_hooks(model): sets _dflash_split_sdpa_enabled=False, unwraps _ExactSmallProjPad back to raw linear layers.

Call _enable before the DFlash speculative path (generate, stream_generate) and _disable before creating the fallback engine.

A threading lock prevents concurrent flag toggling from async request handlers and the MLX executor thread.

The cost of toggling should be negligible.
@sparsely-active
Author

Benchmarks

Hardware: M1 Max 64GB
Model: Qwen3.6-35B-A3B-MLX-oQ5-FP16
Drafter: Qwen3.6-35B-A3B-Draft

Baseline, DFlash completely disabled

--------------------------------------------------------------------------------
Test                TTFT(ms)    TPOT(ms)        pp TPS        tg TPS      E2E(s)    Throughput    Peak Mem
pp1024/tg128          1663.4       17.57   615.6 tok/s    57.4 tok/s       3.895   295.8 tok/s    23.86 GB
pp4096/tg128          4525.1       18.36   905.2 tok/s    54.9 tok/s       6.857   616.0 tok/s    24.64 GB
pp8192/tg128          8893.7       19.20   921.1 tok/s    52.5 tok/s      11.333   734.2 tok/s    24.99 GB
pp16384/tg128        18583.6       21.32   881.6 tok/s    47.3 tok/s      21.291   775.5 tok/s    25.61 GB
pp32768/tg128        42245.3       24.78   775.7 tok/s    40.7 tok/s      45.393   724.7 tok/s    26.95 GB

DFlash enabled, max context: ~4k

--------------------------------------------------------------------------------
Test                TTFT(ms)    TPOT(ms)        pp TPS        tg TPS      E2E(s)    Throughput    Peak Mem
pp1024/tg128          1972.5        7.02   519.1 tok/s   143.6 tok/s       2.864   402.3 tok/s    25.87 GB
pp4096/tg128          6811.6        7.71   601.3 tok/s   130.8 tok/s       7.790   542.2 tok/s    27.01 GB
pp8192/tg128          9939.5       19.79   824.2 tok/s    50.9 tok/s      12.452   668.1 tok/s    25.87 GB
pp16384/tg128        18728.7       21.75   874.8 tok/s    46.3 tok/s      21.490   768.3 tok/s    26.49 GB
pp32768/tg128        42380.0       26.08   773.2 tok/s    38.6 tok/s      45.692   720.0 tok/s    27.83 GB

DFlash enabled, max context: ~16k

--------------------------------------------------------------------------------
Test                TTFT(ms)    TPOT(ms)        pp TPS        tg TPS      E2E(s)    Throughput    Peak Mem
pp1024/tg128          1879.5        6.96   544.8 tok/s   144.8 tok/s       2.764   416.9 tok/s    25.87 GB
pp4096/tg128          7136.2        7.65   574.0 tok/s   131.8 tok/s       8.107   521.0 tok/s    27.01 GB
pp8192/tg128         14744.9        8.59   555.6 tok/s   117.3 tok/s      15.836   525.4 tok/s    27.52 GB
pp16384/tg128        35989.4       10.80   455.2 tok/s    93.3 tok/s      37.361   442.0 tok/s    28.15 GB
pp32768/tg128        44010.7       25.12   744.5 tok/s    40.1 tok/s      47.201   696.9 tok/s    27.83 GB

Sample acceptance rates (~93%)

2026-04-22 14:28:15,664 - omlx.engine.dflash - INFO - [-] - DFlash generation complete: 128 tokens, 46.4 tok/s, acceptance=93.0%, cycles=9
2026-04-22 14:28:23,771 - omlx.engine.dflash - INFO - [-] - DFlash generation complete: 128 tokens, 15.8 tok/s, acceptance=93.0%, cycles=9
2026-04-22 14:28:39,606 - omlx.engine.dflash - INFO - [-] - DFlash generation complete: 128 tokens, 8.1 tok/s, acceptance=93.0%, cycles=9
2026-04-22 14:29:16,967 - omlx.engine.dflash - INFO - [-] - DFlash generation complete: 128 tokens, 3.4 tok/s, acceptance=92.2%, cycles=10

@sparsely-active
Author

Notes on dflash-mlx hooks

Hooks are installed once when load_target_bundle() is called in dflash-mlx's runtime. They persist on the model class for the lifetime of the process.

load_target_bundle()
    │
    ├─ _install_split_full_attention_hook(self_attn)
    ├─ _install_speculative_linear_cache_hook(linear_attn)
    └─ _install_exact_small_proj_hooks(linear_attn)

The hooks are class-level, meaning they affect all instances of the attention class, including any fallback engine that shares the same model object. This changes the behaviour of the target model.

Approach

The goal is to load the primary/target model only once in memory, so when switching between DFlash and non-DFlash requests (based on context size) we flip these hooks on and off like a light switch. Two of them are basically just boolean flags, so the overhead should be negligible.

A simple lock with a reference count is added to ensure we don't handle both a DFlash and a non-DFlash request at the same time on the same model, to avoid any weird scenarios or race conditions.
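A rough sketch of what the toggles look like (hypothetical internals; it assumes the model exposes a layers list, and the real implementation also re-wraps/unwraps _ExactSmallProjPad on linear_attn layers):

import threading

_hook_lock = threading.Lock()

def _enable_dflash_hooks(model):
    # Flip the flag the split-SDPA hook checks on every self_attn layer.
    with _hook_lock:
        for layer in model.layers:
            attn = getattr(layer, "self_attn", None)
            if attn is not None:
                attn._dflash_split_sdpa_enabled = True

def _disable_dflash_hooks(model):
    # Revert to baseline behaviour before handing the shared model to the
    # fallback engine for a long-context request.
    with _hook_lock:
        for layer in model.layers:
            attn = getattr(layer, "self_attn", None)
            if attn is not None:
                attn._dflash_split_sdpa_enabled = False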

@davidpeden3

I ran your PR against oMLX 0.3.6 stable on an M5 Max / 128GB unified memory machine
using mlx-community/Qwen3.5-27B-8bit with the z-lab/Qwen3.5-27B-DFlash draft.
Good news and one real bug.

ctx sweep, median of 3 trials per size (gen tok/s)

Two bench modes: clean restarts oMLX before every size; session runs the
ascending sweep in one process lifetime and then re-fires a 2K prompt as a
post-routing probe.

                                          stock 0.3.6   pr876 (this PR)   pr876 + fix
size=2048  actual=1854  (DFlash)              26.4           26.4            26.4
size=4096  actual=3762  (DFlash)              23.1           23.2            23.2
size=8192  actual=7578  (fallback)            16.5           14.8            16.4
size=16384 actual=15210 (fallback)            16.2           14.8            16.1
size=32768 actual=30474 (fallback)            15.4           14.5            15.4

session post-probe 2K after 32K, same process:
                                               16.8  ❌       24.6  ✅         24.6  ✅

DFlash-path rows (2K/4K) are untouched by the PR — same rate as stock. Good.

Fallback-path rows (8K+) show a consistent ~6-10% regression vs stock in the PR
as-submitted. Closed completely by a small fix (see below).

The post-probe row is the whole point of the PR on small-prompts-after-large-
prompts workloads: stock is stuck in fallback mode for the rest of the process;
the PR keeps DFlash available.

The bug: VLM fallback served from a DFlash-loaded (mlx-lm) target

Qwen3.5-27B-8bit has a vision_config block, so oMLX/model_discovery.py
classifies it as VLM and sets fallback_engine_type="vlm". Both stock and your
PR end up using VLMBatchedEngine as the fallback.

  • Stock's eviction path tears down DFlash, then VLMBatchedEngine.start()
    loads via mlx_vlm.utils.load → returns a VLM-shaped mlx_vlm.models.qwen3_5.Model.
  • Your PR skips eviction and passes model=self._target_model to
    VLMBatchedEngine. But self._target_model came from
    dflash_mlx.runtime.load_target_bundle → mlx_lm.utils.load and returns a
    text-only mlx_lm.models.qwen3_5.Model.

Different Python class, different forward method, different internal structure.
VLMBatchedEngine still works but via a slower path — 6-10% gen tok/s penalty on
this model.

The LLM fallback case (fallback_engine_type == "batched") is fine — BatchedEngine
uses mlx-lm natively, so sharing DFlash's mlx-lm target with it is type-consistent.

Suggested fix

In omlx/engine/dflash.py::_init_fallback_engine, only share when the fallback
type can consume the dflash-loaded model:

if self._fallback_engine_type == "vlm":
    from .vlm import VLMBatchedEngine
    self._fallback_engine = VLMBatchedEngine(
        model_name=self._model_name,
        scheduler_config=self._scheduler_config,
        model_settings=self._model_settings,
        # VLM expects a mlx_vlm-loaded model; sharing DFlash's mlx-lm target
        # forces a slower path. Let VLMBatchedEngine load its own.
    )
else:
    # LLM fallback — BatchedEngine natively consumes mlx-lm models,
    # sharing is type-consistent and saves ~29GB of GPU memory.
    from .batched import BatchedEngine
    self._fallback_engine = BatchedEngine(
        model_name=self._model_name,
        scheduler_config=self._scheduler_config,
        model_settings=self._model_settings,
        model=shared_model,
        tokenizer=shared_tokenizer,
    )

Tradeoff: for VLM-classified models, the fix reverts to stock's behaviour of
loading the target twice. On a memory-constrained machine that may be visible;
on 128GB unified memory it's fine. A follow-up would be a registry keyed by
(model_name, loader_kind) so the model is only loaded once per process even
when DFlash and VLM both need a reference.
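Sketched out, the registry idea could look something like this (hypothetical, not part of this PR):

import threading

# Process-wide cache keyed by (model_name, loader_kind), so the mlx-lm and
# mlx-vlm loaders each keep their own copy of a model, but repeated requests
# for the same pair reuse a single load.
_MODEL_REGISTRY: dict[tuple[str, str], tuple[object, object]] = {}
_REGISTRY_LOCK = threading.Lock()

def get_or_load(model_name: str, loader_kind: str, load_fn):
    """load_fn() -> (model, tokenizer); invoked at most once per key."""
    key = (model_name, loader_kind)
    with _REGISTRY_LOCK:
        if key not in _MODEL_REGISTRY:
            _MODEL_REGISTRY[key] = load_fn()
        return _MODEL_REGISTRY[key]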

pr876 + fix in the table above is the three-line change applied locally.

Side finding: tiered-KV determinism

Pre-existing issue on stock 0.3.6: with DFlash enabled, greedy (temp=0) output
hash differs cold vs warm-after-restart at size=4000 (rerunning the same prompt
after a brew services restart omlx produces a different content SHA). Re-ran
the same scenario against your PR patched at sizes 4000/8000/16000 — content
determinism passes at all sizes. Looks like the per-request routing change
sidesteps whatever was corrupting the post-restart content in stock's eviction
path. Worth mentioning in case someone else hits it and wonders.

Two remaining observations for context:

  1. Post-restart prefill is still slow at all sizes (4K: 629 tok/s post-restart
    vs 740 warm-immediate; 8K: 1855 vs 4937; 16K: 2678 vs 5039). Content is
    deterministic but SSD-cache reuse looks imperfect. Likely unrelated to this
    PR — same signal on stock.

  2. At 8K/16K (routed to VLM), reasoning_content diverges between the first
    request of a process and all subsequent requests (stable after the first).
    Content still matches so user-visible output is identical. Likely a VLM-
    engine temp=0 non-determinism, orthogonal to this PR; happy to file it
    separately if that's useful.

Summary

With the three-line fix, this PR is a strict improvement over stock 0.3.6 on
Qwen3.5-27B: identical perf per-size, ~46% throughput recovery on small prompts
after a large one in the same session. Happy to open a PR on top of yours if
that's easier than you folding the fix in.

@sparsely-active
Author

sparsely-active commented Apr 22, 2026

Nice catch! Thanks for taking a look. I was testing with the model type explicitly set to LLM and didn't test the VLM path. So we have solid DFlash support for the LLM path via bstnxbt/dflash-mlx, but on the VLM side it looks like DFlash support was only merged a few days ago(?) in Blaizzy/mlx-vlm and doesn't seem to have been released yet. I could be reading this wrong, but it seems like we can't actually support the DFlash + VLM combo right now..?

Edit: confirmed, mlx-vlm currently has no DFlash support; it's expected to land in the next release. So the upshot is that (with or without this PR) if you try to use DFlash with a VLM, or with a model loaded by mlx-vlm for whatever reason, DFlash won't work and you might experience weird issues.

@davidpeden3

Interesting — but I think there's a gap between "mlx-vlm has no DFlash support" and "DFlash + VLM-classified models doesn't work." Our bench on Qwen3.5-27B-8bit (classified VLM in oMLX because its config.json has vision_config) shows DFlash delivering a real ~60% gen tok/s speedup on text-only prompts in stock 0.3.6:

size=2048 actual=1854  DFlash:  26.4 tok/s
size=8192 actual=7578  fallback: 16.5 tok/s   (post-eviction BatchedEngine)

That speedup is real because DFlashEngine.start() loads the target via dflash_mlx.runtime.load_target_bundle → mlx_lm.utils.load, so DFlash runs on the mlx-lm backend regardless of the model's classification. mlx-vlm never enters the DFlash-active request path for text-only prompts.

mlx-vlm's missing DFlash support only matters when:

  1. The request actually routes through mlx-vlm's forward path (image inputs, or a large prompt routed to VLMBatchedEngine as the fallback).
  2. Or the two get mixed — which is exactly what your PR as-submitted was doing, and exactly what the fix addresses.

So I'd argue the type-guard fix still stands on its own merits today:

  • Stock: DFlash speedup on text inference of VLM-classified models ✅
  • PR unpatched: DFlash speedup preserved, fallback path regresses 6-10% because the mlx-vlm engine serves from an mlx-lm model ❌
  • PR patched: DFlash speedup preserved, fallback path at stock parity ✅

When mlx-vlm ships DFlash support, the fallback_engine_type == "vlm" branch can be revisited — at that point sharing might become correct, but it depends on whether mlx-vlm's DFlash expects the dflash-loaded target or does its own load. Worth a second pass then; doesn't need to block this PR.

Separately — on the "weird issues" warning: for anyone hitting this thread later with a VLM-classified text model and DFlash enabled, the stock behaviour is fine for text requests and speculative-decodes correctly. The weirdness surfaces if you start sending images to a DFlash-enabled VLM model, which is a different code path entirely.

Happy to open the fixup PR whenever you're ready.

@sparsely-active
Author

Aha, the new mlx-vlm has been released and the update for it has been merged to main. I'll have to rebase and re-assess.

@deepsweet

@sparsely-active hi! Look what I've found #866 (comment)

The DFlash draft model's author claims speculative decoding at 100k tokens on his side.
