perf: env-gated prefill allocator clear + chunked-loop tightening (PR-B of 3) #163
Open
st-adam wants to merge 1 commit into
Conversation
…-B of 3)

Pure-implementation perf wins on the long-context chunked-prefill path. All behaviour-preserving; the headline change (skipping per-chunk `mx.clear_cache()`) is opt-in.

Changes in the `MLLMBatchGenerator._process_prompts` chunked-text branch:

1. **Per-chunk `mx.clear_cache()` is now env-gated.** Default behaviour unchanged. Opt-in via `--prefill-keep-alloc` (or `VMLX_PREFILL_KEEP_ALLOC=1`) keeps the Metal allocator free-list warm across prefill chunks. The 1.5.31 working-set guard (98 %) catches genuine OOMs, so the per-chunk pool flush is redundant on a well-provisioned box. Expected: 5-12 % TTFT improvement on 16K+ prompts when enabled.
2. **Boundary lookup goes from O(B) per chunk to O(1).** Replaced the `next((b for b in ssm_boundaries if processed < b <= processed + chunk_size and b not in captured), None)` generator scan with a sorted-list + pointer index that advances monotonically. For a 32K prompt and ~4 boundaries, that's 32 generator allocations and 128 conditional checks dropped per request.
3. **`input_ids[0].tolist()` is hoisted out of the loop.** Each chunk boundary previously triggered a Metal→CPU sync to materialise the prompt token list for `_maybe_capture_clean_ssm_boundary`. Now computed once before the loop when `_original_token_ids` is not already set.
4. **`_state_layers` precomputed once.** The per-chunk `mx.eval([c.state for c in cache if hasattr(c, 'state')])` walked the full cache layer list each iteration. The hasattr-filtered list is now built once before the loop.

CLI:

- New `--prefill-keep-alloc` flag mirrors the env var.

Tests:

- `tests/test_perf_prefill_loop.py` — 8 unit tests covering each invariant (boundary pointer present, state-layers precomputed, tolist hoisted, env-gate logic, CLI propagation, truthy spellings).
- Adjacent regression: `tests/test_memory_limits.py` 8/8, `tests/test_ssm_companion_cache.py` 14/14, `tests/test_batching.py` 54/56 (same 2 pre-existing pytest-asyncio infra failures).
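For concreteness, the sorted-list + pointer replacement described in item 2 might look something like the sketch below. The function name, closure shape, and signature are illustrative assumptions, not the actual vmlx code; the point is only that the sort happens once and the pointer never rewinds.

```python
def make_boundary_lookup(ssm_boundaries, captured):
    """Amortised O(1) boundary lookup: sort once, advance a pointer monotonically.

    Hypothetical sketch — `captured` is a set the caller updates as boundaries
    are captured; the pointer skips past captured/behind boundaries on each call.
    """
    boundaries = sorted(ssm_boundaries)
    idx = 0  # monotone pointer into the sorted list; never rewinds across chunks

    def next_boundary(processed, chunk_size):
        nonlocal idx
        # Skip boundaries already behind the cursor or already captured.
        while idx < len(boundaries) and (
            boundaries[idx] <= processed or boundaries[idx] in captured
        ):
            idx += 1
        if idx < len(boundaries) and boundaries[idx] <= processed + chunk_size:
            return boundaries[idx]
        return None  # no boundary falls inside this chunk; pointer stays put

    return next_boundary
```

Each chunk then costs at most a few pointer advances in total across the whole prefill, instead of re-scanning every boundary per chunk.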
Out of scope (kept unconditional, justified):

- The image-prefill `mx.clear_cache()` at the one-shot VLM entry: that call is a deliberate pre-guard against the Metal command-buffer OOM on large pixel inputs (server-kill scenario). Not redundant with the 98 % guard for that codepath.
- Error-recovery `mx.clear_cache()` calls in the OOM-recovery and SSM re-derive paths: low frequency, off the hot path.

Stacks on PR-A (perf/safe-wins, jjang-ai#162). Independent of PR-C (sampler reuse + async pipelining).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
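Items 3 and 4 in the change list are the same loop-invariant-hoisting pattern; schematically it looks like the sketch below. All names here (`prefill_chunks`, `materialise_tokens`, `process_chunk`) are illustrative stand-ins, not the vmlx implementation.

```python
def prefill_chunks(chunks, cache, materialise_tokens, process_chunk):
    """Schematic loop-invariant hoist (hypothetical names).

    Before: each chunk re-ran the hasattr filter over the cache layers and,
    on boundary hits, re-materialised the token list (a Metal→CPU sync).
    After: both are computed once before the loop and reused inside it.
    """
    state_layers = [c for c in cache if hasattr(c, "state")]  # hasattr filter once
    all_tokens = materialise_tokens()  # previously re-run on every boundary hit
    for chunk in chunks:
        process_chunk(chunk, state_layers, all_tokens)
```

The hoist is safe exactly because neither value changes across chunks: the cache layer list is fixed for the request, and the prompt token list is immutable during prefill.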
Summary
Tier-2 perf cleanups on the long-context chunked-prefill path. All behaviour-preserving; the headline change (skipping per-chunk `mx.clear_cache()`) is opt-in, so existing users see zero change.

Base: `9cfbeb24` (current `main`, 1.5.32 series). Independent of PR-A #162 (low-risk hot-path cleanups). This is PR-B of 3 in the perf-audit series. PR-C (sampler buffer reuse + async pipelining) follows.
Changes in the `MLLMBatchGenerator._process_prompts` chunked-text branch

- `mx.clear_cache()` at line 2722 flushes the Metal allocator free-list between every chunk of a long-context prefill. Redundant given the v1.5.31 98 % working-set guard. A new `--prefill-keep-alloc` CLI flag (or `VMLX_PREFILL_KEEP_ALLOC=1`) skips the call. Default behaviour unchanged.
- `next((b for b in ssm_boundaries if processed < b <= processed + chunk_size and b not in captured), None)` is an O(B) generator scan. For a 32K prompt with `prefill_step_size=1024` and 4 boundaries, that's 32 generator allocations + 128 conditional checks per request. Replaced with a sorted-list + pointer index that advances monotonically.
- `getattr(request, '_original_token_ids', None) or input_ids[0].tolist()` was evaluated on every boundary hit; the fallback triggers a Metal→CPU sync when `_original_token_ids` is unset. Now computed once before the loop. `_maybe_capture_clean_ssm_boundary` accepts a Python list, so the contract is unchanged.
- `mx.eval([c.state for c in cache if hasattr(c, 'state')])` ran per chunk — `hasattr` reflection over the full cache layer list every iteration. Now `_state_layers = [c for c in cache if hasattr(c, "state")]` is built once before the loop and reused inside.

Out of scope (kept unconditional, justified inline)
- The `mx.clear_cache()` at the one-shot VLM entry (line 2849 in upstream) is a deliberate pre-guard against the Metal command-buffer OOM on large pixel inputs (server-kill scenario). Not redundant with the 98 % guard for that codepath.
- The `mx.clear_cache()` calls in the OOM-recovery and SSM re-derive branches are low-frequency and off the hot path. Left untouched.

Tests
- `tests/test_perf_prefill_loop.py` — 8 new unit tests:
  - `test_chunk_loop_uses_sorted_boundary_pointer` — boundary pointer + sorted list present
  - `test_chunk_loop_precomputes_state_layers` — `_state_layers` reused per chunk
  - `test_chunk_loop_hoists_all_tokens_tolist` — `_hoisted_all_tokens` present
  - `test_chunk_loop_env_gates_clear_cache` — `if not _prefill_keep_alloc: mx.clear_cache()` gate present
  - `test_cli_flag_propagates_to_env` — `--prefill-keep-alloc` sets `VMLX_PREFILL_KEEP_ALLOC=1`
  - `test_prefill_keep_alloc_env_off_by_default` — default behaviour unchanged
  - `test_prefill_keep_alloc_env_recognised` — truthy spellings work (`1`, `true`, `yes`, `on`, case-insensitive)
  - `test_boundary_pointer_advances_past_captured` — synthetic pointer-advance check
- Adjacent regression: `tests/test_memory_limits.py` 8/8, `tests/test_ssm_companion_cache.py` 14/14, `tests/test_batching.py` 54/56 (same 2 pre-existing pytest-asyncio infra failures, unaffected).

Expected gain
With `--prefill-keep-alloc`: expected 5-12 % TTFT improvement on 16K+ prompts. Without the flag: allocator behaviour is unchanged; the boundary-pointer, tolist-hoist, and state-layer changes still apply.
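One plausible shape for the env gate is sketched below. The function name and exact parsing are assumptions; the accepted spellings follow the `test_prefill_keep_alloc_env_recognised` description (`1`, `true`, `yes`, `on`, case-insensitive).

```python
import os

_TRUTHY = {"1", "true", "yes", "on"}

def prefill_keep_alloc_enabled(env=None):
    # Hypothetical sketch: case-insensitive truthy check on the gate variable.
    # Unset, or any other value, keeps the default per-chunk mx.clear_cache().
    env = os.environ if env is None else env
    return env.get("VMLX_PREFILL_KEEP_ALLOC", "").strip().lower() in _TRUTHY
```

Reading the variable once at loop entry (rather than per chunk) would keep the gate itself off the hot path.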
Test plan
- `pytest tests/test_perf_prefill_loop.py -v` — 8/8 pass
- `vmlx serve --help` shows `--prefill-keep-alloc`
- Long-context run with `VMLX_PREFILL_KEEP_ALLOC` on vs off; verify no OOM
- `bench/live_speed_probe.py` before/after on the standard long-context prompt

🤖 Generated with Claude Code
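The flag-to-env mirroring exercised by `test_cli_flag_propagates_to_env` might be wired roughly as follows. This is a sketch only: the function name, the injectable `env` parameter, and the parser shape are assumptions about how the real `vmlx serve` CLI could do it, not its actual code.

```python
import argparse
import os

def parse_serve_args(argv, env=None):
    # Hypothetical sketch of --prefill-keep-alloc mirroring the env var.
    # parse_known_args lets unrelated serve flags pass through untouched.
    env = os.environ if env is None else env
    parser = argparse.ArgumentParser(prog="vmlx serve", add_help=False)
    parser.add_argument("--prefill-keep-alloc", action="store_true")
    args, _unknown = parser.parse_known_args(argv)
    if args.prefill_keep_alloc:
        env["VMLX_PREFILL_KEEP_ALLOC"] = "1"  # flag wins; env readers see "1"
    return args
```

Mirroring into the environment keeps a single source of truth for the gate, so the prefill loop only ever consults `VMLX_PREFILL_KEEP_ALLOC`.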