
feat(pld): hybrid partial-accept replay for SSM models (#134)#149

Open
st-adam wants to merge 1 commit into jjang-ai:main from st-adam:pld-ssm-replay

Conversation


@st-adam st-adam commented May 7, 2026

Summary

  • Fixes token loss on hybrid SSM/ATT models when 0 < num_accept < K in the PLD partial-reject path
  • Adds Scheduler._replay_ssm_forward(), which restores both caches to N, replays the accepted tokens to reach N+K', and emits K'+1 tokens instead of 1
  • Default ON; opt-out: VMLX_DISABLE_PLD_REPLAY=1
  • New /health field pld_ssm_replay.{enabled,attempts,emitted,failures}
  • 6 unit tests in tests/test_pld_ssm_replay.py
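The opt-out described above could be gated with a one-line env-var check. This is a sketch only; the helper name is hypothetical and not the PR's actual code:

```python
import os

def pld_replay_enabled() -> bool:
    # Replay is on by default; setting VMLX_DISABLE_PLD_REPLAY=1 opts out,
    # per the PR summary. Any other value (or unset) leaves it enabled.
    return os.environ.get("VMLX_DISABLE_PLD_REPLAY") != "1"
```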

Problem

PR #26 added PLD on hybrid models (48 GatedDeltaNet + 16 full-attention layers). With K=2, a partial accept (num_accept=1) still emits only 1 correction token, because the SSM state cannot be trimmed: both caches must rewind all the way to N.
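The token accounting behind the bug can be shown with a toy helper (illustrative only; this function does not exist in the codebase):

```python
def emitted_tokens(num_accept: int, k: int, replay: bool) -> int:
    """Tokens emitted per verify step, per the PR description (toy model)."""
    if num_accept == k:
        return k + 1              # full accept: K drafts + 1 bonus token
    if replay and num_accept > 0:
        return num_accept + 1     # partial accept with replay: K' drafts + 1 correction
    return 1                      # old partial-reject path: correction token only
```

With K=2 and num_accept=1, the old path emits 1 token where the replay path emits 2.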

Solution

After a partial rejection, restore both caches to N and replay drafts[:num_accept] forward through the full model, so both caches reach N+num_accept. Then emit drafts[:num_accept] + [bonus_token], the same output as the full-accept path minus the extra K-K' tokens.

Expected gain

+5-10% on top of PR #26's +4-7% on hybrid models. Full PLD target for hybrid moves from +4-7% toward the +15-25% cited in #134.

Test plan

  • pytest tests/test_pld_ssm_replay.py -v — 6 unit tests pass
  • pytest tests/test_ssm_companion_cache.py -v — existing tests unaffected
  • Live model: output with VMLX_DISABLE_PLD_REPLAY=1 vs unset is byte-equal at T=0; tok/s is higher with the flag unset

Fixes #134

🤖 Generated with Claude Code

feat(pld): hybrid partial-accept replay for SSM models (jjang-ai#134)

On hybrid SSM/ATT models with 0 < num_accept < K, the PLD path previously
discarded accepted drafts and emitted only a correction token. This PR adds
_replay_ssm_forward() which restores caches to N, replays the accepted tokens
through the model, and emits num_accept+1 tokens instead of 1.

- New Scheduler._replay_ssm_forward() staticmethod (scheduler.py)
- Modify case (b): try replay first, fall back to correction-only on failure
- Add _pld_replay_{enabled,attempts,emitted,failures} counters
- Add pld_ssm_replay telemetry to /health endpoint (server.py)
- Document the fix in notes/prompt-lookup-decoding.md
- 6 unit tests in tests/test_pld_ssm_replay.py
- New partial_accept_stress benchmark in tests/benchmark/test_pld_acceptance.py
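The /health telemetry field might be assembled along these lines. The key names come from the PR summary (`pld_ssm_replay.{enabled,attempts,emitted,failures}`); the helper itself is hypothetical:

```python
def pld_ssm_replay_health(enabled: bool, attempts: int,
                          emitted: int, failures: int) -> dict:
    # Shape of the new /health field per the PR summary; exact server-side
    # wiring in server.py is not shown here.
    return {
        "pld_ssm_replay": {
            "enabled": enabled,
            "attempts": attempts,
            "emitted": emitted,
            "failures": failures,
        }
    }
```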

Expected gain: +5-10% on top of PR jjang-ai#26's +4-7% on hybrid models.
Disable: VMLX_DISABLE_PLD_REPLAY=1

Closes jjang-ai#134

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

st-adam commented May 12, 2026

Rebased onto current main (1.5.32 series, base commit 9cfbeb24) per @jjang-ai's note on #134 ("PRs in most recent stable versions"). No code changes — clean rebase over the 50 intervening commits.

Verified post-rebase:

  • tests/test_pld_ssm_replay.py — 6/6 passing
  • vmlx_engine.scheduler imports cleanly; Scheduler._replay_ssm_forward present
  • Upstream's memory_limits refactor (v1.5.31) integrated cleanly into our scheduler.py touch points — no shadowing of get_metal_ws_guard_threshold / get_effective_metal_working_set_bytes

PR is CLEAN / MERGEABLE against the 1.5.32 base.



Successfully merging this pull request may close these issues.

Perf: PLD verify-cost on hybrid SSM models — proposal for SSM checkpoint/replay (extends #26)
