
feat(pflash): importance-scored sparse prefill scaffold (#136) #161

Open
st-adam wants to merge 1 commit into jjang-ai:main from st-adam:feat/pflash-sparse-prefill

Conversation


st-adam commented May 12, 2026

Summary

Lands the algorithm pieces + integration scaffold for PFlash importance-scored sparse prefill (issue #136) so the BSA-equivalent Metal kernel work has a stable landing point. Default OFF; no behaviour change when unset.

Per @jjang-ai's note on #136 ("Would appreciate PR's in most recent stable versions"), this PR is based on current main (1.5.32 series, base 9cfbeb24).

What's shipped

Algorithm (pure MLX, no custom kernel):

  • vmlx_engine/utils/pflash.py
    • pflash_score_blocks() — per-block entropy of pooled drafter logits
    • pflash_select_top_k() — top-K with head/tail pinning
    • pflash_block_ranges() — coalesce adjacent kept blocks
    • plan_sparse_prefill() — end-to-end planner returning a PFlashPlan (selection and coalescing sketched below)
  • vmlx_engine/utils/pflash_drafter.py
    • load_pflash_drafter() — co-resident drafter load via mlx_lm.load
    • drafter_score_blocks() — single drafter forward + per-block logit pool
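
The selection and coalescing steps compose roughly as follows. This is an illustrative sketch, not the PR's code: the helper names above are real, but the signatures, the PFlashPlan fields, and the pinning defaults shown here are assumptions.

```python
# Illustrative sketch only; signatures, PFlashPlan fields, and defaults are
# assumed, not taken from the PR. Shows top-K block selection with head/tail
# pinning and coalescing of adjacent kept blocks into ranges.
from dataclasses import dataclass

@dataclass
class PFlashPlan:
    keep_mask: list     # per-block booleans
    keep_ranges: list   # coalesced [start_block, end_block) pairs

def select_top_k(scores, keep_ratio, head_blocks=1, tail_blocks=1):
    """Keep the top keep_ratio fraction of blocks by score, always pinning
    the first and last blocks (prompt head and the tail nearest the query)."""
    n = len(scores)
    k = max(head_blocks + tail_blocks, round(n * keep_ratio))
    keep = [False] * n
    for i in range(min(head_blocks, n)):            # pin head blocks
        keep[i] = True
    for i in range(max(n - tail_blocks, 0), n):     # pin tail blocks
        keep[i] = True
    for i in sorted(range(n), key=lambda j: scores[j], reverse=True):
        if sum(keep) >= k:
            break
        keep[i] = True
    return keep

def coalesce(keep):
    """Merge runs of adjacent kept blocks into [start, end) ranges."""
    ranges, start = [], None
    for i, kept in enumerate(keep + [False]):       # sentinel closes a trailing run
        if kept and start is None:
            start = i
        elif not kept and start is not None:
            ranges.append((start, i))
            start = None
    return ranges

scores = [0.2, 3.1, 0.4, 2.7, 0.1, 0.3]             # per-block importance scores
mask = select_top_k(scores, keep_ratio=0.5)         # keeps blocks 0, 1, 5
plan = PFlashPlan(keep_mask=mask, keep_ranges=coalesce(mask))
print(plan.keep_ranges)                             # [(0, 2), (5, 6)]
```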

Integration:

  • CLI flags: --enable-pflash, --pflash-drafter, --pflash-block-size, --pflash-keep-ratio, --pflash-min-seq-len
  • Env equivalents: VMLX_ENABLE_PFLASH, VMLX_PFLASH_DRAFTER, VMLX_PFLASH_BLOCK_SIZE, VMLX_PFLASH_KEEP_RATIO, VMLX_PFLASH_MIN_SEQ_LEN, VMLX_PFLASH_HEAD_TOKENS, VMLX_PFLASH_TAIL_TOKENS
  • MLLMBatchGenerator._maybe_build_pflash_plan() — invoked from the chunked-prefill entry; builds a plan and emits a one-line [pflash] plan: K/N blocks kept (X% coverage, R ranges) telemetry log (gating flow sketched below)
  • /health.pflash — activations, blocks_total, blocks_kept, skipped_below_min_seq_len, failures
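
For concreteness, the activation gating could look roughly like the sketch below. It is a hypothetical stand-in for the method on MLLMBatchGenerator: plan_sparse_prefill(), the flag and counter names, and the log format come from this PR, while the function signature, the config/health objects, and the plan field names (blocks_kept, blocks_total, keep_ranges) are assumptions.

```python
# Hypothetical stand-in for MLLMBatchGenerator._maybe_build_pflash_plan();
# attribute, parameter, and plan-field names are assumed. Only the counters
# and the log shape come from the PR description.
import logging

from vmlx_engine.utils.pflash import plan_sparse_prefill  # real module; call signature assumed

logger = logging.getLogger("vmlx.pflash")

def maybe_build_pflash_plan(prompt_tokens, config, drafter, health):
    if not config.enable_pflash:                    # default OFF: dense prefill unchanged
        return None
    if drafter is None:                             # --pflash-drafter not supplied
        logger.warning("[pflash] enabled but no drafter loaded; staying dense")
        return None
    if len(prompt_tokens) < config.pflash_min_seq_len:
        health["skipped_below_min_seq_len"] += 1
        return None
    try:
        plan = plan_sparse_prefill(prompt_tokens, drafter, config)
    except Exception:
        health["failures"] += 1                     # surfaced via /health.pflash
        return None
    health["activations"] += 1
    health["blocks_kept"] += plan.blocks_kept
    health["blocks_total"] += plan.blocks_total
    logger.info(
        "[pflash] plan: %d/%d blocks kept (%.1f%% coverage, %d ranges)",
        plan.blocks_kept, plan.blocks_total,
        100.0 * plan.blocks_kept / max(plan.blocks_total, 1),
        len(plan.keep_ranges),
    )
    return plan  # currently informational: the dense prefill still runs
```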

Tests:

  • tests/test_pflash.py — 16 algorithm-only unit tests, no model weights required:
    • Config validation (keep_ratio range, block_size > 0)
    • Score blocks: entropy ordering on uniform vs one-hot (see the example after this list)
    • Top-K selection: basic, head/tail pin, keep-all
    • Block ranges: coalesce, partial trailing block, empty mask
    • plan_sparse_prefill: end-to-end, empty prompt
    • Activation gating: drafter-required, min_seq_len, env config
  • tests/benchmark/test_pflash_ttft.py — opt-in TTFT benchmark harness (target × drafter × ctx-size matrix; requires real model weights; not run in CI)
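
The entropy-ordering case boils down to a property like the one below. This is an illustrative snippet in the same spirit, not a test copied from the PR; the real tests call the pflash helpers directly.

```python
# Illustrative only: a uniform distribution has strictly higher entropy than a
# near-one-hot one, which is the ordering the scoring test relies on.
import math

def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def test_uniform_has_higher_entropy_than_one_hot():
    vocab = 8
    uniform = [1.0 / vocab] * vocab
    one_hot = [1.0 - 1e-6] + [1e-6 / (vocab - 1)] * (vocab - 1)
    assert entropy(uniform) > entropy(one_hot)
```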

Current scope limitation (intentional)

The chunked-prefill path still runs the dense forward over the full prompt. The plan's keep_ranges are logged but not yet routed to a sparse-mask target forward. The reason is that the actually-sparse target prefill requires either:

  1. A BSA-equivalent Metal kernel (large undertaking — FA-2 derived attention kernel for Metal's threadgroup model), or
  2. Reusing mx.fast.scaled_dot_product_attention with a precomputed block-sparse mask plus cache-layout adjustments for the skipped blocks.

Path (1) is the right long-term answer per the CUDA reference. This PR lands the framework so either path slots in without redesigning the integration surface.
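
For reference, path (2) could look roughly like the sketch below: build an additive mask that is 0 over kept key blocks and -inf elsewhere, then pass it to mx.fast.scaled_dot_product_attention. This assumes the additive-mask form of that call and leaves out causal masking and the cache-layout adjustments for skipped blocks; it is not the follow-up implementation.

```python
# Rough sketch of path (2); keep_ranges is assumed to be a list of
# [start_block, end_block) pairs from the planner. Causality and cache-layout
# handling for skipped blocks are deliberately omitted.
import mlx.core as mx

def block_sparse_additive_mask(seq_len, keep_ranges, block_size):
    """0 where queries may attend (keys inside kept blocks), -inf elsewhere."""
    mask = mx.full((seq_len, seq_len), float("-inf"))
    for start_block, end_block in keep_ranges:
        lo = start_block * block_size
        hi = min(end_block * block_size, seq_len)
        mask[:, lo:hi] = 0.0
    return mask

# q, k, v: (batch, n_heads, seq_len, head_dim)
# out = mx.fast.scaled_dot_product_attention(
#     q, k, v, scale=head_dim ** -0.5,
#     mask=block_sparse_additive_mask(seq_len, plan.keep_ranges, block_size),
# )
```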

Reference performance target (from #136)

CUDA reference (RTX 3090): 128K-token cold TTFT of 24.8 s vs 257 s for llama.cpp (10.4×) at keep_ratio=0.05; NIAH single-needle retrieved at every measured context size.

Metal expectation (this scaffold, post-kernel): ~5–8× on long-context cold-prefix workloads.

Quality gate (pre-merge for the sparse-kernel follow-up)

Must pass NIAH single-needle at every context size with keep_ratio ≥ 0.05 plus a multi-document reasoning eval (e.g. LongBench HotpotQA). Validation harness scaffolded in tests/benchmark/; full quality suite to land with the sparse forward.

Test plan

  • pytest tests/test_pflash.py -v — 16/16 pass
  • pytest tests/test_memory_limits.py -v — 8/8 pass (no regression)
  • vmlx serve --help shows the 5 new --pflash-* flags
  • All touched modules (mllm_batch_generator, cli, server, utils.pflash, utils.pflash_drafter) import cleanly
  • Live model: cold-prefix prompt at 8K+ tokens with --enable-pflash --pflash-drafter mlx-community/Qwen3-0.6B-bf16 produces a [pflash] plan: log line and /health.pflash.activations > 0
  • Live model: the same prompt without the flags produces byte-identical output (default OFF semantics)

Follow-up PRs

  1. BSA-equivalent Metal kernel for sparse target forward
  2. Actually-sparse target prefill (route keep_ranges to kernel)
  3. NIAH single-needle quality gate
  4. Drafter chunked scoring for prompts that exceed drafter context

Closes #136 (scaffold; follow-ups linked above).

🤖 Generated with Claude Code

feat(pflash): importance-scored sparse prefill scaffold (jjang-ai#136)

PFlash uses a small drafter to score per-block importance over long
prompts; the target model can then prefill only the spans that matter.
This PR lands the algorithm + integration scaffold so subsequent work on
the BSA-equivalent Metal kernel has a stable landing point.

Shipped in this PR:
- `vmlx_engine/utils/pflash.py` — block scoring (entropy of pooled
  drafter logits), top-K selection with head/tail pinning, range
  coalescing, and `plan_sparse_prefill()` planner.
- `vmlx_engine/utils/pflash_drafter.py` — co-resident drafter loader and
  `drafter_score_blocks()` (single drafter forward + per-block pool).
- CLI flags: `--enable-pflash`, `--pflash-drafter`, `--pflash-block-size`,
  `--pflash-keep-ratio`, `--pflash-min-seq-len`. Env equivalents:
  `VMLX_ENABLE_PFLASH`, `VMLX_PFLASH_*`.
- `MLLMBatchGenerator._maybe_build_pflash_plan()` — invoked from the
  chunked-prefill entry. Currently informational: it builds a plan,
  emits a one-line `[pflash]` telemetry log, and falls through to the
  dense prefill path. The sparse-mask kernel is a follow-up.
- `/health.pflash` — activations / blocks_kept / blocks_total /
  failures counters.
- `tests/test_pflash.py` — 16 algorithm-only unit tests covering
  validation, scoring, selection, range coalescing, planner end-to-end,
  activation gating, env config.
- `tests/benchmark/test_pflash_ttft.py` — opt-in TTFT benchmark harness
  (target × drafter × ctx-size matrix; not run in CI).

Default state: OFF. With `--enable-pflash` but no drafter loaded, the
engine logs a warning and stays on dense prefill — no behaviour change.

Follow-ups (separate PRs):
- BSA-equivalent Metal kernel for the target's sparse forward.
- Actually-sparse target prefill (currently the keep_ranges are logged
  but the existing dense chunked loop still runs over the full prompt).
- NIAH single-needle quality gate at keep_ratio ≥ 0.05.
- Drafter chunked scoring for prompts that exceed drafter context.

Closes jjang-ai#136 (scaffold only; tracking follow-ups linked above).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Development

Successfully merging this pull request may close these issues.

Feature: native PFlash-style speculative prefill (importance-scored sparse prefill) for long-context cold TTFT
