feat(pflash): importance-scored sparse prefill scaffold (#136) #161
Open
st-adam wants to merge 1 commit into
Conversation
feat(pflash): importance-scored sparse prefill scaffold (jjang-ai#136)

PFlash uses a small drafter to score per-block importance over long prompts; the target model can then prefill only the spans that matter. This PR lands the algorithm + integration scaffold so subsequent work on the BSA-equivalent Metal kernel has a stable landing point.

Shipped in this PR:

- `vmlx_engine/utils/pflash.py` — block scoring (entropy of pooled drafter logits), top-K selection with head/tail pinning, range coalescing, and `plan_sparse_prefill()` planner.
- `vmlx_engine/utils/pflash_drafter.py` — co-resident drafter loader and `drafter_score_blocks()` (single drafter forward + per-block pool).
- CLI flags: `--enable-pflash`, `--pflash-drafter`, `--pflash-block-size`, `--pflash-keep-ratio`, `--pflash-min-seq-len`. Env equivalents: `VMLX_ENABLE_PFLASH`, `VMLX_PFLASH_*`.
- `MLLMBatchGenerator._maybe_build_pflash_plan()` — invoked from the chunked-prefill entry. Currently informational: it builds a plan, emits a one-line `[pflash]` telemetry log, and falls through to the dense prefill path. The sparse-mask kernel is a follow-up.
- `/health.pflash` — activations / blocks_kept / blocks_total / failures counters.
- `tests/test_pflash.py` — 16 algorithm-only unit tests covering validation, scoring, selection, range coalescing, planner end-to-end, activation gating, env config.
- `tests/benchmark/test_pflash_ttft.py` — opt-in TTFT benchmark harness (target × drafter × ctx-size matrix; not run in CI).

Default state: OFF. With `--enable-pflash` but no drafter loaded, the engine logs a warning and stays on dense prefill — no behaviour change.

Follow-ups (separate PRs):

- BSA-equivalent Metal kernel for the target's sparse forward.
- Actually-sparse target prefill (currently the kept_ranges are logged but the existing dense chunked loop still runs over the full prompt).
- NIAH single-needle quality gate at keep_ratio ≥ 0.05.
- Drafter chunked scoring for prompts that exceed drafter context.

Closes jjang-ai#136 (scaffold only; tracking follow-ups linked above).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
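For orientation, a purely illustrative sketch of how the pieces above compose. The call signatures and the `PFlashPlan` field names are assumptions for readability, not the shipped API.

```python
# Hypothetical usage sketch; signatures and field names are assumptions.
from vmlx_engine.utils.pflash import plan_sparse_prefill
from vmlx_engine.utils.pflash_drafter import load_pflash_drafter, drafter_score_blocks

prompt_tokens = list(range(32_000))              # stand-in for a long tokenized prompt
drafter = load_pflash_drafter("mlx-community/Qwen3-0.6B-bf16")   # co-resident drafter
scores = drafter_score_blocks(drafter, prompt_tokens, block_size=128)  # one drafter forward

plan = plan_sparse_prefill(
    scores,
    seq_len=len(prompt_tokens),
    block_size=128,
    keep_ratio=0.05,      # keep roughly the top 5% of blocks
)
# plan.keep_ranges: coalesced (start, end) token ranges; today the engine only logs
# them and still runs the dense chunked prefill over the full prompt.
```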
Summary
Lands the algorithm pieces + integration scaffold for PFlash importance-scored sparse prefill (issue #136) so the BSA-equivalent Metal kernel work has a stable landing point. Default OFF; no behaviour change when unset.
Per @jjang-ai's note on #136 ("Would appreciate PR's in most recent stable versions"), this PR is based on current `main` (1.5.32 series, base `9cfbeb24`).

What's shipped
Algorithm (pure MLX, no custom kernel):
- `vmlx_engine/utils/pflash.py` (illustrative sketch after this list)
  - `pflash_score_blocks()` — per-block entropy of pooled drafter logits
  - `pflash_select_top_k()` — top-K with head/tail pinning
  - `pflash_block_ranges()` — coalesce adjacent kept blocks
  - `plan_sparse_prefill()` — end-to-end planner returning a `PFlashPlan`
- `vmlx_engine/utils/pflash_drafter.py`
  - `load_pflash_drafter()` — co-resident drafter load via `mlx_lm.load`
  - `drafter_score_blocks()` — single drafter forward + per-block logit pool
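A rough, self-contained re-implementation of the three pure steps, for reviewers who want the shape of the algorithm without opening the diff. Function names and defaults here are simplified stand-ins for the shipped `pflash_*` helpers.

```python
# Illustrative sketch only; the shipped functions may differ in signature and details.
import math

def block_entropy(pooled_probs):
    # pflash_score_blocks(): score a block by the Shannon entropy of its pooled
    # next-token distribution from the drafter (higher = drafter less certain there).
    return -sum(p * math.log(p) for p in pooled_probs if p > 0.0)

def select_top_k(scores, keep_ratio, head_blocks=1, tail_blocks=1):
    # pflash_select_top_k(): keep the top-K scored blocks, always pinning the
    # first/last blocks so the prompt head and tail survive regardless of score.
    n = len(scores)
    k = max(1, int(round(n * keep_ratio)))
    pinned = set(range(head_blocks)) | set(range(n - tail_blocks, n))
    ranked = sorted(range(n), key=lambda i: scores[i], reverse=True)
    kept = set(pinned)
    for i in ranked:
        if len(kept) >= max(k, len(pinned)):
            break
        kept.add(i)
    return sorted(kept)

def coalesce(kept_blocks, block_size, seq_len):
    # pflash_block_ranges(): merge adjacent kept blocks into (start, end) token ranges.
    ranges = []
    for b in kept_blocks:
        start, end = b * block_size, min((b + 1) * block_size, seq_len)
        if ranges and ranges[-1][1] == start:
            ranges[-1] = (ranges[-1][0], end)
        else:
            ranges.append((start, end))
    return ranges

# Toy run: 8 blocks of 128 tokens over a 1000-token prompt, keep half the blocks.
scores = [0.2, 3.1, 0.1, 0.4, 2.7, 0.3, 0.2, 0.5]
print(coalesce(select_top_k(scores, keep_ratio=0.5), 128, 1000))
# -> [(0, 256), (512, 640), (896, 1000)]
```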
Integration:

- CLI flags: `--enable-pflash`, `--pflash-drafter`, `--pflash-block-size`, `--pflash-keep-ratio`, `--pflash-min-seq-len`
- Env equivalents: `VMLX_ENABLE_PFLASH`, `VMLX_PFLASH_DRAFTER`, `VMLX_PFLASH_BLOCK_SIZE`, `VMLX_PFLASH_KEEP_RATIO`, `VMLX_PFLASH_MIN_SEQ_LEN`, `VMLX_PFLASH_HEAD_TOKENS`, `VMLX_PFLASH_TAIL_TOKENS` (see the config sketch after this list)
- `MLLMBatchGenerator._maybe_build_pflash_plan()` — invoked from the chunked-prefill entry; builds a plan and emits a one-line `[pflash] plan: K/N blocks kept (X% coverage, R ranges)` telemetry log
- `/health.pflash` — activations, blocks_total, blocks_kept, skipped_below_min_seq_len, failures
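How the env equivalents could map onto a config object, shown as a sketch; the field names and default values below are assumptions, not the shipped code.

```python
# Hypothetical mapping of VMLX_PFLASH_* env vars onto a config dataclass.
import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class PFlashConfig:
    enabled: bool = False
    drafter: Optional[str] = None
    block_size: int = 128        # assumed default
    keep_ratio: float = 0.05     # assumed default
    min_seq_len: int = 8192      # assumed default

def pflash_config_from_env() -> PFlashConfig:
    env = os.environ
    return PFlashConfig(
        enabled=env.get("VMLX_ENABLE_PFLASH", "0").lower() in ("1", "true"),
        drafter=env.get("VMLX_PFLASH_DRAFTER"),
        block_size=int(env.get("VMLX_PFLASH_BLOCK_SIZE", 128)),
        keep_ratio=float(env.get("VMLX_PFLASH_KEEP_RATIO", 0.05)),
        min_seq_len=int(env.get("VMLX_PFLASH_MIN_SEQ_LEN", 8192)),
    )
```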
Tests:

- `tests/test_pflash.py` — 16 algorithm-only unit tests, no model weights required (illustrative flavor below):
  - `plan_sparse_prefill`: end-to-end, empty prompt
- `tests/benchmark/test_pflash_ttft.py` — opt-in TTFT benchmark harness (target × drafter × ctx-size matrix; requires real model weights; not run in CI)
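A flavor of the algorithm-only tests. The call signatures below are inferred from the function names listed above, not copied from `tests/test_pflash.py`.

```python
# Sketch of algorithm-only tests (no weights); signatures are assumptions.
from vmlx_engine.utils.pflash import pflash_block_ranges, pflash_select_top_k

def test_block_ranges_coalesces_adjacent_blocks():
    # blocks 0 and 1 are adjacent and merge; block 5 stays its own range
    assert pflash_block_ranges([0, 1, 5], block_size=128, seq_len=1000) == [
        (0, 256),
        (640, 768),
    ]

def test_top_k_pins_head_and_tail():
    scores = [0.0, 9.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
    kept = pflash_select_top_k(scores, keep_ratio=0.5)
    assert 0 in kept and 7 in kept   # head/tail pinning
    assert 1 in kept                 # highest-scoring block survives
```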
Current scope limitation (intentional)

The chunked-prefill path still runs the dense forward over the full prompt. The plan's `keep_ranges` are logged but not yet routed to a sparse-mask target forward. The reason is that the actually-sparse target prefill requires either:

1. a BSA-equivalent Metal kernel for the target's sparse forward, or
2. `mx.fast.scaled_dot_product_attention` with a precomputed block-sparse mask plus cache-layout adjustments for the skipped blocks.

Path (1) is the right long-term answer per the CUDA reference. This PR lands the framework so either path slots in without redesigning the integration surface. A toy sketch of path (2) follows.
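The sketch restricts the prefill to the kept token ranges and hands MLX's fused SDPA a causal mask built from the original positions. Shapes, the `mx.take` gather, and the toy ranges are illustrative; the real integration also has to adjust the KV-cache layout for the skipped blocks, which this snippet ignores.

```python
# Toy sketch of path (2); not the shipped integration.
import mlx.core as mx

def kept_positions(keep_ranges):
    # Flatten coalesced (start, end) token ranges into one sorted index array.
    idx = []
    for start, end in keep_ranges:
        idx.extend(range(start, end))
    return mx.array(idx)

B, H, L, D = 1, 8, 1024, 64                      # toy batch / heads / prompt len / head dim
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))
v = mx.random.normal((B, H, L, D))

pos = kept_positions([(0, 128), (480, 544), (896, 1024)])   # e.g. plan.keep_ranges
q_s = mx.take(q, pos, axis=2)                    # prefill only the kept tokens
k_s = mx.take(k, pos, axis=2)
v_s = mx.take(v, pos, axis=2)

# Causal mask in original token positions, so kept tokens still only attend to the past.
row = mx.expand_dims(pos, 1)
col = mx.expand_dims(pos, 0)
mask = mx.where(col <= row, mx.array(0.0), mx.array(float("-inf")))

out = mx.fast.scaled_dot_product_attention(q_s, k_s, v_s, scale=D ** -0.5, mask=mask)
```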
Reference performance target (from #136)
CUDA reference (RTX 3090): 128K cold TTFT 24.8 s vs llama.cpp 257 s = 10.4× with `keep_ratio=0.05`, NIAH single-needle retrieved at every measured context.

Metal expectation (this scaffold, post-kernel): ~5–8× on long-context cold-prefix workloads.
Quality gate (pre-merge for the sparse-kernel follow-up)
Must pass NIAH single-needle at every context size with `keep_ratio ≥ 0.05`, plus a multi-document reasoning eval (e.g. LongBench HotpotQA). Validation harness scaffolded in `tests/benchmark/`; full quality suite to land with the sparse forward.
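One possible shape for that gate, as a sketch. Everything here is a placeholder: the `generate` fixture stands in for however the engine ends up being invoked, and the haystack builder is a toy.

```python
# Hypothetical NIAH single-needle gate for tests/benchmark/; placeholders throughout.
import pytest

NEEDLE = "The magic number is 7421."

def build_haystack(n_tokens: int, needle: str, depth: float = 0.5) -> str:
    # ~4 chars per token heuristic; drop the needle at the requested relative depth.
    filler = "The sky is blue today and the grass is green. "
    text = filler * (n_tokens * 4 // len(filler))
    cut = int(len(text) * depth)
    return text[:cut] + needle + " " + text[cut:]

@pytest.mark.parametrize("ctx", [8_192, 32_768, 131_072])
@pytest.mark.parametrize("depth", [0.1, 0.5, 0.9])
def test_niah_single_needle(ctx, depth, generate):   # `generate` fixture: placeholder
    prompt = build_haystack(ctx, NEEDLE, depth) + "\nWhat is the magic number?"
    out = generate(prompt, max_tokens=16)             # sparse prefill, keep_ratio=0.05
    assert "7421" in out                               # needle must be retrieved
```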
Test plan

- `pytest tests/test_pflash.py -v` — 16/16 pass
- `pytest tests/test_memory_limits.py -v` — 8/8 pass (no regression)
- `vmlx serve --help` shows the 5 new `--pflash-*` flags
- Touched modules (`mllm_batch_generator`, `cli`, `server`, `utils.pflash`, `utils.pflash_drafter`) import clean
- `--enable-pflash --pflash-drafter mlx-community/Qwen3-0.6B-bf16` produces a `[pflash] plan:` log entry and `/health.pflash.activations > 0`
Follow-up PRs

- BSA-equivalent Metal kernel for the target's sparse forward
- Actually-sparse target prefill (route `keep_ranges` to kernel)
- NIAH single-needle quality gate at `keep_ratio ≥ 0.05`
- Drafter chunked scoring for prompts that exceed drafter context

Closes #136 (scaffold; follow-ups linked above).
🤖 Generated with Claude Code