
perf: env-gated prefill allocator clear + chunked-loop tightening (PR-B of 3)#163

Open
st-adam wants to merge 1 commit into jjang-ai:main from st-adam:perf/prefill-allocator-gate

Conversation

@st-adam st-adam commented May 12, 2026

Summary

Tier-2 perf cleanups on the long-context chunked-prefill path. All behaviour-preserving; the headline change (skipping per-chunk mx.clear_cache()) is opt-in so existing users see zero change.

Base: 9cfbeb24 (current main, 1.5.32 series). Independent of PR-A #162 (low-risk hot-path cleanups).

This is PR-B of 3 in the perf-audit series. PR-C (sampler buffer reuse + async pipelining) follows.

Changes in MLLMBatchGenerator._process_prompts chunked-text branch

| # | Finding | Fix |
| --- | --- | --- |
| 7 | Per-chunk `mx.clear_cache()` at line 2722 flushes the Metal allocator free-list between every chunk of a long-context prefill. Redundant given the v1.5.31 98% working-set guard. | Env-gated: the new `--prefill-keep-alloc` CLI flag (or `VMLX_PREFILL_KEEP_ALLOC=1`) skips the call. Default behaviour unchanged. |
| 8 | Per-chunk `next((b for b in ssm_boundaries if processed < b <= processed + chunk_size and b not in captured), None)` is an O(B) generator scan. For a 32K prompt with `prefill_step_size=1024` and 4 boundaries, that's 32 generator allocations + 128 conditional checks per request. | Sorted list + pointer index; the pointer advances monotonically as the loop progresses. O(1) per chunk. |
| 9 | `getattr(request, '_original_token_ids', None) or input_ids[0].tolist()` is evaluated on every boundary hit; the fallback triggers a Metal→CPU sync. | Computed once before the loop when `_original_token_ids` is unset. `_maybe_capture_clean_ssm_boundary` accepts a Python list, so the contract is unchanged. |
| 10 | `mx.eval([c.state for c in cache if hasattr(c, 'state')])` per chunk — `hasattr` reflection over the full cache layer list every iteration. | Precompute `_state_layers = [c for c in cache if hasattr(c, "state")]` once before the loop; reuse inside. |
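The loop changes above can be sketched in plain Python. The function and parameter names below are illustrative, not the PR's actual symbols, and the real loop operates on MLX arrays inside `MLLMBatchGenerator._process_prompts`; this sketch only shows the control-flow shape of the env gate and the monotone boundary pointer:

```python
import os


def prefill_chunks(total_tokens, chunk_size, ssm_boundaries, clear_cache):
    """Walk a prompt in fixed-size chunks, capturing each SSM boundary once.

    ssm_boundaries holds absolute token offsets; sorting them once lets a
    monotone pointer replace the per-chunk O(B) generator scan.
    """
    keep_alloc = os.environ.get("VMLX_PREFILL_KEEP_ALLOC", "").lower() in (
        "1", "true", "yes", "on"
    )
    boundaries = sorted(ssm_boundaries)
    b_idx = 0                       # monotone pointer: O(1) work per chunk
    captured = []
    processed = 0
    while processed < total_tokens:
        step = min(chunk_size, total_tokens - processed)
        # skip boundaries already behind us (never rewinds)
        while b_idx < len(boundaries) and boundaries[b_idx] <= processed:
            b_idx += 1
        # same predicate as the old scan: processed < b <= processed + step
        if b_idx < len(boundaries) and boundaries[b_idx] <= processed + step:
            captured.append(boundaries[b_idx])
            b_idx += 1
        processed += step
        if not keep_alloc:
            clear_cache()           # default path: flush allocator per chunk
    return captured
```

Because the pointer only moves forward, the captured-set membership test from the original generator expression becomes unnecessary: a boundary the pointer has passed can never be revisited.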

Out of scope (kept unconditional, justified inline)

  • The image-prefill mx.clear_cache() at the one-shot VLM entry (line 2849 in upstream) is a deliberate pre-guard against the Metal command-buffer OOM on large pixel inputs (server-kill scenario). Not redundant with the 98 % guard for that codepath.
  • Error-recovery mx.clear_cache() calls in the OOM-recovery and SSM re-derive branches are low-frequency and off the hot path. Left untouched.

Tests

  • tests/test_perf_prefill_loop.py — 8 new unit tests:
    • test_chunk_loop_uses_sorted_boundary_pointer — boundary pointer + sorted list present
    • test_chunk_loop_precomputes_state_layers — _state_layers reused per chunk
    • test_chunk_loop_hoists_all_tokens_tolist — hoisted _all_tokens present
    • test_chunk_loop_env_gates_clear_cache — if not _prefill_keep_alloc: mx.clear_cache() gate present
    • test_cli_flag_propagates_to_env — --prefill-keep-alloc sets VMLX_PREFILL_KEEP_ALLOC=1
    • test_prefill_keep_alloc_env_off_by_default — default behaviour unchanged
    • test_prefill_keep_alloc_env_recognised — truthy spellings work (1, true, yes, on, case-insensitive)
    • test_boundary_pointer_advances_past_captured — synthetic pointer-advance check
  • Adjacent regression sweep: tests/test_memory_limits.py 8/8, tests/test_ssm_companion_cache.py 14/14, tests/test_batching.py 54/56 (same 2 pre-existing pytest-asyncio infra failures, unaffected).
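The truthy-spelling contract that `test_prefill_keep_alloc_env_recognised` asserts can be captured in a few lines; the helper name below is illustrative, not the PR's actual symbol:

```python
import os

# Accepted truthy spellings, matched case-insensitively.
_TRUTHY = {"1", "true", "yes", "on"}


def prefill_keep_alloc(environ=os.environ):
    """Return True when VMLX_PREFILL_KEEP_ALLOC is set to a truthy value.

    Unset or unrecognised values default to False, preserving today's
    behaviour (the per-chunk clear still runs).
    """
    return environ.get("VMLX_PREFILL_KEEP_ALLOC", "").strip().lower() in _TRUTHY
```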

Expected gain

With --prefill-keep-alloc:

  • 5-12 % TTFT improvement on long prompts (>8K tokens), depending on chunk count. The allocator-clear elision is the bulk; boundary precompute + tolist hoist + state-layer precompute are constant-factor.

Without the flag:

  • Same behaviour as today. Boundary pointer + state-layer precompute + tolist hoist still apply (these are unconditional optimizations — they reduce Python overhead without changing the Metal command stream).

Test plan

  • pytest tests/test_perf_prefill_loop.py -v — 8/8 pass
  • vmlx serve --help shows --prefill-keep-alloc
  • CLI flag → env var propagation verified in unit test
  • Live model: TTFT at 8K / 16K / 32K with VMLX_PREFILL_KEEP_ALLOC on vs off; verify no OOM
  • bench/live_speed_probe.py before/after on the standard long-context prompt
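The flag→env propagation covered by the test plan could be wired as follows. This is an illustrative argparse sketch under the assumption that the server entry point parses its own flags; the actual vmlx CLI plumbing may differ:

```python
import argparse
import os


def parse_serve_args(argv=None):
    """Parse serve flags; mirror --prefill-keep-alloc into the env var."""
    parser = argparse.ArgumentParser(prog="vmlx-serve")
    parser.add_argument(
        "--prefill-keep-alloc",
        action="store_true",
        help="Skip the per-chunk mx.clear_cache() during chunked prefill",
    )
    args = parser.parse_args(argv)
    if args.prefill_keep_alloc:
        # Exported before the generator is constructed, so the chunk loop
        # sees the gate when it reads VMLX_PREFILL_KEEP_ALLOC.
        os.environ["VMLX_PREFILL_KEEP_ALLOC"] = "1"
    return args
```

Mirroring the flag into the environment rather than threading a parameter through keeps the generator's gate readable from a single place, whether the user came in via the CLI or via the env var directly.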

🤖 Generated with Claude Code

perf: env-gated prefill allocator clear + chunked-loop tightening (PR-B of 3)

Pure-implementation perf wins on the long-context chunked-prefill path.
All behaviour-preserving; the headline change (skipping per-chunk
`mx.clear_cache()`) is opt-in.

Changes in `MLLMBatchGenerator._process_prompts` chunked-text branch:

1. **Per-chunk `mx.clear_cache()` is now env-gated.** Default behaviour
   unchanged. Opt-in via `--prefill-keep-alloc` (or
   `VMLX_PREFILL_KEEP_ALLOC=1`) keeps the Metal allocator free-list warm
   across prefill chunks. The 1.5.31 working-set guard (98 %) catches
   genuine OOMs, so the per-chunk pool-flush is redundant on a
   well-provisioned box. Expected: 5-12 % TTFT improvement on 16K+
   prompts when enabled.

2. **Boundary lookup goes from O(B) per chunk to O(1).** Replaced the
   `next((b for b in ssm_boundaries if processed < b <= processed +
   chunk_size and b not in captured), None)` generator scan with a
   sorted-list + pointer index that advances monotonically. For a 32K
   prompt and ~4 boundaries that's 32 generator allocations and 128
   conditional checks dropped per request.

3. **`input_ids[0].tolist()` is hoisted out of the loop.** Each chunk
   boundary previously triggered a Metal→CPU sync to materialise the
   prompt token list for `_maybe_capture_clean_ssm_boundary`. Now
   computed once before the loop when `_original_token_ids` is not
   already set.

4. **`_state_layers` precomputed once.** The per-chunk
   `mx.eval([c.state for c in cache if hasattr(c, 'state')])` walked
   the full cache layer list each iteration. The hasattr-filtered list
   is now built once before the loop.

CLI:
- New `--prefill-keep-alloc` flag mirrors the env var.

Tests:
- `tests/test_perf_prefill_loop.py` — 8 unit tests covering each
  invariant (boundary pointer present, state-layers precomputed,
  tolist hoisted, env-gate logic, CLI propagation, truthy spellings).
- Adjacent regression: `tests/test_memory_limits.py` 8/8,
  `tests/test_ssm_companion_cache.py` 14/14, `tests/test_batching.py`
  54/56 (same 2 pre-existing pytest-asyncio infra failures).

Out of scope (kept unconditional, justified):
- The image-prefill `mx.clear_cache()` at the one-shot VLM entry: that
  call is a deliberate pre-guard against the Metal command-buffer OOM
  on large pixel inputs (server-kill scenario). Not redundant with the
  98 % guard for that codepath.
- Error-recovery `mx.clear_cache()` calls in the OOM-recovery and SSM
  re-derive paths: low frequency, off the hot path.

Stacks on PR-A (perf/safe-wins, jjang-ai#162). Independent of PR-C (sampler
reuse + async pipelining).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
