
Feature request: prompt-lookup (n-gram) drafting as an alternative/hybrid draft source #35

@ashalliants

Description


Summary

Feature request: support prompt-lookup / n-gram drafting as an alternative (or supplementary) draft source alongside the DFlash block-diffusion drafter.

Motivation

For edit-heavy / agentic coding workloads (the model re-emits file contents with small changes, or multi-turn sessions where output overlaps prior context), a large fraction of generated tokens are verbatim copies of tokens already in context. Prompt-lookup decoding (PLD) drafts those spans almost for free using simple n-gram matching, with no draft-model forward passes at all, and typically reaches very high acceptance inside copy regions. vLLM ships this as its ngram speculative method, and llama.cpp has an equivalent lookup-decoding mode.
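To make the mechanism concrete, here is a minimal sketch of prompt lookup over a plain list of token IDs. It tries the longest suffix n-gram first and returns whatever followed its most recent earlier occurrence. The function name and parameters are hypothetical, and this is not vLLM's or llama.cpp's actual code; it is a naive O(context × n) scan, whereas a real implementation would likely keep an index.

```python
from typing import List, Sequence


def prompt_lookup_draft(
    tokens: Sequence[int],
    min_ngram: int = 4,
    max_ngram: int = 8,
    max_draft: int = 16,
) -> List[int]:
    """Hypothetical PLD drafter: match the context's suffix against earlier context.

    Tries suffix lengths from max_ngram down to min_ngram and returns up to
    max_draft tokens that followed the most recent earlier occurrence.
    Returns [] when no match of at least min_ngram tokens exists.
    """
    if min_ngram < 1:
        raise ValueError("min_ngram must be >= 1")
    n_tokens = len(tokens)
    for n in range(min(max_ngram, n_tokens - 1), min_ngram - 1, -1):
        suffix = list(tokens[n_tokens - n:])
        # Scan backwards so the most recent earlier occurrence wins; the
        # suffix itself is excluded because start stops at n_tokens - n - 1.
        for start in range(n_tokens - n - 1, -1, -1):
            if list(tokens[start:start + n]) == suffix:
                # At least one continuation token always exists here.
                return list(tokens[start + n:start + n + max_draft])
    return []


if __name__ == "__main__":
    context = [5, 6, 7, 8, 9, 10, 11, 5, 6, 7, 8]
    print(prompt_lookup_draft(context, max_draft=3))  # [9, 10, 11]
```

The draft costs only list comparisons, which is why copy-heavy regions are so cheap to speculate on.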

Why dflash-mlx is uniquely positioned to do this

I tried to get ngram drafting working on Qwen3.6 through the upstream mlx-lm speculative path (ml-explore/mlx-lm#1297), and it is architecturally blocked there: the generic verifier requires trimmable caches, but the Qwen3.5/3.6 family's gated-delta linear-attention layers use an ArraysCache, which cannot be trimmed:

ValueError: Speculative decoding requires a trimmable prompt cache (got {'ArraysCache'}).

dflash-mlx already solved the hard part: verification with partial-acceptance rollback over these hybrid-cache models. Swapping (or augmenting) the draft source should be far less work than what is already built: an ngram lookup produces a candidate block, and the existing verify path consumes it unchanged.
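The reason the draft source is interchangeable is that greedy verification only needs a block of candidate tokens, not the model that produced them. Below is a minimal sketch of the generic greedy acceptance rule, assuming one target forward pass over the block yields the target's argmax at every position. greedy_verify is a hypothetical name, not dflash-mlx's API, and the actual cache rollback for ArraysCache layers is not shown; the function only reports how many draft positions would need rolling back.

```python
from typing import List, Sequence, Tuple


def greedy_verify(draft: Sequence[int], target_argmax: Sequence[int]) -> Tuple[List[int], int]:
    """Hypothetical greedy verifier for one draft block.

    target_argmax[i] is the target model's greedy token after consuming the
    prefix plus draft[:i], so it has len(draft) + 1 entries.
    Returns (tokens to commit, number of rejected draft tokens to roll back).
    """
    if len(target_argmax) != len(draft) + 1:
        raise ValueError("target_argmax must have len(draft) + 1 entries")
    n_accept = 0
    # Accept the longest prefix on which draft and target agree.
    while n_accept < len(draft) and draft[n_accept] == target_argmax[n_accept]:
        n_accept += 1
    # Commit the accepted prefix plus the target's own token at the first
    # mismatch (or the extra "bonus" token when the whole block is accepted).
    committed = list(draft[:n_accept]) + [target_argmax[n_accept]]
    return committed, len(draft) - n_accept


if __name__ == "__main__":
    print(greedy_verify([9, 10, 11], [9, 10, 42, 7]))  # ([9, 10, 42], 1)
```

Nothing in this rule depends on whether the draft came from a diffusion drafter or from an n-gram lookup.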

Evidence there's headroom

Measured on M5 / 32 GB, Qwen3.6-35B-A3B-4bit + z-lab/Qwen3.6-35B-A3B-DFlash, greedy, ~3K-token generations on a code-edit prompt (re-emit a ~100-line file with a small refactor):

  • DFlash: 67.0 tok/s @ 74.9% acceptance (code-edit), 64.2 tok/s @ 70.9% (fresh code)
  • plain decode: ~54.8 tok/s
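For reference, the speedups over plain decode implied by these numbers can be computed directly. The plain-decode figure is approximate (~54.8 tok/s), so the ratios are approximate too.

```python
from typing import Dict


def speedups(plain_tps: float, drafted_tps: Dict[str, float]) -> Dict[str, float]:
    """Ratio of drafted throughput to plain-decode throughput, rounded to 2 places."""
    return {name: round(tps / plain_tps, 2) for name, tps in drafted_tps.items()}


if __name__ == "__main__":
    # Throughput figures (tok/s) from the M5 / 32 GB measurements above.
    ratios = speedups(54.8, {"code-edit": 67.0, "fresh code": 64.2})
    for workload, ratio in ratios.items():
        print(f"{workload}: {ratio:.2f}x over plain decode")
    # code-edit: 1.22x over plain decode
    # fresh code: 1.17x over plain decode
```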

74.9% acceptance on a workload where long spans are exact copies suggests the diffusion drafter leaves acceptance on the table precisely where ngram lookup is near-perfect. A hybrid policy could be simple and conservative, e.g.:

  1. try n-gram match against context (cheap, exact);
  2. if a match of length ≥ k exists, emit the lookup continuation as the draft block (zero draft-model cost);
  3. otherwise fall back to the DFlash drafter as today.
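The three steps above could look roughly like this. It is a sketch under stated assumptions: NgramIndex and hybrid_draft are hypothetical names, the DFlash drafter is stood in for by any callable that takes the token context, and k is the minimum match length from step 2. The index maps each k-gram to the position right after its latest occurrence, so the match check is a dictionary lookup instead of a rescan of the context.

```python
from typing import Callable, Dict, List, Sequence, Tuple


class NgramIndex:
    """Hypothetical incremental index: k-gram -> position of the token that followed it."""

    def __init__(self, k: int):
        if k < 1:
            raise ValueError("k must be >= 1")
        self.k = k
        self.tokens: List[int] = []
        self.next_pos: Dict[Tuple[int, ...], int] = {}

    def extend(self, new_tokens: Sequence[int]) -> None:
        for t in new_tokens:
            self.tokens.append(t)
            i = len(self.tokens) - 1  # index of t
            if i >= self.k:
                # Record the k-gram ending just before t; t is its continuation.
                # The current suffix is only recorded once a token follows it,
                # so a lookup never matches the suffix against itself.
                self.next_pos[tuple(self.tokens[i - self.k:i])] = i

    def lookup(self, max_draft: int) -> List[int]:
        if len(self.tokens) < self.k:
            return []
        pos = self.next_pos.get(tuple(self.tokens[-self.k:]))
        return [] if pos is None else self.tokens[pos:pos + max_draft]


def hybrid_draft(
    index: NgramIndex,
    dflash_draft: Callable[[Sequence[int]], List[int]],
    max_draft: int = 16,
) -> Tuple[str, List[int]]:
    # Steps 1-2: an exact k-token match means the lookup continuation is the
    # draft block, with no draft-model forward pass.
    draft = index.lookup(max_draft)
    if draft:
        return "ngram", draft
    # Step 3: no match, so fall back to the block-diffusion drafter.
    return "dflash", dflash_draft(index.tokens)


if __name__ == "__main__":
    idx = NgramIndex(k=4)
    idx.extend([5, 6, 7, 8, 9, 10, 11, 5, 6, 7, 8])
    print(hybrid_draft(idx, lambda ctx: [0, 0], max_draft=3))  # ('ngram', [9, 10, 11])
```

After each verify step, the committed tokens would be passed to index.extend so the index stays in sync with the context.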

That keeps DFlash's behaviour on novel text and upgrades copy regions to near-1.0 acceptance with lower draft cost. It would also benefit models that have no trained DFlash drafter yet, since ngram drafting needs no sidecar model at all.

Possible knobs

  • --draft-source {dflash,ngram,hybrid} (default dflash, current behaviour)
  • min match length / max draft length for the ngram path (experience with vLLM's prompt_lookup_min suggests that too-short matches can hurt structured/tool-call output, so a conservative default such as ≥ 4–8 tokens matters)
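One possible shape for these knobs, sketched with Python's argparse. Only --draft-source and its three values come from this proposal; --ngram-min and --ngram-max-draft, and their defaults (4 from the low end of the suggested range, 16 chosen arbitrarily), are hypothetical placeholders rather than an existing dflash-mlx interface.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="draft-source options (sketch)")
    p.add_argument("--draft-source", choices=["dflash", "ngram", "hybrid"], default="dflash",
                   help="where draft blocks come from (default: dflash, current behaviour)")
    # Hypothetical option names and defaults for the ngram path.
    p.add_argument("--ngram-min", type=int, default=4,
                   help="minimum n-gram match length before a lookup draft is used")
    p.add_argument("--ngram-max-draft", type=int, default=16,
                   help="maximum number of tokens drafted from one lookup")
    return p


def parse(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    # parser.error prints a usage message and exits with status 2.
    if args.ngram_min < 1 or args.ngram_max_draft < 1:
        parser.error("--ngram-min and --ngram-max-draft must be >= 1")
    return args


if __name__ == "__main__":
    print(parse(["--draft-source", "hybrid", "--ngram-min", "6"]))
```

Keeping dflash as the default means existing invocations behave exactly as they do today.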

Happy to share benchmark scripts/prompts or run comparisons on Apple Silicon (M5, 32 GB) if that's useful.

Metadata


Assignees

No one assigned

    Labels

    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
