Add Prompt Lookup Decoding (ngram-simple) and Rolling-Hash Speculative Memory (ngram-mod)#1297

Open
mayank2130 wants to merge 6 commits into ml-explore:main from mayank2130:pld-ngram-simple
@mayank2130 commented May 22, 2026

Closes #851

Summary

Adds Prompt Lookup Decoding (PLD) and rolling-hash speculative decoding to mlx_lm via a generalized DraftStrategy abstraction.

Instead of generating speculative drafts with a smaller neural model, the new strategies reuse previously observed token trajectories:

  • ngram-simple performs exact prompt-history lookup
  • ngram-mod implements a rolling-hash associative memory ported from llama.cpp PR #19164

Both strategies preserve output correctness because speculative tokens are only accepted if verified by the target model under the same sampling configuration.
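
For the greedy case (temp=0), the acceptance rule above can be sketched as follows. This is a minimal illustration, assuming the target model's chosen token at each drafted position is already available; `count_accepted` and `verify_step` are hypothetical names, not functions from this PR.

```python
def count_accepted(draft_tokens, target_tokens):
    """Return how many leading draft tokens match the target model's choices.

    draft_tokens:  tokens proposed by the drafter.
    target_tokens: tokens the target model picks at the same positions,
                   optionally followed by one extra token.
    """
    n = 0
    for d, t in zip(draft_tokens, target_tokens):
        if d != t:
            break
        n += 1
    return n


def verify_step(draft_tokens, target_tokens):
    """Emit the accepted prefix plus the target's token at the first mismatch."""
    n = count_accepted(draft_tokens, target_tokens)
    # The target's token at position n is always valid output: it is either the
    # correction for a rejected draft token or the extra token after full acceptance.
    emitted = list(draft_tokens[:n]) + list(target_tokens[n:n + 1])
    return n, emitted
```

Because every emitted token is one the target model would have chosen anyway, a bad draft costs only wasted verification work and never changes the output.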

This PR adds:

  • DraftStrategy interface for pluggable speculative drafters
  • ModelDraftStrategy for existing neural drafting
  • NgramSimpleStrategy for prompt lookup decoding
  • NgramModStrategy + NgramModTable for rolling-hash speculative memory
  • optional adaptive repetition gating
  • CLI, Python, and server integration
  • process-shared speculative memory for cross-request reuse

Usage

from mlx_lm import load, stream_generate
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("Qwen/Qwen3-8B-MLX-4bit")

prompt = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Write a Python function `add(a, b)`."}],
    add_generation_prompt=True,
    enable_thinking=False,
)

for response in stream_generate(
    model,
    tokenizer,
    prompt,
    max_tokens=256,
    sampler=make_sampler(temp=0.0),
    draft_type="ngram-simple",      # or "ngram-mod"
    num_draft_tokens=4,
    ngram_size=3,                   # use 16 for ngram-mod
    disable_adaptive_gate=True,
):
    print(response.text, end="", flush=True)

For ngram-mod, reuse a table across related generations to preserve learned n-gram memory:

from mlx_lm import load, stream_generate
from mlx_lm.generate import NgramModTable
from mlx_lm.sample_utils import make_sampler

model, tokenizer = load("Qwen/Qwen3-8B-MLX-4bit")
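table = NgramModTable(n=16)  # one table shared by every generation below

# Placeholder prompts for illustration; substitute your own related requests.
prompts = [
    "Write a Python function `add(a, b)`.",
    "Now add type hints and a docstring to `add(a, b)`.",
]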

for prompt_text in prompts:
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt_text}],
        add_generation_prompt=True,
        enable_thinking=False,
    )

    for response in stream_generate(
        model,
        tokenizer,
        prompt,
        max_tokens=256,
        sampler=make_sampler(temp=0.0),
        draft_type="ngram-mod",
        num_draft_tokens=6,
        ngram_size=16,
        ngram_mod_table=table,
        disable_adaptive_gate=True,
    ):
        print(response.text, end="", flush=True)

CLI: multi-turn ngram-simple

printf '%s\n%s\n%s\nq\n' \
'Write a Python function summarize_orders(orders) where each order has id, customer, total, and status. Return only the code.' \
'Update summarize_orders so it skips orders whose status is cancelled. Keep the same structure and return the full updated function only.' \
'Now add a currency="$" parameter and use it when formatting money values. Keep the cancelled-order behavior unchanged. Return the full updated function only.' \
| python -m mlx_lm chat \
  --model Qwen/Qwen3-8B-MLX-4bit \
  --max-tokens 500 \
  --temp 0 \
  --chat-template-config '{"enable_thinking": false}' \
  --draft-type ngram-simple \
  --num-draft-tokens 4 \
  --ngram-size 3 \
  --disable-adaptive-gate

CLI: multi-turn ngram-mod

printf '%s\n%s\n%s\nq\n' \
'Write a Python function summarize_orders(orders) where each order has id, customer, total, and status. Return only the code.' \
'Update summarize_orders so it skips orders whose status is cancelled. Keep the same structure and return the full updated function only.' \
'Now add a currency="$" parameter and use it when formatting money values. Keep the cancelled-order behavior unchanged. Return the full updated function only.' \
| python -m mlx_lm chat \
  --model Qwen/Qwen3-8B-MLX-4bit \
  --max-tokens 500 \
  --temp 0 \
  --chat-template-config '{"enable_thinking": false}' \
  --draft-type ngram-mod \
  --num-draft-tokens 6 \
  --ngram-size 16 \
  --disable-adaptive-gate

The chat command keeps the conversation history and prompt cache alive across turns, so turns 2 and 3 (T2/T3) can draft from the code structure generated in turn 1 (T1).

Server

Per-request JSON overrides: draft_type, ngram_size, disable_adaptive_gate

Architecture

Speculative drafting is abstracted behind:

class DraftStrategy(Protocol):
    def propose(self, y, n_max, ctx) -> mx.array: ...
    def rewind(self, n: int) -> None: ...
    def observe(self, tokens) -> None: ...
    def accept(self, n_accepted: int, n_drafted: int) -> None: ...

NgramSimpleStrategy scans backward for matching n-grams and proposes the following continuation tokens directly from prior history.

  • Best suited for: short iterative edits, local repetition, single-user coding flows
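
The backward scan described above can be sketched in plain Python. This is an illustration of the general prompt-lookup idea, not the PR's NgramSimpleStrategy; `propose_from_history` and its defaults are hypothetical.

```python
def propose_from_history(tokens, ngram_size=3, n_max=4):
    """Propose up to n_max draft tokens by matching the trailing n-gram in history."""
    if len(tokens) <= ngram_size:
        return []
    tail = tokens[-ngram_size:]
    # Scan backward so the most recent earlier occurrence wins. The range
    # starts one position before the tail so the tail never matches itself.
    for start in range(len(tokens) - ngram_size - 1, -1, -1):
        if tokens[start:start + ngram_size] == tail:
            follow = start + ngram_size
            return tokens[follow:follow + n_max]
    return []
```

With history `[1, 2, 3, 4, 5, 1, 2, 3]` the trailing 3-gram `[1, 2, 3]` also appears at the start, so the sketch proposes the tokens that followed it there. When nothing matches it returns an empty draft and decoding proceeds normally.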

NgramModStrategy ports llama.cpp's rolling-hash speculative memory.
The architecture mirrors llama.cpp's split between:

  • process-global speculative memory
  • per-request runtime state

The shared table stores:
hash(ngram) -> next_token
allowing speculative reuse across requests handled by the same running server process.

Implementation behavior intentionally matches llama.cpp:

  • fixed-size lossy hash table
  • silent overwrite collision policy
  • verifier-corrected speculative drafts
  • adaptive reset on repeated low acceptance
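
A fixed-size lossy table with silent overwrite can be sketched as below. This is not the PR's NgramModTable or llama.cpp's code: `LossyNgramTable`, its default sizes, and the hash constants are assumptions, and for brevity it recomputes a plain polynomial hash per window instead of updating a rolling hash incrementally.

```python
class LossyNgramTable:
    """Fixed-size table mapping hash(n-gram) -> next token, overwriting on collision."""

    EMPTY = -1

    def __init__(self, n=4, size=1 << 16):
        self.n = n
        self.size = size
        self.slots = [self.EMPTY] * size

    def _slot(self, ngram):
        # Polynomial hash over the token ids, reduced to a table index.
        h = 0
        for tok in ngram:
            h = (h * 1000003 + tok + 1) & 0xFFFFFFFFFFFFFFFF
        return h % self.size

    def observe(self, tokens):
        """Record every (n-gram, next token) pair found in tokens."""
        for i in range(len(tokens) - self.n):
            ngram = tokens[i:i + self.n]
            # Silent overwrite: a colliding or newer entry replaces the old one.
            self.slots[self._slot(ngram)] = tokens[i + self.n]

    def propose(self, context, n_max):
        """Chain lookups from the trailing n-gram to draft up to n_max tokens."""
        window = list(context[-self.n:])
        draft = []
        while len(window) == self.n and len(draft) < n_max:
            nxt = self.slots[self._slot(window)]
            if nxt == self.EMPTY:
                break
            draft.append(nxt)
            window = window[1:] + [nxt]
        return draft
```

Only the next token is stored, not the n-gram itself, so a collision can yield a wrong draft. That is tolerable because the target model verifies every drafted token before it is emitted.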

Adaptive Gate

An optional adaptive gate computes a 3-gram repetition score over the prompt. If repetition falls below:
NGRAM_GATE_THRESHOLD = 0.02
speculation is skipped automatically.

This is particularly important for ngram-mod, whose cold-start behavior can regress below baseline throughput on low-repetition prompts.
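
The gate can be sketched as follows. The PR description does not spell out the exact score, so this assumes it is the fraction of 3-grams that repeat an earlier 3-gram; `repetition_score` and `should_speculate` are hypothetical names.

```python
NGRAM_GATE_THRESHOLD = 0.02  # value quoted in the PR description


def repetition_score(tokens, n=3):
    """Fraction of n-grams in tokens that repeat an earlier one (assumed definition)."""
    total = len(tokens) - n + 1
    if total <= 0:
        return 0.0
    unique = len({tuple(tokens[i:i + n]) for i in range(total)})
    return 1.0 - unique / total


def should_speculate(tokens, threshold=NGRAM_GATE_THRESHOLD):
    """Gate: skip speculation when the prompt shows too little repetition."""
    return repetition_score(tokens) >= threshold
```

Under this definition a prompt with no repeated 3-gram scores 0.0 and is gated off, while `[1, 2, 3, 1, 2, 3]` has one repeat among four 3-grams and scores 0.25.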

Benchmarks

All benchmarks used:

  • mlx-community/Llama-3.2-3B-Instruct-4bit
  • Apple Silicon

LONG MULTI-TURN EDITING (~280 TOK/TURN) — OVERALL

| config | tok/s | acc% | speedup |
|---|---|---|---|
| baseline | 54.09 | - | 1.00× |
| ngram-simple nd=4 | 91.57 | 62.7% | 1.69× |
| ngram-simple nd=6 | 89.01 | 67.8% | 1.65× |
| ngram-mod nd=6 | 84.69 | 59.4% | 1.57× |
| ngram-mod nd=8 | 82.69 | 61.7% | 1.53× |

ngram-mod nd=6 per-turn behavior

| turn | prompt | baseline tok/s | ngram-simple nd=6 | ngram-mod nd=6 |
|---|---|---|---|---|
| T1 | write EmailValidator class | 57.1 | 62.4 (1.09×) | 54.4 (0.95×) |
| T2 | add is_disposable method | 57.9 | 118.3 (2.04×) | 123.9 (2.14×) |
| T3 | reject whitespace in domain | 55.0 | 118.4 (2.15×) | 122.9 (2.24×) |
| T4 | log inside validate | 49.9 | 118.0 (2.37×) | 135.8 (2.72×) |

mayank2130 and others added 2 commits May 19, 2026 15:35
@mayank2130 changed the title from "Add Prompt Lookup Decoding (ngram-simple) via DraftStrategy abstraction" to "Add Prompt Lookup Decoding (ngram-simple & ngram-mod) via DraftStrategy abstraction" on May 22, 2026
@mayank2130 changed the title from "Add Prompt Lookup Decoding (ngram-simple & ngram-mod) via DraftStrategy abstraction" to "Add Prompt Lookup Decoding (ngram-simple) and Rolling-Hash Speculative Memory (ngram-mod)" on May 22, 2026
@mayank2130

Author

Hey @angeloskath, could this PLD/n-gram decoding PR be reviewed?

If you're not the right person to reach out to for mlx-lm PRs, could you point me to someone else? Thanks.

@ashalliants

Tested this branch on Apple Silicon (M5, 32 GB, macOS 25.2.0, Python 3.13) against Qwen3.5/3.6-family models and hit three issues worth flagging — two bugs and one UX trap that together made the feature look like it was working when it wasn't.

1. CLI flags are parsed but never forwarded to generate()

In main() the new arguments (--draft-type, --ngram-size, --disable-adaptive-gate) are added to the parser, but the generate(...) call only forwards num_draft_tokens:

response = generate(
    ...
    draft_model=draft_model,
    num_draft_tokens=args.num_draft_tokens,
)   # args.draft_type, args.ngram_size, args.disable_adaptive_gate never passed

So mlx_lm generate --draft-type ngram-simple silently runs plain decoding. I benchmarked all three draft types via the CLI and got byte-identical speeds (~54.8 tok/s on Qwen3.6-35B-A3B-4bit) before noticing the flags were no-ops. Calling stream_generate(draft_type=...) directly works as documented.
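
One way to avoid dropping a flag is to collect all drafting options in a single place. The sketch below is not the PR's main(): it uses a stand-alone parser with the flag names from this PR, the defaults are placeholders, and `build_parser` and `draft_kwargs` are hypothetical helpers.

```python
import argparse


def build_parser():
    """Stand-alone parser with the drafting flags; defaults are placeholders."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-draft-tokens", type=int, default=3)
    parser.add_argument("--draft-type", default=None)
    parser.add_argument("--ngram-size", type=int, default=3)
    parser.add_argument("--disable-adaptive-gate", action="store_true")
    return parser


def draft_kwargs(args):
    """Collect every drafting option so none is dropped before generate()."""
    return {
        "num_draft_tokens": args.num_draft_tokens,
        "draft_type": args.draft_type,
        "ngram_size": args.ngram_size,
        "disable_adaptive_gate": args.disable_adaptive_gate,
    }
```

The call site could then pass `**draft_kwargs(args)` to generate(), assuming generate() accepts the same keyword names that stream_generate takes in the PR's usage examples.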

2. Hybrid-attention models fail mid-generation with an untrimmable-cache error

Models with gated-delta / linear-attention layers (ArraysCache) raise once speculation actually engages:

ValueError: Speculative decoding requires a trimmable prompt cache (got {'ArraysCache'}).

This covers the entire Qwen3.5/3.6 family (qwen3_5, qwen3_5_moe) — I reproduced on both Qwen3.6-35B-A3B-4bit (MoE) and a Qwen3.6-27B dense build. Given these are popular local models, it might be worth checking cache trimmability up front when the strategy is constructed (in the draft_type resolution block) and raising/warning immediately, rather than failing mid-stream after prefill.
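
An up-front check along these lines could look like the sketch below. It assumes each cache entry exposes an `is_trimmable()` method; `require_trimmable` is a hypothetical helper, and the two cache classes are stand-ins for illustration, not the mlx_lm classes.

```python
def require_trimmable(prompt_cache):
    """Raise before prefill if any cache layer cannot be trimmed."""
    bad = {type(c).__name__ for c in prompt_cache if not c.is_trimmable()}
    if bad:
        raise ValueError(
            f"Speculative decoding requires a trimmable prompt cache (got {sorted(bad)})."
        )


# Stand-in cache classes for illustration only.
class FakeKVCache:
    def is_trimmable(self):
        return True


class FakeArraysCache:
    def is_trimmable(self):
        return False
```

Calling such a check when the draft strategy is resolved would surface the incompatibility before any tokens are generated.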

3. The adaptive gate masks issue 2 — incompatible models appear to "work"

Because the 3-gram repetition gate silently disables speculation when the prompt scores below NGRAM_GATE_THRESHOLD (0.02), short or non-repetitive prompts run as plain decode with no indication. On the hybrid-cache models above, a quick smoke test "succeeds" (it never speculates), and the ValueError only surfaces later with a repetitive prompt. A one-line logging.info when the gate suppresses speculation (and/or a field on GenerationResponse) would make both the gate and issue 2 much easier to diagnose.
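
The suggested log line might look like the sketch below, which uses Python's standard logging module. `gate_allows` and the logger name are hypothetical, and the score is assumed to be computed elsewhere.

```python
import logging

logger = logging.getLogger("ngram_gate_demo")  # hypothetical logger name


def gate_allows(score, threshold=0.02):
    """Return True to speculate; log when the gate turns speculation off."""
    if score < threshold:
        logger.info(
            "Adaptive gate disabled speculation: repetition %.3f < threshold %.3f",
            score,
            threshold,
        )
        return False
    return True
```

With INFO logging enabled, a smoke test that never speculates would then say so explicitly instead of looking like a successful speculative run.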

Happy to provide full repro scripts/timings if useful — everything above was reproduced at commit pld-ngram-simple HEAD with greedy sampling.

@ashalliants

Repro scripts and prompts for the findings above, as offered: https://gist.github.com/ashalliants/91819d410f6822e406a314740c8b7d0e. bench_api.py reproduces both the ArraysCache failure and the gate-masked false positive; the README maps each file to the claim it supports.

Development

Successfully merging this pull request may close these issues.

n-gram hashing for speculative decoding

2 participants