Skip to content

[MLX] Reduce physical footprint memory in RingBufferKVCache for chunked prefill#20341

Merged
metascroy merged 14 commits into
mainfrom
reduce-mem-sliding-kv
Jun 17, 2026
Merged

[MLX] Reduce physical footprint memory in RingBufferKVCache for chunked prefill#20341
metascroy merged 14 commits into
mainfrom
reduce-mem-sliding-kv

Conversation

@metascroy

Copy link
Copy Markdown
Contributor

When doing chunked prefill, the RingBufferKVCache does not need 2x window size, but instead window_size + max_write_length - 1 (prefill chunk size). This PR exposes that knob and wires it to gemma4 31b MLX export, which uses chunk_size 256, smaller than gemma4's window size (1024).

Reduces phys_footprint on a 4K export by around −0.68 GiB (from 13.84 GiB to 13.16 GiB).

@pytorch-bot

pytorch-bot Bot commented Jun 17, 2026

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20341

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 16 Pending, 2 Unrelated Failures, 2 Unclassified Failures

As of commit d91d110 with merge base 0eb8247 (image):

NEW FAILURES - The following jobs have failed:

UNCLASSIFIED FAILURES - DrCI could not classify the following jobs because the workflow did not run on the merge base. The failures may be pre-existing on trunk or introduced by this PR:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla Bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 17, 2026
@metascroy

Copy link
Copy Markdown
Contributor Author

@claude review this code

@github-actions

Copy link
Copy Markdown

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@claude

claude Bot commented Jun 17, 2026

Copy link
Copy Markdown

Claude finished @metascroy's task in 6m 39s —— View job


Code Review

I reviewed the diff against origin/main. This is a clean, well-reasoned change. The core idea — that a sliding-window ring buffer needs window_size + max_write_len - 1 slots rather than 2 × window_size — is correct, and the implementation wires it through cleanly.

Tasks

  • Read PR diff and surrounding code
  • Verify the buffer-sizing math
  • Check downstream consumers of buffer_size / mask logic
  • Review test coverage
  • Post review

Correctness ✅

The buffer-size derivation checks out. During a single multi-token write of seq_len tokens at start_pos:

  • The earliest query (start_pos) needs positions back to start_pos − w + 1.
  • The latest query (start_pos + seq_len − 1) needs the freshly written tail.

The union spans (start_pos + seq_len − 1) − (start_pos − w + 1) + 1 = seq_len + w − 1 distinct positions that must coexist in the ring, so buffer_size ≥ window_size + max_write_len − 1 is exactly right. The create_sliding_window_mask reconstruction (backends/mlx/llm/cache.py:338-366) is already written generically in terms of buffer_size, so it stays correct for any b — no change needed there. Good.

The assert max_write_len <= window_size guard (cache.py:276-278) is the right invariant: if a write exceeded the window, slots within a single write could collide.

Minor: docstring vs. actual default

The default is now 2 * window_size − 1, not 2 * window_size. Both cache.py:235 ("a ~2× buffer") and mlx_source_transformations.py:184 ("defaults to the full window (the original 2× buffer)") describe the default as the original 2× buffer, but it's actually one slot smaller than before. The new value is mathematically correct (and a tiny bit leaner), so this is purely a doc-accuracy nit — consider saying "≈2× (2·window_size − 1)" so the −1 isn't surprising to a future reader comparing against the old 2 * max_context_length.

Suggestion: test coverage for the new path

RingBufferKVCacheTest (backends/mlx/test/test_ops.py:2732) only exercises the default buffer size — RingBufferKVCacheModel doesn't plumb max_write_len through. The reduced-buffer case is the one most likely to regress (a wrapping write into the tighter buffer + mask reconstruction), and it's currently untested. Worth adding a config where max_write_len < window_size and seq_step triggers a ring wrap, asserting numerics match a reference. Add a max_write_len test config →

Observation: generic HF path not wired

backends/mlx/llm/source_transformation.py:206 constructs RingBufferKVCache without max_write_len, so the generic HF sliding-window export still gets the full 2× buffer. That's consistent with the PR scope (gemma4 only), but the same memory win is available there if the export-time max prefill length is known — possible follow-up.

Style

max_write_len: int | None = None and the threading through both functions match the surrounding code well. Comments are appropriately scoped to the non-obvious invariant. No issues.


Nothing blocking — the math is sound and the change does what it claims. The only real ask is the missing test for the reduced-buffer path; the docstring wording is a nit.
· branch reduce-mem-sliding-kv

@mergennachin mergennachin requested a review from Gasoonjia June 17, 2026 19:07
@mergennachin

Copy link
Copy Markdown
Contributor

Can't similar optimization apply to CUDA too?

cc @Gasoonjia

@metascroy metascroy force-pushed the reduce-mem-sliding-kv branch from 94a5797 to d91d110 Compare June 17, 2026 22:32
@Gasoonjia

Copy link
Copy Markdown
Contributor

@mergennachin I think it can but due to different ring kv cache impl we may need to have differetn approch. Not a blocker though.

@metascroy metascroy merged commit 66feb0a into main Jun 17, 2026
238 of 254 checks passed
@metascroy metascroy deleted the reduce-mem-sliding-kv branch June 17, 2026 23:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ciflow/mlx CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants