diff --git a/docs/mamba-dsv4-refactor.md b/docs/mamba-dsv4-refactor.md new file mode 100644 index 000000000..6b98e3ad0 --- /dev/null +++ b/docs/mamba-dsv4-refactor.md @@ -0,0 +1,626 @@ +# tokenspeed-scheduler Hybrid Cache Refactor Plan + +> Finalized design note for converging Mamba/GDN-style state cache and +> DeepSeek-V4 grouped cache in `tokenspeed-scheduler`. +> +> Target families: +> - Standard MHA / MLA token-page cache. +> - DeepSeek-V4 CSA / HCA compressed history plus SWA / tail state. +> - Qwen 3.5 / GDN / Mamba-like recurrent state plus conv state. +> - Future hybrid variants with page-like and state-like cache families. +> +> Baseline: tokenspeed-scheduler `main`, plus the in-flight +> `feat/v4-prefix-cache` bridge as a prototype to preserve and generalize. + +--- + +## 1. Goal + +Refactor `tokenspeed-scheduler/csrc/resource/` so the scheduler critical path is +model-neutral. `Scheduler` should call one `CacheCoordinator` facade; per-model +differences belong in registered `CacheResourceSpec` records and internal +`CacheFamilyPolicy` implementations. + +The scheduler should optimize for the **longest recoverable prefix**, not the +longest equal-token prefix. A token prefix is usable only if every required +cache family can either attach an exact reusable slice, restore an aligned +checkpoint, or provide an explicit replay plan that rebuilds the missing state +before normal execution resumes. + +Design principle: **converge the scheduler-facing substrate, not the cache +semantics**. Cache families may share integer resource ids, RAII ownership, +per-request facades, TreeNode attachments, commit-boundary plumbing, and +coordinator control flow. Allocator geometry, recoverability, split behavior, +and publication policy remain family-specific. + +## 2. Non-Goals + +- Do not force all cache types into one KV page model. +- Do not collapse all allocators into one semantics-blind allocator. +- Do not make `TreeNode` aware of model names such as Mamba or DeepSeek-V4. + Model-specific names live in spec ids. +- Do not publish auxiliary / window / tail state to the global prefix index by + default. They stay request-local unless an explicit recovery contract exists. +- Do not change forward-op tensor ABI in PRs 1-7. Existing Mamba fields and + flattened per-group block tables remain compatibility views until the worker + ABI is deliberately migrated. +- Do not introduce per-token or per-layer virtual dispatch on the scheduler hot + path. A bounded startup-built loop over active families is acceptable. + +## 3. Model Semantics To Preserve + +DeepSeek-V4 and Mamba/GDN layouts share a State/Classical split, but the +shareable object differs. + +| Aspect | DeepSeek-V4 CSA/HCA/SWA | Qwen 3.5 / GDN / Mamba-like | +|---|---|---| +| Classical history | CSA/HCA compressed KV blocks. CSA has main compressed KV plus indexer KV; HCA has main compressed KV. | Standard token KV pages only on full-attention layers. | +| Sequence state | SWA window plus uncompressed compression tail. | Recurrent state plus conv ring buffer. | +| Reuse cap | Last complete compression block / LCM boundary. Incomplete tails are replay work. | Last aligned checkpoint (`mamba_cache_chunk_size` / FLA chunk multiple), plus full-attention page hits. | +| Shareable unit | Canonical compressed-history bundle; optional SWA state only with an explicit policy. | Atomic recurrent+conv checkpoint; conv and recurrent state must not be split. | +| Request-local unit | SWA/tail state unless checkpointed. | In-flight working state. | + +Recoverability mapping: + +- `Recoverability::Exact`: standard KV pages, V4 compressed history, Qwen + full-attention pages. +- `Recoverability::AlignedCheckpoint`: Mamba/GDN state checkpoints and optional + V4 SWA periodic checkpoints. +- `Recoverability::WindowRepairable`: V4 SWA zero-cache style rebuild. +- `Recoverability::RequestLocalOnly`: scratch state that never enters the global + prefix index. + +## 4. Current TokenSpeed State + +Current `main` has three scheduler-visible cache concepts: + +- Standard KV: `Scheduler` owns `KVPrefixCache`; normal KV pages are inserted + into the radix tree. +- Mamba/GDN: `HybridPrefixCache` wraps `KVPrefixCache` for Mamba-aware matching. + `MambaChunkAllocator` allocates integer state slots; `LocalMambaAllocator` + keeps one working slot and one checkpoint slot per request. `InsertMamba` + attaches a checkpoint slot to a block-aligned `TreeNode`. +- DeepSeek-V4 paged groups: `PagedCacheGroupTable` is per `(request, group)` and + owns compact `OwnedPages`, `raw_token_cursor_`, and `base_logical_page_`. + Scheduler helpers allocate, compact, and flatten these request-local tables + into `paged_cache_block_tables` and `paged_cache_block_table_base_offsets`. + +The important current-state boundary: on `main`, V4 paged-cache groups are not +tree-attached prefix-cache entries. They are a parallel request-local address +space keyed by `group_id`. + +Current scheduler-facing structure on `main`: + +```text +Scheduler + | + +-- KVPrefixCache + | | + | +-- RadixTree + | | + | +-- TreeNode + | | + | +-- DeviceResource / HostResource (standard KV pages) + | `-- MambaSlot (Mamba checkpoint) + | + +-- Mamba side + | | + | +-- MambaChunkAllocator (slot allocator) + | `-- LocalMambaAllocator per request + | | + | +-- working slot + | `-- checkpoint slot + | + `-- DeepSeek-V4 paged-cache-group side + | + +-- PagedCacheGroupAllocator[group_id] (page allocator) + `-- PagedCacheGroupTable[request_id][group_id] + | + +-- OwnedPages (request-owned only) + +-- raw_token_cursor_ + `-- base_logical_page_ +``` + +Current critical path on `main`: + +```text +Scheduler::newForwardOperation + | + +-- Match + | | + | +-- KVPrefixCache::Match + | `-- HybridPrefixCache::augmentMatch (Mamba checkpoint) + | + +-- Admission + | | + | +-- KVPrefixCache::EnsureCapacityByEvict + | +-- HybridPrefixCache::EnsureMambaCapacityByEvict + | `-- Scheduler::checkPagedCacheGroupAdmission (V4 request-local) + | + +-- Apply FSM event + | + +-- Cache state mutation + | | + | +-- Scheduler::acquirePagedCachePagesForRequest (ReleaseSkipped + Acquire) + | `-- Scheduler::populatePagedCachePagesForOp (forward metadata) + | + `-- Finish / retract cleanup + | + +-- KV node refs / cache-op tracking + +-- Mamba slot release + `-- PagedCacheGroupTable::ReleaseAll +``` + +The in-flight `feat/v4-prefix-cache` work should be treated as a Phase-0 bridge, +not the final abstraction. Preserve its useful semantics: opt-in switch, +History/State completeness split, state-only prune, split guards, +borrowed-prefix import ordering, and RAII ownership of pages moved from +request-local tables into tree-node attachments. + +## 5. Target Architecture + +`Scheduler` sees only five coordinator operations: + +| Operation | Scheduler-facing meaning | +|---|---| +| `MatchPrefix` | Produce a `RecoveryPlan`: exact hits, checkpoint hits, and replay ranges. | +| `Admit` | Decide whether the next step can allocate, must evict, or must reject. | +| `StepCommit` | Publish complete pages, compressed bundles, or eligible checkpoints. | +| `FinishRequest` | Release request-local state while leaving shared attachments managed by refcount/LRU. | +| `Statistics` | Report family-level hit, replay, allocation, and eviction counters. | + +`CacheCoordinator` owns: + +- `FamilyRegistry`: startup-registered `CacheResourceSpec` records. +- `RadixTree`: token prefix tree with family slots on each node. +- Existing allocators: `PageAllocator`, `PagedCacheGroupAllocator`, + `MambaChunkAllocator`, and their request-local users stay resource-specific + in this refactor. +- The mutation protocol for request-local cache state: admission, detach, + release, tree attach, finish cleanup, and atomic cohort transitions. +- Startup-built active-family policy arrays for match, admit, commit, evict, + and finish phases. + +Expected simplification: + +- `scheduler.cpp` no longer branches over KV vs Mamba vs V4 paged groups. +- `MatchResult::mamba_*` and existing paged-cache metadata become compatibility + views derived from `RecoveryPlan`. +- `Scheduler` keeps a `RequestCacheContext` per request. Flatten reads + snapshots from that context; mutating operations pass the context by reference + to `CacheCoordinator`. +- There is no separate indirection registry. `RequestCacheContext` itself is + the request-local API surface; ownership-changing methods still live on + `CacheCoordinator`. +- `PagedCacheSnapshot` becomes a bridge representation of generic + `TreeAttachment` slots, with V4 roles encoded by spec id. + +The wrong abstraction is a universal allocator. The right shared substrate is: +per-request state shape, explicit commit boundaries, radix-node attachments, +recovery/admission plans, and one coordinator-owned control-flow order. + +Target scheduler-facing structure: + +```text +Scheduler + | + +-- RequestCacheContext[request_id] + | | + | +-- page-array state entries + | `-- checkpoint-pair state entries + | | + | +-- read-only snapshot helpers + | `-- no tree attach / detach authority + | + +-- Forward metadata flattening step + | | + | `-- asks RequestCacheContext for read-only snapshots + | no detach, release, or tree attach + | + `-- CacheCoordinator + | + +-- FamilyRegistry + | | + | `-- CacheResourceSpec[FamilyId] + | | + | +-- family: TokenPage / CompressedPage / + | | SlidingWindowState / CompressionTailState / + | | RecurrentState / ConvState + | `-- attachment_kind: TokenPageAttachment / + | CompressedBlockAttachment / + | SlidingWindowStateAttachment / + | StateSlotAttachment / + | NoneForRequestLocal + | + +-- RadixTree + | | + | `-- TreeNode + | | + | `-- slots[FamilyId] -> TreeAttachment + | (shape follows spec.attachment_kind) + | | + | +-- TokenPageAttachment + | +-- CompressedBlockAttachment + | +-- SlidingWindowStateAttachment + | `-- StateSlotAttachment + | + +-- Existing allocators + | | + | +-- PageAllocator / LocalKVAllocator + | +-- PagedCacheGroupAllocator / PagedCacheGroupTable + | `-- MambaChunkAllocator / LocalMambaAllocator + | + `-- CacheFamilyPolicy[FamilyId] + | + +-- Match / ComputeSlice + +-- ComputeDemand + +-- OnCommit + +-- OnEvict + `-- OnFinish +``` + +Target critical path: + +```text +Scheduler::newForwardOperation + | + +-- RecoveryPlan plan = CacheCoordinator::MatchPrefix(request_key) + | | + | +-- walk the token radix tree once + | +-- validate each active family against the terminal node + | | | + | | +-- exact page / compressed-bundle hit + | | +-- aligned recurrent+conv checkpoint hit + | | `-- window/tail replay plan + | | + | `-- return slices + replay ranges + execution resume point + | + +-- AdmissionVerdict verdict = CacheCoordinator::Admit(ctx, next_step) + | | + | +-- compute new demand by family + | +-- subtract borrowed_prefix_units + | +-- credit sliding releasable_units + | `-- evict/prune through family policies if needed + | + +-- Apply FSM event + | + +-- CacheCoordinator::StepCommit(ctx, plan, step_tokens) + | | + | +-- attach complete canonical bundles at commit boundaries + | +-- publish eligible checkpoints + | `-- keep RequestLocalOnly state local + | + +-- ForwardMetadata::Flatten(ctx, plan) + | | + | +-- old KV block tables + | +-- old mamba_pool_indices / cow_src / branching seqlen + | `-- old V4 paged_cache_block_tables / base_offsets + | + `-- CacheCoordinator::FinishRequest(ctx) + | + `-- release request-local state; shared TreeNode attachments remain +``` + +## 6. Core Data Model + +Only the stable scheduler-facing model is shown here. Implementation helpers and +test adapters should stay local to the PR that introduces them. + +Data-structure relationship: + +```text +┌──────────────────────────────────────────────────────────────┐ +│ scheduler.cpp / operations::forward │ +│ owns ctx; calls coordinator(ctx) │ +└──────────────┬──────────────────────────┬────────────────────┘ + │ owns │ calls with ctx& + ▼ ▼ +┌──────────────────────────────┐ ┌──────────────────────────────┐ +│ RequestCacheContext │ │ CacheCoordinator │ +│ - PageArrayState entries │◄─┤ - mutates ctx via APIs │ +│ - CheckpointPair entries │ │ - FamilyRegistry │ +│ - read-only state views │ │ - active family policies │ +│ - no allocator ownership │ │ - produces RecoveryPlan │ +│ - no tree access │ └──────────────┬───────────────┘ +└──────────────┬───────────────┘ │ owns + │ request-local lifetime ▼ + │ ┌────────────────────────────────────────┐ + │ │ RadixTree / TreeNode slots │ + │ │ slots[FamilyId] -> TreeAttachment │ + │ │ │ + │ │ - TokenPageAttachment │ + │ │ - CompressedBlockAttachment │ + │ │ - SlidingWindowStateAttachment │ + │ │ - StateSlotAttachment │ + │ └───────────────────┬────────────────────┘ + │ │ shared lifetime + └──────────────────┬───────────────────┘ + ▼ + ┌────────────────────────────────────────┐ + │ Existing allocator layer │ + │ PageAllocator │ + │ PagedCacheGroupAllocator │ + │ MambaChunkAllocator │ + │ request-local adapters │ + └────────────────────────────────────────┘ +``` + +There is no standalone request-local entity beside `CacheCoordinator`. +`PageArrayState` and `CheckpointPair` are entries inside +`RequestCacheContext`. The coordinator can mutate them only because the +scheduler passes the context by reference to coordinator methods. The tree side +stores reusable `TreeAttachment` shapes; the request side stores current +per-request state. + +`RequestCacheContext` is split out because its lifecycle is different from both +the tree and the allocators: + +- Request-local state is created at admission, changes every scheduling step, + and is fully cleaned up at request finish or abort. +- Tree attachments are shared artifacts. They are created only at commit + boundaries and outlive the request until refcount/LRU eviction releases them. +- Allocators are global resource owners. They provide and reclaim pages or + slots, but they should not encode request progress, prefix-match decisions, + or tree publication state. + +Keeping these lifecycles separate lets the scheduler carry one per-request +object through the forward path while preserving coordinator ownership of +mutating operations. + +The code model should stay small. It needs five stable concepts: + +| Concept | Purpose | Essential fields | +|---|---|---| +| `CacheResourceSpec` | Static declaration of one cache family role. | `id`, `family_index`, `family`, `attachment_kind`, `recoverability`, `publication`, `split_policy`, `rows_per_page`, `entry_stride_tokens`, optional `sliding_window_tokens`, `checkpoint_chunk_tokens`, `layer_indices`, optional `state_cohort_id`, `required_for_recovery`. | +| `RequestCacheContext` | Scheduler-held per-request cache facade. | request id, family-indexed page-array/checkpoint-pair entries, read-only state views; no tree attach/detach authority. | +| `FamilySlice` | One family contribution inside a match result. | `family`, `hit_node`, `recoverable_end_tokens`, `replay_from_tokens`, `replay_to_tokens`, `replay_mode`, borrowed page/slot ids, optional page-table base. | +| `RecoveryPlan` | Scheduler-facing answer to prefix matching. | `kv_terminal`, `matched_prefix_tokens`, `reusable_classical_end_tokens`, `recoverable_prefix_end_tokens`, `execution_resume_tokens`, `replay_ranges`, `slices`. | +| `ResourceDemand` / `AdmissionVerdict` | Admission accounting by family and by atomic state cohort. | family id, new units needed, releasable units, borrowed prefix units, optional `state_cohort_id`, verdict kind. | + +The small enum set behind those structs is: + +| Enum | Values | +|---|---| +| `CacheFamily` | `TokenPage`, `CompressedPage`, `SlidingWindowState`, `CompressionTailState`, `RecurrentState`, `ConvState`. | +| `AttachmentKind` | `TokenPageAttachment`, `CompressedBlockAttachment`, `SlidingWindowStateAttachment`, `StateSlotAttachment`, `NoneForRequestLocal`. | +| `Recoverability` | `Exact`, `AlignedCheckpoint`, `WindowRepairable`, `RequestLocalOnly`. | +| `PublicationPolicy` | `CanonicalPrefixIndex`, `AuxiliaryLocalOnly`, `RequestLocalOnly`. | +| `SplitPolicy` | `SplitPrefix`, `KeepOnSuffix`, `DropOnSplit`, `NoSplitAllowed`. | +| `ReplayMode` | `None`, `PrefillRecompute`, `StateForwardPropagate`, `WindowRebuild`. | + +`CacheFamily` and `AttachmentKind` should not drift. `CacheFamily` names the +model semantic role; `AttachmentKind` names the reusable shape stored on a +`TreeNode`. The mapping is declared once in `CacheResourceSpec`: + +| Cache family | Tree attachment shape | +|---|---| +| `TokenPage` | `TokenPageAttachment` | +| `CompressedPage` | `CompressedBlockAttachment`; CSA/HCA specs for the same raw span may be committed as one atomic bundle. | +| `SlidingWindowState` | `NoneForRequestLocal` by default; `SlidingWindowStateAttachment` if explicit SWA sharing is enabled. Full SWA Caching and Periodic Checkpointing are strategies on that one attachment kind, not separate tree state. | +| `CompressionTailState` | `NoneForRequestLocal` by default; incomplete tails are replay work, not shared attachments. | +| `RecurrentState` | `StateSlotAttachment` as part of the recurrent+conv checkpoint cohort. | +| `ConvState` | `StateSlotAttachment` as part of the same checkpoint cohort; it must not own a separate tree lifetime from `RecurrentState`. | + +Spec-id naming keeps model roles out of scheduler branches: + +- `kv.device`, `kv.host` -> `TokenPage` +- `v4.csa.main.compressed_kv`, `v4.csa.indexer.compressed_kv` -> + `CompressedPage` +- `v4.hca.main.compressed_kv` -> `CompressedPage` +- `v4.swa.window_state` -> `SlidingWindowState` +- `v4.compression_tail.state` -> `CompressionTailState` +- `qwen35.gdn.recurrent_state` -> `RecurrentState` +- `qwen35.gdn.conv_state` -> `ConvState` + +Layout comparison: + +```text +1. Standard KV / MLA token pages + + logical tokens: [0 ........ 63][64 ....... 127][128 ...... 191] + page table: page A page B page C + tree attachment: exact reusable token-page slices + growth: linear with context length + + +2. DeepSeek-V4 compressed history + sequence state + + raw tokens: [---------- LCM span ----------][tail not complete] + | | | + v v v + CSA compressed KV HCA compressed KV SWA/tail state + + indexer KV main history request-local + + tree attachment: complete compressed-history bundle only + replay boundary: incomplete compression tail / SWA window policy + growth: compressed history grows by completed spans; + sequence state stays fixed-size per request + + +3. Mamba/GDN recurrent + conv checkpoint + + raw tokens: [0 ........ checkpoint N][N .... current tail] + | | + v v + state slot checkpoint working slot + conv_state + recurrent runtime updates every step + + tree attachment: atomic state slot checkpoint at aligned boundary + replay boundary: nearest checkpoint <= token prefix + growth: state size is independent of sequence length; + checkpoints are sparse reusable snapshots +``` + +### 6.1 Allocators And Request Facades + +Allocator classes should remain resource-specific in this refactor. Do not +introduce a generic allocator layer, and do not rename existing allocators into +pool-like abstractions. + +Keep the current allocator ownership model: + +- `PageAllocator` remains the allocator for standard KV pages. +- `PagedCacheGroupAllocator` remains the allocator for V4 paged-cache groups. +- `MambaChunkAllocator` remains the allocator for Mamba/GDN state slots. +- `OwnedPages` and `OwnedStateSlots` keep their concrete RAII ownership names. +- Low-level free-list or counter helpers may be factored later, but they are + implementation details, not part of this design's public model. + +The shared scheduler-facing layer is `RequestCacheContext`, above the +allocators: + +- Page-array state: request-local state whose visible shape is borrowed prefix + pages, owned suffix pages, raw-token cursor, base logical page, and optional + sliding-window release. This covers standard KV suffix state and V4 + `PagedCacheGroupTable` behavior. +- Checkpoint-pair state: request-local state whose visible shape is one working + slot plus one checkpoint slot. This covers `LocalMambaAllocator` behavior. + +Those names describe context entry shapes, not new allocator base classes. At +runtime, `Scheduler` stores a `RequestCacheContext` per request and passes it by +reference to `CacheCoordinator`. `RequestCacheContext` may expose read-only +state views, but detach, release, attach, admission, and cohort mutation must go +through `CacheCoordinator`. + +Do not add an intermediate indirection layer in this rollout. The context +object is passed directly to coordinator methods, and the coordinator mutates +the page-array or checkpoint-pair entries through explicit function calls on +that context. + +The lifetime rule is still explicit: + +- Page-like family: existing allocator -> request-local page-array entry in + `RequestCacheContext` -> Tree attachment at commit boundary -> eviction + returns pages to the same allocator. +- Slot-like family: existing allocator -> request-local checkpoint-pair entry in + `RequestCacheContext` -> `StateSlotAttachment` at aligned checkpoint boundary + -> eviction returns the slot to the same allocator. + +`RecurrentState` and `ConvState` are separate families but one atomic checkpoint +cohort. A request must not publish, transfer, admit, or evict one without the +other, so the coordinator treats them as one checkpoint-pair recovery unit even +though the underlying allocator remains `MambaChunkAllocator`. + +### 6.2 Tree Attachments + +Tree-node family slots are the reusable side of the design. + +- `TokenPageAttachment`: normal KV page ownership. +- `CompressedBlockAttachment`: canonical V4 compressed-history bundle. It may + own multiple spec ids for the same raw-token span, because V4 history is + reusable only when the required CSA/HCA groups are complete together. +- `SlidingWindowStateAttachment`: optional SWA storage when SWA is published to + the tree. Full SWA Caching and Periodic Checkpointing are two strategies for + the same logical attachment; an implementation should not keep both for the + same SWA span. +- `StateSlotAttachment`: Mamba/GDN checkpoint slot. It is leaf-like and uses + `SplitPolicy::KeepOnSuffix`. + +The Phase-0 `PagedCacheSnapshot` bridge maps to `CompressedBlockAttachment` plus +optional sliding-window-state attachments. Preserve its no-split guard until +generic attachment slots are ready. + +### 6.3 Recovery And Admission + +The important fields in `RecoveryPlan` are conceptual, not syntactic: + +```text +matched_prefix_tokens + raw token equality from the radix walk + +reusable_classical_end_tokens + furthest canonical page/bundle boundary that can be attached directly + +recoverable_prefix_end_tokens + min boundary across all required families after checkpoint/replay rules + +execution_resume_tokens + first token the worker must execute after materializing borrowed state + +replay_ranges + explicit family-local work needed to rebuild missing sequence state + +slices + per-family attachments, borrowed page/slot ids, base offsets, and replay mode +``` + +Admission consumes the same family view: + +```text +RecoveryPlan + | + v +ResourceDemand per family + new_units_needed + borrowed_prefix_units + releasable_units + state_cohort_id + | + v +AdmissionVerdict + Ok / NeedEvict / Reject +``` + +Demands with the same `state_cohort_id` are admitted as one unit. V4 +SlidingWindowState + CompressionTailState share one sequence-state block per +request; Mamba/GDN RecurrentState + ConvState share one checkpoint slot. + +## 7. Critical-Path Discipline + +The critical path is the target pipeline in Section 5. This section constrains +where indirection is allowed. + +```text +HOT: per request scheduled, per token committed + - fixed loop over startup-built ActiveFamilies + - at most one family hook per Match / Admit / Commit phase + - no string lookup, dynamic family discovery, or per-layer scan + - no scheduler-level model branch + +WARM: per insert decision, per eviction batch + - finalize recovery slices + - drive family-aware eviction + - update LRU/refcount/lock accounting + +COLD: radix split, HiCache, PD transfer, remote import + - validate layout signatures + - serialize immutable transfer descriptors + - use split policy before attachment-specific split logic +``` + +Derived `CacheFamilyPolicy` implementations may simplify family-specific code, +but they must not create a second family-specific control flow outside +`CacheCoordinator`. + +## 8. Invariants + +| # | Invariant | +|---|---| +| I1 | Auxiliary-family lock implies its canonical carrier lock on the same node: `aux_lock_ref <= carrier_lock_ref`. | +| I2 | A family may attach to the global index only at its commit boundary. | +| I3 | `PublicationPolicy::RequestLocalOnly` families never call `TreeNode::Attach`. | +| I4 | `Recoverability::RequestLocalOnly` families return no global hit. | +| I5 | Auxiliary attachments are evicted before their carrier attachment is freed. | +| I6 | Borrowed prefix ids are imported before sliding release or fresh allocation. | +| I7 | `NoSplitAllowed` and Phase-0 `PagedCacheSnapshot` nodes are not radix-split. | + +## 9. Cold-Path Layout And Transfer + +HiCache, prefill/decode disaggregation, remote cache import, and future +cross-instance transfer paths must validate layout compatibility before treating +a cache hit as materializable. This is cold-path metadata; it must not add +string lookup, serializer probing, or topology checks to the scheduler hot path. + +Every exported transfer descriptor should carry: + +- `CacheResourceId` / `FamilyId` and raw logical token span. +- `Recoverability`, `PublicationPolicy`, and checkpoint / replay mode. +- Page or slot geometry: `rows_per_page`, `entry_stride_tokens`, + `checkpoint_chunk_tokens`, and optional `sliding_window_tokens`. +- Tensor layout signature: dtype, layer membership, TP shard layout, PP stage + ownership, and serializer version. +- Storage references: page ids, state-slot blobs, or external object ids. + +Descriptors are immutable compatibility records. Local runtime state such as +`ref_count`, `lock_ref`, LRU position, pin state, and in-flight ownership stays +local to the receiving cache manager. diff --git a/tokenspeed-scheduler/CMakeLists.txt b/tokenspeed-scheduler/CMakeLists.txt index b0770e477..37347ee10 100644 --- a/tokenspeed-scheduler/CMakeLists.txt +++ b/tokenspeed-scheduler/CMakeLists.txt @@ -39,9 +39,11 @@ add_library(tokenspeed_scheduler_core STATIC csrc/resource/radix_tree/radix_tree.cpp csrc/resource/radix_tree/tree_node.cpp csrc/resource/kv_prefix_cache/kv_prefix_cache.cpp - csrc/resource/kv_prefix_cache/cache_coordinator.cpp + csrc/resource/hybrid_prefix_cache/family_registry.cpp csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp + csrc/resource/hybrid_prefix_cache/mamba_family_ops.cpp csrc/resource/hybrid_prefix_cache/mamba_eviction_manager.cpp + csrc/resource/hybrid_prefix_cache/paged_cache_family_ops.cpp csrc/resource/allocator/mamba_chunk_allocator.cpp csrc/resource/allocator/mamba_host_allocator.cpp csrc/resource/allocator/local_mamba_allocator.cpp @@ -50,7 +52,6 @@ add_library(tokenspeed_scheduler_core STATIC csrc/fsm/cache_events.cpp csrc/fsm/pd_events.cpp - csrc/scheduler/operations/cache.cpp csrc/scheduler/operations/forward.cpp csrc/scheduler/kv_cache_events.cpp @@ -112,12 +113,15 @@ if(TOKENSPEED_SCHEDULER_BUILD_TESTS) tests/cpp/test_batch_scheduling.cpp tests/cpp/test_prefetch.cpp tests/cpp/test_owned_pages.cpp + tests/cpp/test_hybrid_cache_registry.cpp tests/cpp/test_paged_cache_prefix_match.cpp tests/cpp/test_paged_cache_attach_loop.cpp tests/cpp/test_paged_cache_eviction.cpp tests/cpp/test_paged_cache_family_split.cpp tests/cpp/test_paged_cache_prefix_hit_commit.cpp tests/cpp/test_retract_abort_pages.cpp + tests/cpp/test_request_cache_context.cpp + tests/cpp/test_scheduler_memory_diagnostics.cpp tests/cpp/test_req_pool_allocator.cpp tests/cpp/test_mamba_slot.cpp tests/cpp/test_mamba_eviction.cpp diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp index 83dd0354d..9e04e114f 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp @@ -33,10 +33,7 @@ #include "fsm/forward_events.h" #include "fsm/forward_states.h" #include "resource/allocator/kv_allocator.h" -#include "resource/allocator/owned_pages.h" #include "resource/allocator/req_pool_allocator.h" -#include "resource/radix_tree/node_range.h" -#include "resource/kv_prefix_cache/kv_prefix_cache.h" #include "resource/radix_tree/tree_node.h" #include "resource/types.h" #include "scheduler/operations/cache.h" @@ -58,24 +55,6 @@ std::vector BuildWriteBackPairs(const std::vector MambaNodesForTransferPairs(const std::vector& candidates, - const std::vector& transfers) { - std::unordered_set src_slots; - for (const auto& transfer : transfers) { - if (transfer.kind == tokenspeed::CacheKind::kMamba) { - src_slots.insert(transfer.src); - } - } - std::vector nodes; - nodes.reserve(src_slots.size()); - for (tokenspeed::TreeNode* node : candidates) { - if (node != nullptr && node->HasMamba() && src_slots.find(node->MambaSlotIndex()) != src_slots.end()) { - nodes.push_back(node); - } - } - return nodes; -} - void DemoteWrittenBackDevice(tokenspeed::KVPrefixCache* kv_prefix_cache, tokenspeed::HybridPrefixCache* hybrid_prefix_cache, tokenspeed::TreeNode* device_node) { if (kv_prefix_cache == nullptr || device_node == nullptr) return; @@ -86,94 +65,57 @@ void DemoteWrittenBackDevice(tokenspeed::KVPrefixCache* kv_prefix_cache, }); } -bool ShouldPublishMambaCheckpoint(tokenspeed::HybridPrefixCache* hybrid_cache, std::int32_t chunk_begin, - std::int32_t chunk_size, std::int32_t page_size) { - if (hybrid_cache == nullptr || chunk_size <= 0 || page_size <= 0) return false; - const std::int32_t final_len = chunk_begin + chunk_size; - const std::int32_t last_inserted_len = (final_len / page_size) * page_size; - if (last_inserted_len <= chunk_begin) return false; - if (last_inserted_len == final_len) return true; - - const std::int32_t track_len = last_inserted_len - chunk_begin; - return hybrid_cache->AlignMambaCacheSeqlen(track_len) == track_len; -} - } // namespace namespace tokenspeed::fsm { -void InsertHybridCache(HybridPrefixCache* hybrid_cache, - const std::vector>& full_paged_tokens, - std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, - LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size, - std::int32_t page_size) { - if (hybrid_cache == nullptr) return; - - std::vector prefix_pages = DevicePagesFromRoot(device_node_ref->Node()); - std::int32_t new_page_count = - static_cast(full_paged_tokens.size()) - static_cast(prefix_pages.size()); - if (new_page_count <= 0) { - if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { - local_mamba_allocator->DetachCheckpoint(); - } - return; - } - - OwnedPages pages_to_insert = local_kv_allocator->TakeFirst(new_page_count); - auto insert_result = hybrid_cache->GetKVPrefixCache().Insert(full_paged_tokens, prefix_pages, - std::move(pages_to_insert)); - - if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { - if (ShouldPublishMambaCheckpoint(hybrid_cache, chunk_begin, chunk_size, page_size)) { - hybrid_cache->InsertMamba(insert_result.last_node, local_mamba_allocator->DetachCheckpoint()); - } else { - local_mamba_allocator->DetachCheckpoint(); - } - } - device_node_ref = std::make_unique(insert_result.last_node); -} - // Submitted -> PrefillDone / Prefilling std::variant SchedulePrefillFirstChunkEvent::operator()(Submitted&& state) { // Lock node std::unique_ptr host_node_ref{nullptr}; std::unique_ptr device_node_ref{nullptr}; + std::int32_t max_matched_pages = + disable_l2_cache_ ? match_result_.device.DepthInPage() + : std::max(match_result_.device.DepthInPage(), match_result_.host.DepthInPage()); + std::int32_t window_begin = max_matched_pages * state.GetPageSize(); + const std::int32_t checkpoint_raw_position = window_begin + tokens_this_round_; if (!disable_l2_cache_ && (match_result_.host.DepthInPage() > match_result_.device.DepthInPage())) { host_node_ref = std::make_unique(match_result_.host.last_node); - kv_prefix_cache_->AllocateResourceOfType( - match_result_.NodesWithout()); + StepCommitRequest materialization_request{ + .materialize_prefix = + PrefixMaterializationRequest{ + .compat_match = &match_result_, + .require_all_pages = false, + }, + }; + (void)hybrid_prefix_cache_.StepCommit(std::move(materialization_request)); device_node_ref = std::make_unique(match_result_.host.last_node); } else { device_node_ref = std::make_unique(match_result_.device.last_node); } - // Allocate KV pages for tokens not covered by the prefix cache - auto local_kv_allocator = std::make_unique(device_allocator_, tokens_this_round_); - // Reserve token slots for draft multi-step decode - local_kv_allocator->Acquire(decode_input_tokens_); + auto step_result = hybrid_prefix_cache_.StepCommit({ + .request_local_kv = + RequestLocalKVStateRequest{ + .create_allocator = true, + .initial_tokens = tokens_this_round_, + .acquire_tokens = decode_input_tokens_, + }, + .request_local_mamba = + RequestLocalMambaStateRequest{ + .create_allocator = true, + .checkpoint_raw_position = checkpoint_raw_position, + }, + }); + auto local_kv_allocator = std::move(step_result.local_kv_allocator); // Allocate req_pool_idx when first-time scheduled auto req_pool_index = std::make_unique(req_pool_allocator_->Allocate()); - // Mamba: allocate working + checkpoint slots if mamba is enabled - std::unique_ptr local_mamba_allocator; - if (mamba_allocator_ != nullptr) { - local_mamba_allocator = std::make_unique(mamba_allocator_); - if (!local_mamba_allocator->AllocateWorking()) { - local_mamba_allocator.reset(); - } else { - if (!local_mamba_allocator->AllocateCheckpoint()) { - throw std::logic_error("SchedulePrefillFirstChunkEvent: failed to allocate Mamba checkpoint slot"); - } - } - } + auto local_mamba_allocator = std::move(step_result.local_mamba_allocator); TokenContainer* token_container = state.GetTokenContainer(); - std::int32_t max_matched_pages = - disable_l2_cache_ ? match_result_.device.DepthInPage() - : std::max(match_result_.device.DepthInPage(), match_result_.host.DepthInPage()); - std::int32_t window_begin = max_matched_pages * state.GetPageSize(); TokenContainer::Window window{.begin = window_begin, .size = tokens_this_round_}; bool is_last_chunk = (window.begin + window.size) == token_container->PrefillSize(); @@ -213,17 +155,26 @@ std::variant SchedulePrefillEvent::operator()(Prefillin if (end_of_window_pages < static_cast(paged_tokens.size())) { paged_tokens.resize(end_of_window_pages); } - InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), - local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); - // Allocate KV pages for the new chunk - local_kv_allocator->Acquire(tokens_this_round_); - - // Allocate fresh mamba checkpoint for this chunk. - if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr) { - if (!local_mamba_allocator->AllocateCheckpoint()) { - throw std::logic_error("SchedulePrefillEvent: failed to allocate Mamba checkpoint slot"); - } - } + StepCommitRequest publication_request{ + .publish_device_prefix = + DevicePrefixPublicationRequest{ + .full_paged_tokens = &paged_tokens, + .device_node_ref = &device_node_ref, + .local_kv_allocator = local_kv_allocator.get(), + .local_mamba_allocator = local_mamba_allocator.get(), + }, + .request_local_kv = + RequestLocalKVStateRequest{ + .allocator = local_kv_allocator.get(), + .acquire_tokens = tokens_this_round_, + }, + .request_local_mamba = + RequestLocalMambaStateRequest{ + .refresh_checkpoint_allocator = local_mamba_allocator.get(), + .checkpoint_raw_position = state.window.begin + state.window.size + tokens_this_round_, + }, + }; + (void)hybrid_prefix_cache_.StepCommit(std::move(publication_request)); TokenContainer::Window window{.begin = state.window.begin + state.window.size, .size = tokens_this_round_}; @@ -263,17 +214,26 @@ Decoding ScheduleDecodeEvent::operator()(PrefillDone&& state) { if (end_of_window_pages < static_cast(paged_tokens.size())) { paged_tokens.resize(end_of_window_pages); } - InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), - local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); - // Allocate fresh checkpoint for decode-phase mamba state tracking - if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr) { - if (!local_mamba_allocator->AllocateCheckpoint()) { - throw std::logic_error("ScheduleDecodeEvent: failed to allocate Mamba checkpoint slot"); - } - } - - std::int32_t reserve = state.GetReserveNumTokensInNextScheduleEvent(); - local_kv_allocator->Acquire(reserve); + StepCommitRequest publication_request{ + .publish_device_prefix = + DevicePrefixPublicationRequest{ + .full_paged_tokens = &paged_tokens, + .device_node_ref = &device_node_ref, + .local_kv_allocator = local_kv_allocator.get(), + .local_mamba_allocator = local_mamba_allocator.get(), + }, + .request_local_kv = + RequestLocalKVStateRequest{ + .allocator = local_kv_allocator.get(), + .acquire_tokens = state.GetReserveNumTokensInNextScheduleEvent(), + }, + .request_local_mamba = + RequestLocalMambaStateRequest{ + .refresh_checkpoint_allocator = local_mamba_allocator.get(), + .checkpoint_raw_position = state.GetTokenContainer()->Size() + decode_input_tokens_, + }, + }; + (void)hybrid_prefix_cache_.StepCommit(std::move(publication_request)); return Decoding{state.GetTokenContainer(), state.GetPageSize(), std::move(host_node_ref), std::move(device_node_ref), @@ -289,7 +249,13 @@ Decoding ScheduleDecodeEvent::operator()(Decoding&& state) { auto host_node_ref = std::move(state).TakeHostNodeRef(); std::int32_t reserve = state.GetReserveNumTokensInNextScheduleEvent(); - local_kv_allocator->Acquire(reserve); + (void)hybrid_prefix_cache_.StepCommit({ + .request_local_kv = + RequestLocalKVStateRequest{ + .allocator = local_kv_allocator.get(), + .acquire_tokens = reserve, + }, + }); return Decoding{state.GetTokenContainer(), state.GetPageSize(), std::move(host_node_ref), std::move(device_node_ref), @@ -304,8 +270,14 @@ Decoding ScheduleDecodeFromRetractedEvent::operator()(Retracted&& state) { std::unique_ptr device_node_ref{nullptr}; if (match_result_.host.DepthInPage() > match_result_.device.DepthInPage()) { host_node_ref = std::make_unique(match_result_.host.last_node); - if (!kv_prefix_cache_->AllocateResourceOfType( - match_result_.NodesWithout())) { + StepCommitRequest materialization_request{ + .materialize_prefix = + PrefixMaterializationRequest{ + .compat_match = &match_result_, + .require_all_pages = true, + }, + }; + if (!hybrid_prefix_cache_.StepCommit(std::move(materialization_request)).ok) { // Device allocation failed (race between capacity check and actual alloc). throw std::logic_error( "ScheduleDecodeFromRetractedEvent: failed to allocate device pages for host cache recovery"); @@ -320,19 +292,20 @@ Decoding ScheduleDecodeFromRetractedEvent::operator()(Retracted&& state) { auto local_kv_allocator = std::move(state).TakeKVAllocator(); auto old_mamba_allocator = std::move(state).TakeMambaAllocator(); old_mamba_allocator.reset(); - std::unique_ptr local_mamba_allocator; - if (mamba_allocator_ != nullptr) { - local_mamba_allocator = std::make_unique(mamba_allocator_); - if (!local_mamba_allocator->AllocateWorking()) { - throw std::logic_error("ScheduleDecodeFromRetractedEvent: failed to allocate Mamba recovery working slot"); - } - if (!local_mamba_allocator->AllocateCheckpoint()) { - throw std::logic_error( - "ScheduleDecodeFromRetractedEvent: failed to allocate Mamba recovery checkpoint slot"); - } - } auto req_pool_index = std::make_unique(req_pool_allocator_->Allocate()); - local_kv_allocator->Acquire(decode_input_tokens_); + auto step_result = hybrid_prefix_cache_.StepCommit({ + .request_local_kv = + RequestLocalKVStateRequest{ + .allocator = local_kv_allocator.get(), + .acquire_tokens = decode_input_tokens_, + }, + .request_local_mamba = + RequestLocalMambaStateRequest{ + .create_allocator = true, + .require_allocator = true, + }, + }); + auto local_mamba_allocator = std::move(step_result.local_mamba_allocator); return Decoding{token_container, page_size, std::move(host_node_ref), @@ -348,59 +321,46 @@ Decoding ScheduleDecodeFromRetractedEvent::operator()(Retracted&& state) { template std::variant FinishEvent::apply(ForwardStateT&& state) { auto full_paged_tokens = state.GetFullPagedTokens(true); - std::vector prefix_pages = DevicePagesFromRoot(state.GetDeviceNode()); - std::int32_t alloc_count = - static_cast(full_paged_tokens.size()) - static_cast(prefix_pages.size()); + const TreeNode* current_device_node = state.GetDeviceNode(); auto local_mamba_allocator = std::move(state).TakeLocalMambaAllocator(); auto local_allocator = std::move(state).TakeLocalKVAllocator(); - if (alloc_count > 0) { - OwnedPages alloc_pages = local_allocator->TakeFirst(alloc_count); - - kv_prefix_cache_->Insert(full_paged_tokens, prefix_pages, std::move(alloc_pages), - page_hashes_); - - // Mamba: insert the latest checkpoint snapshot at the terminal node. - if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr && - (local_mamba_allocator->HasCheckpoint() || local_mamba_allocator->HasWorking())) { - MatchResult post_match = kv_prefix_cache_->Match(full_paged_tokens); - TreeNode* terminal = post_match.device.last_node; - if (terminal != nullptr && !terminal->HasMamba()) { - if (local_mamba_allocator->HasCheckpoint()) { - hybrid_prefix_cache_->InsertMamba(terminal, local_mamba_allocator->DetachCheckpoint()); - } else { - hybrid_prefix_cache_->InsertMamba(terminal, local_mamba_allocator->DetachWorking()); - } - } - } - } + StepCommitRequest publication_request{ + .publish_finished_request = + FinishedRequestPublicationRequest{ + .full_paged_tokens = &full_paged_tokens, + .current_device_node = current_device_node, + .local_kv_allocator = local_allocator.get(), + .local_mamba_allocator = local_mamba_allocator.get(), + .page_hashes = &page_hashes_, + }, + }; + MatchResult match = hybrid_prefix_cache_.StepCommit(std::move(publication_request)).match_result; // local_mamba_allocator dropped here — destructor frees remaining slots - MatchResult match = kv_prefix_cache_->Match(full_paged_tokens); if (!disable_l2_cache_ && (match.device.DepthInPage() > match.host.DepthInPage())) { std::vector write_diff = match.NodesWithout(); - std::int32_t host_pages_num = 0; - for (TreeNode* node : write_diff) { - host_pages_num += node->Device().NumPages(); - } std::unique_ptr temp_lock = std::make_unique(match.host.last_node); - if (!kv_prefix_cache_->EnsureCapacityByEvict(host_pages_num)) { + StepCommitRequest materialization_request{ + .materialize_host_writeback = + HostWritebackMaterializationRequest{ + .write_diff = &write_diff, + .ensure_capacity_before_allocate = true, + }, + }; + StepCommitResult materialization_result = hybrid_prefix_cache_.StepCommit(std::move(materialization_request)); + if (!materialization_result.ok) { return Finished{}; } - kv_prefix_cache_->AllocateResourceOfType(write_diff); std::unique_ptr device_node_ref = std::make_unique(match.device.last_node); std::unique_ptr host_node_ref = std::make_unique(match.device.last_node); auto pages_to_transfer = BuildWriteBackPairs(write_diff); - std::vector mamba_writeback_nodes; - if (hybrid_prefix_cache_ != nullptr) { - auto mamba_pairs = hybrid_prefix_cache_->PrepareMambaHostWriteBack(write_diff); - mamba_writeback_nodes = MambaNodesForTransferPairs(write_diff, mamba_pairs); - pages_to_transfer.insert(pages_to_transfer.end(), std::make_move_iterator(mamba_pairs.begin()), - std::make_move_iterator(mamba_pairs.end())); - } + pages_to_transfer.insert(pages_to_transfer.end(), + std::make_move_iterator(materialization_result.cache_transfer_pairs.begin()), + std::make_move_iterator(materialization_result.cache_transfer_pairs.end())); return Draining{std::move(pages_to_transfer), std::move(device_node_ref), std::move(host_node_ref), - std::move(mamba_writeback_nodes)}; + std::move(materialization_result.mamba_writeback_nodes)}; } return Finished{}; } @@ -518,16 +478,22 @@ Retracting ScheduleRetractEvent::applyRetract(ForwardStateT&& state) { if (match_result_.device.DepthInPage() > match_result_.host.DepthInPage()) { std::vector write_diff = match_result_.NodesWithout(); device_node_ref = std::make_unique(match_result_.device.last_node); - if (!kv_prefix_cache_->AllocateResourceOfType(write_diff)) { + StepCommitRequest materialization_request{ + .materialize_host_writeback = + HostWritebackMaterializationRequest{ + .write_diff = &write_diff, + .ensure_capacity_before_allocate = false, + }, + }; + StepCommitResult materialization_result = hybrid_prefix_cache_.StepCommit(std::move(materialization_request)); + if (!materialization_result.ok) { throw std::logic_error("ScheduleRetractEvent: failed to allocate host pages for device cache writeback"); } pages_to_transfer = BuildWriteBackPairs(write_diff); - if (hybrid_prefix_cache_ != nullptr) { - auto mamba_pairs = hybrid_prefix_cache_->PrepareMambaHostWriteBack(write_diff); - mamba_writeback_nodes = MambaNodesForTransferPairs(write_diff, mamba_pairs); - pages_to_transfer.insert(pages_to_transfer.end(), std::make_move_iterator(mamba_pairs.begin()), - std::make_move_iterator(mamba_pairs.end())); - } + pages_to_transfer.insert(pages_to_transfer.end(), + std::make_move_iterator(materialization_result.cache_transfer_pairs.begin()), + std::make_move_iterator(materialization_result.cache_transfer_pairs.end())); + mamba_writeback_nodes = std::move(materialization_result.mamba_writeback_nodes); host_node_ref = std::make_unique(match_result_.device.last_node); } else { host_node_ref = std::make_unique(match_result_.device.last_node); @@ -538,23 +504,13 @@ Retracting ScheduleRetractEvent::applyRetract(ForwardStateT&& state) { auto local_allocator = std::move(state).TakeLocalKVAllocator(); auto local_mamba_allocator = std::move(state).TakeLocalMambaAllocator(); - // Mamba: save the latest checkpoint/working state into the prefix cache - // before the request is retracted, so it can be recovered on loadback. - if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr && - (local_mamba_allocator->HasCheckpoint() || local_mamba_allocator->HasWorking())) { - TreeNode* terminal = match_result_.device.last_node; - if (terminal != nullptr && !terminal->HasMamba()) { - if (local_mamba_allocator->HasCheckpoint()) { - hybrid_prefix_cache_->InsertMamba(terminal, local_mamba_allocator->DetachCheckpoint()); - } else { - hybrid_prefix_cache_->InsertMamba(terminal, local_mamba_allocator->DetachWorking()); - } - } - // Once retracted, the recoverable Mamba state is tree-owned and - // therefore evictable by HybridPrefixCache. Do not keep request-local - // slots alive in Retracting/Retracted. - local_mamba_allocator.reset(); - } + (void)hybrid_prefix_cache_.StepCommit({ + .publish_tree_owned_request_state = + TreeOwnedRequestStatePublicationRequest{ + .terminal = match_result_.device.last_node, + .local_mamba_allocator_owner = &local_mamba_allocator, + }, + }); return Retracting{token_container, page_size, diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.h b/tokenspeed-scheduler/csrc/fsm/forward_events.h index 0f42b86b6..11d50753f 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.h +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.h @@ -36,13 +36,9 @@ #include "fsm/forward_states.h" #include "resource/types.h" #include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" -#include "resource/allocator/mamba_chunk_allocator.h" -#include "resource/allocator/local_mamba_allocator.h" #include "utils.h" namespace tokenspeed { -class PageAllocator; -class KVPrefixCache; class ReqPoolAllocator; class TreeNode; } // namespace tokenspeed @@ -52,33 +48,22 @@ namespace tokenspeed::fsm { struct PrefetchDone; struct Prefetching; -void InsertHybridCache(HybridPrefixCache* hybrid_prefix_cache, - const std::vector>& full_paged_tokens, - std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, - LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size, - std::int32_t page_size); - struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); SchedulePrefillFirstChunkEvent(std::int32_t tokens_this_round, std::int32_t decode_input_tokens, - PageAllocator* device_allocator, ReqPoolAllocator* req_pool_allocator, - MatchResult match_result, Role role, KVPrefixCache* kv_prefix_cache, + ReqPoolAllocator* req_pool_allocator, MatchResult match_result, Role role, bool disable_l2_cache, std::vector loadback_diff, - HybridPrefixCache* hybrid_prefix_cache = nullptr, - MambaChunkAllocator* mamba_allocator = nullptr, - std::vector mamba_loadback_nodes = {}) + std::vector cache_transfer_pairs, + HybridPrefixCache& hybrid_prefix_cache) : tokens_this_round_(tokens_this_round), decode_input_tokens_(decode_input_tokens), - device_allocator_(device_allocator), req_pool_allocator_(req_pool_allocator), match_result_(match_result), role_{role}, disable_l2_cache_{disable_l2_cache}, loadback_diff_(std::move(loadback_diff)), - mamba_loadback_nodes_(std::move(mamba_loadback_nodes)), - kv_prefix_cache_(kv_prefix_cache), - hybrid_prefix_cache_(hybrid_prefix_cache), - mamba_allocator_(mamba_allocator) {} + cache_transfer_pairs_(std::move(cache_transfer_pairs)), + hybrid_prefix_cache_(hybrid_prefix_cache) {} // Returns PrefillDone (single-chunk or last chunk) or Prefilling (more chunks remain). std::variant operator()(Submitted&& state); @@ -86,27 +71,24 @@ struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler& GetLoadbackDiff() const { return loadback_diff_; } - const std::vector& GetMambaLoadbackNodes() const { return mamba_loadback_nodes_; } + const std::vector& GetCacheTransferPairs() const { return cache_transfer_pairs_; } private: std::int32_t tokens_this_round_{}; std::int32_t decode_input_tokens_{}; - PageAllocator* device_allocator_{}; ReqPoolAllocator* req_pool_allocator_{}; const MatchResult match_result_{}; const Role role_; bool disable_l2_cache_{}; std::vector loadback_diff_; - std::vector mamba_loadback_nodes_; - KVPrefixCache* kv_prefix_cache_; - HybridPrefixCache* hybrid_prefix_cache_{}; - MambaChunkAllocator* mamba_allocator_{}; + std::vector cache_transfer_pairs_; + HybridPrefixCache& hybrid_prefix_cache_; }; struct SchedulePrefillEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); SchedulePrefillEvent(std::int32_t tokens_this_round, std::int32_t reserve_num_tokens_in_next_schedule_event, - HybridPrefixCache* hybrid_prefix_cache = nullptr) + HybridPrefixCache& hybrid_prefix_cache) : tokens_this_round_(tokens_this_round), reserve_num_tokens_in_next_schedule_event_(reserve_num_tokens_in_next_schedule_event), hybrid_prefix_cache_(hybrid_prefix_cache) {} @@ -117,13 +99,13 @@ struct SchedulePrefillEvent : InvalidTransitionHandler { private: std::int32_t tokens_this_round_{}; std::int32_t reserve_num_tokens_in_next_schedule_event_{}; - HybridPrefixCache* hybrid_prefix_cache_{}; + HybridPrefixCache& hybrid_prefix_cache_; }; struct ScheduleDecodeEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); - ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr) + ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache& hybrid_prefix_cache) : decode_input_tokens_(decode_input_tokens), hybrid_prefix_cache_(hybrid_prefix_cache) {} Decoding operator()(PrefillDone&& state); @@ -131,53 +113,45 @@ struct ScheduleDecodeEvent : InvalidTransitionHandler { private: std::int32_t decode_input_tokens_; - HybridPrefixCache* hybrid_prefix_cache_{}; + HybridPrefixCache& hybrid_prefix_cache_; }; struct ScheduleDecodeFromRetractedEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); // Constructor for Retracted → Decoding recovery (LoadBack from host). - ScheduleDecodeFromRetractedEvent(std::int32_t decode_input_tokens, PageAllocator* device_allocator, - ReqPoolAllocator* req_pool_allocator, KVPrefixCache* kv_prefix_cache, + ScheduleDecodeFromRetractedEvent(std::int32_t decode_input_tokens, ReqPoolAllocator* req_pool_allocator, MatchResult match_result, std::vector loadback_diff, - MambaChunkAllocator* mamba_allocator = nullptr, - std::vector mamba_loadback_nodes = {}) + std::vector cache_transfer_pairs, + HybridPrefixCache& hybrid_prefix_cache) : decode_input_tokens_(decode_input_tokens), - device_allocator_(device_allocator), req_pool_allocator_(req_pool_allocator), - kv_prefix_cache_(kv_prefix_cache), match_result_(std::move(match_result)), loadback_diff_(std::move(loadback_diff)), - mamba_loadback_nodes_(std::move(mamba_loadback_nodes)), - mamba_allocator_(mamba_allocator) {} + cache_transfer_pairs_(std::move(cache_transfer_pairs)), + hybrid_prefix_cache_(hybrid_prefix_cache) {} Decoding operator()(Retracted&& state); const MatchResult& GetMatchResult() const { return match_result_; } const std::vector& GetLoadbackDiff() const { return loadback_diff_; } - const std::vector& GetMambaLoadbackNodes() const { return mamba_loadback_nodes_; } + const std::vector& GetCacheTransferPairs() const { return cache_transfer_pairs_; } private: std::int32_t decode_input_tokens_{}; - PageAllocator* device_allocator_{}; ReqPoolAllocator* req_pool_allocator_{}; - KVPrefixCache* kv_prefix_cache_{}; MatchResult match_result_{}; std::vector loadback_diff_; - std::vector mamba_loadback_nodes_; - MambaChunkAllocator* mamba_allocator_{}; + std::vector cache_transfer_pairs_; + HybridPrefixCache& hybrid_prefix_cache_; }; struct FinishEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); - explicit FinishEvent(KVPrefixCache* kv_prefix_cache, PageAllocator* host_allocator, - std::vector page_hashes = {}, bool disable_l2_cache = false, - HybridPrefixCache* hybrid_prefix_cache = nullptr) - : kv_prefix_cache_(kv_prefix_cache), - host_allocator_(host_allocator), - page_hashes_(std::move(page_hashes)), + explicit FinishEvent(std::vector page_hashes, bool disable_l2_cache, + HybridPrefixCache& hybrid_prefix_cache) + : page_hashes_(std::move(page_hashes)), disable_l2_cache_(disable_l2_cache), hybrid_prefix_cache_(hybrid_prefix_cache) {} @@ -192,11 +166,9 @@ struct FinishEvent : InvalidTransitionHandler { Finished operator()(Finished&& state) { return std::move(state); } private: - KVPrefixCache* kv_prefix_cache_{}; std::vector page_hashes_; - PageAllocator* host_allocator_; bool disable_l2_cache_; - HybridPrefixCache* hybrid_prefix_cache_{}; + HybridPrefixCache& hybrid_prefix_cache_; template std::variant apply(ForwardStateT&& state); @@ -221,12 +193,8 @@ struct AbortEvent : InvalidTransitionHandler { struct ScheduleRetractEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); - ScheduleRetractEvent(KVPrefixCache* kv_prefix_cache, PageAllocator* host_allocator, MatchResult match_result, - HybridPrefixCache* hybrid_prefix_cache = nullptr) - : kv_prefix_cache_(kv_prefix_cache), - host_allocator_(host_allocator), - match_result_(match_result), - hybrid_prefix_cache_(hybrid_prefix_cache) {} + ScheduleRetractEvent(MatchResult match_result, HybridPrefixCache& hybrid_prefix_cache) + : match_result_(match_result), hybrid_prefix_cache_(hybrid_prefix_cache) {} Retracting operator()(Decoding&& state); Retracting operator()(PrefillDone&& state); @@ -237,10 +205,8 @@ struct ScheduleRetractEvent : InvalidTransitionHandler { template Retracting applyRetract(ForwardStateT&& state); - KVPrefixCache* kv_prefix_cache_{}; - PageAllocator* host_allocator_{}; const MatchResult match_result_{}; - HybridPrefixCache* hybrid_prefix_cache_{}; + HybridPrefixCache& hybrid_prefix_cache_; }; // Draining → WritingBack: WriteBack op has been generated this round; transfer diff --git a/tokenspeed-scheduler/csrc/fsm/forward_states.h b/tokenspeed-scheduler/csrc/fsm/forward_states.h index c37a41ba1..53a79a41c 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_states.h +++ b/tokenspeed-scheduler/csrc/fsm/forward_states.h @@ -115,6 +115,7 @@ struct BaseState { } const TreeNode* GetDeviceNode() const { return device_node_ref_->Node(); } + TreeNode* GetMutableDeviceNode() const { return device_node_ref_->Node(); } std::int32_t TailPageAvailableTokens() const { return local_kv_allocator_->TailPageAvailableTokens(); } diff --git a/tokenspeed-scheduler/csrc/resource/allocator/local_mamba_allocator.cpp b/tokenspeed-scheduler/csrc/resource/allocator/local_mamba_allocator.cpp index 75d7532bb..a318ef01e 100644 --- a/tokenspeed-scheduler/csrc/resource/allocator/local_mamba_allocator.cpp +++ b/tokenspeed-scheduler/csrc/resource/allocator/local_mamba_allocator.cpp @@ -40,10 +40,11 @@ std::int32_t LocalMambaAllocator::WorkingIndex() const { return working_ ? working_->Index() : -1; } -bool LocalMambaAllocator::AllocateCheckpoint() { +bool LocalMambaAllocator::AllocateCheckpoint(std::int32_t raw_position) { auto slot = allocator_->Allocate(); if (!slot.has_value()) return false; checkpoint_ = std::make_unique(std::move(*slot)); + checkpoint_position_ = raw_position; return true; } @@ -52,6 +53,7 @@ std::int32_t LocalMambaAllocator::CheckpointIndex() const { } std::unique_ptr LocalMambaAllocator::DetachCheckpoint() { + checkpoint_position_ = -1; return std::move(checkpoint_); } diff --git a/tokenspeed-scheduler/csrc/resource/allocator/local_mamba_allocator.h b/tokenspeed-scheduler/csrc/resource/allocator/local_mamba_allocator.h index a541edee3..ed6a15eb2 100644 --- a/tokenspeed-scheduler/csrc/resource/allocator/local_mamba_allocator.h +++ b/tokenspeed-scheduler/csrc/resource/allocator/local_mamba_allocator.h @@ -44,8 +44,9 @@ class LocalMambaAllocator { std::int32_t WorkingIndex() const; bool HasWorking() const { return working_ != nullptr; } - bool AllocateCheckpoint(); + bool AllocateCheckpoint(std::int32_t raw_position = -1); std::int32_t CheckpointIndex() const; + std::int32_t CheckpointPosition() const { return checkpoint_position_; } bool HasCheckpoint() const { return checkpoint_ != nullptr; } std::unique_ptr DetachCheckpoint(); std::unique_ptr DetachWorking(); @@ -54,6 +55,7 @@ class LocalMambaAllocator { MambaChunkAllocator* allocator_; std::unique_ptr working_{}; std::unique_ptr checkpoint_{}; + std::int32_t checkpoint_position_{-1}; }; } // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/family_registry.cpp b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/family_registry.cpp new file mode 100644 index 000000000..75a635fdf --- /dev/null +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/family_registry.cpp @@ -0,0 +1,199 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "resource/hybrid_prefix_cache/family_registry.h" + +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" + +#include +#include +#include +#include +#include + +namespace tokenspeed { + +void FamilyRegistry::Clear() { + specs.clear(); + id_to_index_.clear(); + active_match_family_indices.clear(); + active_admit_family_indices.clear(); + active_commit_family_indices.clear(); + active_evict_family_indices.clear(); + active_finish_family_indices.clear(); + active_stats_family_indices.clear(); + active_compatibility_family_indices.clear(); +} + +const CacheResourceSpec* FamilyRegistry::FindById(const std::string& id) const { + auto it = id_to_index_.find(id); + if (it == id_to_index_.end()) return nullptr; + return &specs.at(static_cast(it->second)); +} + +const CacheResourceSpec& FamilyRegistry::At(std::int32_t family_index) const { + if (family_index < 0 || family_index >= static_cast(specs.size())) { + throw std::out_of_range("FamilyRegistry::At: family_index out of range"); + } + return specs.at(static_cast(family_index)); +} + +std::int32_t FamilyRegistry::Register(CacheResourceSpec spec, bool active_match, bool active_admit, bool active_commit, + bool active_evict, bool active_finish, bool active_stats, + bool active_compatibility) { + if (spec.id.empty()) { + throw std::invalid_argument("FamilyRegistry::Register: id must be non-empty"); + } + if (id_to_index_.find(spec.id) != id_to_index_.end()) { + throw std::invalid_argument("FamilyRegistry::Register: duplicate cache family id: " + spec.id); + } + + const std::int32_t index = static_cast(specs.size()); + spec.family_index = index; + id_to_index_.emplace(spec.id, index); + specs.push_back(std::move(spec)); + if (active_match) active_match_family_indices.push_back(index); + if (active_admit) active_admit_family_indices.push_back(index); + if (active_commit) active_commit_family_indices.push_back(index); + if (active_evict) active_evict_family_indices.push_back(index); + if (active_finish) active_finish_family_indices.push_back(index); + if (active_stats) active_stats_family_indices.push_back(index); + if (active_compatibility) active_compatibility_family_indices.push_back(index); + return index; +} + +namespace hybrid_prefix_cache::detail { + +FamilyRegistryBuildResult BuildFamilyRegistry(const FamilyRegistryBuildInput& input) { + FamilyRegistryBuildResult result{}; + + result.registry.Register( + CacheResourceSpec{ + .id = "kv.token_page", + .family = CacheFamily::TokenPage, + .attachment_kind = TreeAttachmentKind::ReusableTree, + .recoverability = Recoverability::Exact, + .publication = PublicationKind::CanonicalPrefixIndex, + .split_policy = SplitPolicy::CarrierKV, + .rows_per_page = input.kv_page_size, + .entry_stride_tokens = 1, + .required_for_recovery = true, + }, + /*active_match=*/true, /*active_admit=*/true, /*active_commit=*/true, + /*active_evict=*/true, /*active_finish=*/true, /*active_stats=*/true, + /*active_compatibility=*/true); + + if (input.has_mamba_adjunct) { + result.registry.Register( + CacheResourceSpec{ + .id = "mamba.checkpoint", + .family = CacheFamily::RecurrentState, + .attachment_kind = TreeAttachmentKind::ReusableTree, + .recoverability = Recoverability::AlignedCheckpoint, + .publication = PublicationKind::AuxiliaryLocalOnly, + .split_policy = SplitPolicy::CheckpointBoundary, + .checkpoint_chunk_tokens = input.mamba_cache_chunk_size, + .state_cohort_id = "mamba.checkpoint", + .required_for_recovery = true, + }, + /*active_match=*/true, /*active_admit=*/true, /*active_commit=*/true, + /*active_evict=*/true, /*active_finish=*/true, /*active_stats=*/true, + /*active_compatibility=*/true); + } + + for (const PagedCacheFamilyRegistryInput& group : input.paged_cache_groups) { + const auto& cfg = group.config; + CacheFamily family = CacheFamily::CompressedPage; + if (cfg.family == PagedCacheGroupFamily::State) { + family = cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow + ? CacheFamily::SlidingWindowState + : CacheFamily::CompressionTailState; + } + + if (group.required) { + if (cfg.family == PagedCacheGroupFamily::History) { + result.paged_cache_history_groups.push_back(group.group_id); + } else { + result.paged_cache_state_groups.push_back(group.group_id); + } + } + + CacheResourceSpec spec{ + .id = group.group_id, + .family = family, + .attachment_kind = + group.required ? TreeAttachmentKind::ReusableTree : TreeAttachmentKind::NoneForRequestLocal, + .recoverability = group.required ? Recoverability::Exact : Recoverability::RequestLocalOnly, + .publication = + group.required && cfg.family == PagedCacheGroupFamily::History + ? PublicationKind::CanonicalPrefixIndex + : (group.required ? PublicationKind::AuxiliaryLocalOnly : PublicationKind::RequestLocalOnly), + .split_policy = group.required ? SplitPolicy::SnapshotBoundary : SplitPolicy::RequestLocalOnly, + .rows_per_page = cfg.rows_per_page, + .entry_stride_tokens = cfg.entry_stride_tokens, + .sliding_window_tokens = cfg.sliding_window_tokens, + .state_cohort_id = group.required ? "paged.required" : std::string{}, + .required_for_recovery = group.required, + }; + + result.registry.Register(std::move(spec), /*active_match=*/group.required, /*active_admit=*/true, + /*active_commit=*/group.required, /*active_evict=*/group.required, + /*active_finish=*/true, /*active_stats=*/true, + /*active_compatibility=*/true); + } + + result.paged_cache_history_group_set = std::unordered_set(result.paged_cache_history_groups.begin(), + result.paged_cache_history_groups.end()); + result.paged_cache_state_group_set = + std::unordered_set(result.paged_cache_state_groups.begin(), result.paged_cache_state_groups.end()); + return result; +} + +} // namespace hybrid_prefix_cache::detail + +void HybridPrefixCache::RebuildFamilyRegistry() { + const std::unordered_set required_group_set(paged_cache_required_groups_.begin(), + paged_cache_required_groups_.end()); + std::vector paged_groups; + paged_groups.reserve(paged_cache_allocators_.size()); + for (const auto& [gid, allocator] : paged_cache_allocators_) { + if (allocator == nullptr) continue; + paged_groups.push_back({ + .group_id = gid, + .config = allocator->Config(), + .required = required_group_set.find(gid) != required_group_set.end(), + }); + } + + auto result = hybrid_prefix_cache::detail::BuildFamilyRegistry({ + .kv_page_size = kv_prefix_cache_.PageSize(), + .has_mamba_adjunct = HasMambaAdjunct(), + .mamba_cache_chunk_size = mamba_cache_chunk_size_, + .paged_cache_groups = std::span{paged_groups}, + }); + + family_registry_ = std::move(result.registry); + paged_cache_history_groups_ = std::move(result.paged_cache_history_groups); + paged_cache_state_groups_ = std::move(result.paged_cache_state_groups); + paged_cache_history_group_set_ = std::move(result.paged_cache_history_group_set); + paged_cache_state_group_set_ = std::move(result.paged_cache_state_group_set); +} + +} // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/family_registry.h b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/family_registry.h new file mode 100644 index 000000000..e9a81513e --- /dev/null +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/family_registry.h @@ -0,0 +1,57 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include +#include +#include +#include +#include + +#include "resource/allocator/paged_cache_group.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache_types.h" + +namespace tokenspeed::hybrid_prefix_cache::detail { + +struct PagedCacheFamilyRegistryInput { + std::string group_id; + PagedCacheGroupConfig config; + bool required{false}; +}; + +struct FamilyRegistryBuildInput { + std::int32_t kv_page_size{0}; + bool has_mamba_adjunct{false}; + std::int32_t mamba_cache_chunk_size{0}; + std::span paged_cache_groups{}; +}; + +struct FamilyRegistryBuildResult { + FamilyRegistry registry{}; + std::vector paged_cache_history_groups{}; + std::vector paged_cache_state_groups{}; + std::unordered_set paged_cache_history_group_set{}; + std::unordered_set paged_cache_state_group_set{}; +}; + +FamilyRegistryBuildResult BuildFamilyRegistry(const FamilyRegistryBuildInput& input); + +} // namespace tokenspeed::hybrid_prefix_cache::detail diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp index 361067454..19a110b13 100644 --- a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp @@ -19,11 +19,15 @@ // SOFTWARE. #include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" +#include "resource/allocator/kv_allocator.h" +#include "resource/allocator/local_mamba_allocator.h" #include "resource/allocator/mamba_chunk_allocator.h" #include "resource/allocator/mamba_host_allocator.h" +#include "resource/allocator/owned_pages.h" +#include "resource/allocator/page_allocator.h" #include "resource/allocator/paged_cache_group.h" -#include "resource/radix_tree/paged_cache_snapshot.h" #include "resource/radix_tree/node_range.h" +#include "resource/radix_tree/paged_cache_snapshot.h" #include "resource/radix_tree/radix_tree.h" #include "resource/radix_tree/tree_node.h" #include "scheduler/operations/forward.h" @@ -34,256 +38,357 @@ #include #include #include -#include #include #include namespace tokenspeed { -HybridPrefixCache::HybridPrefixCache(KVPrefixCache& kv_prefix_cache, MambaChunkAllocator* mamba_allocator, - std::int32_t mamba_cache_chunk_size, MambaHostAllocator* mamba_host_allocator) +HybridPrefixCache::HybridPrefixCache(KVPrefixCache& kv_prefix_cache, PageAllocator& device_allocator, + MambaChunkAllocator* mamba_allocator, std::int32_t mamba_cache_chunk_size, + MambaHostAllocator* mamba_host_allocator) : kv_prefix_cache_{kv_prefix_cache}, + device_allocator_{device_allocator}, mamba_allocator_{mamba_allocator}, mamba_host_allocator_{mamba_host_allocator}, mamba_eviction_manager_{mamba_allocator}, - mamba_cache_chunk_size_{mamba_cache_chunk_size} {} - -MatchResult HybridPrefixCache::Match(const token_vec_t& token_ids, MatchIntent intent) { - auto match = kv_prefix_cache_.Match(token_ids, intent); - augmentMatch(match); - augmentMatchPagedCache(match); - return match; + mamba_cache_chunk_size_{mamba_cache_chunk_size} { + kv_prefix_cache_.GetDeviceManager().SetEvictionCallback([this](TreeNode* node) { OnKVEvict(node); }); + kv_prefix_cache_.GetHostManager().SetEvictionCallback([this](TreeNode* node) { OnKVHostEvict(node); }); + RebuildFamilyRegistry(); } -MatchResult HybridPrefixCache::Match(const std::vector>& token_pages, - MatchIntent intent) { - auto match = kv_prefix_cache_.Match(token_pages, intent); - augmentMatch(match); - augmentMatchPagedCache(match); - return match; +HybridPrefixCache::~HybridPrefixCache() { + SetKvEventSink({}); + kv_prefix_cache_.GetDeviceManager().SetEvictionCallback({}); + kv_prefix_cache_.GetHostManager().SetEvictionCallback({}); } -void HybridPrefixCache::augmentMatch(MatchResult& match) const { - if (mamba_allocator_ == nullptr) return; - TreeNode* root = match.device.last_node; - while (root != nullptr && !root->IsRoot()) root = root->Parent(); - if (root == nullptr) return; - - // Backward-compatible path: before Mamba L2 is enabled, only device Mamba is - // a valid hybrid prefix source and both match tiers are truncated together. - if (mamba_host_allocator_ == nullptr) { - TreeNode* kv_terminal = match.device.last_node; - if (kv_terminal == nullptr || kv_terminal->IsRoot()) return; - - TreeNode* mamba_node = FindLastMambaNode(kv_terminal); - if (mamba_node == nullptr) { - const std::int32_t kv_depth = match.device.DepthInPage(); - const std::int32_t aligned_seqlen = AlignMambaCacheSeqlen(kv_depth * match.device.page_size); - if (aligned_seqlen > 0) { - match.mamba_branching_seqlen = aligned_seqlen; - } - match.device.last_node = root; - match.host.last_node = root; - return; - } - - std::int32_t page_size = match.device.page_size; - std::int32_t kv_depth = match.device.DepthInPage(); - std::int32_t mamba_depth = mamba_node->DepthInPage(page_size); - match.mamba_cow_src_index = mamba_node->MambaSlotIndex(); - if (kv_depth > mamba_depth) { - const std::int32_t aligned_seqlen = AlignMambaCacheSeqlen(kv_depth * page_size); - if (aligned_seqlen > mamba_depth * page_size) { - match.mamba_branching_seqlen = aligned_seqlen; +RecoveryPlan HybridPrefixCache::MatchPrefix(const token_vec_t& token_ids, MatchIntent intent) { + DemoteIdleMambaDeviceCopiesPresentOnHost(); + MatchResult raw_match = kv_prefix_cache_.Match(token_ids, intent); + RecoveryPlan plan{}; + const auto depth_tokens = [](const TreeNode* node) -> std::int32_t { + return node == nullptr ? 0 : static_cast(node->DepthInTokens()); + }; + plan.raw_token_match_end_tokens = + std::max(depth_tokens(raw_match.device.last_node), depth_tokens(raw_match.host.last_node)); + plan.compat_match = raw_match; + augmentMatch(plan.compat_match); + augmentMatchPagedCache(plan.compat_match); + if (intent == MatchIntent::StateRecovery) { + const DecodeFromRetractedRecovery recovery = PrepareDecodeFromRetractedRecovery(plan.compat_match); + plan.recovery_state_available = recovery.ok; + plan.protected_recovery_node = recovery.protected_source_node; + } + plan.recoverable_prefix_end_tokens = + std::max(depth_tokens(plan.compat_match.device.last_node), depth_tokens(plan.compat_match.host.last_node)); + if (plan.compat_match.paged_cache.prefix_len_tokens > 0) { + plan.recoverable_prefix_end_tokens = + std::min(plan.recoverable_prefix_end_tokens, plan.compat_match.paged_cache.prefix_len_tokens); + } + plan.execution_resume_tokens = plan.recoverable_prefix_end_tokens; + BuildRecoveryPlanSlices(plan); + return plan; +} + +RecoveryPlan HybridPrefixCache::MatchPrefix(const std::vector>& token_pages, + MatchIntent intent) { + DemoteIdleMambaDeviceCopiesPresentOnHost(); + MatchResult raw_match = kv_prefix_cache_.Match(token_pages, intent); + RecoveryPlan plan{}; + const auto depth_tokens = [](const TreeNode* node) -> std::int32_t { + return node == nullptr ? 0 : static_cast(node->DepthInTokens()); + }; + plan.raw_token_match_end_tokens = + std::max(depth_tokens(raw_match.device.last_node), depth_tokens(raw_match.host.last_node)); + plan.compat_match = raw_match; + augmentMatch(plan.compat_match); + augmentMatchPagedCache(plan.compat_match); + if (intent == MatchIntent::StateRecovery) { + const DecodeFromRetractedRecovery recovery = PrepareDecodeFromRetractedRecovery(plan.compat_match); + plan.recovery_state_available = recovery.ok; + plan.protected_recovery_node = recovery.protected_source_node; + } + plan.recoverable_prefix_end_tokens = + std::max(depth_tokens(plan.compat_match.device.last_node), depth_tokens(plan.compat_match.host.last_node)); + if (plan.compat_match.paged_cache.prefix_len_tokens > 0) { + plan.recoverable_prefix_end_tokens = + std::min(plan.recoverable_prefix_end_tokens, plan.compat_match.paged_cache.prefix_len_tokens); + } + plan.execution_resume_tokens = plan.recoverable_prefix_end_tokens; + BuildRecoveryPlanSlices(plan); + return plan; +} + +void HybridPrefixCache::BuildRecoveryPlanSlices(RecoveryPlan& plan) const { + plan.slices.clear(); + plan.slices.reserve(family_registry_.active_match_family_indices.size()); + for (std::int32_t family_index : family_registry_.active_match_family_indices) { + const CacheResourceSpec& spec = family_registry_.At(family_index); + FamilySlice slice{ + .family_index = family_index, + .family_id = spec.id, + .family = spec.family, + .recoverable_end_tokens = plan.recoverable_prefix_end_tokens, + .required_for_recovery = spec.required_for_recovery, + }; + switch (spec.family) { + case CacheFamily::TokenPage: + slice.hit_node = plan.compat_match.device.last_node; + break; + case CacheFamily::RecurrentState: + case CacheFamily::ConvState: + slice.hit_node = plan.compat_match.device.last_node; + if (plan.compat_match.mamba_cow_src_index >= 0) { + slice.borrowed_ids.push_back(plan.compat_match.mamba_cow_src_index); + } + break; + case CacheFamily::CompressedPage: + case CacheFamily::SlidingWindowState: + case CacheFamily::CompressionTailState: { + slice.hit_node = plan.compat_match.paged_cache.last_node; + slice.recoverable_end_tokens = plan.compat_match.paged_cache.prefix_len_tokens; + auto page_it = plan.compat_match.paged_cache.per_group_page_ids.find(spec.id); + if (page_it != plan.compat_match.paged_cache.per_group_page_ids.end()) { + slice.borrowed_ids = page_it->second; + } + auto base_it = plan.compat_match.paged_cache.per_group_base_logical_page.find(spec.id); + if (base_it != plan.compat_match.paged_cache.per_group_base_logical_page.end()) { + slice.base_logical_page = base_it->second; + } + break; } } - match.device.last_node = mamba_node; - match.host.last_node = mamba_node; - return; - } - - const std::int32_t page_size = match.device.page_size; - const std::int32_t kv_depth = std::max(match.device.DepthInPage(), match.host.DepthInPage()); - - TreeNode* device_mamba_node = FindLastMambaNode(match.device.last_node); - TreeNode* host_mamba_node = FindLastMambaHostNode(match.host.last_node); - const std::int32_t device_mamba_depth = - device_mamba_node == nullptr ? 0 : device_mamba_node->DepthInPage(page_size); - const std::int32_t host_mamba_depth = host_mamba_node == nullptr ? 0 : host_mamba_node->DepthInPage(page_size); - const bool prefer_host_mamba = host_mamba_depth > device_mamba_depth; - std::int32_t mamba_depth = 0; - - if (device_mamba_node != nullptr) { - match.device.last_node = device_mamba_node; - if (!prefer_host_mamba) { - match.mamba_cow_src_index = device_mamba_node->MambaSlotIndex(); - } - mamba_depth = std::max(mamba_depth, device_mamba_depth); - } else { - match.device.last_node = root; - } - - if (host_mamba_node != nullptr) { - match.host.last_node = host_mamba_node; - match.mamba_host_src_index = host_mamba_node->MambaHostSlotIndex(); - if (prefer_host_mamba) { - match.mamba_cow_src_index = -1; - } - mamba_depth = std::max(mamba_depth, host_mamba_depth); - } else { - match.host.last_node = root; - } - - if (kv_depth > mamba_depth) { - const std::int32_t aligned_seqlen = AlignMambaCacheSeqlen(kv_depth * page_size); - if (aligned_seqlen > mamba_depth * page_size) { - match.mamba_branching_seqlen = aligned_seqlen; - } + plan.slices.push_back(std::move(slice)); } } -std::int32_t HybridPrefixCache::AlignMambaCacheSeqlen(std::int32_t seqlen) const { - if (mamba_cache_chunk_size_ <= 0) return seqlen; - return (seqlen / mamba_cache_chunk_size_) * mamba_cache_chunk_size_; -} - -TreeNode* HybridPrefixCache::FindLastMambaNode(TreeNode* from) const { - for (TreeNode* cur = from; cur != nullptr && !cur->IsRoot(); cur = cur->Parent()) { - if (cur->HasMamba()) return cur; - } - return nullptr; +HybridPrefixCache::RawHostStorageHashSeed HybridPrefixCache::LookupRawHostStorageHashSeed( + const std::vector>& token_pages) { + MatchResult match = kv_prefix_cache_.Match(token_pages); + const std::int32_t host_matched_pages = match.host.DepthInPage(); + const auto& page_hashes = match.host.last_node->PageHashes(); + return RawHostStorageHashSeed{ + .host_matched_pages = host_matched_pages, + .prior_hash_seed = page_hashes.empty() ? std::string{} : page_hashes.back(), + }; } -TreeNode* HybridPrefixCache::FindLastMambaHostNode(TreeNode* from) const { - for (TreeNode* cur = from; cur != nullptr && !cur->IsRoot(); cur = cur->Parent()) { - if (cur->HasMambaOnHost()) return cur; - } - return nullptr; +cache_op_id HybridPrefixCache::AllocateCacheOpId() { + return kv_prefix_cache_.AllocateCacheOpId(); } -bool HybridPrefixCache::EnsureMambaHostCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node) { - if (mamba_host_allocator_ == nullptr) return num_slots <= 0; - if (mamba_host_allocator_->AvailableSlots() >= num_slots) return true; - - std::vector candidates; - candidates.reserve(mamba_host_nodes_.size()); - for (TreeNode* node : mamba_host_nodes_) { - if (node == nullptr || node == protected_node || !node->HasMambaOnHost()) continue; - if (node->OnHost() && GetResource(node).RefCount() > 0) continue; - candidates.push_back(node); +void HybridPrefixCache::SetKvEventSink(KvEventSink sink) { + if (!sink) { + if (has_facade_kv_event_sink_) { + kv_prefix_cache_.SetKvEventSink({}); + has_facade_kv_event_sink_ = false; + } + return; } - std::sort(candidates.begin(), candidates.end(), - [](const TreeNode* lhs, const TreeNode* rhs) { return lhs->Time() < rhs->Time(); }); - for (TreeNode* node : candidates) { - if (mamba_host_allocator_->AvailableSlots() >= num_slots) break; - node->DetachMambaHost(); - mamba_host_nodes_.erase(node); - } - if (mamba_host_allocator_->AvailableSlots() < num_slots) { - spdlog::warn("[HybridPrefixCache] mamba host capacity exhausted required={} after_evict_available={}", - num_slots, mamba_host_allocator_->AvailableSlots()); - } - return mamba_host_allocator_->AvailableSlots() >= num_slots; + kv_prefix_cache_.SetKvEventSink(std::move(sink)); + has_facade_kv_event_sink_ = true; } -std::vector HybridPrefixCache::PrepareMambaHostWriteBack(const std::vector& nodes) { - std::vector transfers; - if (mamba_allocator_ == nullptr || mamba_host_allocator_ == nullptr) return transfers; - - std::int32_t needed = 0; - for (TreeNode* node : nodes) { - if (node != nullptr && node->HasMamba() && !node->HasMambaOnHost() && - pending_mamba_host_writebacks_.find(node) == pending_mamba_host_writebacks_.end()) { - needed++; +HybridPrefixCache::RequestLocalKVResult HybridPrefixCache::PrepareRequestLocalKV( + const RequestLocalKVRequest& request) const { + RequestLocalKVResult result{}; + auto require_allocator = [&]() -> LocalKVAllocator& { + if (request.local_kv_allocator == nullptr) { + throw std::invalid_argument("HybridPrefixCache::PrepareRequestLocalKV requires local_kv_allocator"); } - } - if (!EnsureMambaHostCapacityByEvict(needed)) return transfers; + return *request.local_kv_allocator; + }; - for (TreeNode* node : nodes) { - if (node == nullptr || !node->HasMamba() || node->HasMambaOnHost()) continue; - if (pending_mamba_host_writebacks_.find(node) != pending_mamba_host_writebacks_.end()) continue; - auto slot = mamba_host_allocator_->Allocate(); - if (!slot.has_value()) break; - const std::int32_t device_idx = node->MambaSlotIndex(); - const std::int32_t host_idx = slot->Index(); - pending_mamba_host_writebacks_.emplace(node, std::make_unique(std::move(*slot))); - transfers.push_back(TransferPair{CacheKind::kMamba, device_idx, host_idx}); + switch (request.kind) { + case RequestLocalKVKind::kPrefillFirstChunk: + result.local_kv_allocator = + std::make_unique(&device_allocator_, request.tokens_this_round); + result.local_kv_allocator->Acquire(request.decode_input_tokens); + return result; + case RequestLocalKVKind::kPrefillChunk: + require_allocator().Acquire(request.tokens_this_round); + return result; + case RequestLocalKVKind::kDecodeReserve: + require_allocator().Acquire(request.reserve_tokens); + return result; + case RequestLocalKVKind::kDecodeFromRetractedReserve: + require_allocator().Acquire(request.decode_input_tokens); + return result; } - return transfers; + return result; } -std::vector HybridPrefixCache::PrepareMambaDeviceLoadBack(const std::vector& nodes) { - std::vector transfers; - if (mamba_allocator_ == nullptr || mamba_host_allocator_ == nullptr) return transfers; +HybridPrefixCache::CachePublicationResult HybridPrefixCache::Publish(CachePublicationRequest request) { + CachePublicationResult result{}; + auto require_tokens = [&]() -> const std::vector>& { + if (request.full_paged_tokens == nullptr) { + throw std::invalid_argument("HybridPrefixCache::Publish requires full_paged_tokens"); + } + return *request.full_paged_tokens; + }; + auto require_device_ref = [&]() -> std::unique_ptr& { + if (request.device_node_ref == nullptr || *request.device_node_ref == nullptr) { + throw std::invalid_argument("HybridPrefixCache::Publish requires device_node_ref"); + } + return *request.device_node_ref; + }; + auto require_current_device_node = [&]() -> const TreeNode* { + if (request.current_device_node == nullptr) { + throw std::invalid_argument("HybridPrefixCache::Publish requires current_device_node"); + } + return request.current_device_node; + }; + auto require_local_kv_allocator = [&]() -> LocalKVAllocator& { + if (request.local_kv_allocator == nullptr) { + throw std::invalid_argument("HybridPrefixCache::Publish requires local_kv_allocator"); + } + return *request.local_kv_allocator; + }; + auto require_page_hashes = [&]() -> const std::vector& { + if (request.page_hashes == nullptr) { + throw std::invalid_argument("HybridPrefixCache::Publish requires page_hashes"); + } + return *request.page_hashes; + }; - for (TreeNode* node : nodes) { - if (node == nullptr || !node->HasMambaOnHost() || node->HasMamba()) continue; - auto slot = mamba_allocator_->Allocate(); - if (!slot.has_value()) break; - const std::int32_t host_idx = node->MambaHostSlotIndex(); - const std::int32_t device_idx = slot->Index(); - node->AttachMamba(std::make_unique(std::move(*slot))); - mamba_eviction_manager_.TrackNode(node); - transfers.push_back(TransferPair{CacheKind::kMamba, host_idx, device_idx}); - } - return transfers; -} + switch (request.kind) { + case CachePublicationKind::kForwardChunk: { + const auto& full_paged_tokens = require_tokens(); + auto& device_node_ref = require_device_ref(); + LocalKVAllocator& local_kv_allocator = require_local_kv_allocator(); + std::vector prefix_pages = DevicePagesFromRoot(device_node_ref->Node()); + const std::int32_t new_page_count = + static_cast(full_paged_tokens.size()) - static_cast(prefix_pages.size()); + if (new_page_count <= 0) return result; + + OwnedPages pages_to_insert = local_kv_allocator.TakeFirst(new_page_count); + auto insert_result = kv_prefix_cache_.Insert(full_paged_tokens, prefix_pages, + std::move(pages_to_insert)); + + if (request.local_mamba_allocator != nullptr && request.local_mamba_allocator->HasCheckpoint()) { + const std::int32_t checkpoint_position = request.local_mamba_allocator->CheckpointPosition(); + if (checkpoint_position >= 0 && + checkpoint_position != static_cast(insert_result.last_node->DepthInTokens())) { + device_node_ref = std::make_unique(insert_result.last_node); + result.device_insert_page_count = new_page_count; + return result; + } + InsertMamba(insert_result.last_node, request.local_mamba_allocator->DetachCheckpoint()); + } + device_node_ref = std::make_unique(insert_result.last_node); + result.device_insert_page_count = new_page_count; + return result; + } + case CachePublicationKind::kFinishChunk: { + const auto& full_paged_tokens = require_tokens(); + const TreeNode* current_device_node = require_current_device_node(); + LocalKVAllocator& local_kv_allocator = require_local_kv_allocator(); + const std::vector& page_hashes = require_page_hashes(); + std::vector prefix_pages = DevicePagesFromRoot(current_device_node); + const std::int32_t alloc_count = + static_cast(full_paged_tokens.size()) - static_cast(prefix_pages.size()); + + if (alloc_count > 0) { + OwnedPages alloc_pages = local_kv_allocator.TakeFirst(alloc_count); + kv_prefix_cache_.Insert(full_paged_tokens, prefix_pages, std::move(alloc_pages), + page_hashes); + PublishFinishMambaState(full_paged_tokens, request.local_mamba_allocator); + } -bool HybridPrefixCache::EnsureMambaCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node) { - if (mamba_allocator_ == nullptr) return num_slots <= 0; - return mamba_eviction_manager_.EnsureCapacity(num_slots, protected_node); -} + result.device_insert_page_count = std::max(0, alloc_count); + result.match_result = kv_prefix_cache_.Match(full_paged_tokens); + return result; + } + case CachePublicationKind::kRetractDeviceInsertPageCount: { + const auto& full_paged_tokens = require_tokens(); + const TreeNode* current_device_node = require_current_device_node(); + std::vector prefix_pages = DevicePagesFromRoot(current_device_node); + const std::int32_t full_page_count = static_cast(full_paged_tokens.size()); + const std::int32_t prefix_page_count = static_cast(prefix_pages.size()); + if (full_page_count < prefix_page_count) { + throw std::logic_error( + "HybridPrefixCache::Publish retract plan: current device prefix exceeds " + "available full token pages"); + } + result.device_insert_page_count = full_page_count - prefix_page_count; + return result; + } + case CachePublicationKind::kRetractChunk: { + const auto& full_paged_tokens = require_tokens(); + const TreeNode* current_device_node = require_current_device_node(); + std::vector prefix_pages = DevicePagesFromRoot(current_device_node); + const std::int32_t alloc_count = + static_cast(full_paged_tokens.size()) - static_cast(prefix_pages.size()); + if (alloc_count < 0) { + throw std::logic_error( + "HybridPrefixCache::Publish retract chunk: current device prefix exceeds " + "available full token pages"); + } + if (request.pages_to_insert.Size() != alloc_count) { + throw std::logic_error("HybridPrefixCache::Publish retract chunk: request-local page count mismatch"); + } -void HybridPrefixCache::InsertMamba(TreeNode* terminal_node, std::unique_ptr slot) { - if (terminal_node == nullptr || slot == nullptr) return; - if (mamba_allocator_ == nullptr) { - throw std::logic_error("HybridPrefixCache::InsertMamba: mamba adjunct not enabled"); - } - const std::int32_t page_size = kv_prefix_cache_.PageSize(); - if (page_size <= 0 || terminal_node->DepthInTokens() % static_cast(page_size) != 0) { - throw std::logic_error("HybridPrefixCache::InsertMamba: terminal node is not block-aligned"); + kv_prefix_cache_.Insert(full_paged_tokens, prefix_pages, + std::move(request.pages_to_insert)); + result.device_insert_page_count = alloc_count; + result.match_result = kv_prefix_cache_.Match(full_paged_tokens, MatchIntent::StateRecovery); + return result; + } } - terminal_node->AttachMamba(std::move(slot)); - mamba_eviction_manager_.TrackNode(terminal_node); + return result; } -bool HybridPrefixCache::AttachPagedCacheSnapshotToNode(TreeNode* node, std::unique_ptr snapshot) { - if (node == nullptr || snapshot == nullptr) return false; - // Compute completeness from what is present. The policy-driven "snapshot - // must be full" invariant is enforced upstream by CommitChunk, which only - // attaches full snapshots; direct callers (tests, future restore paths) - // may attach history-only or state-only snapshots without policy gating. - snapshot->complete_families.clear(); - bool history_complete = !paged_cache_history_groups_.empty(); - for (const auto& gid : paged_cache_history_groups_) { - if (snapshot->groups.find(gid) == snapshot->groups.end()) { - history_complete = false; - break; +HybridPrefixCache::CacheMaterializationResult HybridPrefixCache::Materialize( + const CacheMaterializationRequest& request) { + CacheMaterializationResult result{}; + auto require_match = [&]() -> const MatchResult& { + if (request.match_result == nullptr) { + throw std::invalid_argument("HybridPrefixCache::Materialize requires match_result"); } - } - if (history_complete) { - snapshot->complete_families.insert(PagedCacheGroupFamily::History); - } - bool state_complete = !paged_cache_state_groups_.empty(); - for (const auto& gid : paged_cache_state_groups_) { - if (snapshot->groups.find(gid) == snapshot->groups.end()) { - state_complete = false; - break; + return *request.match_result; + }; + auto require_write_diff = [&]() -> const std::vector& { + if (request.write_diff == nullptr) { + throw std::invalid_argument("HybridPrefixCache::Materialize requires write_diff"); } - } - if (state_complete) { - snapshot->complete_families.insert(PagedCacheGroupFamily::State); - } - node->AttachPagedCacheSnapshot(std::move(snapshot)); - paged_cache_snapshot_nodes_.insert(node); - return true; -} + return *request.write_diff; + }; -std::unique_ptr HybridPrefixCache::DetachPagedCacheSnapshotFromNode(TreeNode* node) { - if (node == nullptr) return nullptr; - paged_cache_snapshot_nodes_.erase(node); - return node->DetachPagedCacheSnapshot(); + switch (request.kind) { + case CacheMaterializationKind::kPrefillHostPrefixOnDevice: { + const MatchResult& match_result = require_match(); + (void)kv_prefix_cache_.AllocateResourceOfType( + match_result.NodesWithout()); + return result; + } + case CacheMaterializationKind::kDecodeRecoveryHostPrefixOnDevice: { + const MatchResult& match_result = require_match(); + result.ok = kv_prefix_cache_.AllocateResourceOfType( + match_result.NodesWithout()); + return result; + } + case CacheMaterializationKind::kFinishWritebackHostPages: { + const std::vector& write_diff = require_write_diff(); + std::int32_t host_pages_num = 0; + for (TreeNode* node : write_diff) { + host_pages_num += node->Device().NumPages(); + } + if (!kv_prefix_cache_.EnsureCapacityByEvict(host_pages_num)) { + result.ok = false; + return result; + } + (void)kv_prefix_cache_.AllocateResourceOfType(write_diff); + return result; + } + case CacheMaterializationKind::kRetractWritebackHostPages: { + const std::vector& write_diff = require_write_diff(); + result.ok = kv_prefix_cache_.AllocateResourceOfType(write_diff); + return result; + } + } + return result; } void HybridPrefixCache::OnKVEvict(TreeNode* node) { @@ -386,803 +491,460 @@ void HybridPrefixCache::OnKVDeviceDemote(TreeNode* node) { } } -std::int32_t HybridPrefixCache::AvailableSlots() const { - if (mamba_allocator_ == nullptr) return 0; - return mamba_allocator_->AvailableSlots(); +std::size_t HybridPrefixCache::AvailableDevicePages() const { + return Stats().available_device_pages; } -void HybridPrefixCache::RegisterPagedCacheGroup(std::unique_ptr allocator) { - if (allocator == nullptr) { - throw std::invalid_argument("HybridPrefixCache::RegisterPagedCacheGroup: null allocator"); - } - std::string gid = allocator->Config().group_id; - if (paged_cache_allocators_.find(gid) != paged_cache_allocators_.end()) { - throw std::invalid_argument("HybridPrefixCache::RegisterPagedCacheGroup: duplicate group_id: " + gid); - } - paged_cache_allocators_.emplace(std::move(gid), std::move(allocator)); -} - -void HybridPrefixCache::EnablePagedCacheAdjunct(std::vector required_groups, - std::unordered_map sliding_window_per_group, - StateRestorePolicy policy) { - if (required_groups.empty()) { - throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: required_groups must be non-empty"); - } - std::vector history_gids; - std::vector state_gids; - std::vector required_sliding_gids; - history_gids.reserve(required_groups.size()); - state_gids.reserve(required_groups.size()); - required_sliding_gids.reserve(required_groups.size()); - - // Partition required groups by family; collect sliding-group entries for - // post-validation against `sliding_window_per_group`. - for (const auto& gid : required_groups) { - auto it = paged_cache_allocators_.find(gid); - if (it == paged_cache_allocators_.end() || it->second == nullptr) { - throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: required group '" + gid + - "' missing from registered allocators"); - } - const auto& cfg = it->second->Config(); - const std::int32_t raw_per_page = cfg.RawTokensPerPage(); - if (raw_per_page <= 0) { - throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: required group '" + gid + - "' has non-positive RawTokensPerPage"); - } - if (cfg.family == PagedCacheGroupFamily::History) { - history_gids.push_back(gid); - } else { - state_gids.push_back(gid); - } - if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow) { - auto win_it = sliding_window_per_group.find(gid); - if (win_it == sliding_window_per_group.end() || win_it->second <= 0) { - throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: sliding group '" + gid + - "' missing positive sliding_window entry"); - } - required_sliding_gids.push_back(gid); - } - } - if (history_gids.empty()) { - throw std::invalid_argument( - "HybridPrefixCache::EnablePagedCacheAdjunct: at least one History-family group required"); - } - if (sliding_window_per_group.size() != required_sliding_gids.size()) { - throw std::invalid_argument( - "HybridPrefixCache::EnablePagedCacheAdjunct: sliding_window_per_group keys must exactly " - "match the set of required groups whose retention is SlidingWindow"); - } - - // History alignment = LCM(raw_per_page) across History-family groups. - std::int32_t history_alignment = 1; - for (const auto& gid : history_gids) { - const auto& cfg = paged_cache_allocators_.find(gid)->second->Config(); - history_alignment = std::lcm(history_alignment, cfg.RawTokensPerPage()); - } - // Phase 1: state groups must align with the history alignment (so trailing - // segments are themselves page-aligned). Phase 2 will relax this via replay. - if (policy == StateRestorePolicy::kSnapshotRequired) { - for (const auto& gid : state_gids) { - const auto& cfg = paged_cache_allocators_.find(gid)->second->Config(); - const std::int32_t raw_per_page = cfg.RawTokensPerPage(); - if (history_alignment % raw_per_page != 0) { - throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: state group '" + gid + - "' RawTokensPerPage=" + std::to_string(raw_per_page) + - " does not divide history_alignment=" + std::to_string(history_alignment)); - } - } - } - - paged_cache_history_alignment_tokens_ = history_alignment; - paged_cache_required_groups_ = std::move(required_groups); - paged_cache_sliding_window_per_group_ = std::move(sliding_window_per_group); - paged_cache_history_groups_ = std::move(history_gids); - paged_cache_state_groups_ = std::move(state_gids); - paged_cache_history_group_set_ = - std::unordered_set(paged_cache_history_groups_.begin(), paged_cache_history_groups_.end()); - paged_cache_state_group_set_ = - std::unordered_set(paged_cache_state_groups_.begin(), paged_cache_state_groups_.end()); - paged_cache_state_policy_ = policy; -} - -namespace { - -// Ancestor path (excluding root), reversed so element 0 is closest to root. -std::vector CollectAncestorPathRootToLeaf(TreeNode* from) { - std::vector path; - for (TreeNode* n = from; n != nullptr && !n->IsRoot(); n = n->Parent()) { - path.push_back(n); - } - std::reverse(path.begin(), path.end()); - return path; -} - -} // namespace - -void HybridPrefixCache::augmentMatchPagedCache(MatchResult& match) const { - if (!HasPagedCacheAdjunct()) return; - if (match.device.last_node == nullptr) return; - - const std::int32_t align = paged_cache_history_alignment_tokens_; - - auto cap_to_root = [&]() { - TreeNode* root = match.device.last_node; - while (root != nullptr && !root->IsRoot()) root = root->Parent(); - match.device.last_node = root; - if (match.host.last_node != nullptr) { - TreeNode* h = match.host.last_node; - while (h != nullptr && !h->IsRoot()) h = h->Parent(); - match.host.last_node = h; - } - }; - - std::vector path = CollectAncestorPathRootToLeaf(match.device.last_node); - - // Phase A: history chain. Walk root→leaf, advance only on contiguous - // History-family completeness at every k*align boundary. - TreeNode* deepest_history = nullptr; - std::vector history_chain; - std::int32_t expected_depth = align; - for (TreeNode* n : path) { - const std::int32_t d = static_cast(n->DepthInTokens()); - if (d < expected_depth) continue; - if (d > expected_depth) break; - const auto* snap = n->GetPagedCacheSnapshot(); - if (snap == nullptr) break; - if (!snap->IsCompleteFor(PagedCacheGroupFamily::History)) break; - deepest_history = n; - history_chain.push_back(n); - expected_depth += align; - } - if (deepest_history == nullptr) { - cap_to_root(); - return; - } - - // Phase B: state window. `segments_needed` is the worst-case trailing - // coverage across state groups (so every state group is satisfied at the - // chosen depth). Walk back through history_chain, pick the deepest D' - // whose trailing `segments_needed` history_chain entries all have State - // complete. - std::int32_t worst_window = 0; - for (const auto& gid : paged_cache_state_groups_) { - auto it = paged_cache_sliding_window_per_group_.find(gid); - if (it != paged_cache_sliding_window_per_group_.end()) { - worst_window = std::max(worst_window, it->second); - } - } - const std::int32_t segments_needed = worst_window > 0 ? (worst_window + align - 1) / align : 1; - - TreeNode* usable_node = nullptr; - if (paged_cache_state_groups_.empty()) { - usable_node = deepest_history; - } else { - for (std::int32_t end_idx = static_cast(history_chain.size()) - 1; end_idx >= 0; --end_idx) { - const std::int32_t start_idx = std::max(0, end_idx - segments_needed + 1); - bool ok = true; - for (std::int32_t i = start_idx; i <= end_idx; ++i) { - const auto* snap = history_chain[i]->GetPagedCacheSnapshot(); - if (snap == nullptr || !snap->IsCompleteFor(PagedCacheGroupFamily::State)) { - ok = false; - break; - } - } - if (ok) { - usable_node = history_chain[end_idx]; - break; - } - } - } - if (usable_node == nullptr) { - cap_to_root(); - return; - } - - const std::int32_t usable = static_cast(usable_node->DepthInTokens()); - // Trim history_chain to ancestors up to and including usable_node. - while (!history_chain.empty() && static_cast(history_chain.back()->DepthInTokens()) > usable) { - history_chain.pop_back(); - } - - // Phase C: per-group page-id assembly. History groups take the full chain; - // State groups share a trailing-window slice computed once. - match.paged_cache.last_node = usable_node; - match.paged_cache.prefix_len_tokens = usable; - match.paged_cache.per_group_page_ids.clear(); - match.paged_cache.per_group_base_logical_page.clear(); - - auto assemble = [&](const std::string& gid, std::span chain, bool is_sliding) { - std::vector page_ids; - std::int32_t base_logical_page = 0; - if (!chain.empty()) { - const PagedCacheSnapshot* earliest_snap = chain.front()->GetPagedCacheSnapshot(); - if (earliest_snap != nullptr && is_sliding) { - auto git = earliest_snap->groups.find(gid); - if (git != earliest_snap->groups.end()) { - base_logical_page = git->second.base_logical_page; - } - } - for (TreeNode* anc : chain) { - const PagedCacheSnapshot* snap = anc->GetPagedCacheSnapshot(); - if (snap == nullptr) continue; - auto git = snap->groups.find(gid); - if (git == snap->groups.end()) continue; - const auto& seg_ids = git->second.pages.Ids(); - page_ids.insert(page_ids.end(), seg_ids.begin(), seg_ids.end()); - } - } - match.paged_cache.per_group_page_ids[gid] = std::move(page_ids); - match.paged_cache.per_group_base_logical_page[gid] = base_logical_page; - }; - - const std::span history_span{history_chain}; - for (const auto& gid : paged_cache_history_groups_) { - const bool is_sliding = - paged_cache_sliding_window_per_group_.find(gid) != paged_cache_sliding_window_per_group_.end(); - assemble(gid, history_span, is_sliding); - } - if (!paged_cache_state_groups_.empty()) { - const std::size_t take = std::min(history_chain.size(), static_cast(segments_needed)); - const std::span state_span = history_span.last(take); - for (const auto& gid : paged_cache_state_groups_) { - const bool is_sliding = - paged_cache_sliding_window_per_group_.find(gid) != paged_cache_sliding_window_per_group_.end(); - assemble(gid, state_span, is_sliding); - } - } - - // Cap device/host match nodes to the paged-cache usable depth. - match.device.last_node = usable_node; - if (match.host.last_node != nullptr && static_cast(match.host.last_node->DepthInTokens()) > usable) { - TreeNode* h = match.host.last_node; - while (h != nullptr && !h->IsRoot() && static_cast(h->DepthInTokens()) > usable) { - h = h->Parent(); - } - match.host.last_node = h; - } - - match.paged_cache.restore_kind = MatchResult::PagedCache::RestoreKind::kSnapshotComplete; - match.paged_cache.replay_start_tokens = 0; -} - -std::vector HybridPrefixCache::PagedCacheGroupIds() const { - std::vector ids; - ids.reserve(paged_cache_allocators_.size()); - for (const auto& [gid, _] : paged_cache_allocators_) { - ids.push_back(gid); - } - return ids; -} - -std::int32_t HybridPrefixCache::PagedCacheGroupTotalPages(const std::string& group_id) const { - auto it = paged_cache_allocators_.find(group_id); - if (it == paged_cache_allocators_.end()) { - throw std::out_of_range("HybridPrefixCache::PagedCacheGroupTotalPages: group_id not configured"); - } - return it->second->TotalPages(); -} - -std::int32_t HybridPrefixCache::PagedCacheGroupAvailablePages(const std::string& group_id) const { - auto it = paged_cache_allocators_.find(group_id); - if (it == paged_cache_allocators_.end()) { - throw std::out_of_range("HybridPrefixCache::PagedCacheGroupAvailablePages: group_id not configured"); - } - return it->second->AvailablePages(); -} - -std::int64_t HybridPrefixCache::PagedCacheGroupFailedAllocCount(const std::string& group_id) const { - auto it = paged_cache_allocators_.find(group_id); - if (it == paged_cache_allocators_.end()) { - throw std::out_of_range("HybridPrefixCache::PagedCacheGroupFailedAllocCount: group_id not configured"); - } - return it->second->FailedAllocCount(); +HybridPrefixCache::DeviceMemoryDiagnosticsSnapshot HybridPrefixCache::CollectDeviceMemoryDiagnostics() const { + auto snapshot = Stats({.include_device_memory_diagnostics = true}).device_memory_diagnostics; + _assert(snapshot.has_value(), "HybridPrefixCache::CollectDeviceMemoryDiagnostics: missing stats snapshot"); + return std::move(*snapshot); } std::vector HybridPrefixCache::GetRequestPagedCachePageIds(const std::string& request_id, const std::string& group_id) const { - if (paged_cache_allocators_.find(group_id) == paged_cache_allocators_.end()) { - throw std::out_of_range("HybridPrefixCache::GetRequestPagedCachePageIds: group_id not configured"); - } - auto req_it = request_paged_cache_tables_.find(request_id); - if (req_it == request_paged_cache_tables_.end()) { - return {}; - } - auto group_it = req_it->second.find(group_id); - if (group_it == req_it->second.end()) { - return {}; - } - return group_it->second.PageIds(); + auto snapshot = Stats({.request_id = request_id, .paged_cache_group_ids = {group_id}}); + return snapshot.request_paged_cache_page_ids.at(group_id); } std::int32_t HybridPrefixCache::GetRequestPagedCacheBaseLogicalPage(const std::string& request_id, const std::string& group_id) const { - if (paged_cache_allocators_.find(group_id) == paged_cache_allocators_.end()) { - throw std::out_of_range("HybridPrefixCache::GetRequestPagedCacheBaseLogicalPage: group_id not configured"); - } - auto req_it = request_paged_cache_tables_.find(request_id); - if (req_it == request_paged_cache_tables_.end()) { - return 0; - } - auto group_it = req_it->second.find(group_id); - if (group_it == req_it->second.end()) { - return 0; - } - return group_it->second.BaseLogicalPage(); + auto snapshot = Stats({.request_id = request_id, .paged_cache_group_ids = {group_id}}); + return snapshot.request_paged_cache_base_logical_page.at(group_id); } -std::map HybridPrefixCache::InitialSimulatedFree() const { - std::map out; - for (const auto& [gid, allocator] : paged_cache_allocators_) { - out[gid] = allocator->AvailablePages(); - } - return out; -} +CacheStatsSnapshot HybridPrefixCache::Stats(const StatsRequest& request) const { + CacheStatsSnapshot snapshot{ + .available_device_pages = static_cast(device_allocator_.AvailablePages()), + }; -void HybridPrefixCache::AcquireForRequest(const std::string& request_id, std::int32_t first_raw_position_of_op, - std::int32_t target_raw_tokens_exclusive, - const MatchResult::PagedCache& paged_cache_hit) { - if (paged_cache_allocators_.empty()) return; - auto& tables = request_paged_cache_tables_[request_id]; - const bool has_hit = (paged_cache_hit.last_node != nullptr) && (paged_cache_hit.prefix_len_tokens > 0); - for (const auto& [group_id, allocator] : paged_cache_allocators_) { - auto it = tables.find(group_id); - const bool fresh_table = (it == tables.end()); - if (fresh_table) { - it = tables.emplace(group_id, PagedCacheGroupTable(allocator.get())).first; - // Import borrowed-prefix BEFORE ReleaseSkipped/Acquire on a fresh table. - if (has_hit) { - auto pid_it = paged_cache_hit.per_group_page_ids.find(group_id); - if (pid_it != paged_cache_hit.per_group_page_ids.end() && !pid_it->second.empty()) { - std::int32_t base_logical_page = 0; - auto base_it = paged_cache_hit.per_group_base_logical_page.find(group_id); - if (base_it != paged_cache_hit.per_group_base_logical_page.end()) { - base_logical_page = base_it->second; - } - std::vector page_ids_copy = pid_it->second; - it->second.ImportPrefixBorrowed(std::move(page_ids_copy), base_logical_page, - paged_cache_hit.prefix_len_tokens); + snapshot.paged_cache_group_ids.reserve(paged_cache_allocators_.size()); + for (const auto& [gid, _] : paged_cache_allocators_) { + snapshot.paged_cache_group_ids.push_back(gid); + } + + std::vector requested_groups = request.paged_cache_group_ids; + if (requested_groups.empty()) { + requested_groups = snapshot.paged_cache_group_ids; + } + for (const auto& gid : requested_groups) { + auto alloc_it = paged_cache_allocators_.find(gid); + if (alloc_it == paged_cache_allocators_.end() || alloc_it->second == nullptr) { + throw std::out_of_range("HybridPrefixCache::Stats: group_id not configured"); + } + snapshot.paged_cache_total_pages[gid] = alloc_it->second->TotalPages(); + snapshot.paged_cache_available_pages[gid] = alloc_it->second->AvailablePages(); + snapshot.paged_cache_failed_alloc_count[gid] = alloc_it->second->FailedAllocCount(); + + if (request.request_id.has_value()) { + std::vector pages; + std::int32_t base_logical_page = 0; + auto req_it = request_paged_cache_tables_.find(*request.request_id); + if (req_it != request_paged_cache_tables_.end()) { + auto group_it = req_it->second.find(gid); + if (group_it != req_it->second.end()) { + pages = group_it->second.PageIds(); + base_logical_page = group_it->second.BaseLogicalPage(); } } + snapshot.request_paged_cache_page_ids[gid] = std::move(pages); + snapshot.request_paged_cache_base_logical_page[gid] = base_logical_page; } - const auto& cfg = allocator->Config(); - if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow && cfg.sliding_window_tokens.has_value()) { - const std::int32_t lower = std::max(0, first_raw_position_of_op - *cfg.sliding_window_tokens + 1); - it->second.ReleaseSkipped(lower); - } - it->second.Acquire(target_raw_tokens_exclusive); } -} -void HybridPrefixCache::ReleaseRequest(const std::string& request_id) { - auto it = request_paged_cache_tables_.find(request_id); - if (it != request_paged_cache_tables_.end()) { - for (auto& [_, table] : it->second) { - table.ReleaseAll(); - } - request_paged_cache_tables_.erase(it); + if (request.include_device_memory_diagnostics) { + snapshot.device_memory_diagnostics = CacheDeviceMemoryDiagnosticsSnapshot{ + .tree_device_pages = kv_prefix_cache_.CollectAllPages(), + .free_device_pages = device_allocator_.AvailablePages(), + .total_device_pages = device_allocator_.TotalPages() - 1, + }; } - DemoteIdleMambaDeviceCopiesPresentOnHost(); -} -void HybridPrefixCache::PopulateOp(ForwardOperationBase& op_base) const { - if (paged_cache_allocators_.empty()) return; - auto req_it = request_paged_cache_tables_.find(op_base.request_id); - for (const auto& [gid, allocator] : paged_cache_allocators_) { - std::vector pages; - std::int32_t base_offset = 0; - if (req_it != request_paged_cache_tables_.end()) { - auto table_it = req_it->second.find(gid); - if (table_it != req_it->second.end()) { - pages = table_it->second.PageIds(); - base_offset = table_it->second.BaseLogicalPage(); - } - } - op_base.paged_cache_pages[gid] = std::move(pages); - if (allocator->Config().retention == PagedCacheGroupConfig::Retention::SlidingWindow) { - op_base.paged_cache_page_base_offsets[gid] = base_offset; - } - } + return snapshot; } -HybridPrefixCache::PagedCacheGroupAdmission HybridPrefixCache::checkPagedCacheGroupAdmission( - const std::string& request_id, std::int32_t first_raw_position_of_op, std::int32_t target_raw_tokens_exclusive, - const std::map& simulated_free, const MatchResult::PagedCache& paged_cache_hit, - const PagedCacheAdmissionContext& context) const { - PagedCacheGroupAdmission result; - if (paged_cache_allocators_.empty() || target_raw_tokens_exclusive < 0) { - return result; - } - - auto req_it = - context.fresh_table_view ? request_paged_cache_tables_.end() : request_paged_cache_tables_.find(request_id); - const bool has_hit = (paged_cache_hit.last_node != nullptr) && (paged_cache_hit.prefix_len_tokens > 0); - for (const auto& [gid, allocator] : paged_cache_allocators_) { - const auto& cfg = allocator->Config(); - const std::int32_t raw_per_page = cfg.RawTokensPerPage(); - if (cfg.entry_stride_tokens <= 0 || cfg.rows_per_page <= 0 || raw_per_page <= 0) { - continue; +void HybridPrefixCache::PrepareForwardOp(ForwardOperationBase& op_base, const CacheOpPrepareRequest& request) { + auto require_match = [&]() -> const MatchResult& { + if (request.match_result == nullptr) { + throw std::invalid_argument( + "HybridPrefixCache::PrepareForwardOp requires match_result for this prepare kind"); } + return *request.match_result; + }; - const std::int32_t entries = CeilDivPositive(target_raw_tokens_exclusive, cfg.entry_stride_tokens); - const std::int32_t required = (entries + cfg.rows_per_page - 1) / cfg.rows_per_page; - - std::int32_t current_size = 0; - std::int32_t current_active = 0; - std::int32_t borrowed_in_table = 0; - std::int32_t owned_in_table = 0; - std::int32_t already_released = 0; - bool table_exists = false; - if (req_it != request_paged_cache_tables_.end()) { - auto t_it = req_it->second.find(gid); - if (t_it != req_it->second.end()) { - table_exists = true; - current_size = t_it->second.Size(); - current_active = t_it->second.ActivePagesCount(); - borrowed_in_table = t_it->second.BorrowedPagesCount(); - owned_in_table = t_it->second.OwnedPagesCount(); - already_released = t_it->second.ReleasedPagesCount(); - } - } + PopulateMambaRequestLocalCompatibilityFields(op_base, request.local_mamba_allocator); - std::int32_t borrowed_count = 0; - std::int32_t borrowed_base = 0; - if (has_hit && !table_exists) { - auto pid_it = paged_cache_hit.per_group_page_ids.find(gid); - if (pid_it != paged_cache_hit.per_group_page_ids.end()) { - borrowed_count = static_cast(pid_it->second.size()); - } - auto base_it = paged_cache_hit.per_group_base_logical_page.find(gid); - if (base_it != paged_cache_hit.per_group_base_logical_page.end()) { - borrowed_base = base_it->second; - } + switch (request.kind) { + case CacheOpPrepareKind::kPrefillFirstChunk: { + const MatchResult& match_result = require_match(); + PopulateMambaMatchCompatibilityFields(op_base, match_result); + CommitChunk(op_base.request_id, request.terminal); + acquireAndPopulateOp(op_base, request.first_raw_position_of_op, request.target_raw_tokens_exclusive, + match_result.paged_cache); + return; } - - std::int32_t releasable_total = 0; - std::int32_t releasable_owned = 0; - if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow && cfg.sliding_window_tokens.has_value()) { - const std::int32_t lower = std::max(0, first_raw_position_of_op - *cfg.sliding_window_tokens + 1); - const std::int32_t target_releases = lower / raw_per_page; - const std::int32_t logical_released_base = table_exists ? already_released : borrowed_base; - releasable_total = std::max(0, target_releases - logical_released_base); - releasable_total = std::min(releasable_total, current_active + borrowed_count); - - // Borrowed pages drop the index only (no pool credit); only the - // owned-prefix slice contributes to releasable_owned. - const std::int32_t borrowed_present_total = table_exists ? borrowed_in_table : borrowed_count; - releasable_owned = releasable_total - std::min(releasable_total, borrowed_present_total); - if (table_exists) { - releasable_owned = std::min(releasable_owned, owned_in_table); + case CacheOpPrepareKind::kPrefillChunk: + CommitChunk(op_base.request_id, request.terminal); + acquireAndPopulateOp(op_base, request.first_raw_position_of_op, request.target_raw_tokens_exclusive, + request.paged_cache_hit); + return; + case CacheOpPrepareKind::kDecode: + if (request.commit_prior_chunk) { + CommitChunk(op_base.request_id, request.terminal); } - } - - const std::int32_t absolute_have = - table_exists ? (already_released + current_size) : (borrowed_base + borrowed_count); - const std::int32_t new_pages = std::max(0, required - absolute_have); - std::int32_t free = allocator->AvailablePages(); - auto sf_it = simulated_free.find(gid); - if (sf_it != simulated_free.end()) { - free = sf_it->second; - } - auto credit_it = context.owned_release_credit.find(gid); - if (credit_it != context.owned_release_credit.end()) { - free += credit_it->second; - } - - result.releasable_owned_pages[gid] = releasable_owned; - result.new_pages_needed[gid] = new_pages; - if (free + releasable_owned < new_pages) { - result.ok = false; + acquireAndPopulateOp(op_base, request.first_raw_position_of_op, request.target_raw_tokens_exclusive, {}); + return; + case CacheOpPrepareKind::kDecodeFromRetracted: { + const MatchResult& match_result = require_match(); + PopulateMambaRecoveryCompatibilityFields(op_base, match_result); + ReleaseRequest(op_base.request_id); + acquireAndPopulateOp(op_base, 0, request.target_raw_tokens_exclusive, match_result.paged_cache); + return; } } - return result; } -void HybridPrefixCache::applyPagedCacheGroupAdmissionDebit(std::map& simulated_free, - const PagedCacheGroupAdmission& admission) { - for (const auto& [gid, releasable_owned] : admission.releasable_owned_pages) { - simulated_free[gid] += releasable_owned; - } - for (const auto& [gid, new_pages] : admission.new_pages_needed) { - simulated_free[gid] -= new_pages; +AdmissionVerdict HybridPrefixCache::Admit(const AdmissionRequest& request, + std::map& simulated_free) { + const MatchResult* match = request.compat_match; + if (match == nullptr && request.recovery_plan != nullptr) { + match = &request.recovery_plan->compat_match; } -} -HybridPrefixCache::AdmissionFailureKind HybridPrefixCache::ClassifyAdmissionFailure( - const PagedCacheGroupAdmission& admission) const { - if (admission.ok) return AdmissionFailureKind::kNone; - bool history_starved = false; - bool state_starved = false; - for (const auto& [gid, needed] : admission.new_pages_needed) { - if (needed <= 0) continue; - if (paged_cache_history_group_set_.find(gid) != paged_cache_history_group_set_.end()) { - history_starved = true; - } - if (paged_cache_state_group_set_.find(gid) != paged_cache_state_group_set_.end()) { - state_starved = true; - } + CacheAdmissionKind kind = CacheAdmissionKind::kDecodeChunk; + if (request.protect_host_match_node || request.host_pages_needed > 0) { + kind = CacheAdmissionKind::kRetract; + } else if (request.fresh_request_table_view) { + kind = CacheAdmissionKind::kDecodeFromRetracted; + } else if (request.compute_branching_checkpoint) { + kind = CacheAdmissionKind::kPrefillFirstChunk; + } else if (request.auxiliary_tree_slots_needed > 0) { + kind = CacheAdmissionKind::kPrefillChunk; } - if (history_starved && state_starved) return AdmissionFailureKind::kBothStarved; - if (history_starved) return AdmissionFailureKind::kHistoryStarved; - if (state_starved) return AdmissionFailureKind::kStateStarved; - return AdmissionFailureKind::kNone; -} -void HybridPrefixCache::refreshPagedCacheSimulatedFree(std::map& simulated_free) const { - for (const auto& [gid, allocator] : paged_cache_allocators_) { - simulated_free[gid] = allocator->AvailablePages(); - } + CacheAdmissionRequest compat_request{ + .kind = kind, + .request_id = request.request_id, + .device_pages_needed = request.device_pages_needed, + .host_pages_needed = request.host_pages_needed, + .tokens_this_round = request.tokens_this_round, + .first_raw_position_of_op = request.first_raw_position_of_op, + .target_raw_tokens_exclusive = request.target_raw_tokens_exclusive, + .match_result = match, + .mamba_recovery_node = request.protected_recovery_node, + .refresh_mamba_checkpoint = request.refresh_mamba_checkpoint, + }; + CacheAdmissionResult compat_result = Admit(compat_request, simulated_free); + return AdmissionVerdict{ + .admitted = compat_result.admitted, + .mamba_branching_seqlen = compat_result.mamba_branching_seqlen, + .mamba_cow_src_index = compat_result.mamba_cow_src_index, + .cache_transfer_pairs = std::move(compat_result.cache_transfer_pairs), + .demands = request.demands, + }; } -bool HybridPrefixCache::admitPagedCacheChunk(const std::string& request_id, std::int32_t first_raw_position_of_op, - std::int32_t target_raw_tokens_exclusive, - std::map& simulated_free, - const MatchResult::PagedCache& paged_cache_hit, - const PagedCacheAdmissionContext& context) { - PagedCacheGroupAdmission admission = checkPagedCacheGroupAdmission( - request_id, first_raw_position_of_op, target_raw_tokens_exclusive, simulated_free, paged_cache_hit, context); - const std::size_t prune_budget = paged_cache_snapshot_nodes_.size(); - for (std::size_t pruned = 0; !admission.ok && pruned < prune_budget; ++pruned) { - AdmissionFailureKind kind = ClassifyAdmissionFailure(admission); - if (kind == AdmissionFailureKind::kNone) break; - if (!tryPrunePagedCacheSnapshot(kind)) break; - refreshPagedCacheSimulatedFree(simulated_free); - admission = checkPagedCacheGroupAdmission(request_id, first_raw_position_of_op, target_raw_tokens_exclusive, - simulated_free, paged_cache_hit, context); - } - if (!admission.ok) return false; - for (const auto& [gid, credit] : context.owned_release_credit) { - simulated_free[gid] += credit; - } - applyPagedCacheGroupAdmissionDebit(simulated_free, admission); - return true; -} +StepCommitResult HybridPrefixCache::StepCommit(StepCommitRequest request) { + StepCommitResult result{}; + auto match_or_plan = [](const RecoveryPlan* recovery_plan, const MatchResult* compat_match) -> const MatchResult* { + if (compat_match != nullptr) return compat_match; + if (recovery_plan != nullptr) return &recovery_plan->compat_match; + return nullptr; + }; -bool HybridPrefixCache::DetachStateSnapshotFromNode(TreeNode* node) { - if (node == nullptr) return false; - PagedCacheSnapshot* snap = node->GetPagedCacheSnapshotMut(); - if (snap == nullptr) return false; - bool removed_any = false; - for (const auto& gid : paged_cache_state_groups_) { - auto it = snap->groups.find(gid); - if (it != snap->groups.end()) { - snap->groups.erase(it); - removed_any = true; + if (request.materialize_prefix.has_value()) { + const auto& materialize = *request.materialize_prefix; + const MatchResult* match = match_or_plan(materialize.recovery_plan, materialize.compat_match); + const CacheMaterializationKind kind = materialize.require_all_pages + ? CacheMaterializationKind::kDecodeRecoveryHostPrefixOnDevice + : CacheMaterializationKind::kPrefillHostPrefixOnDevice; + result.ok = Materialize({ + .kind = kind, + .match_result = match, + }) + .ok; + if (!result.ok) return result; + } + + if (request.publish_device_prefix.has_value()) { + const auto& publish = *request.publish_device_prefix; + (void)Publish({ + .kind = CachePublicationKind::kForwardChunk, + .full_paged_tokens = publish.full_paged_tokens, + .device_node_ref = publish.device_node_ref, + .local_kv_allocator = publish.local_kv_allocator, + .local_mamba_allocator = publish.local_mamba_allocator, + }); + } + + if (request.publish_finished_request.has_value()) { + const auto& publish = *request.publish_finished_request; + auto publish_result = Publish({ + .kind = CachePublicationKind::kFinishChunk, + .full_paged_tokens = publish.full_paged_tokens, + .current_device_node = publish.current_device_node, + .local_kv_allocator = publish.local_kv_allocator, + .local_mamba_allocator = publish.local_mamba_allocator, + .page_hashes = publish.page_hashes, + }); + result.match_result = std::move(publish_result.match_result); + result.device_insert_page_count = publish_result.device_insert_page_count; + } + + if (request.plan_device_prefix_insertion.has_value()) { + const auto& plan = *request.plan_device_prefix_insertion; + auto publish_result = Publish({ + .kind = CachePublicationKind::kRetractDeviceInsertPageCount, + .full_paged_tokens = plan.full_paged_tokens, + .current_device_node = plan.current_device_node, + }); + result.device_insert_page_count = publish_result.device_insert_page_count; + } + + if (request.publish_device_prefix_insertion.has_value()) { + auto& publish = *request.publish_device_prefix_insertion; + auto publish_result = Publish({ + .kind = CachePublicationKind::kRetractChunk, + .full_paged_tokens = publish.full_paged_tokens, + .current_device_node = publish.current_device_node, + .pages_to_insert = std::move(publish.pages_to_insert), + }); + result.match_result = std::move(publish_result.match_result); + result.device_insert_page_count = publish_result.device_insert_page_count; + } + + if (request.materialize_host_writeback.has_value()) { + const auto& materialize = *request.materialize_host_writeback; + const CacheMaterializationKind kind = materialize.ensure_capacity_before_allocate + ? CacheMaterializationKind::kFinishWritebackHostPages + : CacheMaterializationKind::kRetractWritebackHostPages; + result.ok = Materialize({ + .kind = kind, + .write_diff = materialize.write_diff, + }) + .ok; + if (!result.ok) return result; + if (materialize.write_diff != nullptr) { + result.cache_transfer_pairs = PrepareMambaHostWriteBack(*materialize.write_diff); + for (const TransferPair& transfer : result.cache_transfer_pairs) { + if (transfer.kind != CacheKind::kMamba) continue; + for (TreeNode* node : *materialize.write_diff) { + if (node != nullptr && node->HasMamba() && node->MambaSlotIndex() == transfer.src) { + result.mamba_writeback_nodes.push_back(node); + break; + } + } + } } } - if (!removed_any) return false; - snap->complete_families.erase(PagedCacheGroupFamily::State); - // If nothing remains, fall through to full detach to keep invariants tidy. - if (snap->groups.empty()) { - DetachPagedCacheSnapshotFromNode(node); + + if (request.publish_tree_owned_request_state.has_value()) { + const auto& publish = *request.publish_tree_owned_request_state; + if (publish.local_mamba_allocator_owner == nullptr) { + throw std::invalid_argument( + "HybridPrefixCache::StepCommit publish_tree_owned_request_state requires " + "local_mamba_allocator_owner"); + } + PublishRetractMambaState(publish.terminal, *publish.local_mamba_allocator_owner); } - return true; -} -bool HybridPrefixCache::tryPrunePagedCacheSnapshot(AdmissionFailureKind kind) { - if (!HasPagedCacheAdjunct()) return false; - if (kind == AdmissionFailureKind::kNone) return false; + const bool strict_mamba_create = request.request_local_mamba.has_value() && + request.request_local_mamba->create_allocator && + request.request_local_mamba->require_allocator; - auto is_pinned = [](TreeNode* node) { - for (TreeNode* cur = node; cur != nullptr && !cur->IsRoot(); cur = cur->Parent()) { - if (!cur->OnDevice()) continue; - if (cur->Device().RefCount() > 0) return true; + auto apply_request_local_mamba = [&]() { + if (!request.request_local_mamba.has_value()) return; + const auto& mamba = *request.request_local_mamba; + if (mamba.create_allocator) { + auto mamba_result = PrepareRequestLocalMamba({ + .kind = mamba.require_allocator ? RequestLocalMambaKind::kDecodeFromRetracted + : RequestLocalMambaKind::kPrefillFirstChunk, + .checkpoint_raw_position = mamba.checkpoint_raw_position, + }); + result.local_mamba_allocator = std::move(mamba_result.local_mamba_allocator); } - return false; - }; - - // Sort once and share between branches: oldest first, then deepest within - // same Time(). Both try_state_only and try_full walk this same order. - std::vector candidates; - candidates.reserve(paged_cache_snapshot_nodes_.size()); - for (TreeNode* node : paged_cache_snapshot_nodes_) { - if (node == nullptr) continue; - if (!node->HasPagedCacheSnapshot()) continue; - candidates.push_back(node); - } - std::sort(candidates.begin(), candidates.end(), [](TreeNode* a, TreeNode* b) { - if (a->Time() != b->Time()) return a->Time() < b->Time(); - return a->DepthInTokens() > b->DepthInTokens(); - }); - - auto try_state_only = [&]() { - for (TreeNode* node : candidates) { - if (is_pinned(node)) continue; - const auto* snap = node->GetPagedCacheSnapshot(); - if (snap == nullptr) continue; - if (!snap->IsCompleteFor(PagedCacheGroupFamily::State)) continue; - if (DetachStateSnapshotFromNode(node)) return true; + if (mamba.refresh_checkpoint_allocator != nullptr) { + (void)PrepareRequestLocalMamba({ + .kind = RequestLocalMambaKind::kNextCheckpoint, + .local_mamba_allocator = mamba.refresh_checkpoint_allocator, + .checkpoint_raw_position = mamba.checkpoint_raw_position, + }); } - return false; }; - auto try_full = [&]() { - TreeNode* victim = nullptr; - for (TreeNode* node : candidates) { - if (is_pinned(node)) continue; - victim = node; - break; - } - if (victim == nullptr) return false; - const std::size_t victim_depth = victim->DepthInTokens(); - auto primary = DetachPagedCacheSnapshotFromNode(victim); - (void)primary; - std::vector descendants; - for (TreeNode* node : paged_cache_snapshot_nodes_) { - if (node == nullptr || node == victim) continue; - if (!node->HasPagedCacheSnapshot()) continue; - if (node->DepthInTokens() <= victim_depth) continue; - for (TreeNode* cur = node->Parent(); cur != nullptr && !cur->IsRoot(); cur = cur->Parent()) { - if (cur == victim) { - descendants.push_back(node); - break; - } - } + auto apply_request_local_kv = [&]() { + if (!request.request_local_kv.has_value()) return; + const auto& kv = *request.request_local_kv; + if (kv.create_allocator) { + auto kv_result = PrepareRequestLocalKV({ + .kind = RequestLocalKVKind::kPrefillFirstChunk, + .tokens_this_round = kv.initial_tokens, + .decode_input_tokens = kv.acquire_tokens, + }); + result.local_kv_allocator = std::move(kv_result.local_kv_allocator); + return; } - for (TreeNode* d : descendants) { - if (is_pinned(d)) continue; - auto cascaded = DetachPagedCacheSnapshotFromNode(d); - (void)cascaded; + if (kv.allocator != nullptr || kv.acquire_tokens > 0) { + (void)PrepareRequestLocalKV({ + .kind = RequestLocalKVKind::kPrefillChunk, + .local_kv_allocator = kv.allocator, + .tokens_this_round = kv.acquire_tokens, + }); } - return true; }; - // kBothStarved: state-only cannot solve history shortage; go straight to - // full. The outer admit loop will re-classify if state still needs more. - switch (kind) { - case AdmissionFailureKind::kStateStarved: - return try_state_only(); - case AdmissionFailureKind::kHistoryStarved: - case AdmissionFailureKind::kBothStarved: - return try_full(); - case AdmissionFailureKind::kNone: - return false; + if (strict_mamba_create) { + apply_request_local_mamba(); + apply_request_local_kv(); + } else { + apply_request_local_kv(); + apply_request_local_mamba(); + } + + if (request.worker_metadata.has_value()) { + const auto& worker = *request.worker_metadata; + if (worker.op_base == nullptr) { + throw std::invalid_argument("HybridPrefixCache::StepCommit PrepareWorkerOp requires op_base"); + } + CacheOpPrepareKind kind = CacheOpPrepareKind::kDecode; + if (worker.populate_recovery_metadata || worker.release_request_state_before_acquire) { + kind = CacheOpPrepareKind::kDecodeFromRetracted; + } else if (worker.populate_prefix_reuse_metadata) { + kind = CacheOpPrepareKind::kPrefillFirstChunk; + } else if (worker.import_paged_cache_hit) { + kind = CacheOpPrepareKind::kPrefillChunk; + } + const MatchResult* match = match_or_plan(worker.recovery_plan, worker.compat_match); + PrepareForwardOp(*worker.op_base, + { + .kind = kind, + .terminal = worker.terminal, + .first_raw_position_of_op = worker.first_raw_position_of_op, + .target_raw_tokens_exclusive = worker.target_raw_tokens_exclusive, + .match_result = match, + .local_mamba_allocator = worker.local_mamba_allocator_view, + .paged_cache_hit = match == nullptr ? worker.paged_cache_hit : match->paged_cache, + .commit_prior_chunk = worker.commit_tree_prefix_before_acquire, + }); } - return false; -} -bool HybridPrefixCache::AdmitChunk(const std::string& request_id, std::int32_t first_raw_position_of_op, - std::int32_t target_raw_tokens_exclusive, - std::map& simulated_free, - const MatchResult::PagedCache& paged_cache_hit) { - return admitPagedCacheChunk(request_id, first_raw_position_of_op, target_raw_tokens_exclusive, simulated_free, - paged_cache_hit, {}); + return result; } -bool HybridPrefixCache::AdmitChunkFromRetracted(const std::string& request_id, std::int32_t target_raw_tokens_exclusive, - std::map& simulated_free, - const MatchResult::PagedCache& paged_cache_hit) { - PagedCacheAdmissionContext context{.fresh_table_view = true}; - auto req_it = request_paged_cache_tables_.find(request_id); - if (req_it != request_paged_cache_tables_.end()) { - for (const auto& [gid, table] : req_it->second) { - context.owned_release_credit[gid] = table.OwnedPagesCount(); +HybridPrefixCache::CacheAdmissionResult HybridPrefixCache::Admit(const CacheAdmissionRequest& request, + std::map& simulated_free) { + CacheAdmissionResult result{}; + auto require_match = [&]() -> const MatchResult& { + if (request.match_result == nullptr) { + throw std::invalid_argument("HybridPrefixCache::Admit requires match_result for this admission kind"); } - } - return admitPagedCacheChunk(request_id, 0, target_raw_tokens_exclusive, simulated_free, paged_cache_hit, context); -} - -void HybridPrefixCache::CommitChunk(const std::string& request_id, TreeNode* terminal) { - if (!HasPagedCacheAdjunct()) return; - if (terminal == nullptr) return; - - auto tables_it = request_paged_cache_tables_.find(request_id); - if (tables_it == request_paged_cache_tables_.end()) return; - auto& tables = tables_it->second; - - const std::int32_t lcm = paged_cache_history_alignment_tokens_; - if (lcm <= 0) return; - const auto& required_groups = paged_cache_required_groups_; - if (required_groups.empty()) return; - - auto canonical_it = tables.find(required_groups.front()); - if (canonical_it == tables.end()) return; - std::int32_t last_committed = canonical_it->second.CommittedPrefixLenTokens(); - - const std::int32_t chunk_depth = static_cast(terminal->DepthInTokens()); - if (chunk_depth <= 0) return; - - while (last_committed + lcm <= chunk_depth) { - const std::int32_t target = last_committed + lcm; - - TreeNode* attach_node = kv_prefix_cache_.GetRadixTree().SplitAt(terminal, target); - if (attach_node == nullptr) break; + return *request.match_result; + }; + auto mamba_device_loadback_nodes = [this](const MatchResult& match_result, TreeNode* preferred_source = nullptr) { + std::vector nodes; + if (mamba_host_allocator_ == nullptr || match_result.mamba_host_src_index < 0 || + match_result.mamba_cow_src_index >= 0) { + return nodes; + } + TreeNode* host_mamba_node = + preferred_source != nullptr ? preferred_source : FindLastMambaHostNode(match_result.host.last_node); + if (host_mamba_node != nullptr && host_mamba_node->HasMambaOnHost() && !host_mamba_node->HasMamba()) { + nodes.push_back(host_mamba_node); + } + return nodes; + }; - if (attach_node->HasPagedCacheSnapshot()) { - bool covered = true; - for (const auto& gid : required_groups) { - auto t_it = tables.find(gid); - if (t_it == tables.end()) { - covered = false; - break; + switch (request.kind) { + case CacheAdmissionKind::kPrefillFirstChunk: { + const MatchResult& match_result = require_match(); + std::unique_ptr temp_lock = std::make_unique(match_result.device.last_node); + if (!kv_prefix_cache_.EnsureCapacityByEvict(request.device_pages_needed)) { + return result; + } + std::vector loadback_nodes = mamba_device_loadback_nodes(match_result); + std::optional mamba_branching_seqlen; + if (HasMambaAdjunct()) { + if (match_result.mamba_branching_seqlen == -1) { + const std::int32_t aligned = AlignMambaCacheSeqlen(request.tokens_this_round); + if (aligned > 0) { + mamba_branching_seqlen = aligned; + } } - if (t_it->second.CommittedPrefixLenTokens() < target) { - covered = false; - break; + const std::int32_t slots_needed = 2 + static_cast(loadback_nodes.size()); + if (!EnsureMambaCapacityByEvict(slots_needed)) { + return result; } } - if (!covered) { - spdlog::warn( - "[HybridPrefixCache] CommitChunk: target depth {} already has a paged-cache " - "snapshot but request {} has uncommitted owned pages in [{}, {}); leaving " - "existing snapshot intact", - target, request_id, last_committed, target); - break; + result.admitted = AdmitChunk(request.request_id, request.first_raw_position_of_op, + request.target_raw_tokens_exclusive, simulated_free, match_result.paged_cache); + if (result.admitted) { + result.mamba_branching_seqlen = mamba_branching_seqlen; + result.cache_transfer_pairs = PrepareMambaDeviceLoadBack(loadback_nodes); + if (!loadback_nodes.empty() && loadback_nodes.front()->HasMamba()) { + result.mamba_cow_src_index = loadback_nodes.front()->MambaSlotIndex(); + } } - last_committed = target; - continue; + return result; } - - bool preflight_ok = true; - for (const auto& gid : required_groups) { - auto t_it = tables.find(gid); - if (t_it == tables.end()) { - preflight_ok = false; - break; + case CacheAdmissionKind::kPrefillChunk: + if (!kv_prefix_cache_.EnsureCapacityByEvict(request.device_pages_needed)) { + return result; } - const auto& table = t_it->second; - const std::int32_t raw_per_page = table.RawTokensPerPage(); - if (raw_per_page <= 0) { - preflight_ok = false; - break; + if (HasMambaAdjunct() && !EnsureMambaCapacityByEvict(1)) { + return result; } - if (table.CommittedPrefixLenTokens() % raw_per_page != 0) { - preflight_ok = false; - break; + result.admitted = AdmitChunk(request.request_id, request.first_raw_position_of_op, + request.target_raw_tokens_exclusive, simulated_free); + return result; + case CacheAdmissionKind::kDecodeChunk: + if (!kv_prefix_cache_.EnsureCapacityByEvict(request.device_pages_needed)) { + return result; } - if (target % raw_per_page != 0) { - preflight_ok = false; - break; + if (request.refresh_mamba_checkpoint && HasMambaAdjunct() && !EnsureMambaCapacityByEvict(1)) { + return result; } - if (target <= table.CommittedPrefixLenTokens()) { - preflight_ok = false; - break; + result.admitted = AdmitChunk(request.request_id, request.first_raw_position_of_op, + request.target_raw_tokens_exclusive, simulated_free); + return result; + case CacheAdmissionKind::kDecodeFromRetracted: { + const MatchResult& match_result = require_match(); + std::unique_ptr temp_lock = std::make_unique(match_result.device.last_node); + if (!kv_prefix_cache_.EnsureCapacityByEvict(request.device_pages_needed)) { + return result; } - if (target > table.RawTokenCursor()) { - preflight_ok = false; - break; + std::vector loadback_nodes = + mamba_device_loadback_nodes(match_result, request.mamba_recovery_node); + if (HasMambaAdjunct()) { + if (request.mamba_recovery_node == nullptr) { + return result; + } + // Recovery COWs the tree-owned Mamba state into fresh + // request-local working/checkpoint slots. Protect the source + // node only for this allocation; retracted Mamba states are + // otherwise normal evictable tree-owned cache entries. + const std::int32_t slots_needed = 2 + static_cast(loadback_nodes.size()); + if (!EnsureMambaCapacityByEvict(slots_needed, request.mamba_recovery_node)) { + return result; + } } - } - if (!preflight_ok) { - spdlog::warn( - "[HybridPrefixCache] CommitChunk: preflight failed for request {} at target " - "depth {}; leaving prior commits intact", - request_id, target); - break; - } - - auto snapshot = std::make_unique(); - snapshot->prefix_len_tokens = target; - for (const auto& gid : required_groups) { - auto& table = tables.find(gid)->second; - auto group_alloc_it = paged_cache_allocators_.find(gid); - const auto& cfg = group_alloc_it->second->Config(); - auto result = cfg.family == PagedCacheGroupFamily::History ? table.CommitHistoryToSnapshot(target) - : table.CheckpointStateToSnapshot(target); - PagedCacheGroupSnapshot group_snap{}; - group_snap.pages = std::move(result.pages); - group_snap.base_logical_page = result.segment_base_logical_page; - group_snap.raw_token_cursor = table.RawTokenCursor(); - group_snap.sliding = table.IsSliding(); - snapshot->groups.emplace(gid, std::move(group_snap)); - } - - bool snapshot_complete = true; - for (const auto& gid : required_groups) { - if (snapshot->groups.find(gid) == snapshot->groups.end()) { - snapshot_complete = false; - break; + result.admitted = AdmitChunkFromRetracted(request.request_id, request.target_raw_tokens_exclusive, + simulated_free, match_result.paged_cache); + if (result.admitted) { + result.cache_transfer_pairs = PrepareMambaDeviceLoadBack(loadback_nodes); + if (!loadback_nodes.empty() && loadback_nodes.front()->HasMamba()) { + result.mamba_cow_src_index = loadback_nodes.front()->MambaSlotIndex(); + } } + return result; + } + case CacheAdmissionKind::kRetract: { + const MatchResult& match_result = require_match(); + std::unique_ptr temp_lock = std::make_unique(match_result.host.last_node); + result.admitted = kv_prefix_cache_.EnsureCapacityByEvict(request.host_pages_needed); + return result; } - _assert(snapshot_complete, - "HybridPrefixCache::CommitChunk: built snapshot missing a required group after " - "preflight+commit; invariant violated"); - const bool attached = AttachPagedCacheSnapshotToNode(attach_node, std::move(snapshot)); - _assert(attached, - "HybridPrefixCache::CommitChunk: attach rejected a non-null snapshot on a non-null " - "node; invariant violated"); - - last_committed = target; } + return result; } } // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h index a519427ce..bdc196821 100644 --- a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h @@ -24,6 +24,7 @@ #include #include #include +#include #include #include #include @@ -31,7 +32,9 @@ #include #include +#include "resource/allocator/owned_pages.h" #include "resource/allocator/paged_cache_group.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache_types.h" #include "resource/hybrid_prefix_cache/mamba_eviction_manager.h" #include "resource/radix_tree/mamba_slot.h" #include "scheduler/operations/cache.h" @@ -42,35 +45,49 @@ namespace tokenspeed { class MambaChunkAllocator; class MambaHostAllocator; +class LocalKVAllocator; +class LocalMambaAllocator; +class PageAllocator; class ForwardOperationBase; class HybridPrefixCache { public: // `mamba_allocator` may be null; paged-cache adjunct is enabled separately. - HybridPrefixCache(KVPrefixCache& prefix_cache, MambaChunkAllocator* allocator, std::int32_t mamba_cache_chunk_size, - MambaHostAllocator* mamba_host_allocator = nullptr); + HybridPrefixCache(KVPrefixCache& prefix_cache, PageAllocator& device_allocator, MambaChunkAllocator* allocator, + std::int32_t mamba_cache_chunk_size, MambaHostAllocator* mamba_host_allocator = nullptr); + ~HybridPrefixCache(); + + RecoveryPlan MatchPrefix(const token_vec_t& token_ids, MatchIntent intent = MatchIntent::PrefixReuse); + RecoveryPlan MatchPrefix(const std::vector>& token_pages, + MatchIntent intent = MatchIntent::PrefixReuse); + [[nodiscard]] AdmissionVerdict Admit(const AdmissionRequest& request, + std::map& simulated_free); + StepCommitResult StepCommit(StepCommitRequest request); + [[nodiscard]] CacheStatsSnapshot Stats(const StatsRequest& request = {}) const; + const FamilyRegistry& Registry() const { return family_registry_; } + + struct RawHostStorageHashSeed { + std::int32_t host_matched_pages{0}; + std::string prior_hash_seed{}; + }; - MatchResult Match(const token_vec_t& token_ids, MatchIntent intent = MatchIntent::PrefixReuse); - MatchResult Match(const std::vector>& token_pages, - MatchIntent intent = MatchIntent::PrefixReuse); + // Cold-path storage rolling-hash seed lookup. This intentionally returns + // only raw host KV match depth plus the terminal host page-hash seed; it + // does not apply Mamba/paged-cache recovery augmentation. + RawHostStorageHashSeed LookupRawHostStorageHashSeed(const std::vector>& token_pages); - bool EnsureMambaCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node = nullptr); - void InsertMamba(TreeNode* terminal_node, std::unique_ptr slot); - std::int32_t AlignMambaCacheSeqlen(std::int32_t seqlen) const; - TreeNode* FindLastMambaNode(TreeNode* from) const; - TreeNode* FindLastMambaHostNode(TreeNode* from) const; - bool EnsureMambaHostCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node = nullptr); - std::vector PrepareMambaHostWriteBack(const std::vector& nodes); - std::vector PrepareMambaDeviceLoadBack(const std::vector& nodes); - void OnKVHostEvict(TreeNode* node); - void OnKVDeviceDemote(TreeNode* node); - void OnMambaHostWriteBackDone(TreeNode* last_node); - void OnMambaHostWriteBackDone(const std::vector& nodes); - void DemoteIdleMambaDeviceCopiesPresentOnHost(); + cache_op_id AllocateCacheOpId(); + void SetKvEventSink(KvEventSink sink); // Takes ownership. Duplicate group_id throws std::invalid_argument. void RegisterPagedCacheGroup(std::unique_ptr allocator); + // Startup-only scheduler configuration facade. Copies and validates group + // configs, registers concrete paged-cache group allocators, and optionally + // enables prefix-cache adjunct state for the required groups. + void ConfigurePagedCacheAdjunct(std::span group_configs, + std::optional> required_groups); + // History alignment is the LCM of RawTokensPerPage() over the History-family // groups; state groups only need the trailing window. Sliding groups must // have a window entry; full-history groups must not. @@ -78,17 +95,19 @@ class HybridPrefixCache { std::unordered_map sliding_window_per_group, StateRestorePolicy policy = StateRestorePolicy::kSnapshotRequired); - bool HasMambaAdjunct() const { return mamba_allocator_ != nullptr; } - bool HasPagedCacheAdjunct() const { return paged_cache_history_alignment_tokens_ > 0; } - std::int32_t PagedCacheHistoryAlignmentTokens() const { return paged_cache_history_alignment_tokens_; } - const std::vector& PagedCacheRequiredGroups() const { return paged_cache_required_groups_; } - // Group introspection: throws std::out_of_range on unknown group_id. std::vector PagedCacheGroupIds() const; std::int32_t PagedCacheGroupTotalPages(const std::string& group_id) const; std::int32_t PagedCacheGroupAvailablePages(const std::string& group_id) const; std::int64_t PagedCacheGroupFailedAllocCount(const std::string& group_id) const; + // Standard-KV device-page introspection. These are narrow read-only + // facades over the wrapped KV tree and concrete device allocator; they do + // not aggregate adjunct family capacity or mutate cache state. + using DeviceMemoryDiagnosticsSnapshot = CacheDeviceMemoryDiagnosticsSnapshot; + std::size_t AvailableDevicePages() const; + DeviceMemoryDiagnosticsSnapshot CollectDeviceMemoryDiagnostics() const; + // Per-request introspection: unknown group_id throws; unknown request_id returns empty. std::vector GetRequestPagedCachePageIds(const std::string& request_id, const std::string& group_id) const; @@ -100,51 +119,184 @@ class HybridPrefixCache { // Initial per-group simulated_free budget mirroring live allocator state. std::map InitialSimulatedFree() const; - // Ensure tables exist and cover [first_raw_position_of_op, target_raw_tokens_exclusive). - // Borrowed prefix is imported BEFORE any fresh allocation on a fresh table. - void AcquireForRequest(const std::string& request_id, std::int32_t first_raw_position_of_op, - std::int32_t target_raw_tokens_exclusive, - const MatchResult::PagedCache& paged_cache_hit = {}); + // Finish scheduler lifecycle for a request: release request-local + // paged-cache tables/state only. Shared TreeNode attachments remain owned by + // prefix-cache refcount/LRU/eviction paths. + void FinishRequest(const std::string& request_id); - // Owned pages return to the pool via OwnedPages RAII; borrowed ids are dropped. - void ReleaseRequest(const std::string& request_id); + // Callback from KV prefix-cache eviction. + void OnKVEvict(TreeNode* node); + void OnKVDeviceDemote(TreeNode* node); + void OnMambaHostWriteBackDone(TreeNode* last_node); + void OnMambaHostWriteBackDone(const std::vector& nodes); + void DemoteIdleMambaDeviceCopiesPresentOnHost(); - // Fill op.paged_cache_pages / op.paged_cache_page_base_offsets from the tables. - void PopulateOp(ForwardOperationBase& op_base) const; +private: + friend class HybridPrefixCacheTestPeer; - // Run admission against `simulated_free`; prunes evictable snapshots on - // group-pool pressure, then applies the debit on success. - bool AdmitChunk(const std::string& request_id, std::int32_t first_raw_position_of_op, - std::int32_t target_raw_tokens_exclusive, std::map& simulated_free, - const MatchResult::PagedCache& paged_cache_hit = {}); + struct DecodeFromRetractedRecovery { + bool ok{true}; + TreeNode* protected_source_node{nullptr}; + }; - // Retract-decode variant: admission uses a fresh-table view and credits - // pages owned by the stale table before it is released. - bool AdmitChunkFromRetracted(const std::string& request_id, std::int32_t target_raw_tokens_exclusive, - std::map& simulated_free, - const MatchResult::PagedCache& paged_cache_hit); + enum class CacheAdmissionKind { + kPrefillFirstChunk, + kPrefillChunk, + kDecodeChunk, + kDecodeFromRetracted, + kRetract, + }; - // Commit newly-written full LCM segments into TreeNode PagedCacheSnapshots. - void CommitChunk(const std::string& request_id, TreeNode* terminal); + struct CacheAdmissionRequest { + CacheAdmissionKind kind{CacheAdmissionKind::kDecodeChunk}; + std::string request_id{}; + std::int32_t device_pages_needed{0}; + std::int32_t host_pages_needed{0}; + std::int32_t tokens_this_round{0}; + std::int32_t first_raw_position_of_op{0}; + std::int32_t target_raw_tokens_exclusive{0}; + const MatchResult* match_result{nullptr}; + TreeNode* mamba_recovery_node{nullptr}; + bool refresh_mamba_checkpoint{false}; + }; - // Attach a snapshot to `node`, computing `complete_families` from which - // required-per-family group ids are present and registering the node in - // `paged_cache_snapshot_nodes_`. Returns false when either argument is - // null (defensive no-op). Accepts partial snapshots; the per-policy - // "snapshot must be full" invariant is enforced upstream by CommitChunk. - bool AttachPagedCacheSnapshotToNode(TreeNode* node, std::unique_ptr snapshot); + struct CacheAdmissionResult { + bool admitted{false}; + std::optional mamba_branching_seqlen{}; + std::optional mamba_cow_src_index{}; + std::vector cache_transfer_pairs{}; + }; - // Drops `node` from the membership set, then detaches and returns the snapshot. - std::unique_ptr DetachPagedCacheSnapshotFromNode(TreeNode* node); + enum class RequestLocalKVKind { + kPrefillFirstChunk, + kPrefillChunk, + kDecodeReserve, + kDecodeFromRetractedReserve, + }; - // Callback from KV prefix-cache eviction. - void OnKVEvict(TreeNode* node); + struct RequestLocalKVRequest { + RequestLocalKVKind kind{RequestLocalKVKind::kPrefillChunk}; + LocalKVAllocator* local_kv_allocator{nullptr}; + std::int32_t tokens_this_round{0}; + std::int32_t decode_input_tokens{0}; + std::int32_t reserve_tokens{0}; + }; + + struct RequestLocalKVResult { + std::unique_ptr local_kv_allocator{}; + }; - std::int32_t AvailableSlots() const; - KVPrefixCache& GetKVPrefixCache() { return kv_prefix_cache_; } + enum class RequestLocalMambaKind { + kPrefillFirstChunk, + kDecodeFromRetracted, + kNextCheckpoint, + }; -private: - friend class HybridPrefixCacheTestPeer; + struct RequestLocalMambaRequest { + RequestLocalMambaKind kind{RequestLocalMambaKind::kNextCheckpoint}; + LocalMambaAllocator* local_mamba_allocator{nullptr}; + std::optional checkpoint_raw_position{}; + }; + + struct RequestLocalMambaResult { + std::unique_ptr local_mamba_allocator{}; + }; + + enum class CachePublicationKind { + kForwardChunk, + kFinishChunk, + kRetractDeviceInsertPageCount, + kRetractChunk, + }; + + struct CachePublicationRequest { + CachePublicationKind kind{CachePublicationKind::kForwardChunk}; + const std::vector>* full_paged_tokens{nullptr}; + std::unique_ptr* device_node_ref{nullptr}; + const TreeNode* current_device_node{nullptr}; + LocalKVAllocator* local_kv_allocator{nullptr}; + LocalMambaAllocator* local_mamba_allocator{nullptr}; + const std::vector* page_hashes{nullptr}; + OwnedPages pages_to_insert{}; + }; + + struct CachePublicationResult { + MatchResult match_result{}; + std::int32_t device_insert_page_count{0}; + }; + + enum class CacheMaterializationKind { + kPrefillHostPrefixOnDevice, + kDecodeRecoveryHostPrefixOnDevice, + kFinishWritebackHostPages, + kRetractWritebackHostPages, + }; + + struct CacheMaterializationRequest { + CacheMaterializationKind kind{CacheMaterializationKind::kPrefillHostPrefixOnDevice}; + const MatchResult* match_result{nullptr}; + const std::vector* write_diff{nullptr}; + }; + + struct CacheMaterializationResult { + bool ok{true}; + }; + + enum class CacheOpPrepareKind { + kPrefillFirstChunk, + kPrefillChunk, + kDecode, + kDecodeFromRetracted, + }; + + struct CacheOpPrepareRequest { + CacheOpPrepareKind kind{CacheOpPrepareKind::kDecode}; + TreeNode* terminal{nullptr}; + std::int32_t first_raw_position_of_op{0}; + std::int32_t target_raw_tokens_exclusive{0}; + const MatchResult* match_result{nullptr}; + const LocalMambaAllocator* local_mamba_allocator{nullptr}; + MatchResult::PagedCache paged_cache_hit{}; + bool commit_prior_chunk{false}; + }; + + [[nodiscard]] CacheAdmissionResult Admit(const CacheAdmissionRequest& request, + std::map& simulated_free); + DecodeFromRetractedRecovery PrepareDecodeFromRetractedRecovery(MatchResult& match_result) const; + [[nodiscard]] RequestLocalKVResult PrepareRequestLocalKV(const RequestLocalKVRequest& request) const; + [[nodiscard]] RequestLocalMambaResult PrepareRequestLocalMamba(const RequestLocalMambaRequest& request) const; + CachePublicationResult Publish(CachePublicationRequest request); + [[nodiscard]] CacheMaterializationResult Materialize(const CacheMaterializationRequest& request); + void PublishRetractMambaState(TreeNode* terminal, std::unique_ptr& local_mamba_allocator); + void PrepareForwardOp(ForwardOperationBase& op_base, const CacheOpPrepareRequest& request); + + bool HasMambaAdjunct() const { return mamba_allocator_ != nullptr; } + bool HasPagedCacheAdjunct() const { return paged_cache_history_alignment_tokens_ > 0; } + void PopulateMambaMatchCompatibilityFields(ForwardOperationBase& op_base, const MatchResult& match_result) const; + void PopulateMambaRecoveryCompatibilityFields(ForwardOperationBase& op_base, const MatchResult& match_result) const; + void PopulateMambaRequestLocalCompatibilityFields(ForwardOperationBase& op_base, + const LocalMambaAllocator* local_mamba_allocator) const; + bool EnsureMambaCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node = nullptr); + bool EnsureMambaHostCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node = nullptr); + std::int32_t AlignMambaCacheSeqlen(std::int32_t seqlen) const; + void InsertMamba(TreeNode* terminal_node, std::unique_ptr slot); + TreeNode* FindLastMambaNode(TreeNode* from) const; + TreeNode* FindLastMambaHostNode(TreeNode* from) const; + std::vector PrepareMambaHostWriteBack(const std::vector& nodes); + std::vector PrepareMambaDeviceLoadBack(const std::vector& nodes); + void OnKVHostEvict(TreeNode* node); + // Publish request-local Mamba state for finish after the caller has inserted + // new terminal KV pages. The caller owns the "new KV pages were inserted" + // gate so finish publication remains coupled to successful KV insertion. + void PublishFinishMambaState(const std::vector>& full_paged_tokens, + LocalMambaAllocator* local_mamba_allocator); + + // Fill op.paged_cache_pages / op.paged_cache_page_base_offsets from the tables. + void PopulateOp(ForwardOperationBase& op_base) const; + std::unique_ptr allocateRequestLocalMambaState( + std::optional checkpoint_raw_position = {}) const; + void RebuildFamilyRegistry(); + void BuildRecoveryPlanSlices(RecoveryPlan& plan) const; // Per-family classification of admission failure; drives state-only vs // full prune strategy. @@ -154,6 +306,7 @@ class HybridPrefixCache { bool ok{true}; std::map releasable_owned_pages{}; std::map new_pages_needed{}; + std::map shortfall_pages{}; }; struct PagedCacheAdmissionContext { @@ -168,8 +321,43 @@ class HybridPrefixCache { // remains and the node stays registered. Returns true iff state groups removed. bool DetachStateSnapshotFromNode(TreeNode* node); + // Ensure tables exist and cover [first_raw_position_of_op, target_raw_tokens_exclusive). + // Borrowed prefix is imported BEFORE any fresh allocation on a fresh table. + void AcquireForRequest(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + const MatchResult::PagedCache& paged_cache_hit = {}); + + // Owned pages return to the pool via OwnedPages RAII; borrowed ids are dropped. + void ReleaseRequest(const std::string& request_id); + + // Run paged-cache admission against `simulated_free`; prunes evictable + // snapshots on group-pool pressure, then applies the debit on success. + bool AdmitChunk(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit = {}); + + // Retract-decode variant: admission uses a fresh-table view and credits + // pages owned by the stale table before it is released. + bool AdmitChunkFromRetracted(const std::string& request_id, std::int32_t target_raw_tokens_exclusive, + std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit); + + // Commit newly-written full LCM segments into TreeNode PagedCacheSnapshots. + void CommitChunk(const std::string& request_id, TreeNode* terminal); + + // Attach a snapshot to `node`, computing `complete_families` from which + // required-per-family group ids are present and registering the node in + // `paged_cache_snapshot_nodes_`. Returns false when either argument is + // null (defensive no-op). Accepts partial snapshots; the per-policy + // "snapshot must be full" invariant is enforced upstream by CommitChunk. + bool AttachPagedCacheSnapshotToNode(TreeNode* node, std::unique_ptr snapshot); + + // Drops `node` from the membership set, then detaches and returns the snapshot. + std::unique_ptr DetachPagedCacheSnapshotFromNode(TreeNode* node); + void augmentMatch(MatchResult& match) const; void augmentMatchPagedCache(MatchResult& match) const; + bool publishRequestMambaState(TreeNode* terminal, LocalMambaAllocator* local_mamba_allocator); // Detach oldest evictable snapshot to free pool pages. State-only path is // used only on kStateStarved; history/both go to full cascade. @@ -180,6 +368,8 @@ class HybridPrefixCache { std::map& simulated_free, const MatchResult::PagedCache& paged_cache_hit, const PagedCacheAdmissionContext& context); + void acquireAndPopulateOp(ForwardOperationBase& op_base, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, const MatchResult::PagedCache& paged_cache_hit); // Build admission record without mutating any table. PagedCacheGroupAdmission checkPagedCacheGroupAdmission(const std::string& request_id, @@ -195,6 +385,7 @@ class HybridPrefixCache { void refreshPagedCacheSimulatedFree(std::map& simulated_free) const; KVPrefixCache& kv_prefix_cache_; + PageAllocator& device_allocator_; MambaChunkAllocator* mamba_allocator_; MambaHostAllocator* mamba_host_allocator_; MambaEvictionManager mamba_eviction_manager_; @@ -202,6 +393,8 @@ class HybridPrefixCache { std::unordered_set mamba_host_nodes_; std::unordered_map> pending_mamba_host_writebacks_; std::unordered_set mamba_host_writeback_done_nodes_; + bool has_facade_kv_event_sink_{false}; + FamilyRegistry family_registry_; // `paged_cache_history_alignment_tokens_ == 0` means adjunct disabled; tables still work. std::map> paged_cache_allocators_; diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache_types.h b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache_types.h new file mode 100644 index 000000000..bf29a8cc0 --- /dev/null +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache_types.h @@ -0,0 +1,293 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "resource/allocator/owned_pages.h" +#include "resource/types.h" +#include "scheduler/operations/cache.h" + +namespace tokenspeed { + +class ForwardOperationBase; +class LocalKVAllocator; +class LocalMambaAllocator; +class TreeNode; + +enum class CacheFamily { + TokenPage, + CompressedPage, + SlidingWindowState, + CompressionTailState, + RecurrentState, + ConvState, +}; + +enum class TreeAttachmentKind { + ReusableTree, + NoneForRequestLocal, +}; + +enum class Recoverability { + Exact, + AlignedCheckpoint, + WindowRepairable, + RequestLocalOnly, +}; + +enum class PublicationKind { + CanonicalPrefixIndex, + AuxiliaryLocalOnly, + RequestLocalOnly, +}; + +enum class SplitPolicy { + CarrierKV, + CheckpointBoundary, + SnapshotBoundary, + RequestLocalOnly, +}; + +struct CacheResourceSpec { + std::string id; + std::int32_t family_index{-1}; + CacheFamily family{CacheFamily::TokenPage}; + TreeAttachmentKind attachment_kind{TreeAttachmentKind::NoneForRequestLocal}; + Recoverability recoverability{Recoverability::RequestLocalOnly}; + PublicationKind publication{PublicationKind::RequestLocalOnly}; + SplitPolicy split_policy{SplitPolicy::RequestLocalOnly}; + std::int32_t rows_per_page{0}; + std::int32_t entry_stride_tokens{0}; + std::int32_t checkpoint_chunk_tokens{0}; + std::optional sliding_window_tokens{}; + std::string state_cohort_id{}; + bool required_for_recovery{false}; +}; + +struct FamilyRegistry { + std::vector specs; + std::vector active_match_family_indices; + std::vector active_admit_family_indices; + std::vector active_commit_family_indices; + std::vector active_evict_family_indices; + std::vector active_finish_family_indices; + std::vector active_stats_family_indices; + std::vector active_compatibility_family_indices; + + void Clear(); + const CacheResourceSpec* FindById(const std::string& id) const; + const CacheResourceSpec& At(std::int32_t family_index) const; + std::int32_t Register(CacheResourceSpec spec, bool active_match, bool active_admit, bool active_commit, + bool active_evict, bool active_finish, bool active_stats, bool active_compatibility); + +private: + std::unordered_map id_to_index_; +}; + +struct FamilySlice { + std::int32_t family_index{-1}; + std::string family_id{}; + CacheFamily family{CacheFamily::TokenPage}; + TreeNode* hit_node{nullptr}; + std::int32_t recoverable_end_tokens{0}; + std::int32_t replay_from_tokens{0}; + std::int32_t replay_to_tokens{0}; + std::vector borrowed_ids{}; + std::int32_t base_logical_page{0}; + bool required_for_recovery{false}; +}; + +struct RecoveryPlan { + std::int32_t raw_token_match_end_tokens{0}; + std::int32_t recoverable_prefix_end_tokens{0}; + std::int32_t execution_resume_tokens{0}; + bool recovery_state_available{true}; + TreeNode* protected_recovery_node{nullptr}; + std::vector slices{}; + MatchResult compat_match{}; +}; + +struct ResourceDemand { + std::int32_t family_index{-1}; + std::string family_id{}; + std::int32_t new_units_needed{0}; + std::int32_t releasable_units{0}; + std::int32_t borrowed_prefix_units{0}; + std::string state_cohort_id{}; +}; + +struct AdmissionRequest { + std::string request_id{}; + std::int32_t device_pages_needed{0}; + std::int32_t host_pages_needed{0}; + std::int32_t tokens_this_round{0}; + std::int32_t first_raw_position_of_op{0}; + std::int32_t target_raw_tokens_exclusive{0}; + const RecoveryPlan* recovery_plan{nullptr}; + const MatchResult* compat_match{nullptr}; + TreeNode* protected_recovery_node{nullptr}; + std::int32_t auxiliary_tree_slots_needed{0}; + bool protect_host_match_node{false}; + bool fresh_request_table_view{false}; + bool compute_branching_checkpoint{false}; + bool refresh_mamba_checkpoint{false}; + std::vector demands{}; +}; + +struct AdmissionVerdict { + bool admitted{false}; + std::optional mamba_branching_seqlen{}; + std::optional mamba_cow_src_index{}; + std::vector cache_transfer_pairs{}; + std::vector demands{}; +}; + +struct PrefixMaterializationRequest { + const RecoveryPlan* recovery_plan{nullptr}; + const MatchResult* compat_match{nullptr}; + bool require_all_pages{false}; +}; + +struct RequestLocalKVStateRequest { + bool create_allocator{false}; + LocalKVAllocator* allocator{nullptr}; + std::int32_t initial_tokens{0}; + std::int32_t acquire_tokens{0}; +}; + +struct RequestLocalMambaStateRequest { + bool create_allocator{false}; + bool require_allocator{false}; + LocalMambaAllocator* refresh_checkpoint_allocator{nullptr}; + std::optional checkpoint_raw_position{}; +}; + +struct DevicePrefixPublicationRequest { + const std::vector>* full_paged_tokens{nullptr}; + std::unique_ptr* device_node_ref{nullptr}; + LocalKVAllocator* local_kv_allocator{nullptr}; + LocalMambaAllocator* local_mamba_allocator{nullptr}; +}; + +struct FinishedRequestPublicationRequest { + const std::vector>* full_paged_tokens{nullptr}; + const TreeNode* current_device_node{nullptr}; + LocalKVAllocator* local_kv_allocator{nullptr}; + LocalMambaAllocator* local_mamba_allocator{nullptr}; + const std::vector* page_hashes{nullptr}; +}; + +struct DevicePrefixInsertionPlanRequest { + const std::vector>* full_paged_tokens{nullptr}; + const TreeNode* current_device_node{nullptr}; +}; + +struct DevicePrefixInsertionRequest { + const std::vector>* full_paged_tokens{nullptr}; + const TreeNode* current_device_node{nullptr}; + OwnedPages pages_to_insert{}; +}; + +struct HostWritebackMaterializationRequest { + const std::vector* write_diff{nullptr}; + bool ensure_capacity_before_allocate{false}; +}; + +struct TreeOwnedRequestStatePublicationRequest { + TreeNode* terminal{nullptr}; + std::unique_ptr* local_mamba_allocator_owner{nullptr}; +}; + +struct WorkerCompatibilityCommitRequest { + ForwardOperationBase* op_base{nullptr}; + TreeNode* terminal{nullptr}; + const RecoveryPlan* recovery_plan{nullptr}; + const MatchResult* compat_match{nullptr}; + const LocalMambaAllocator* local_mamba_allocator_view{nullptr}; + MatchResult::PagedCache paged_cache_hit{}; + std::int32_t first_raw_position_of_op{0}; + std::int32_t target_raw_tokens_exclusive{0}; + bool commit_tree_prefix_before_acquire{false}; + bool import_paged_cache_hit{false}; + bool populate_prefix_reuse_metadata{false}; + bool populate_recovery_metadata{false}; + bool release_request_state_before_acquire{false}; +}; + +struct StepCommitRequest { + std::optional materialize_prefix{}; + std::optional publish_device_prefix{}; + std::optional publish_finished_request{}; + std::optional plan_device_prefix_insertion{}; + std::optional publish_device_prefix_insertion{}; + std::optional materialize_host_writeback{}; + std::optional publish_tree_owned_request_state{}; + std::optional request_local_kv{}; + std::optional request_local_mamba{}; + std::optional worker_metadata{}; +}; + +struct StepCommitResult { + bool ok{true}; + MatchResult match_result{}; + std::int32_t device_insert_page_count{0}; + std::unique_ptr local_kv_allocator{}; + std::unique_ptr local_mamba_allocator{}; + std::vector cache_transfer_pairs{}; + std::vector mamba_writeback_nodes{}; +}; + +struct CacheDeviceMemoryDiagnosticsSnapshot { + std::unordered_map tree_device_pages{}; + std::int32_t free_device_pages{0}; + // Usable device pages; page id 0 remains reserved/invalid. + std::int32_t total_device_pages{0}; +}; + +struct StatsRequest { + std::optional request_id{}; + std::vector paged_cache_group_ids{}; + bool include_device_memory_diagnostics{false}; +}; + +struct CacheStatsSnapshot { + std::size_t available_device_pages{0}; + std::vector paged_cache_group_ids{}; + std::map paged_cache_total_pages{}; + std::map paged_cache_available_pages{}; + std::map paged_cache_failed_alloc_count{}; + std::map> request_paged_cache_page_ids{}; + std::map request_paged_cache_base_logical_page{}; + std::optional device_memory_diagnostics{}; +}; + +} // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/mamba_family_ops.cpp b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/mamba_family_ops.cpp new file mode 100644 index 000000000..33a7d500f --- /dev/null +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/mamba_family_ops.cpp @@ -0,0 +1,367 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "resource/hybrid_prefix_cache/mamba_family_ops.h" + +#include "resource/allocator/local_mamba_allocator.h" +#include "resource/allocator/mamba_chunk_allocator.h" +#include "resource/allocator/mamba_host_allocator.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" +#include "resource/radix_tree/mamba_slot.h" +#include "resource/radix_tree/node_range.h" +#include "resource/radix_tree/tree_node.h" +#include "scheduler/operations/forward.h" + +#include + +#include +#include +#include + +namespace tokenspeed { +namespace hybrid_prefix_cache::detail { + +std::int32_t AlignMambaCacheSeqlen(std::int32_t seqlen, std::int32_t chunk_size) { + if (chunk_size <= 0) return seqlen; + return (seqlen / chunk_size) * chunk_size; +} + +TreeNode* FindLastMambaNode(TreeNode* from) { + for (TreeNode* cur = from; cur != nullptr && !cur->IsRoot(); cur = cur->Parent()) { + if (cur->HasMamba()) return cur; + } + return nullptr; +} + +TreeNode* FindLastMambaHostNode(TreeNode* from) { + for (TreeNode* cur = from; cur != nullptr && !cur->IsRoot(); cur = cur->Parent()) { + if (cur->HasMambaOnHost()) return cur; + } + return nullptr; +} + +} // namespace hybrid_prefix_cache::detail + +HybridPrefixCache::DecodeFromRetractedRecovery HybridPrefixCache::PrepareDecodeFromRetractedRecovery( + MatchResult& match_result) const { + if (!HasMambaAdjunct()) return {}; + + TreeNode* mamba_recovery_node = FindLastMambaNode(match_result.host.last_node); + if (mamba_recovery_node != nullptr) { + match_result.mamba_cow_src_index = mamba_recovery_node->MambaSlotIndex(); + return {.protected_source_node = mamba_recovery_node}; + } + + TreeNode* host_mamba_recovery_node = FindLastMambaHostNode(match_result.host.last_node); + if (host_mamba_recovery_node == nullptr) { + return {.ok = false}; + } + match_result.mamba_host_src_index = host_mamba_recovery_node->MambaHostSlotIndex(); + match_result.mamba_cow_src_index = -1; + return {.protected_source_node = host_mamba_recovery_node}; +} + +void HybridPrefixCache::PopulateMambaMatchCompatibilityFields(ForwardOperationBase& op_base, + const MatchResult& match_result) const { + if (!HasMambaAdjunct()) return; + op_base.mamba_cow_src_idx = match_result.mamba_cow_src_index; + op_base.mamba_branching_seqlen = match_result.mamba_branching_seqlen; +} + +void HybridPrefixCache::PopulateMambaRecoveryCompatibilityFields(ForwardOperationBase& op_base, + const MatchResult& match_result) const { + if (!HasMambaAdjunct()) return; + op_base.mamba_cow_src_idx = match_result.mamba_cow_src_index; +} + +void HybridPrefixCache::PopulateMambaRequestLocalCompatibilityFields( + ForwardOperationBase& op_base, const LocalMambaAllocator* local_mamba_allocator) const { + if (!HasMambaAdjunct() || local_mamba_allocator == nullptr) return; + if (local_mamba_allocator->HasWorking()) { + op_base.mamba_working_idx = local_mamba_allocator->WorkingIndex(); + } + if (local_mamba_allocator->HasCheckpoint()) { + op_base.mamba_checkpoint_dst_idx = local_mamba_allocator->CheckpointIndex(); + } +} + +HybridPrefixCache::RequestLocalMambaResult HybridPrefixCache::PrepareRequestLocalMamba( + const RequestLocalMambaRequest& request) const { + RequestLocalMambaResult result{}; + const auto should_materialize_checkpoint = [&]() { + if (!request.checkpoint_raw_position.has_value()) return true; + const std::int32_t position = *request.checkpoint_raw_position; + return position > 0 && AlignMambaCacheSeqlen(position) == position; + }; + switch (request.kind) { + case RequestLocalMambaKind::kPrefillFirstChunk: + result.local_mamba_allocator = allocateRequestLocalMambaState(request.checkpoint_raw_position); + return result; + case RequestLocalMambaKind::kDecodeFromRetracted: + result.local_mamba_allocator = allocateRequestLocalMambaState(); + if (HasMambaAdjunct() && result.local_mamba_allocator == nullptr) { + throw std::logic_error("ScheduleDecodeFromRetractedEvent: failed to allocate Mamba recovery slots"); + } + return result; + case RequestLocalMambaKind::kNextCheckpoint: + if (!HasMambaAdjunct() || request.local_mamba_allocator == nullptr) return result; + if (!should_materialize_checkpoint()) { + (void)request.local_mamba_allocator->DetachCheckpoint(); + return result; + } + (void)request.local_mamba_allocator->AllocateCheckpoint(request.checkpoint_raw_position.value_or(-1)); + return result; + } + return result; +} + +std::unique_ptr HybridPrefixCache::allocateRequestLocalMambaState( + std::optional checkpoint_raw_position) const { + if (!HasMambaAdjunct()) return nullptr; + + auto local_mamba_allocator = std::make_unique(mamba_allocator_); + if (!local_mamba_allocator->AllocateWorking() || + !local_mamba_allocator->AllocateCheckpoint(checkpoint_raw_position.value_or(-1))) { + return nullptr; + } + return local_mamba_allocator; +} + +void HybridPrefixCache::PublishFinishMambaState(const std::vector>& full_paged_tokens, + LocalMambaAllocator* local_mamba_allocator) { + if (!HasMambaAdjunct() || local_mamba_allocator == nullptr || !local_mamba_allocator->HasCheckpoint()) { + return; + } + MatchResult post_match = kv_prefix_cache_.Match(full_paged_tokens); + TreeNode* terminal = post_match.device.last_node; + if (terminal == nullptr || terminal->HasMamba()) return; + const std::int32_t checkpoint_position = local_mamba_allocator->CheckpointPosition(); + if (checkpoint_position >= 0 && checkpoint_position != static_cast(terminal->DepthInTokens())) { + return; + } + InsertMamba(terminal, local_mamba_allocator->DetachCheckpoint()); +} + +void HybridPrefixCache::PublishRetractMambaState(TreeNode* terminal, + std::unique_ptr& local_mamba_allocator) { + if (local_mamba_allocator == nullptr) return; + + const bool had_request_local_mamba = local_mamba_allocator->HasCheckpoint() || local_mamba_allocator->HasWorking(); + if (!had_request_local_mamba) return; + + if (HasMambaAdjunct()) { + publishRequestMambaState(terminal, local_mamba_allocator.get()); + } + + // Once retracted, any recoverable Mamba state is tree-owned and therefore + // evictable by HybridPrefixCache. Do not keep request-local slots alive in + // Retracting/Retracted. + local_mamba_allocator.reset(); +} + +bool HybridPrefixCache::publishRequestMambaState(TreeNode* terminal, LocalMambaAllocator* local_mamba_allocator) { + if (!HasMambaAdjunct() || terminal == nullptr || terminal->HasMamba() || local_mamba_allocator == nullptr) { + return false; + } + if (local_mamba_allocator->HasCheckpoint()) { + InsertMamba(terminal, local_mamba_allocator->DetachCheckpoint()); + return true; + } + if (local_mamba_allocator->HasWorking()) { + InsertMamba(terminal, local_mamba_allocator->DetachWorking()); + return true; + } + return false; +} + +void HybridPrefixCache::augmentMatch(MatchResult& match) const { + if (mamba_allocator_ == nullptr) return; + TreeNode* root = match.device.last_node; + while (root != nullptr && !root->IsRoot()) root = root->Parent(); + if (root == nullptr) return; + + if (mamba_host_allocator_ == nullptr) { + const std::int32_t page_size = match.device.page_size > 0 ? match.device.page_size : match.host.page_size; + const std::int32_t kv_depth = std::max(match.device.DepthInPage(), match.host.DepthInPage()); + TreeNode* device_mamba_node = FindLastMambaNode(match.device.last_node); + TreeNode* host_mamba_node = device_mamba_node == nullptr ? FindLastMambaNode(match.host.last_node) : nullptr; + TreeNode* mamba_node = device_mamba_node != nullptr ? device_mamba_node : host_mamba_node; + if (mamba_node == nullptr) { + const std::int32_t aligned_seqlen = AlignMambaCacheSeqlen(kv_depth * page_size); + if (aligned_seqlen > 0) { + match.mamba_branching_seqlen = aligned_seqlen; + } + match.device.last_node = root; + match.host.last_node = root; + return; + } + + const std::int32_t mamba_depth = mamba_node->DepthInPage(page_size); + match.mamba_cow_src_index = mamba_node->MambaSlotIndex(); + if (kv_depth > mamba_depth) { + const std::int32_t aligned_seqlen = AlignMambaCacheSeqlen(kv_depth * page_size); + if (aligned_seqlen > mamba_depth * page_size) { + match.mamba_branching_seqlen = aligned_seqlen; + } + } + match.device.last_node = device_mamba_node != nullptr ? mamba_node : root; + match.host.last_node = mamba_node; + return; + } + + const std::int32_t page_size = match.device.page_size > 0 ? match.device.page_size : match.host.page_size; + const std::int32_t kv_depth = std::max(match.device.DepthInPage(), match.host.DepthInPage()); + TreeNode* device_mamba_node = FindLastMambaNode(match.device.last_node); + TreeNode* host_mamba_node = FindLastMambaHostNode(match.host.last_node); + const std::int32_t device_mamba_depth = + device_mamba_node == nullptr ? 0 : device_mamba_node->DepthInPage(page_size); + const std::int32_t host_mamba_depth = host_mamba_node == nullptr ? 0 : host_mamba_node->DepthInPage(page_size); + const bool prefer_host_mamba = host_mamba_depth > device_mamba_depth; + std::int32_t mamba_depth = 0; + + if (device_mamba_node != nullptr) { + match.device.last_node = device_mamba_node; + if (!prefer_host_mamba) { + match.mamba_cow_src_index = device_mamba_node->MambaSlotIndex(); + } + mamba_depth = std::max(mamba_depth, device_mamba_depth); + } else { + match.device.last_node = root; + } + + if (host_mamba_node != nullptr) { + match.host.last_node = host_mamba_node; + match.mamba_host_src_index = host_mamba_node->MambaHostSlotIndex(); + if (prefer_host_mamba) { + match.mamba_cow_src_index = -1; + } + mamba_depth = std::max(mamba_depth, host_mamba_depth); + } else { + match.host.last_node = root; + } + + if (kv_depth > mamba_depth) { + const std::int32_t aligned_seqlen = AlignMambaCacheSeqlen(kv_depth * page_size); + if (aligned_seqlen > mamba_depth * page_size) { + match.mamba_branching_seqlen = aligned_seqlen; + } + } +} + +std::int32_t HybridPrefixCache::AlignMambaCacheSeqlen(std::int32_t seqlen) const { + return hybrid_prefix_cache::detail::AlignMambaCacheSeqlen(seqlen, mamba_cache_chunk_size_); +} + +TreeNode* HybridPrefixCache::FindLastMambaNode(TreeNode* from) const { + return hybrid_prefix_cache::detail::FindLastMambaNode(from); +} + +TreeNode* HybridPrefixCache::FindLastMambaHostNode(TreeNode* from) const { + return hybrid_prefix_cache::detail::FindLastMambaHostNode(from); +} + +bool HybridPrefixCache::EnsureMambaCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node) { + if (mamba_allocator_ == nullptr) return num_slots <= 0; + return mamba_eviction_manager_.EnsureCapacity(num_slots, protected_node); +} + +bool HybridPrefixCache::EnsureMambaHostCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node) { + if (mamba_host_allocator_ == nullptr) return num_slots <= 0; + if (mamba_host_allocator_->AvailableSlots() >= num_slots) return true; + + std::vector candidates; + candidates.reserve(mamba_host_nodes_.size()); + for (TreeNode* node : mamba_host_nodes_) { + if (node == nullptr || node == protected_node || !node->HasMambaOnHost()) continue; + if (node->OnHost() && GetResource(node).RefCount() > 0) continue; + candidates.push_back(node); + } + std::sort(candidates.begin(), candidates.end(), + [](const TreeNode* lhs, const TreeNode* rhs) { return lhs->Time() < rhs->Time(); }); + + for (TreeNode* node : candidates) { + if (mamba_host_allocator_->AvailableSlots() >= num_slots) break; + node->DetachMambaHost(); + mamba_host_nodes_.erase(node); + } + if (mamba_host_allocator_->AvailableSlots() < num_slots) { + spdlog::warn("[HybridPrefixCache] mamba host capacity exhausted required={} after_evict_available={}", + num_slots, mamba_host_allocator_->AvailableSlots()); + } + return mamba_host_allocator_->AvailableSlots() >= num_slots; +} + +std::vector HybridPrefixCache::PrepareMambaHostWriteBack(const std::vector& nodes) { + std::vector transfers; + if (mamba_allocator_ == nullptr || mamba_host_allocator_ == nullptr) return transfers; + + std::int32_t needed = 0; + for (TreeNode* node : nodes) { + if (node != nullptr && node->HasMamba() && !node->HasMambaOnHost() && + pending_mamba_host_writebacks_.find(node) == pending_mamba_host_writebacks_.end()) { + ++needed; + } + } + if (!EnsureMambaHostCapacityByEvict(needed)) return transfers; + + for (TreeNode* node : nodes) { + if (node == nullptr || !node->HasMamba() || node->HasMambaOnHost()) continue; + if (pending_mamba_host_writebacks_.find(node) != pending_mamba_host_writebacks_.end()) continue; + auto slot = mamba_host_allocator_->Allocate(); + if (!slot.has_value()) break; + const std::int32_t device_idx = node->MambaSlotIndex(); + const std::int32_t host_idx = slot->Index(); + pending_mamba_host_writebacks_.emplace(node, std::make_unique(std::move(*slot))); + transfers.push_back(TransferPair{CacheKind::kMamba, device_idx, host_idx}); + } + return transfers; +} + +std::vector HybridPrefixCache::PrepareMambaDeviceLoadBack(const std::vector& nodes) { + std::vector transfers; + if (mamba_allocator_ == nullptr || mamba_host_allocator_ == nullptr) return transfers; + + for (TreeNode* node : nodes) { + if (node == nullptr || !node->HasMambaOnHost() || node->HasMamba()) continue; + auto slot = mamba_allocator_->Allocate(); + if (!slot.has_value()) break; + const std::int32_t host_idx = node->MambaHostSlotIndex(); + const std::int32_t device_idx = slot->Index(); + node->AttachMamba(std::make_unique(std::move(*slot))); + mamba_eviction_manager_.TrackNode(node); + transfers.push_back(TransferPair{CacheKind::kMamba, host_idx, device_idx}); + } + return transfers; +} + +void HybridPrefixCache::InsertMamba(TreeNode* terminal_node, std::unique_ptr slot) { + if (terminal_node == nullptr || slot == nullptr) return; + if (mamba_allocator_ == nullptr) { + throw std::logic_error("HybridPrefixCache::InsertMamba: mamba adjunct not enabled"); + } + const std::int32_t page_size = kv_prefix_cache_.PageSize(); + if (page_size <= 0 || terminal_node->DepthInTokens() % static_cast(page_size) != 0) { + throw std::logic_error("HybridPrefixCache::InsertMamba: terminal node is not block-aligned"); + } + terminal_node->AttachMamba(std::move(slot)); + mamba_eviction_manager_.TrackNode(terminal_node); +} + +} // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/mamba_family_ops.h b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/mamba_family_ops.h new file mode 100644 index 000000000..360e9c6d8 --- /dev/null +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/mamba_family_ops.h @@ -0,0 +1,36 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include + +namespace tokenspeed { + +class TreeNode; + +namespace hybrid_prefix_cache::detail { + +std::int32_t AlignMambaCacheSeqlen(std::int32_t seqlen, std::int32_t chunk_size); +TreeNode* FindLastMambaNode(TreeNode* from); +TreeNode* FindLastMambaHostNode(TreeNode* from); + +} // namespace hybrid_prefix_cache::detail +} // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/paged_cache_family_ops.cpp b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/paged_cache_family_ops.cpp new file mode 100644 index 000000000..1f8e85dee --- /dev/null +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/paged_cache_family_ops.cpp @@ -0,0 +1,874 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "resource/hybrid_prefix_cache/paged_cache_family_ops.h" + +#include "resource/allocator/paged_cache_group.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" +#include "resource/radix_tree/paged_cache_snapshot.h" +#include "resource/radix_tree/tree_node.h" +#include "scheduler/operations/forward.h" +#include "utils.h" + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tokenspeed { +namespace hybrid_prefix_cache::detail { + +std::vector CollectAncestorPathRootToLeaf(TreeNode* from) { + std::vector path; + for (TreeNode* n = from; n != nullptr && !n->IsRoot(); n = n->Parent()) { + path.push_back(n); + } + std::reverse(path.begin(), path.end()); + return path; +} + +} // namespace hybrid_prefix_cache::detail + +bool HybridPrefixCache::AttachPagedCacheSnapshotToNode(TreeNode* node, std::unique_ptr snapshot) { + if (node == nullptr || snapshot == nullptr) return false; + // Compute completeness from what is present. The policy-driven "snapshot + // must be full" invariant is enforced upstream by CommitChunk, which only + // attaches full snapshots; direct callers (tests, future restore paths) + // may attach history-only or state-only snapshots without policy gating. + snapshot->complete_families.clear(); + bool history_complete = !paged_cache_history_groups_.empty(); + for (const auto& gid : paged_cache_history_groups_) { + if (snapshot->groups.find(gid) == snapshot->groups.end()) { + history_complete = false; + break; + } + } + if (history_complete) { + snapshot->complete_families.insert(PagedCacheGroupFamily::History); + } + bool state_complete = !paged_cache_state_groups_.empty(); + for (const auto& gid : paged_cache_state_groups_) { + if (snapshot->groups.find(gid) == snapshot->groups.end()) { + state_complete = false; + break; + } + } + if (state_complete) { + snapshot->complete_families.insert(PagedCacheGroupFamily::State); + } + node->AttachPagedCacheSnapshot(std::move(snapshot)); + paged_cache_snapshot_nodes_.insert(node); + return true; +} + +std::unique_ptr HybridPrefixCache::DetachPagedCacheSnapshotFromNode(TreeNode* node) { + if (node == nullptr) return nullptr; + paged_cache_snapshot_nodes_.erase(node); + return node->DetachPagedCacheSnapshot(); +} + +void HybridPrefixCache::RegisterPagedCacheGroup(std::unique_ptr allocator) { + if (allocator == nullptr) { + throw std::invalid_argument("HybridPrefixCache::RegisterPagedCacheGroup: null allocator"); + } + std::string gid = allocator->Config().group_id; + if (paged_cache_allocators_.find(gid) != paged_cache_allocators_.end()) { + throw std::invalid_argument("HybridPrefixCache::RegisterPagedCacheGroup: duplicate group_id: " + gid); + } + paged_cache_allocators_.emplace(std::move(gid), std::move(allocator)); + RebuildFamilyRegistry(); +} + +void HybridPrefixCache::ConfigurePagedCacheAdjunct(std::span group_configs, + std::optional> required_groups) { + for (const auto& cfg : group_configs) { + PagedCacheGroupConfig copy = cfg; + copy.Validate(); + RegisterPagedCacheGroup(std::make_unique(std::move(copy))); + } + + if (!required_groups.has_value()) { + return; + } + if (required_groups->empty()) { + throw std::invalid_argument("HybridPrefixCache::ConfigurePagedCacheAdjunct: required_groups must be non-empty"); + } + + std::vector required; + required.reserve(required_groups->size()); + std::unordered_map sliding_window_per_group; + for (const auto& gid : *required_groups) { + auto it = paged_cache_allocators_.find(gid); + if (it == paged_cache_allocators_.end() || it->second == nullptr) { + throw std::invalid_argument("HybridPrefixCache::ConfigurePagedCacheAdjunct: required group '" + gid + + "' missing from registered paged-cache groups"); + } + const auto& cfg = it->second->Config(); + if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow) { + if (!cfg.sliding_window_tokens.has_value() || *cfg.sliding_window_tokens <= 0) { + throw std::invalid_argument("HybridPrefixCache::ConfigurePagedCacheAdjunct: sliding group '" + gid + + "' must declare positive sliding_window_tokens"); + } + sliding_window_per_group.emplace(gid, *cfg.sliding_window_tokens); + } + required.push_back(gid); + } + + EnablePagedCacheAdjunct(std::move(required), std::move(sliding_window_per_group)); +} + +void HybridPrefixCache::EnablePagedCacheAdjunct(std::vector required_groups, + std::unordered_map sliding_window_per_group, + StateRestorePolicy policy) { + if (required_groups.empty()) { + throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: required_groups must be non-empty"); + } + std::vector history_gids; + std::vector state_gids; + std::vector required_sliding_gids; + history_gids.reserve(required_groups.size()); + state_gids.reserve(required_groups.size()); + required_sliding_gids.reserve(required_groups.size()); + + // Partition required groups by family; collect sliding-group entries for + // post-validation against `sliding_window_per_group`. + for (const auto& gid : required_groups) { + auto it = paged_cache_allocators_.find(gid); + if (it == paged_cache_allocators_.end() || it->second == nullptr) { + throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: required group '" + gid + + "' missing from registered allocators"); + } + const auto& cfg = it->second->Config(); + const std::int32_t raw_per_page = cfg.RawTokensPerPage(); + if (raw_per_page <= 0) { + throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: required group '" + gid + + "' has non-positive RawTokensPerPage"); + } + if (cfg.family == PagedCacheGroupFamily::History) { + history_gids.push_back(gid); + } else { + state_gids.push_back(gid); + } + if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow) { + auto win_it = sliding_window_per_group.find(gid); + if (win_it == sliding_window_per_group.end() || win_it->second <= 0) { + throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: sliding group '" + gid + + "' missing positive sliding_window entry"); + } + required_sliding_gids.push_back(gid); + } + } + if (history_gids.empty()) { + throw std::invalid_argument( + "HybridPrefixCache::EnablePagedCacheAdjunct: at least one History-family group required"); + } + if (sliding_window_per_group.size() != required_sliding_gids.size()) { + throw std::invalid_argument( + "HybridPrefixCache::EnablePagedCacheAdjunct: sliding_window_per_group keys must exactly " + "match the set of required groups whose retention is SlidingWindow"); + } + + // History alignment = LCM(raw_per_page) across History-family groups. + std::int32_t history_alignment = 1; + for (const auto& gid : history_gids) { + const auto& cfg = paged_cache_allocators_.find(gid)->second->Config(); + history_alignment = std::lcm(history_alignment, cfg.RawTokensPerPage()); + } + // Phase 1: state groups must align with the history alignment (so trailing + // segments are themselves page-aligned). Phase 2 will relax this via replay. + if (policy == StateRestorePolicy::kSnapshotRequired) { + for (const auto& gid : state_gids) { + const auto& cfg = paged_cache_allocators_.find(gid)->second->Config(); + const std::int32_t raw_per_page = cfg.RawTokensPerPage(); + if (history_alignment % raw_per_page != 0) { + throw std::invalid_argument("HybridPrefixCache::EnablePagedCacheAdjunct: state group '" + gid + + "' RawTokensPerPage=" + std::to_string(raw_per_page) + + " does not divide history_alignment=" + std::to_string(history_alignment)); + } + } + } + + paged_cache_history_alignment_tokens_ = history_alignment; + paged_cache_required_groups_ = std::move(required_groups); + paged_cache_sliding_window_per_group_ = std::move(sliding_window_per_group); + paged_cache_state_policy_ = policy; + RebuildFamilyRegistry(); +} + +void HybridPrefixCache::augmentMatchPagedCache(MatchResult& match) const { + if (!HasPagedCacheAdjunct()) return; + if (match.device.last_node == nullptr) return; + + const std::int32_t align = paged_cache_history_alignment_tokens_; + + auto cap_to_root = [&]() { + TreeNode* root = match.device.last_node; + while (root != nullptr && !root->IsRoot()) root = root->Parent(); + match.device.last_node = root; + if (match.host.last_node != nullptr) { + TreeNode* h = match.host.last_node; + while (h != nullptr && !h->IsRoot()) h = h->Parent(); + match.host.last_node = h; + } + }; + + std::vector path = hybrid_prefix_cache::detail::CollectAncestorPathRootToLeaf(match.device.last_node); + + // Phase A: history chain. Walk root->leaf, advance only on contiguous + // History-family completeness at every k*align boundary. + TreeNode* deepest_history = nullptr; + std::vector history_chain; + std::int32_t expected_depth = align; + for (TreeNode* n : path) { + const std::int32_t d = static_cast(n->DepthInTokens()); + if (d < expected_depth) continue; + if (d > expected_depth) break; + const auto* snap = n->GetPagedCacheSnapshot(); + if (snap == nullptr) break; + if (!snap->IsCompleteFor(PagedCacheGroupFamily::History)) break; + deepest_history = n; + history_chain.push_back(n); + expected_depth += align; + } + if (deepest_history == nullptr) { + cap_to_root(); + return; + } + + // Phase B: state window. `segments_needed` is the worst-case trailing + // coverage across state groups (so every state group is satisfied at the + // chosen depth). Walk back through history_chain, pick the deepest D' + // whose trailing `segments_needed` history_chain entries all have State + // complete. + std::int32_t worst_window = 0; + for (const auto& gid : paged_cache_state_groups_) { + auto it = paged_cache_sliding_window_per_group_.find(gid); + if (it != paged_cache_sliding_window_per_group_.end()) { + worst_window = std::max(worst_window, it->second); + } + } + const std::int32_t segments_needed = worst_window > 0 ? (worst_window + align - 1) / align : 1; + + TreeNode* usable_node = nullptr; + if (paged_cache_state_groups_.empty()) { + usable_node = deepest_history; + } else { + for (std::int32_t end_idx = static_cast(history_chain.size()) - 1; end_idx >= 0; --end_idx) { + const std::int32_t start_idx = std::max(0, end_idx - segments_needed + 1); + bool ok = true; + for (std::int32_t i = start_idx; i <= end_idx; ++i) { + const auto* snap = history_chain[i]->GetPagedCacheSnapshot(); + if (snap == nullptr || !snap->IsCompleteFor(PagedCacheGroupFamily::State)) { + ok = false; + break; + } + } + if (ok) { + usable_node = history_chain[end_idx]; + break; + } + } + } + if (usable_node == nullptr) { + cap_to_root(); + return; + } + + const std::int32_t usable = static_cast(usable_node->DepthInTokens()); + // Trim history_chain to ancestors up to and including usable_node. + while (!history_chain.empty() && static_cast(history_chain.back()->DepthInTokens()) > usable) { + history_chain.pop_back(); + } + + // Phase C: per-group page-id assembly. History groups take the full chain; + // State groups share a trailing-window slice computed once. + match.paged_cache.last_node = usable_node; + match.paged_cache.prefix_len_tokens = usable; + match.paged_cache.per_group_page_ids.clear(); + match.paged_cache.per_group_base_logical_page.clear(); + + auto assemble = [&](const std::string& gid, std::span chain, bool is_sliding) { + std::vector page_ids; + std::int32_t base_logical_page = 0; + if (!chain.empty()) { + const PagedCacheSnapshot* earliest_snap = chain.front()->GetPagedCacheSnapshot(); + if (earliest_snap != nullptr && is_sliding) { + auto git = earliest_snap->groups.find(gid); + if (git != earliest_snap->groups.end()) { + base_logical_page = git->second.base_logical_page; + } + } + for (TreeNode* anc : chain) { + const PagedCacheSnapshot* snap = anc->GetPagedCacheSnapshot(); + if (snap == nullptr) continue; + auto git = snap->groups.find(gid); + if (git == snap->groups.end()) continue; + const auto& seg_ids = git->second.pages.Ids(); + page_ids.insert(page_ids.end(), seg_ids.begin(), seg_ids.end()); + } + } + match.paged_cache.per_group_page_ids[gid] = std::move(page_ids); + match.paged_cache.per_group_base_logical_page[gid] = base_logical_page; + }; + + const auto sliding_group_end = paged_cache_sliding_window_per_group_.end(); + const auto is_sliding_group = [&](const std::string& gid) { + return paged_cache_sliding_window_per_group_.find(gid) != sliding_group_end; + }; + + const std::span history_span{history_chain}; + for (const auto& gid : paged_cache_history_groups_) { + assemble(gid, history_span, is_sliding_group(gid)); + } + if (!paged_cache_state_groups_.empty()) { + const std::size_t take = std::min(history_chain.size(), static_cast(segments_needed)); + const std::span state_span = history_span.last(take); + for (const auto& gid : paged_cache_state_groups_) { + assemble(gid, state_span, is_sliding_group(gid)); + } + } + + // Cap device/host match nodes to the paged-cache usable depth. + match.device.last_node = usable_node; + if (match.host.last_node != nullptr && static_cast(match.host.last_node->DepthInTokens()) > usable) { + TreeNode* h = match.host.last_node; + while (h != nullptr && !h->IsRoot() && static_cast(h->DepthInTokens()) > usable) { + h = h->Parent(); + } + match.host.last_node = h; + } + + match.paged_cache.restore_kind = MatchResult::PagedCache::RestoreKind::kSnapshotComplete; + match.paged_cache.replay_start_tokens = 0; +} + +std::vector HybridPrefixCache::PagedCacheGroupIds() const { + return Stats().paged_cache_group_ids; +} + +std::int32_t HybridPrefixCache::PagedCacheGroupTotalPages(const std::string& group_id) const { + return Stats({.paged_cache_group_ids = {group_id}}).paged_cache_total_pages.at(group_id); +} + +std::int32_t HybridPrefixCache::PagedCacheGroupAvailablePages(const std::string& group_id) const { + return Stats({.paged_cache_group_ids = {group_id}}).paged_cache_available_pages.at(group_id); +} + +std::int64_t HybridPrefixCache::PagedCacheGroupFailedAllocCount(const std::string& group_id) const { + return Stats({.paged_cache_group_ids = {group_id}}).paged_cache_failed_alloc_count.at(group_id); +} + +std::map HybridPrefixCache::InitialSimulatedFree() const { + std::map out; + for (const auto& [gid, allocator] : paged_cache_allocators_) { + out[gid] = allocator->AvailablePages(); + } + return out; +} + +void HybridPrefixCache::AcquireForRequest(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + const MatchResult::PagedCache& paged_cache_hit) { + if (paged_cache_allocators_.empty()) return; + auto& tables = request_paged_cache_tables_[request_id]; + const bool has_hit = (paged_cache_hit.last_node != nullptr) && (paged_cache_hit.prefix_len_tokens > 0); + for (const auto& [group_id, allocator] : paged_cache_allocators_) { + auto it = tables.find(group_id); + const bool fresh_table = (it == tables.end()); + if (fresh_table) { + it = tables.emplace(group_id, PagedCacheGroupTable(allocator.get())).first; + // Import borrowed-prefix BEFORE ReleaseSkipped/Acquire on a fresh table. + if (has_hit) { + auto pid_it = paged_cache_hit.per_group_page_ids.find(group_id); + if (pid_it != paged_cache_hit.per_group_page_ids.end() && !pid_it->second.empty()) { + std::int32_t base_logical_page = 0; + auto base_it = paged_cache_hit.per_group_base_logical_page.find(group_id); + if (base_it != paged_cache_hit.per_group_base_logical_page.end()) { + base_logical_page = base_it->second; + } + std::vector page_ids_copy = pid_it->second; + it->second.ImportPrefixBorrowed(std::move(page_ids_copy), base_logical_page, + paged_cache_hit.prefix_len_tokens); + } + } + } + const auto& cfg = allocator->Config(); + if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow && cfg.sliding_window_tokens.has_value()) { + const std::int32_t lower = std::max(0, first_raw_position_of_op - *cfg.sliding_window_tokens + 1); + it->second.ReleaseSkipped(lower); + } + it->second.Acquire(target_raw_tokens_exclusive); + } +} + +void HybridPrefixCache::FinishRequest(const std::string& request_id) { + ReleaseRequest(request_id); +} + +void HybridPrefixCache::ReleaseRequest(const std::string& request_id) { + auto it = request_paged_cache_tables_.find(request_id); + if (it == request_paged_cache_tables_.end()) return; + for (auto& [_, table] : it->second) { + table.ReleaseAll(); + } + request_paged_cache_tables_.erase(it); +} + +void HybridPrefixCache::PopulateOp(ForwardOperationBase& op_base) const { + if (paged_cache_allocators_.empty()) return; + auto req_it = request_paged_cache_tables_.find(op_base.request_id); + for (const auto& [gid, allocator] : paged_cache_allocators_) { + std::vector pages; + std::int32_t base_offset = 0; + if (req_it != request_paged_cache_tables_.end()) { + auto table_it = req_it->second.find(gid); + if (table_it != req_it->second.end()) { + pages = table_it->second.PageIds(); + base_offset = table_it->second.BaseLogicalPage(); + } + } + op_base.paged_cache_pages[gid] = std::move(pages); + if (allocator->Config().retention == PagedCacheGroupConfig::Retention::SlidingWindow) { + op_base.paged_cache_page_base_offsets[gid] = base_offset; + } + } +} + +void HybridPrefixCache::acquireAndPopulateOp(ForwardOperationBase& op_base, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + const MatchResult::PagedCache& paged_cache_hit) { + AcquireForRequest(op_base.request_id, first_raw_position_of_op, target_raw_tokens_exclusive, paged_cache_hit); + PopulateOp(op_base); +} + +HybridPrefixCache::PagedCacheGroupAdmission HybridPrefixCache::checkPagedCacheGroupAdmission( + const std::string& request_id, std::int32_t first_raw_position_of_op, std::int32_t target_raw_tokens_exclusive, + const std::map& simulated_free, const MatchResult::PagedCache& paged_cache_hit, + const PagedCacheAdmissionContext& context) const { + PagedCacheGroupAdmission result; + if (paged_cache_allocators_.empty() || target_raw_tokens_exclusive < 0) { + return result; + } + + auto req_it = + context.fresh_table_view ? request_paged_cache_tables_.end() : request_paged_cache_tables_.find(request_id); + const bool has_hit = (paged_cache_hit.last_node != nullptr) && (paged_cache_hit.prefix_len_tokens > 0); + for (const auto& [gid, allocator] : paged_cache_allocators_) { + const auto& cfg = allocator->Config(); + const std::int32_t raw_per_page = cfg.RawTokensPerPage(); + if (cfg.entry_stride_tokens <= 0 || cfg.rows_per_page <= 0 || raw_per_page <= 0) { + continue; + } + + const std::int32_t entries = CeilDivPositive(target_raw_tokens_exclusive, cfg.entry_stride_tokens); + const std::int32_t required = (entries + cfg.rows_per_page - 1) / cfg.rows_per_page; + + std::int32_t current_size = 0; + std::int32_t current_active = 0; + std::int32_t borrowed_in_table = 0; + std::int32_t owned_in_table = 0; + std::int32_t already_released = 0; + bool table_exists = false; + if (req_it != request_paged_cache_tables_.end()) { + auto t_it = req_it->second.find(gid); + if (t_it != req_it->second.end()) { + table_exists = true; + current_size = t_it->second.Size(); + current_active = t_it->second.ActivePagesCount(); + borrowed_in_table = t_it->second.BorrowedPagesCount(); + owned_in_table = t_it->second.OwnedPagesCount(); + already_released = t_it->second.ReleasedPagesCount(); + } + } + + std::int32_t borrowed_count = 0; + std::int32_t borrowed_base = 0; + if (has_hit && !table_exists) { + auto pid_it = paged_cache_hit.per_group_page_ids.find(gid); + if (pid_it != paged_cache_hit.per_group_page_ids.end()) { + borrowed_count = static_cast(pid_it->second.size()); + } + auto base_it = paged_cache_hit.per_group_base_logical_page.find(gid); + if (base_it != paged_cache_hit.per_group_base_logical_page.end()) { + borrowed_base = base_it->second; + } + } + + std::int32_t releasable_total = 0; + std::int32_t releasable_owned = 0; + if (cfg.retention == PagedCacheGroupConfig::Retention::SlidingWindow && cfg.sliding_window_tokens.has_value()) { + const std::int32_t lower = std::max(0, first_raw_position_of_op - *cfg.sliding_window_tokens + 1); + const std::int32_t target_releases = lower / raw_per_page; + const std::int32_t logical_released_base = table_exists ? already_released : borrowed_base; + releasable_total = std::max(0, target_releases - logical_released_base); + releasable_total = std::min(releasable_total, current_active + borrowed_count); + + // Borrowed pages drop the index only (no pool credit); only the + // owned-prefix slice contributes to releasable_owned. + const std::int32_t borrowed_present_total = table_exists ? borrowed_in_table : borrowed_count; + releasable_owned = releasable_total - std::min(releasable_total, borrowed_present_total); + if (table_exists) { + releasable_owned = std::min(releasable_owned, owned_in_table); + } + } + + const std::int32_t absolute_have = + table_exists ? (already_released + current_size) : (borrowed_base + borrowed_count); + const std::int32_t new_pages = std::max(0, required - absolute_have); + std::int32_t free = allocator->AvailablePages(); + auto sf_it = simulated_free.find(gid); + if (sf_it != simulated_free.end()) { + free = sf_it->second; + } + auto credit_it = context.owned_release_credit.find(gid); + if (credit_it != context.owned_release_credit.end()) { + free += credit_it->second; + } + + const std::int32_t shortfall = std::max(0, new_pages - free - releasable_owned); + result.releasable_owned_pages[gid] = releasable_owned; + result.new_pages_needed[gid] = new_pages; + result.shortfall_pages[gid] = shortfall; + if (shortfall > 0) { + result.ok = false; + } + } + return result; +} + +void HybridPrefixCache::applyPagedCacheGroupAdmissionDebit(std::map& simulated_free, + const PagedCacheGroupAdmission& admission) { + for (const auto& [gid, releasable_owned] : admission.releasable_owned_pages) { + simulated_free[gid] += releasable_owned; + } + for (const auto& [gid, new_pages] : admission.new_pages_needed) { + simulated_free[gid] -= new_pages; + } +} + +HybridPrefixCache::AdmissionFailureKind HybridPrefixCache::ClassifyAdmissionFailure( + const PagedCacheGroupAdmission& admission) const { + if (admission.ok) return AdmissionFailureKind::kNone; + bool history_starved = false; + bool state_starved = false; + for (const auto& [gid, shortfall] : admission.shortfall_pages) { + if (shortfall <= 0) continue; + if (paged_cache_history_group_set_.find(gid) != paged_cache_history_group_set_.end()) { + history_starved = true; + } + if (paged_cache_state_group_set_.find(gid) != paged_cache_state_group_set_.end()) { + state_starved = true; + } + } + if (history_starved && state_starved) return AdmissionFailureKind::kBothStarved; + if (history_starved) return AdmissionFailureKind::kHistoryStarved; + if (state_starved) return AdmissionFailureKind::kStateStarved; + return AdmissionFailureKind::kNone; +} + +void HybridPrefixCache::refreshPagedCacheSimulatedFree(std::map& simulated_free) const { + for (const auto& [gid, allocator] : paged_cache_allocators_) { + simulated_free[gid] = allocator->AvailablePages(); + } +} + +bool HybridPrefixCache::admitPagedCacheChunk(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit, + const PagedCacheAdmissionContext& context) { + PagedCacheGroupAdmission admission = checkPagedCacheGroupAdmission( + request_id, first_raw_position_of_op, target_raw_tokens_exclusive, simulated_free, paged_cache_hit, context); + const std::size_t prune_budget = paged_cache_snapshot_nodes_.size(); + for (std::size_t pruned = 0; !admission.ok && pruned < prune_budget; ++pruned) { + AdmissionFailureKind kind = ClassifyAdmissionFailure(admission); + if (kind == AdmissionFailureKind::kNone) break; + if (!tryPrunePagedCacheSnapshot(kind)) break; + refreshPagedCacheSimulatedFree(simulated_free); + admission = checkPagedCacheGroupAdmission(request_id, first_raw_position_of_op, target_raw_tokens_exclusive, + simulated_free, paged_cache_hit, context); + } + if (!admission.ok) return false; + for (const auto& [gid, credit] : context.owned_release_credit) { + simulated_free[gid] += credit; + } + applyPagedCacheGroupAdmissionDebit(simulated_free, admission); + return true; +} + +bool HybridPrefixCache::DetachStateSnapshotFromNode(TreeNode* node) { + if (node == nullptr) return false; + PagedCacheSnapshot* snap = node->GetPagedCacheSnapshotMut(); + if (snap == nullptr) return false; + bool removed_any = false; + for (const auto& gid : paged_cache_state_groups_) { + auto it = snap->groups.find(gid); + if (it != snap->groups.end()) { + snap->groups.erase(it); + removed_any = true; + } + } + if (!removed_any) return false; + snap->complete_families.erase(PagedCacheGroupFamily::State); + // If nothing remains, fall through to full detach to keep invariants tidy. + if (snap->groups.empty()) { + DetachPagedCacheSnapshotFromNode(node); + } + return true; +} + +bool HybridPrefixCache::tryPrunePagedCacheSnapshot(AdmissionFailureKind kind) { + if (!HasPagedCacheAdjunct()) return false; + if (kind == AdmissionFailureKind::kNone) return false; + + auto is_pinned = [](TreeNode* node) { + for (TreeNode* cur = node; cur != nullptr && !cur->IsRoot(); cur = cur->Parent()) { + if (!cur->OnDevice()) continue; + if (cur->Device().RefCount() > 0) return true; + } + return false; + }; + + // Sort once and share between branches: oldest first, then deepest within + // same Time(). Both try_state_only and try_full walk this same order. + std::vector candidates; + candidates.reserve(paged_cache_snapshot_nodes_.size()); + for (TreeNode* node : paged_cache_snapshot_nodes_) { + if (node == nullptr) continue; + if (!node->HasPagedCacheSnapshot()) continue; + candidates.push_back(node); + } + std::sort(candidates.begin(), candidates.end(), [](TreeNode* a, TreeNode* b) { + if (a->Time() != b->Time()) return a->Time() < b->Time(); + return a->DepthInTokens() > b->DepthInTokens(); + }); + + auto try_state_only = [&]() { + for (TreeNode* node : candidates) { + if (is_pinned(node)) continue; + const auto* snap = node->GetPagedCacheSnapshot(); + if (snap == nullptr) continue; + if (!snap->IsCompleteFor(PagedCacheGroupFamily::State)) continue; + if (DetachStateSnapshotFromNode(node)) return true; + } + return false; + }; + + auto try_full = [&]() { + TreeNode* victim = nullptr; + for (TreeNode* node : candidates) { + if (is_pinned(node)) continue; + victim = node; + break; + } + if (victim == nullptr) return false; + const std::size_t victim_depth = victim->DepthInTokens(); + auto primary = DetachPagedCacheSnapshotFromNode(victim); + (void)primary; + std::vector descendants; + for (TreeNode* node : paged_cache_snapshot_nodes_) { + if (node == nullptr || node == victim) continue; + if (!node->HasPagedCacheSnapshot()) continue; + if (node->DepthInTokens() <= victim_depth) continue; + for (TreeNode* cur = node->Parent(); cur != nullptr && !cur->IsRoot(); cur = cur->Parent()) { + if (cur == victim) { + descendants.push_back(node); + break; + } + } + } + for (TreeNode* d : descendants) { + if (is_pinned(d)) continue; + auto cascaded = DetachPagedCacheSnapshotFromNode(d); + (void)cascaded; + } + return true; + }; + + // kBothStarved: state-only cannot solve history shortage; go straight to + // full. The outer admit loop will re-classify if state still needs more. + switch (kind) { + case AdmissionFailureKind::kStateStarved: + return try_state_only(); + case AdmissionFailureKind::kHistoryStarved: + case AdmissionFailureKind::kBothStarved: + return try_full(); + case AdmissionFailureKind::kNone: + return false; + } + return false; +} + +bool HybridPrefixCache::AdmitChunk(const std::string& request_id, std::int32_t first_raw_position_of_op, + std::int32_t target_raw_tokens_exclusive, + std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit) { + return admitPagedCacheChunk(request_id, first_raw_position_of_op, target_raw_tokens_exclusive, simulated_free, + paged_cache_hit, {}); +} + +bool HybridPrefixCache::AdmitChunkFromRetracted(const std::string& request_id, std::int32_t target_raw_tokens_exclusive, + std::map& simulated_free, + const MatchResult::PagedCache& paged_cache_hit) { + PagedCacheAdmissionContext context{.fresh_table_view = true}; + auto req_it = request_paged_cache_tables_.find(request_id); + if (req_it != request_paged_cache_tables_.end()) { + for (const auto& [gid, table] : req_it->second) { + context.owned_release_credit[gid] = table.OwnedPagesCount(); + } + } + return admitPagedCacheChunk(request_id, 0, target_raw_tokens_exclusive, simulated_free, paged_cache_hit, context); +} + +void HybridPrefixCache::CommitChunk(const std::string& request_id, TreeNode* terminal) { + if (!HasPagedCacheAdjunct()) return; + if (terminal == nullptr) return; + + auto tables_it = request_paged_cache_tables_.find(request_id); + if (tables_it == request_paged_cache_tables_.end()) return; + auto& tables = tables_it->second; + + const std::int32_t lcm = paged_cache_history_alignment_tokens_; + if (lcm <= 0) return; + const auto& required_groups = paged_cache_required_groups_; + if (required_groups.empty()) return; + + auto canonical_it = tables.find(required_groups.front()); + if (canonical_it == tables.end()) return; + std::int32_t last_committed = canonical_it->second.CommittedPrefixLenTokens(); + + const std::int32_t chunk_depth = static_cast(terminal->DepthInTokens()); + if (chunk_depth <= 0) return; + + while (last_committed + lcm <= chunk_depth) { + const std::int32_t target = last_committed + lcm; + + TreeNode* attach_node = kv_prefix_cache_.GetRadixTree().SplitAt(terminal, target); + if (attach_node == nullptr) break; + + if (attach_node->HasPagedCacheSnapshot()) { + bool covered = true; + for (const auto& gid : required_groups) { + auto t_it = tables.find(gid); + if (t_it == tables.end()) { + covered = false; + break; + } + if (t_it->second.CommittedPrefixLenTokens() < target) { + covered = false; + break; + } + } + if (!covered) { + spdlog::warn( + "[HybridPrefixCache] CommitChunk: target depth {} already has a paged-cache " + "snapshot but request {} has uncommitted owned pages in [{}, {}); leaving " + "existing snapshot intact", + target, request_id, last_committed, target); + break; + } + last_committed = target; + continue; + } + + bool preflight_ok = true; + for (const auto& gid : required_groups) { + auto t_it = tables.find(gid); + if (t_it == tables.end()) { + preflight_ok = false; + break; + } + const auto& table = t_it->second; + const std::int32_t raw_per_page = table.RawTokensPerPage(); + if (raw_per_page <= 0) { + preflight_ok = false; + break; + } + if (table.CommittedPrefixLenTokens() % raw_per_page != 0) { + preflight_ok = false; + break; + } + if (target % raw_per_page != 0) { + preflight_ok = false; + break; + } + if (target <= table.CommittedPrefixLenTokens()) { + preflight_ok = false; + break; + } + if (target > table.RawTokenCursor()) { + preflight_ok = false; + break; + } + } + if (!preflight_ok) { + spdlog::warn( + "[HybridPrefixCache] CommitChunk: preflight failed for request {} at target " + "depth {}; leaving prior commits intact", + request_id, target); + break; + } + + auto snapshot = std::make_unique(); + snapshot->prefix_len_tokens = target; + for (const auto& gid : required_groups) { + auto& table = tables.find(gid)->second; + auto group_alloc_it = paged_cache_allocators_.find(gid); + const auto& cfg = group_alloc_it->second->Config(); + auto result = cfg.family == PagedCacheGroupFamily::History ? table.CommitHistoryToSnapshot(target) + : table.CheckpointStateToSnapshot(target); + PagedCacheGroupSnapshot group_snap{}; + group_snap.pages = std::move(result.pages); + group_snap.base_logical_page = result.segment_base_logical_page; + group_snap.raw_token_cursor = table.RawTokenCursor(); + group_snap.sliding = table.IsSliding(); + snapshot->groups.emplace(gid, std::move(group_snap)); + } + + bool snapshot_complete = true; + for (const auto& gid : required_groups) { + if (snapshot->groups.find(gid) == snapshot->groups.end()) { + snapshot_complete = false; + break; + } + } + _assert(snapshot_complete, + "HybridPrefixCache::CommitChunk: built snapshot missing a required group after " + "preflight+commit; invariant violated"); + const bool attached = AttachPagedCacheSnapshotToNode(attach_node, std::move(snapshot)); + _assert(attached, + "HybridPrefixCache::CommitChunk: attach rejected a non-null snapshot on a non-null " + "node; invariant violated"); + + last_committed = target; + } +} + +} // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/paged_cache_family_ops.h b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/paged_cache_family_ops.h new file mode 100644 index 000000000..f1cf49015 --- /dev/null +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/paged_cache_family_ops.h @@ -0,0 +1,34 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include + +namespace tokenspeed { + +class TreeNode; + +namespace hybrid_prefix_cache::detail { + +std::vector CollectAncestorPathRootToLeaf(TreeNode* from); + +} // namespace hybrid_prefix_cache::detail +} // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/cache_coordinator.cpp b/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/cache_coordinator.cpp deleted file mode 100644 index ec9af752f..000000000 --- a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/cache_coordinator.cpp +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright (c) 2026 LightSeek Foundation -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -#include "resource/kv_prefix_cache/cache_coordinator.h" - -#include -#include -#include -#include -#include -#include -#include - -#include "resource/radix_tree/node_range.h" -#include "resource/radix_tree/tree_node.h" -#include "scheduler/outside_events/cache.h" - -namespace tokenspeed { - -CacheOpSpec::CacheOpSpec() = default; -CacheOpSpec::~CacheOpSpec() = default; -CacheOpSpec::CacheOpSpec(CacheOpSpec&&) noexcept = default; -CacheOpSpec& CacheOpSpec::operator=(CacheOpSpec&&) noexcept = default; - -std::vector CollectNodesByOpId(TreeNode* last_node, cache_op_id op_id) { - return Collect(LeafToRoot(last_node) | std::views::filter([op_id](TreeNode* n) { - auto node_op_id = n->CacheOpId(); - return node_op_id.has_value() && *node_op_id == op_id; - })); -} - -std::optional CacheCoordinator::takeOpSpec(cache_op_id op_id) { - auto iter = pending_ops_.find(op_id); - if (iter == pending_ops_.end()) { - return std::nullopt; - } - CacheOpSpec op = std::move(iter->second); - pending_ops_.erase(iter); - if (op.last_node == nullptr) { - return std::nullopt; - } - op.nodes = CollectNodesByOpId(op.last_node, op_id); - return op; -} - -void CacheCoordinator::HandleEvent(const cache::WriteBackDone& event) { - auto spec = takeOpSpec(event.op_id); - if (!spec) return; - - auto access_time = std::chrono::steady_clock::now(); - for (TreeNode* current : spec->nodes) { - current->Touch(access_time); - } - if (enable_l3_storage_ && !spec->nodes.empty()) { - EnqueueTransfer(spec->last_node); - } -} - -void CacheCoordinator::EnqueueTransfer(TreeNode* last_node) { - if (last_node != nullptr) { - waiting_last_nodes_.push_back(last_node); - } -} - -std::vector CacheCoordinator::DrainTransferQueue() { - std::vector nodes = std::move(waiting_last_nodes_); - waiting_last_nodes_.clear(); - return nodes; -} - -} // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/resource/types.h b/tokenspeed-scheduler/csrc/resource/types.h index 4d53e5c0b..d557d6a5c 100644 --- a/tokenspeed-scheduler/csrc/resource/types.h +++ b/tokenspeed-scheduler/csrc/resource/types.h @@ -108,13 +108,11 @@ struct WalkResult { struct CacheOpSpec { std::string request_id; - TreeNode* last_node{nullptr}; - std::vector nodes; - CacheOpSpec(); - ~CacheOpSpec(); - CacheOpSpec(CacheOpSpec&&) noexcept; - CacheOpSpec& operator=(CacheOpSpec&&) noexcept; + CacheOpSpec() = default; + ~CacheOpSpec() = default; + CacheOpSpec(CacheOpSpec&&) noexcept = default; + CacheOpSpec& operator=(CacheOpSpec&&) noexcept = default; CacheOpSpec(const CacheOpSpec&) = delete; CacheOpSpec& operator=(const CacheOpSpec&) = delete; }; diff --git a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/cache_coordinator.h b/tokenspeed-scheduler/csrc/scheduler/device_memory_diagnostics.h similarity index 64% rename from tokenspeed-scheduler/csrc/resource/kv_prefix_cache/cache_coordinator.h rename to tokenspeed-scheduler/csrc/scheduler/device_memory_diagnostics.h index 8b8b0b8ea..f10034e83 100644 --- a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/cache_coordinator.h +++ b/tokenspeed-scheduler/csrc/scheduler/device_memory_diagnostics.h @@ -20,34 +20,24 @@ #pragma once -#include -#include +#include +#include #include -#include "resource/radix_tree/tree_node.h" -#include "resource/types.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" namespace tokenspeed { -namespace cache { -struct WriteBackDone; -} - -class CacheCoordinator { -public: - explicit CacheCoordinator(bool enable_l3_storage = false) : enable_l3_storage_(enable_l3_storage) {} - - void HandleEvent(const cache::WriteBackDone& event); - - void EnqueueTransfer(TreeNode* last_node); - std::vector DrainTransferQueue(); - -private: - std::optional takeOpSpec(cache_op_id op_id); - - bool enable_l3_storage_; - std::unordered_map pending_ops_; - std::vector waiting_last_nodes_; +struct RequestLocalKVPagesSnapshot { + std::string request_id; + std::string state_name; + std::vector pages; }; +// Validates the debug-only device-page accounting snapshot used by +// Scheduler::check_device_mem(). The function is side-effect free except for +// the same diagnostic logging as the historical inline checks. +bool ValidateDeviceMemoryDiagnostics(const std::vector& request_pages, + const HybridPrefixCache::DeviceMemoryDiagnosticsSnapshot& device_snapshot); + } // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/cache.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/cache.cpp deleted file mode 100644 index 60be05c1d..000000000 --- a/tokenspeed-scheduler/csrc/scheduler/operations/cache.cpp +++ /dev/null @@ -1,81 +0,0 @@ -// Copyright (c) 2026 LightSeek Foundation -// -// Permission is hereby granted, free of charge, to any person obtaining a copy -// of this software and associated documentation files (the "Software"), to deal -// in the Software without restriction, including without limitation the rights -// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -// copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in -// all copies or substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -// SOFTWARE. - -#include "scheduler/operations/cache.h" - -#include -#include -#include -#include -#include -#include - -#include "fsm/cache_events.h" -#include "fsm/forward_states.h" -#include "resource/kv_prefix_cache/kv_prefix_cache.h" -#include "resource/types.h" -#include "scheduler/request.h" -#include "scheduler/request_spec.h" -#include "scheduler/scheduler.h" -#include "scheduler/types.h" - -namespace tokenspeed { - -std::optional Scheduler::schedulePrefetch(Request* request, const MatchResult& match) { - const auto& storage = request->GetStorageInfo(); - if (config_.disable_prefix_cache || !config_.enable_l3_storage || !request->Is() || - storage.hit_pages <= config_.prefetch_threshold) { - return {}; - } - - const std::int32_t num_pages_to_fetch = storage.hit_pages; - if (!kv_prefix_cache_.EnsureCapacityByEvict(num_pages_to_fetch)) { - return {}; - } - - std::vector hashes(storage.rolling_hashes.begin(), - storage.rolling_hashes.begin() + num_pages_to_fetch); - - return fsm::SchedulePrefetchEvent{num_pages_to_fetch, std::move(hashes), &host_allocator_, match.host.last_node}; -} - -PrefetchOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::SchedulePrefetchEvent event) { - // Save rolling hashes BEFORE Apply (event will be moved into the state transition). - auto rolling_hashes = event.TakeRollingPageHashes(); - - // Apply event: Submitted → Prefetching (host pages allocated inside the state transition). - request->Apply(event); - - // After Apply, request is in Prefetching state; read back the allocated host pages. - cache_op_id op_id = kv_prefix_cache_.AllocateCacheOpId(); - - CacheOpSpec spec; - spec.request_id = request->Id(); - cache_op_tracker_[op_id] = std::move(spec); - - PrefetchOperation prefetch_op; - prefetch_op.op_id = op_id; - prefetch_op.dst_pages = request->GetHostPageIds(); - prefetch_op.request_id = request->Id(); - prefetch_op.rolling_page_hashes = std::move(rolling_hashes); - return prefetch_op; -} - -} // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index a8ce8f900..1fa02bf11 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -49,6 +49,7 @@ #include "scheduler/operations/cache.h" #include "scheduler/operations/forward.h" #include "scheduler/request.h" +#include "scheduler/request_cache_context.h" #include "scheduler/request_spec.h" #include "scheduler/scheduler.h" #include "scheduler/types.h" @@ -56,37 +57,15 @@ namespace tokenspeed { -namespace { - -std::int32_t CountMambaDeviceLoadBackSlots(const std::vector& nodes) { - std::int32_t slots = 0; - for (TreeNode* node : nodes) { - if (node != nullptr && node->HasMambaOnHost() && !node->HasMamba()) { - ++slots; - } - } - return slots; -} - -void AddUniqueNode(std::vector& nodes, TreeNode* node) { - if (node == nullptr) return; - if (std::find(nodes.begin(), nodes.end(), node) == nodes.end()) { - nodes.push_back(node); - } -} - -} // namespace - std::optional Scheduler::schedulePrefillFirstChunk( Request* request, std::int32_t remaining, std::int32_t decode_input_tokens, bool disable_l2_cache, std::map& simulated_free) { if (req_pool_allocator_.AvailableSlots() == 0) return {}; - MatchResult match_result = hybrid_prefix_cache_ ? hybrid_prefix_cache_->Match(request->GetFullPagedTokens(true)) - : kv_prefix_cache_.Match(request->GetFullPagedTokens(true)); + RecoveryPlan recovery_plan = hybrid_prefix_cache_.MatchPrefix(request->GetFullPagedTokens(true)); + MatchResult match_result = recovery_plan.compat_match; std::int32_t loadback_tokens = 0; std::int32_t unscheduled = 0; std::vector loadback_diff; - std::vector mamba_loadback_nodes; const std::int32_t device_matched = match_result.device.DepthInPage(); const std::int32_t host_matched = match_result.host.DepthInPage(); @@ -101,70 +80,36 @@ std::optional Scheduler::schedulePrefillFir } std::int32_t tokens_this_round = std::min(remaining, unscheduled); - if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct() && match_result.mamba_branching_seqlen == -1) { - const std::int32_t aligned = hybrid_prefix_cache_->AlignMambaCacheSeqlen(tokens_this_round); - if (aligned > 0) { - match_result.mamba_branching_seqlen = aligned; - } - } - std::int32_t num_tokens = loadback_tokens + tokens_this_round + decode_input_tokens; std::int32_t device_pages_needed = (num_tokens + config_.page_size - 1) / config_.page_size; - std::unique_ptr temp_lock = std::make_unique(match_result.device.last_node); - - // Evict unlocked prefix-cache nodes before allocating request-local pages. - if (!(kv_prefix_cache_.EnsureCapacityByEvict(device_pages_needed))) { - return {}; - } - - if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct() && match_result.mamba_host_src_index >= 0 && - match_result.mamba_cow_src_index < 0) { - TreeNode* host_mamba_node = hybrid_prefix_cache_->FindLastMambaHostNode(match_result.host.last_node); - if (host_mamba_node != nullptr && host_mamba_node->HasMambaOnHost() && !host_mamba_node->HasMamba()) { - AddUniqueNode(mamba_loadback_nodes, host_mamba_node); - } - } - const bool needs_mamba_loadback = !mamba_loadback_nodes.empty(); - const std::int32_t mamba_loadback_slots_needed = - needs_mamba_loadback ? CountMambaDeviceLoadBackSlots(mamba_loadback_nodes) : 0; - const std::int32_t mamba_slots_needed = 2 + mamba_loadback_slots_needed; - if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct() && - !hybrid_prefix_cache_->EnsureMambaCapacityByEvict(mamba_slots_needed)) { - return {}; - } - const std::int32_t first_pos = request->PrefillSize() - unscheduled; const std::int32_t target = first_pos + tokens_this_round; - if (hybrid_prefix_cache_ && - !hybrid_prefix_cache_->AdmitChunk(request->Id(), first_pos, target, simulated_free, match_result.paged_cache)) { + AdmissionRequest admission_request{ + .request_id = request->Id(), + .device_pages_needed = device_pages_needed, + .tokens_this_round = tokens_this_round, + .first_raw_position_of_op = first_pos, + .target_raw_tokens_exclusive = target, + .recovery_plan = &recovery_plan, + .auxiliary_tree_slots_needed = 2, + .compute_branching_checkpoint = true, + }; + AdmissionVerdict admission = hybrid_prefix_cache_.Admit(admission_request, simulated_free); + if (!admission.admitted) { return {}; } - if (needs_mamba_loadback) { - hybrid_prefix_cache_->PrepareMambaDeviceLoadBack(mamba_loadback_nodes); - TreeNode* mamba_node = hybrid_prefix_cache_->FindLastMambaNode(match_result.host.last_node); - if (mamba_node != nullptr) { - match_result.mamba_cow_src_index = mamba_node->MambaSlotIndex(); - } + if (admission.mamba_branching_seqlen.has_value()) { + match_result.mamba_branching_seqlen = *admission.mamba_branching_seqlen; } - if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct() && mamba_allocator_ && - mamba_allocator_->AvailableSlots() < 1) { - return {}; + if (admission.mamba_cow_src_index.has_value()) { + match_result.mamba_cow_src_index = *admission.mamba_cow_src_index; } return fsm::SchedulePrefillFirstChunkEvent{ - tokens_this_round, - decode_input_tokens, - &device_allocator_, - &req_pool_allocator_, - match_result, - config_.role, - &kv_prefix_cache_, - disable_l2_cache, - std::move(loadback_diff), - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, - mamba_allocator_ ? &*mamba_allocator_ : nullptr, - std::move(mamba_loadback_nodes), + tokens_this_round, decode_input_tokens, &req_pool_allocator_, match_result, + config_.role, disable_l2_cache, std::move(loadback_diff), std::move(admission.cache_transfer_pairs), + hybrid_prefix_cache_, }; } @@ -176,23 +121,21 @@ std::optional Scheduler::schedulePrefill( std::int32_t pages_needed = (tokens_this_round + config_.page_size - 1) / config_.page_size; - if (!kv_prefix_cache_.EnsureCapacityByEvict(pages_needed)) { - return {}; - } - - if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct() && - !hybrid_prefix_cache_->EnsureMambaCapacityByEvict(1)) { - return {}; - } - const std::int32_t first_pos = request->PrefillSize() - unscheduled; const std::int32_t target = first_pos + tokens_this_round; - if (hybrid_prefix_cache_ && !hybrid_prefix_cache_->AdmitChunk(request->Id(), first_pos, target, simulated_free)) { + AdmissionRequest admission_request{ + .request_id = request->Id(), + .device_pages_needed = pages_needed, + .first_raw_position_of_op = first_pos, + .target_raw_tokens_exclusive = target, + .auxiliary_tree_slots_needed = 1, + }; + if (!hybrid_prefix_cache_.Admit(admission_request, simulated_free).admitted) { return {}; } return fsm::SchedulePrefillEvent{tokens_this_round, reserve_num_tokens_in_next_schedule_event, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}; + hybrid_prefix_cache_}; } std::optional Scheduler::scheduleDecode(Request* request, @@ -201,57 +144,37 @@ std::optional Scheduler::scheduleDecode(Request* reque std::int32_t extra_tokens = std::max(0, request->GetReserveNumTokensInNextScheduleEvent() - tail_available); std::int32_t pages_needed = (extra_tokens + config_.page_size - 1) / config_.page_size; - if (!kv_prefix_cache_.EnsureCapacityByEvict(pages_needed)) { - return {}; - } - - if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct() && mamba_allocator_ && - request->Is() && request->GetLocalMambaAllocator() != nullptr && - !hybrid_prefix_cache_->EnsureMambaCapacityByEvict(1)) { - return {}; - } - const std::int32_t first_pos = request->TokenSize(); const std::int32_t target = first_pos + config_.decode_input_tokens; - if (hybrid_prefix_cache_ && !hybrid_prefix_cache_->AdmitChunk(request->Id(), first_pos, target, simulated_free)) { + AdmissionRequest admission_request{ + .request_id = request->Id(), + .device_pages_needed = pages_needed, + .first_raw_position_of_op = first_pos, + .target_raw_tokens_exclusive = target, + .refresh_mamba_checkpoint = request->Is(), + }; + if (!hybrid_prefix_cache_.Admit(admission_request, simulated_free).admitted) { return {}; } - return fsm::ScheduleDecodeEvent{config_.decode_input_tokens, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}; + return fsm::ScheduleDecodeEvent{config_.decode_input_tokens, hybrid_prefix_cache_}; } std::optional Scheduler::scheduleDecodeFromRetracted( Request* request, std::map& simulated_free) { if (req_pool_allocator_.AvailableSlots() == 0) return {}; - MatchResult match_result = - hybrid_prefix_cache_ - ? hybrid_prefix_cache_->Match(request->GetFullPagedTokens(true), MatchIntent::StateRecovery) - : kv_prefix_cache_.Match(request->GetFullPagedTokens(true), MatchIntent::StateRecovery); + RecoveryPlan recovery_plan = + hybrid_prefix_cache_.MatchPrefix(request->GetFullPagedTokens(true), MatchIntent::StateRecovery); + MatchResult match_result = recovery_plan.compat_match; std::vector loadback_diff = match_result.NodesWithout(); - std::vector mamba_loadback_nodes; - TreeNode* mamba_recovery_node = nullptr; - bool needs_mamba_loadback = false; - if (hybrid_prefix_cache_ && mamba_allocator_) { - mamba_recovery_node = hybrid_prefix_cache_->FindLastMambaNode(match_result.host.last_node); - if (mamba_recovery_node == nullptr) { - mamba_recovery_node = hybrid_prefix_cache_->FindLastMambaHostNode(match_result.host.last_node); - needs_mamba_loadback = mamba_recovery_node != nullptr; - if (needs_mamba_loadback && !mamba_recovery_node->HasMamba()) { - AddUniqueNode(mamba_loadback_nodes, mamba_recovery_node); - } - } - if (mamba_recovery_node == nullptr) { - spdlog::warn("[Scheduler] Retracted request {} lost tree-owned Mamba state, aborting request", - request->Id()); - request->Apply(fsm::AbortEvent{}); - return {}; - } - if (!needs_mamba_loadback) { - match_result.mamba_cow_src_index = mamba_recovery_node->MambaSlotIndex(); - } + if (!recovery_plan.recovery_state_available) { + spdlog::warn("[Scheduler] Retracted request {} lost required cache recovery state, aborting request", + request->Id()); + request->Apply(fsm::AbortEvent{}); + return {}; } + TreeNode* mamba_recovery_node = recovery_plan.protected_recovery_node; const std::int32_t device_matched2 = match_result.device.DepthInPage(); const std::int32_t host_matched2 = match_result.host.DepthInPage(); @@ -264,49 +187,36 @@ std::optional Scheduler::scheduleDecodeFr } std::int32_t device_pages_needed = (num_tokens + config_.page_size - 1) / config_.page_size; - std::unique_ptr temp_lock = std::make_unique(match_result.device.last_node); - if (!kv_prefix_cache_.EnsureCapacityByEvict(device_pages_needed)) { - return {}; - } - if (hybrid_prefix_cache_ && mamba_allocator_) { - // Recovery COWs the tree-owned Mamba state into fresh request-local - // working/checkpoint slots. Protect the source node only for this - // allocation; retracted Mamba states are otherwise normal evictable - // tree-owned cache entries. - const std::int32_t mamba_slots_needed = 2 + CountMambaDeviceLoadBackSlots(mamba_loadback_nodes); - if (!hybrid_prefix_cache_->EnsureMambaCapacityByEvict(mamba_slots_needed, mamba_recovery_node)) { - return {}; - } - } - const std::int32_t target = request->TokenSize(); - if (hybrid_prefix_cache_ && !hybrid_prefix_cache_->AdmitChunkFromRetracted(request->Id(), target, simulated_free, - match_result.paged_cache)) { + AdmissionRequest admission_request{ + .request_id = request->Id(), + .device_pages_needed = device_pages_needed, + .target_raw_tokens_exclusive = target, + .compat_match = &match_result, + .protected_recovery_node = mamba_recovery_node, + .auxiliary_tree_slots_needed = 2, + .fresh_request_table_view = true, + }; + AdmissionVerdict admission = hybrid_prefix_cache_.Admit(admission_request, simulated_free); + if (!admission.admitted) { return {}; } - if (needs_mamba_loadback) { - hybrid_prefix_cache_->PrepareMambaDeviceLoadBack(mamba_loadback_nodes); - if (mamba_recovery_node->HasMamba()) { - match_result.mamba_cow_src_index = mamba_recovery_node->MambaSlotIndex(); - } + if (admission.mamba_cow_src_index.has_value()) { + match_result.mamba_cow_src_index = *admission.mamba_cow_src_index; } - return fsm::ScheduleDecodeFromRetractedEvent{ - config_.decode_input_tokens, - &device_allocator_, - &req_pool_allocator_, - &kv_prefix_cache_, - std::move(match_result), - loadback_diff, - mamba_allocator_ ? &*mamba_allocator_ : nullptr, - std::move(mamba_loadback_nodes), - }; + return fsm::ScheduleDecodeFromRetractedEvent{config_.decode_input_tokens, + &req_pool_allocator_, + std::move(match_result), + std::move(loadback_diff), + std::move(admission.cache_transfer_pairs), + hybrid_prefix_cache_}; } std::optional Scheduler::scheduleRetract(Request* request) { auto full_paged_tokens = request->GetFullPagedTokens(true); - std::vector prefix_pages = DevicePagesFromRoot(request->GetDeviceNode()); - std::int32_t total_available = static_cast(request->GetOccupiedPages().size()); + RequestCacheContext cache_context(*request); + std::int32_t total_available = cache_context.OccupiedPageCountSnapshot(); // Overlap scheduling: ExtendResult may grow the token container before the // next Acquire runs. Clamp to the pages we actually have. @@ -314,16 +224,27 @@ std::optional Scheduler::scheduleRetract(Request* req full_paged_tokens.resize(total_available); } - std::int32_t alloc_count = - static_cast(full_paged_tokens.size()) - static_cast(prefix_pages.size()); - - OwnedPages alloc_pages = request->TakeFirstPages(alloc_count); - - kv_prefix_cache_.Insert(full_paged_tokens, prefix_pages, std::move(alloc_pages)); - - MatchResult match_result = kv_prefix_cache_.Match(full_paged_tokens, MatchIntent::StateRecovery); + RequestCacheMutation cache_mutation(*request); + TreeNode* terminal_device_node = cache_mutation.MutableTerminalDeviceNode(); + StepCommitRequest insert_count_request{ + .plan_device_prefix_insertion = + DevicePrefixInsertionPlanRequest{ + .full_paged_tokens = &full_paged_tokens, + .current_device_node = terminal_device_node, + }, + }; + const std::int32_t alloc_count = + hybrid_prefix_cache_.StepCommit(std::move(insert_count_request)).device_insert_page_count; + StepCommitRequest publication_request{ + .publish_device_prefix_insertion = + DevicePrefixInsertionRequest{ + .full_paged_tokens = &full_paged_tokens, + .current_device_node = terminal_device_node, + .pages_to_insert = cache_mutation.TakeFirstLocalKVPages(alloc_count), + }, + }; + MatchResult match_result = hybrid_prefix_cache_.StepCommit(std::move(publication_request)).match_result; - std::unique_ptr temp_lock = std::make_unique(match_result.host.last_node); const std::int32_t device_matched3 = match_result.device.DepthInPage(); const std::int32_t host_matched3 = match_result.host.DepthInPage(); std::int32_t host_pages_needed = 0; @@ -331,14 +252,19 @@ std::optional Scheduler::scheduleRetract(Request* req host_pages_needed = device_matched3 - host_matched3; } - if (!kv_prefix_cache_.EnsureCapacityByEvict(host_pages_needed)) { + AdmissionRequest admission_request{ + .host_pages_needed = host_pages_needed, + .compat_match = &match_result, + .protect_host_match_node = true, + }; + std::map simulated_free; + if (!hybrid_prefix_cache_.Admit(admission_request, simulated_free).admitted) { return {}; } - return fsm::ScheduleRetractEvent{&kv_prefix_cache_, &host_allocator_, match_result, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}; + return fsm::ScheduleRetractEvent{match_result, hybrid_prefix_cache_}; } -LoadBackOperation GenerateLoadBackOp(const std::vector& diff, const std::vector& mamba_nodes, +LoadBackOperation GenerateLoadBackOp(const std::vector& diff, std::vector extra_transfers, cache_op_id op_id) { std::vector transfers; @@ -349,11 +275,8 @@ LoadBackOperation GenerateLoadBackOp(const std::vector& diff, const s transfers.push_back(TransferPair{CacheKind::kKV, host_pages[i], device_pages[i]}); } } - for (TreeNode* node : mamba_nodes) { - if (node != nullptr && node->HasMambaOnHost() && node->HasMamba()) { - transfers.push_back(TransferPair{CacheKind::kMamba, node->MambaHostSlotIndex(), node->MambaSlotIndex()}); - } - } + transfers.insert(transfers.end(), std::make_move_iterator(extra_transfers.begin()), + std::make_move_iterator(extra_transfers.end())); return LoadBackOperation{op_id, std::move(transfers)}; } @@ -365,12 +288,11 @@ std::optional Scheduler::applyEventAndGenerateOp(Request* re const auto& pages_to_transfer = request->GetPagesToTransfer(); if (pages_to_transfer.empty()) { // No copy needed; advance Retracting to Retracted without an op_id. - request->Apply( - fsm::WriteBackDoneEvent{&kv_prefix_cache_, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}); + request->Apply(fsm::WriteBackDoneEvent{&kv_prefix_cache_, &hybrid_prefix_cache_}); return std::nullopt; } // Register op_id so WriteBackDone can route back. - cache_op_id op_id = kv_prefix_cache_.AllocateCacheOpId(); + cache_op_id op_id = hybrid_prefix_cache_.AllocateCacheOpId(); CacheOpSpec spec; spec.request_id = request->Id(); cache_op_tracker_[op_id] = std::move(spec); @@ -395,15 +317,17 @@ std::optional Scheduler::newRetractOperation(Request* retrac template requires(std::same_as || std::same_as) static PrefillOperation applyPrefillEvent(Request* request, Event event) { - std::int32_t begin = static_cast(request->GetOccupiedPages().size()); + RequestCacheContext pre_apply_cache(*request); + std::int32_t begin = pre_apply_cache.OccupiedPageCountSnapshot(); request->Apply(event); - std::vector all_pages = request->GetOccupiedPages(); + RequestCacheContext cache_context(*request); + std::vector all_pages = cache_context.OccupiedPagesSnapshot(); std::int32_t sz = static_cast(all_pages.size()) - begin; auto info = request->GetPrefillInfo(); auto op = PrefillOperation{{ .request_id = request->Id(), - .request_pool_index = request->GetReqPoolIndex(), + .request_pool_index = cache_context.RequestPoolIndex(), .input_length = info.extend_len, .occupied_pages = std::move(all_pages), .begin = begin, @@ -414,45 +338,49 @@ static PrefillOperation applyPrefillEvent(Request* request, Event event) { op.shifted_input_ids = std::move(info.shifted_input_ids); op.extend_prefix_len = info.already_scheduled_len; - auto* mamba = request->GetLocalMambaAllocator(); - if (mamba != nullptr && mamba->HasWorking()) { - op.mamba_working_idx = mamba->WorkingIndex(); - if (mamba->HasCheckpoint()) { - op.mamba_checkpoint_dst_idx = mamba->CheckpointIndex(); - } - } - return op; } PrefillOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::SchedulePrefillFirstChunkEvent event) { auto match = event.GetMatchResult(); auto op = applyPrefillEvent(request, std::move(event)); - // Mamba fields only when adjunct is active. - if (hybrid_prefix_cache_ && hybrid_prefix_cache_->HasMambaAdjunct()) { - op.mamba_cow_src_idx = match.mamba_cow_src_index; - op.mamba_branching_seqlen = match.mamba_branching_seqlen; - } - // Order: attach, acquire, populate. Attach before acquire so prior-chunk - // tail pages commit into snapshots before Acquire's ReleaseSkipped frees them. - if (hybrid_prefix_cache_) { - hybrid_prefix_cache_->CommitChunk(op.request_id, const_cast(request->GetDeviceNode())); - hybrid_prefix_cache_->AcquireForRequest(op.request_id, op.extend_prefix_len, - op.extend_prefix_len + op.input_length, match.paged_cache); - hybrid_prefix_cache_->PopulateOp(op); - } + RequestCacheContext cache_context(*request); + RequestCacheMutation cache_mutation(*request); + StepCommitRequest prepare_request{ + .worker_metadata = + WorkerCompatibilityCommitRequest{ + .op_base = &op, + .terminal = cache_mutation.MutableTerminalDeviceNode(), + .compat_match = &match, + .local_mamba_allocator_view = cache_context.LocalMambaAllocatorView(), + .first_raw_position_of_op = op.extend_prefix_len, + .target_raw_tokens_exclusive = op.extend_prefix_len + op.input_length, + .commit_tree_prefix_before_acquire = true, + .import_paged_cache_hit = true, + .populate_prefix_reuse_metadata = true, + }, + }; + (void)hybrid_prefix_cache_.StepCommit(std::move(prepare_request)); return op; } PrefillOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::SchedulePrefillEvent event) { auto op = applyPrefillEvent(request, std::move(event)); - // Order: attach, acquire, populate (see SchedulePrefillFirstChunkEvent). - if (hybrid_prefix_cache_) { - hybrid_prefix_cache_->CommitChunk(op.request_id, const_cast(request->GetDeviceNode())); - hybrid_prefix_cache_->AcquireForRequest(op.request_id, op.extend_prefix_len, - op.extend_prefix_len + op.input_length); - hybrid_prefix_cache_->PopulateOp(op); - } + RequestCacheContext cache_context(*request); + RequestCacheMutation cache_mutation(*request); + StepCommitRequest prepare_request{ + .worker_metadata = + WorkerCompatibilityCommitRequest{ + .op_base = &op, + .terminal = cache_mutation.MutableTerminalDeviceNode(), + .local_mamba_allocator_view = cache_context.LocalMambaAllocatorView(), + .first_raw_position_of_op = op.extend_prefix_len, + .target_raw_tokens_exclusive = op.extend_prefix_len + op.input_length, + .commit_tree_prefix_before_acquire = true, + .import_paged_cache_hit = true, + }, + }; + (void)hybrid_prefix_cache_.StepCommit(std::move(prepare_request)); return op; } @@ -460,14 +388,16 @@ template requires(std::same_as || std::same_as) static DecodeOperation applyDecodeEvent(Request* request, Event event, std::int32_t decode_input_tokens) { - std::int32_t begin = static_cast(request->GetOccupiedPages().size()); + RequestCacheContext pre_apply_cache(*request); + std::int32_t begin = pre_apply_cache.OccupiedPageCountSnapshot(); request->Apply(std::move(event)); - std::vector all_pages = request->GetOccupiedPages(); + RequestCacheContext cache_context(*request); + std::vector all_pages = cache_context.OccupiedPagesSnapshot(); std::int32_t sz = static_cast(all_pages.size()) - begin; auto op = DecodeOperation{{ .request_id = request->Id(), - .request_pool_index = request->GetReqPoolIndex(), + .request_pool_index = cache_context.RequestPoolIndex(), .input_length = decode_input_tokens, .occupied_pages = std::move(all_pages), .begin = begin, @@ -475,14 +405,6 @@ static DecodeOperation applyDecodeEvent(Request* request, Event event, std::int3 .prefill_length = request->PrefillSize(), }}; - auto* mamba = request->GetLocalMambaAllocator(); - if (mamba != nullptr && mamba->HasWorking()) { - op.mamba_working_idx = mamba->WorkingIndex(); - if (mamba->HasCheckpoint()) { - op.mamba_checkpoint_dst_idx = mamba->CheckpointIndex(); - } - } - return op; } @@ -496,31 +418,37 @@ DecodeOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::Schedu if (need_bootstrap_token) { op.decode_input_id = bootstrap_token; } - // Order: attach, acquire, populate. - if (hybrid_prefix_cache_) { - if (came_from_prefill_done) { - hybrid_prefix_cache_->CommitChunk(op.request_id, const_cast(request->GetDeviceNode())); - } - hybrid_prefix_cache_->AcquireForRequest(op.request_id, first_pos, first_pos + op.input_length); - hybrid_prefix_cache_->PopulateOp(op); - } + RequestCacheMutation cache_mutation(*request); + RequestCacheContext cache_context(*request); + StepCommitRequest prepare_request{ + .worker_metadata = + WorkerCompatibilityCommitRequest{ + .op_base = &op, + .terminal = cache_mutation.MutableTerminalDeviceNode(), + .local_mamba_allocator_view = cache_context.LocalMambaAllocatorView(), + .first_raw_position_of_op = first_pos, + .target_raw_tokens_exclusive = first_pos + op.input_length, + .commit_tree_prefix_before_acquire = came_from_prefill_done, + }, + }; + (void)hybrid_prefix_cache_.StepCommit(std::move(prepare_request)); return op; } DecodeOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::ScheduleDecodeFromRetractedEvent event) { - const std::int32_t mamba_cow_src_index = event.GetMatchResult().mamba_cow_src_index; - auto paged_cache_hit = event.GetMatchResult().paged_cache; + auto match = event.GetMatchResult(); request->Apply(std::move(event)); if (!request->Is()) { throw std::logic_error( "Scheduler::applyEventAndGenerateOp: expected state=Decoding after loadback recovery; got state=" + request->StateName()); } - std::vector all_pages = request->GetOccupiedPages(); + RequestCacheContext cache_context(*request); + std::vector all_pages = cache_context.OccupiedPagesSnapshot(); std::int32_t sz = static_cast(all_pages.size()); DecodeOperation op{{ .request_id = request->Id(), - .request_pool_index = request->GetReqPoolIndex(), + .request_pool_index = cache_context.RequestPoolIndex(), .input_length = config_.decode_input_tokens, .occupied_pages = std::move(all_pages), .begin = 0, @@ -528,21 +456,19 @@ DecodeOperation Scheduler::applyEventAndGenerateOp(Request* request, fsm::Schedu }}; op.decode_input_id = request->GetLastToken(); op.hist_token_len = request->TokenSize() - 1; - op.mamba_cow_src_idx = mamba_cow_src_index; - auto* mamba = request->GetLocalMambaAllocator(); - if (mamba != nullptr && mamba->HasWorking()) { - op.mamba_working_idx = mamba->WorkingIndex(); - if (mamba->HasCheckpoint()) { - op.mamba_checkpoint_dst_idx = mamba->CheckpointIndex(); - } - } - - if (hybrid_prefix_cache_) { - hybrid_prefix_cache_->ReleaseRequest(op.request_id); - hybrid_prefix_cache_->AcquireForRequest(op.request_id, 0, request->TokenSize(), paged_cache_hit); - hybrid_prefix_cache_->PopulateOp(op); - } + StepCommitRequest prepare_request{ + .worker_metadata = + WorkerCompatibilityCommitRequest{ + .op_base = &op, + .compat_match = &match, + .local_mamba_allocator_view = cache_context.LocalMambaAllocatorView(), + .target_raw_tokens_exclusive = request->TokenSize(), + .populate_recovery_metadata = true, + .release_request_state_before_acquire = true, + }, + }; + (void)hybrid_prefix_cache_.StepCommit(std::move(prepare_request)); return op; } @@ -574,8 +500,7 @@ Scheduler::newForwardOperation(std::vector candidates) { ops.push_back(std::move(op)); }; std::vector loadback_ops; - auto simulated_free = - hybrid_prefix_cache_ ? hybrid_prefix_cache_->InitialSimulatedFree() : std::map{}; + auto simulated_free = hybrid_prefix_cache_.InitialSimulatedFree(); for (Request* request : candidates) { if (token_budget <= 0 || config_.max_batch_size == ops.size()) break; @@ -591,12 +516,12 @@ Scheduler::newForwardOperation(std::vector candidates) { if (auto ev = schedulePrefillFirstChunk(request, token_budget, decode_input_tokens, config_.disable_l2_cache, simulated_free)) { std::vector loadback_diff = ev->GetLoadbackDiff(); - std::vector mamba_loadback_nodes = ev->GetMambaLoadbackNodes(); + std::vector cache_transfer_pairs = ev->GetCacheTransferPairs(); push_op(applyEventAndGenerateOp(request, std::move(*ev)), true); // will be empty when disable_l2_cache - if (!loadback_diff.empty() || !mamba_loadback_nodes.empty()) { - cache_op_id op_id = kv_prefix_cache_.AllocateCacheOpId(); - loadback_ops.push_back(GenerateLoadBackOp(loadback_diff, mamba_loadback_nodes, op_id)); + if (!loadback_diff.empty() || !cache_transfer_pairs.empty()) { + cache_op_id op_id = hybrid_prefix_cache_.AllocateCacheOpId(); + loadback_ops.push_back(GenerateLoadBackOp(loadback_diff, std::move(cache_transfer_pairs), op_id)); } } } else if (request->Is() || (request->Is() && config_.role != Role::kP)) { @@ -613,11 +538,11 @@ Scheduler::newForwardOperation(std::vector candidates) { if (auto ev = scheduleDecodeFromRetracted(request, simulated_free)) { std::vector loadback_diff = ev->GetLoadbackDiff(); - std::vector mamba_loadback_nodes = ev->GetMambaLoadbackNodes(); + std::vector cache_transfer_pairs = ev->GetCacheTransferPairs(); push_op(applyEventAndGenerateOp(request, std::move(*ev))); - if (!loadback_diff.empty() || !mamba_loadback_nodes.empty()) { - cache_op_id op_id = kv_prefix_cache_.AllocateCacheOpId(); - loadback_ops.push_back(GenerateLoadBackOp(loadback_diff, mamba_loadback_nodes, op_id)); + if (!loadback_diff.empty() || !cache_transfer_pairs.empty()) { + cache_op_id op_id = hybrid_prefix_cache_.AllocateCacheOpId(); + loadback_ops.push_back(GenerateLoadBackOp(loadback_diff, std::move(cache_transfer_pairs), op_id)); } } } diff --git a/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp b/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp index 2279df74a..d2ec9467e 100644 --- a/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp @@ -26,6 +26,7 @@ #include "scheduler/page_hasher.h" #include "scheduler/scheduler.h" +#include "fsm/cache_events.h" #include "fsm/forward_events.h" #include "fsm/pd_events.h" @@ -92,8 +93,7 @@ void Scheduler::handleEvent(const pd::FailedEvent& event) {} void Scheduler::handleEvent(const pd::SucceededEvent& event) { std::vector page_hashes; requests_.at(event.request_id) - ->Apply(fsm::FinishEvent{&kv_prefix_cache_, &host_allocator_, std::move(page_hashes), config_.disable_l2_cache, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}); + ->Apply(fsm::FinishEvent{std::move(page_hashes), config_.disable_l2_cache, hybrid_prefix_cache_}); } void Scheduler::handleEvent(const pd::RemotePrefillDoneEvent& event) { @@ -114,8 +114,7 @@ void Scheduler::handleEvent(const forward::Finish& event) { page_hashes = ComputePagedHashes(token_pages, ""); } } - req->Apply(fsm::FinishEvent{&kv_prefix_cache_, &host_allocator_, std::move(page_hashes), - config_.disable_l2_cache, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}); + req->Apply(fsm::FinishEvent{std::move(page_hashes), config_.disable_l2_cache, hybrid_prefix_cache_}); } } @@ -149,13 +148,9 @@ void Scheduler::handleEvent(const cache::WriteBackDone& event) { auto spec = std::move(it->second); cache_op_tracker_.erase(it); - auto now = std::chrono::steady_clock::now(); - for (TreeNode* n : spec.nodes) n->Touch(now); - if (!spec.request_id.empty()) { if (auto* req = find_request(spec.request_id)) { - req->Apply( - fsm::WriteBackDoneEvent{&kv_prefix_cache_, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}); + req->Apply(fsm::WriteBackDoneEvent{&kv_prefix_cache_, &hybrid_prefix_cache_}); } } } diff --git a/tokenspeed-scheduler/csrc/scheduler/request.h b/tokenspeed-scheduler/csrc/scheduler/request.h index 89b770c68..c59b9bdf6 100644 --- a/tokenspeed-scheduler/csrc/scheduler/request.h +++ b/tokenspeed-scheduler/csrc/scheduler/request.h @@ -152,13 +152,13 @@ class Request { state_); } - const TreeNode* GetDeviceNode() const { + TreeNode* GetMutableDeviceNode() { return std::visit(Overloaded{ - [](const T& s) -> const TreeNode* + [](T& s) -> TreeNode* requires(std::derived_from) - { return s.GetDeviceNode(); }, - [this](const auto&) -> const TreeNode* { - throw std::logic_error("Request::GetDeviceNode: expected a base request state; got state=" + + { return s.GetMutableDeviceNode(); }, + [this](auto&) -> TreeNode* { + throw std::logic_error("Request::GetMutableDeviceNode: expected a base request state; got state=" + StateName()); }, }, diff --git a/tokenspeed-scheduler/csrc/scheduler/request_cache_context.h b/tokenspeed-scheduler/csrc/scheduler/request_cache_context.h new file mode 100644 index 000000000..c2cf29c51 --- /dev/null +++ b/tokenspeed-scheduler/csrc/scheduler/request_cache_context.h @@ -0,0 +1,80 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#pragma once + +#include +#include + +#include "resource/allocator/owned_pages.h" +#include "scheduler/request.h" + +namespace tokenspeed { + +class LocalMambaAllocator; +class TreeNode; + +// Non-owning read view over request-local cache state used for worker-facing +// metadata, statistics, and diagnostics. It does not own allocators, node refs, +// req-pool indices, cache-operation ids, or tree/table state. +class RequestCacheContext { +public: + explicit RequestCacheContext(Request& request) : request_(request) {} + + RequestCacheContext(const RequestCacheContext&) = delete; + RequestCacheContext& operator=(const RequestCacheContext&) = delete; + + std::vector OccupiedPagesSnapshot() const { return request_.GetOccupiedPages(); } + + std::int32_t OccupiedPageCountSnapshot() const { return static_cast(OccupiedPagesSnapshot().size()); } + + // Debug-memory diagnostics read only the request-local KV allocator pages + // through this request cache-state boundary. Shared radix-tree pages remain + // owned and reported by HybridPrefixCache diagnostics. + std::vector LocalKVPagesSnapshot() const { return request_.GetLocalAllocatorPages(); } + + std::int32_t RequestPoolIndex() const { return request_.GetReqPoolIndex(); } + + const LocalMambaAllocator* LocalMambaAllocatorView() const { return request_.GetLocalMambaAllocator(); } + +private: + Request& request_; +}; + +// Explicit mutable bridge for cache lifecycle operations that need a mutable +// tree observer or request-local page ownership transfer. Keeping this separate +// from RequestCacheContext prevents read-only flattening paths from growing +// hidden mutation authority. +class RequestCacheMutation { +public: + explicit RequestCacheMutation(Request& request) : request_(request) {} + + RequestCacheMutation(const RequestCacheMutation&) = delete; + RequestCacheMutation& operator=(const RequestCacheMutation&) = delete; + + TreeNode* MutableTerminalDeviceNode() { return request_.GetMutableDeviceNode(); } + + OwnedPages TakeFirstLocalKVPages(std::int32_t alloc_count) { return request_.TakeFirstPages(alloc_count); } + +private: + Request& request_; +}; + +} // namespace tokenspeed diff --git a/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp b/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp index 5df53231d..8c457a1f5 100644 --- a/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp @@ -20,7 +20,6 @@ #include "scheduler/scheduler.h" -#include #include #include #include @@ -28,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -44,22 +44,124 @@ #include "fsm/forward_states.h" #include "resource/kv_prefix_cache/kv_prefix_cache.h" #include "resource/radix_tree/radix_tree.h" -#include "resource/radix_tree/tree_node.h" +#include "scheduler/device_memory_diagnostics.h" #include "scheduler/execution_event.h" #include "scheduler/operations/cache.h" #include "scheduler/page_hasher.h" #include "scheduler/request.h" +#include "scheduler/request_cache_context.h" #include "scheduler/request_spec.h" #include "scheduler/types.h" namespace tokenspeed { +namespace { + +std::optional MakeMambaAllocator(const SchedulerConfig& config) { + if (config.enable_mamba && config.mamba_pool_total_chunks > 0) { + return std::optional{std::in_place, config.mamba_pool_total_chunks}; + } + return std::nullopt; +} + +MambaChunkAllocator* MambaAdjunctAllocator(std::optional& allocator, + const SchedulerConfig& config) { + const bool has_mamba_pool = allocator.has_value(); + const bool has_mamba_adjunct = has_mamba_pool && config.role != Role::kD; + return has_mamba_adjunct ? &*allocator : nullptr; +} + +std::optional MakeMambaHostAllocator(const SchedulerConfig& config) { + if (config.enable_mamba && config.mamba_pool_total_chunks > 0 && config.enable_mamba_l2 && + config.mamba_l2_host_slots > 0) { + return std::optional{std::in_place, config.mamba_l2_host_slots}; + } + return std::nullopt; +} + +MambaHostAllocator* MambaHostAdjunctAllocator(std::optional& host_allocator, + std::optional& device_allocator, + const SchedulerConfig& config) { + const bool has_mamba_adjunct = device_allocator.has_value() && config.role != Role::kD; + return has_mamba_adjunct && host_allocator.has_value() ? &*host_allocator : nullptr; +} + +} // namespace + +bool ValidateDeviceMemoryDiagnostics(const std::vector& request_pages, + const HybridPrefixCache::DeviceMemoryDiagnosticsSnapshot& device_snapshot) { + bool ok = true; + const std::int32_t total_device = device_snapshot.total_device_pages; + // page_id → (owner_req_id, state_name) for duplicate tail-page reporting + std::unordered_map> page_owner; + + for (const auto& snapshot : request_pages) { + for (std::int32_t p : snapshot.pages) { + auto [it, inserted] = page_owner.emplace(p, std::make_pair(snapshot.request_id, snapshot.state_name)); + if (!inserted) { + spdlog::error("[check_mem] DEVICE TAIL PAGE OVERLAP: page={} req1={}({}) req2={}({})", p, + it->second.first, it->second.second, snapshot.request_id, snapshot.state_name); + ok = false; + } + } + } + + // ── 2a. Check for duplicate page_ids inside the tree itself ───────────── + for (auto& [page, cnt] : device_snapshot.tree_device_pages) { + if (cnt > 1) { + spdlog::error("[check_mem] DEVICE TREE DUPLICATE: page={} appears {} times in radix tree", page, cnt); + ok = false; + } + } + + const std::int32_t tree_device_total = static_cast(device_snapshot.tree_device_pages.size()); + + std::int32_t req_device_total = 0; + for (const auto& snapshot : request_pages) { + req_device_total += static_cast(snapshot.pages.size()); + } + + const std::int32_t free_device = device_snapshot.free_device_pages; + + if (tree_device_total + req_device_total + free_device != total_device) { + spdlog::error("[check_mem] DEVICE PAGE ACCOUNTING MISMATCH: tree={} req={} free={} sum={} total={}", + tree_device_total, req_device_total, free_device, + tree_device_total + req_device_total + free_device, total_device); + ok = false; + } + + // ── 4. Per-request: page ids must be in [1, total] ──────────────────── + // PageAllocator starts from page id 1 (0 is reserved as invalid/null). + for (const auto& snapshot : request_pages) { + for (std::int32_t p : snapshot.pages) { + if (p <= 0 || p > total_device) { + spdlog::error("[check_mem] INVALID DEVICE PAGE id={} for req={} (valid range [1,{}])", p, + snapshot.request_id, total_device); + ok = false; + } + } + } + for (const auto& entry : device_snapshot.tree_device_pages) { + const std::int32_t p = entry.first; + if (p <= 0 || p > total_device) { + spdlog::error("[check_mem] INVALID DEVICE PAGE id={} in radix tree (valid range [1,{}])", p, total_device); + ok = false; + } + } + + return ok; +} + Scheduler::Scheduler(SchedulerConfig config) : config_{std::move(config)}, device_allocator_{config_.page_size, config_.device_allocator.total_pages}, host_allocator_{config_.page_size, config_.host_allocator.total_pages}, - mamba_allocator_{}, + mamba_allocator_{MakeMambaAllocator(config_)}, + mamba_host_allocator_{MakeMambaHostAllocator(config_)}, kv_prefix_cache_{&device_allocator_, &host_allocator_, config_.enable_l3_storage, config_.disable_prefix_cache}, + hybrid_prefix_cache_{kv_prefix_cache_, device_allocator_, MambaAdjunctAllocator(mamba_allocator_, config_), + config_.mamba_cache_chunk_size, + MambaHostAdjunctAllocator(mamba_host_allocator_, mamba_allocator_, config_)}, req_pool_allocator_{config_.max_batch_size} { if (auto* env = std::getenv("SPDLOG_LEVEL")) { std::string level_str{env}; @@ -68,68 +170,18 @@ Scheduler::Scheduler(SchedulerConfig config) } if (config_.enable_kv_cache_events) { - kv_prefix_cache_.SetKvEventSink([this](KvCacheEvent event) { kv_events_.push_back(std::move(event)); }); - } - const bool has_mamba_pool = config_.enable_mamba && config_.mamba_pool_total_chunks > 0; - if (has_mamba_pool) { - mamba_allocator_.emplace(config_.mamba_pool_total_chunks); - } - const bool has_mamba_l2_pool = has_mamba_pool && config_.enable_mamba_l2 && config_.mamba_l2_host_slots > 0; - if (has_mamba_l2_pool) { - mamba_host_allocator_.emplace(config_.mamba_l2_host_slots); - } - - // Construct HybridPrefixCache when any adjunct/paged-cache feature is configured. - // Role::kD skips Mamba but still participates in paged-cache transport. - const bool has_mamba_adjunct = has_mamba_pool && config_.role != Role::kD; - const bool has_prefix_cache_adjunct = config_.prefix_cache_adjunct.has_value(); - const bool has_paged_cache_groups = !config_.paged_cache_groups.empty(); - if (has_mamba_adjunct || has_prefix_cache_adjunct || has_paged_cache_groups) { - MambaChunkAllocator* mamba_ptr = has_mamba_adjunct ? &*mamba_allocator_ : nullptr; - MambaHostAllocator* mamba_host_ptr = has_mamba_l2_pool ? &*mamba_host_allocator_ : nullptr; - hybrid_prefix_cache_.emplace(kv_prefix_cache_, mamba_ptr, config_.mamba_cache_chunk_size, mamba_host_ptr); - kv_prefix_cache_.GetDeviceManager().SetEvictionCallback( - [this](TreeNode* node) { hybrid_prefix_cache_->OnKVEvict(node); }); - kv_prefix_cache_.GetHostManager().SetEvictionCallback( - [this](TreeNode* node) { hybrid_prefix_cache_->OnKVHostEvict(node); }); - - for (const auto& cfg : config_.paged_cache_groups) { - PagedCacheGroupConfig copy = cfg; - copy.Validate(); - hybrid_prefix_cache_->RegisterPagedCacheGroup(std::make_unique(std::move(copy))); - } - - if (has_prefix_cache_adjunct) { - const auto& spec = *config_.prefix_cache_adjunct; - if (spec.required_groups.empty()) { - throw std::invalid_argument("Scheduler: prefix_cache_adjunct.required_groups must be non-empty"); - } - // HybridPrefixCache derives history alignment from the registered - // group configs; we still build the sliding-window map here. - std::unordered_map sliding_window_per_group; - for (const auto& gid : spec.required_groups) { - const PagedCacheGroupConfig* cfg = nullptr; - for (const auto& g : config_.paged_cache_groups) { - if (g.group_id == gid) { - cfg = &g; - break; - } - } - if (cfg == nullptr) { - throw std::invalid_argument("Scheduler: prefix_cache_adjunct required group_id '" + gid + - "' not found in paged_cache_groups"); - } - if (cfg->retention == PagedCacheGroupConfig::Retention::SlidingWindow) { - if (!cfg->sliding_window_tokens.has_value() || *cfg->sliding_window_tokens <= 0) { - throw std::invalid_argument("Scheduler: prefix_cache_adjunct sliding group '" + gid + - "' must declare positive sliding_window_tokens"); - } - sliding_window_per_group.emplace(gid, *cfg->sliding_window_tokens); - } - } - hybrid_prefix_cache_->EnablePagedCacheAdjunct(spec.required_groups, std::move(sliding_window_per_group)); - } + hybrid_prefix_cache_.SetKvEventSink([this](KvCacheEvent event) { kv_events_.push_back(std::move(event)); }); + } + std::optional> required_paged_cache_groups; + if (config_.prefix_cache_adjunct.has_value()) { + required_paged_cache_groups = std::span{config_.prefix_cache_adjunct->required_groups}; } + hybrid_prefix_cache_.ConfigurePagedCacheAdjunct(std::span{config_.paged_cache_groups}, + required_paged_cache_groups); +} + +Scheduler::~Scheduler() { + hybrid_prefix_cache_.SetKvEventSink({}); } std::vector Scheduler::DrainKvEvents() { @@ -149,16 +201,15 @@ std::vector Scheduler::CalcRollingHash(const std::vector= static_cast(num_pages)) { return {}; } - const auto& hashes = result.host.last_node->PageHashes(); - std::string prior = hashes.empty() ? std::string{} : hashes.back(); return ComputePagedHashes( - std::vector>(token_pages.begin() + host_matched, token_pages.end()), prior); + std::vector>(token_pages.begin() + host_matched, token_pages.end()), + raw_host_seed.prior_hash_seed); } void Scheduler::SubmitRequests(const std::vector& request_specs) { @@ -209,14 +260,15 @@ std::size_t Scheduler::RetractedSize() const { } std::size_t Scheduler::AvailableKvPages() const { - return device_allocator_.AvailablePages(); + return hybrid_prefix_cache_.Stats().available_device_pages; } std::size_t Scheduler::ActiveKvPages() const { std::unordered_set active_pages; for (const auto& [_, req] : requests_) { if (req->Is() || req->Is() || req->Is()) { - for (std::int32_t page : req->GetOccupiedPages()) { + RequestCacheContext cache_context(*req); + for (std::int32_t page : cache_context.OccupiedPagesSnapshot()) { active_pages.insert(page); } } @@ -225,45 +277,32 @@ std::size_t Scheduler::ActiveKvPages() const { } std::vector Scheduler::PagedCacheGroupIds() const { - if (!hybrid_prefix_cache_) return {}; - return hybrid_prefix_cache_->PagedCacheGroupIds(); + return hybrid_prefix_cache_.Stats().paged_cache_group_ids; } std::int32_t Scheduler::PagedCacheGroupTotalPages(const std::string& group_id) const { - if (!hybrid_prefix_cache_) { - throw std::out_of_range("Scheduler::PagedCacheGroupTotalPages: group_id not configured"); - } - return hybrid_prefix_cache_->PagedCacheGroupTotalPages(group_id); + return hybrid_prefix_cache_.Stats({.paged_cache_group_ids = {group_id}}).paged_cache_total_pages.at(group_id); } std::int32_t Scheduler::PagedCacheGroupAvailablePages(const std::string& group_id) const { - if (!hybrid_prefix_cache_) { - throw std::out_of_range("Scheduler::PagedCacheGroupAvailablePages: group_id not configured"); - } - return hybrid_prefix_cache_->PagedCacheGroupAvailablePages(group_id); + return hybrid_prefix_cache_.Stats({.paged_cache_group_ids = {group_id}}).paged_cache_available_pages.at(group_id); } std::int64_t Scheduler::PagedCacheGroupFailedAllocCount(const std::string& group_id) const { - if (!hybrid_prefix_cache_) { - throw std::out_of_range("Scheduler::PagedCacheGroupFailedAllocCount: group_id not configured"); - } - return hybrid_prefix_cache_->PagedCacheGroupFailedAllocCount(group_id); + return hybrid_prefix_cache_.Stats({.paged_cache_group_ids = {group_id}}) + .paged_cache_failed_alloc_count.at(group_id); } std::vector Scheduler::GetRequestPagedCachePageIds(const std::string& request_id, const std::string& group_id) const { - if (!hybrid_prefix_cache_) { - throw std::out_of_range("Scheduler::GetRequestPagedCachePageIds: group_id not configured"); - } - return hybrid_prefix_cache_->GetRequestPagedCachePageIds(request_id, group_id); + return hybrid_prefix_cache_.Stats({.request_id = request_id, .paged_cache_group_ids = {group_id}}) + .request_paged_cache_page_ids.at(group_id); } std::int32_t Scheduler::GetRequestPagedCacheBaseLogicalPage(const std::string& request_id, const std::string& group_id) const { - if (!hybrid_prefix_cache_) { - throw std::out_of_range("Scheduler::GetRequestPagedCacheBaseLogicalPage: group_id not configured"); - } - return hybrid_prefix_cache_->GetRequestPagedCacheBaseLogicalPage(request_id, group_id); + return hybrid_prefix_cache_.Stats({.request_id = request_id, .paged_cache_group_ids = {group_id}}) + .request_paged_cache_base_logical_page.at(group_id); } std::int32_t Scheduler::GetRequestTokenSize(const std::string& id) const { @@ -285,7 +324,7 @@ std::vector Scheduler::newWriteBackOperation( const auto& pages_to_transfer = req->GetPagesToTransfer(); if (!pages_to_transfer.empty()) { - cache_op_id op_id = kv_prefix_cache_.AllocateCacheOpId(); + cache_op_id op_id = hybrid_prefix_cache_.AllocateCacheOpId(); CacheOpSpec spec; spec.request_id = id; cache_op_tracker_[op_id] = std::move(spec); @@ -305,11 +344,9 @@ ExecutionPlan Scheduler::NextExecutionPlan() { std::vector write_back_ops; write_back_ops = std::move(newWriteBackOperation(requests_)); - if (hybrid_prefix_cache_) { - for (const auto& [id, req] : requests_) { - if (req->Is()) { - hybrid_prefix_cache_->ReleaseRequest(id); - } + for (const auto& [id, req] : requests_) { + if (req->Is()) { + hybrid_prefix_cache_.FinishRequest(id); } } std::erase_if(requests_, [](const auto& req) { return req.second->template Is(); }); @@ -345,73 +382,27 @@ ExecutionPlan Scheduler::NextExecutionPlan() { } void Scheduler::check_device_mem() { - bool ok = true; - const std::int32_t total_device = device_allocator_.TotalPages() - 1; - std::unordered_map> req_pages_map; - // page_id → (owner_req_id, state_name) for duplicate tail-page reporting - std::unordered_map> page_owner; + std::vector request_page_snapshots; for (auto& [id, req] : requests_) { - std::string state = req->StateName(); - std::vector pages = req->GetLocalAllocatorPages(); + RequestCacheContext cache_context(*req); + std::vector pages = cache_context.LocalKVPagesSnapshot(); if (pages.empty()) continue; - req_pages_map[id] = pages; - - for (std::int32_t p : pages) { - auto [it, inserted] = page_owner.emplace(p, std::make_pair(id, state)); - if (!inserted) { - spdlog::error("[check_mem] DEVICE TAIL PAGE OVERLAP: page={} req1={}({}) req2={}({})", p, - it->second.first, it->second.second, id, state); - ok = false; - } - } - } - - // ── 2. Collect pages in radix tree ─────────────────────────────────────── - auto tree_device_pages = kv_prefix_cache_.CollectAllPages(); - - // 2a. Check for duplicate page_ids inside the tree itself - for (auto& [page, cnt] : tree_device_pages) { - if (cnt > 1) { - spdlog::error("[check_mem] DEVICE TREE DUPLICATE: page={} appears {} times in radix tree", page, cnt); - ok = false; - } + request_page_snapshots.push_back(RequestLocalKVPagesSnapshot{ + .request_id = id, + .state_name = req->StateName(), + .pages = std::move(pages), + }); } - std::int32_t tree_device_total = static_cast(tree_device_pages.size()); - - std::int32_t req_device_total = 0; - for (auto& [id, pages] : req_pages_map) req_device_total += static_cast(pages.size()); - - std::int32_t free_device = device_allocator_.AvailablePages(); - - if (tree_device_total + req_device_total + free_device != total_device) { - spdlog::error("[check_mem] DEVICE PAGE ACCOUNTING MISMATCH: tree={} req={} free={} sum={} total={}", - tree_device_total, req_device_total, free_device, - tree_device_total + req_device_total + free_device, total_device); - ok = false; - } - - // ── 4. Per-request: page ids must be in [1, total] ──────────────────── - // PageAllocator starts from page id 1 (0 is reserved as invalid/null). - for (auto& [id, pages] : req_pages_map) { - for (std::int32_t p : pages) { - if (p <= 0 || p > total_device) { - spdlog::error("[check_mem] INVALID DEVICE PAGE id={} for req={} (valid range [1,{}])", p, id, - total_device); - ok = false; - } - } - } - for (auto& [p, cnt] : tree_device_pages) { - if (p <= 0 || p > total_device) { - spdlog::error("[check_mem] INVALID DEVICE PAGE id={} in radix tree (valid range [1,{}])", p, total_device); - ok = false; - } + auto stats_snapshot = hybrid_prefix_cache_.Stats({.include_device_memory_diagnostics = true}); + if (!stats_snapshot.device_memory_diagnostics.has_value()) { + throw std::runtime_error("Scheduler::check_device_mem: missing diagnostics snapshot"); } + auto device_snapshot = std::move(*stats_snapshot.device_memory_diagnostics); // ── 5. Summary ──────────────────────────────────────────────────────────── - if (!ok) { + if (!ValidateDeviceMemoryDiagnostics(request_page_snapshots, device_snapshot)) { throw std::runtime_error("Scheduler::CheckMem: device page accounting check failed"); } } diff --git a/tokenspeed-scheduler/csrc/scheduler/scheduler.h b/tokenspeed-scheduler/csrc/scheduler/scheduler.h index c36c3a413..24b0c425a 100644 --- a/tokenspeed-scheduler/csrc/scheduler/scheduler.h +++ b/tokenspeed-scheduler/csrc/scheduler/scheduler.h @@ -45,13 +45,12 @@ #include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" #include "fsm/forward_events.h" -#include "fsm/cache_events.h" -#include "fsm/pd_events.h" namespace tokenspeed { class Scheduler { public: explicit Scheduler(SchedulerConfig config); + ~Scheduler(); void SubmitRequests(const std::vector& request_specs); std::vector CalcRollingHash(const std::vector& input_tokens, bool apply_match = false); @@ -91,9 +90,6 @@ class Scheduler { DecodeOperation applyEventAndGenerateOp(Request* request, fsm::ScheduleDecodeEvent event); DecodeOperation applyEventAndGenerateOp(Request* request, fsm::ScheduleDecodeFromRetractedEvent event); std::optional applyEventAndGenerateOp(Request* request, fsm::ScheduleRetractEvent event); - PrefetchOperation applyEventAndGenerateOp(Request* request, fsm::SchedulePrefetchEvent event); - - std::optional schedulePrefetch(Request* request, const MatchResult& match); std::optional schedulePrefillFirstChunk( Request* request, std::int32_t remaining, std::int32_t reserve_num_tokens_in_next_schedule_event, @@ -136,8 +132,8 @@ class Scheduler { std::optional mamba_allocator_{}; std::optional mamba_host_allocator_{}; KVPrefixCache kv_prefix_cache_; + HybridPrefixCache hybrid_prefix_cache_; ReqPoolAllocator req_pool_allocator_; - std::optional hybrid_prefix_cache_{}; private: std::unordered_map> requests_; diff --git a/tokenspeed-scheduler/tests/cpp/hybrid_prefix_cache_test_peer.h b/tokenspeed-scheduler/tests/cpp/hybrid_prefix_cache_test_peer.h index 2fc5ba956..a50858eb8 100644 --- a/tokenspeed-scheduler/tests/cpp/hybrid_prefix_cache_test_peer.h +++ b/tokenspeed-scheduler/tests/cpp/hybrid_prefix_cache_test_peer.h @@ -16,14 +16,75 @@ #pragma once -// Test-only friend of HybridPrefixCache; exposes hooks needed to drive prune -// paths whose direct public surface is non-trivial to set up via AdmitChunk. +// Test-only friend of HybridPrefixCache; exposes narrow hooks needed to seed or +// inspect internals while production callers use the scheduler-facing facades. + +#include +#include +#include +#include +#include +#include +#include #include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" #include "resource/radix_tree/tree_node.h" namespace tokenspeed { -class HybridPrefixCacheTestPeer {}; +class HybridPrefixCacheTestPeer { +public: + static void InsertMamba(HybridPrefixCache& cache, TreeNode* terminal_node, std::unique_ptr slot) { + cache.InsertMamba(terminal_node, std::move(slot)); + } + + static TreeNode* FindLastMambaNode(const HybridPrefixCache& cache, TreeNode* from) { + return cache.FindLastMambaNode(from); + } + + static TreeNode* FindLastMambaHostNode(const HybridPrefixCache& cache, TreeNode* from) { + return cache.FindLastMambaHostNode(from); + } + + static std::vector PrepareMambaHostWriteBack(HybridPrefixCache& cache, + const std::vector& nodes) { + return cache.PrepareMambaHostWriteBack(nodes); + } + + static std::vector PrepareMambaDeviceLoadBack(HybridPrefixCache& cache, + const std::vector& nodes) { + return cache.PrepareMambaDeviceLoadBack(nodes); + } + + static void PublishFinishMambaState(HybridPrefixCache& cache, + const std::vector>& full_paged_tokens, + LocalMambaAllocator* local_mamba_allocator) { + cache.PublishFinishMambaState(full_paged_tokens, local_mamba_allocator); + } + + static void AcquireForRequest(HybridPrefixCache& cache, const std::string& request_id, + std::int32_t first_raw_position_of_op, std::int32_t target_raw_tokens_exclusive, + const MatchResult::PagedCache& paged_cache_hit = {}) { + cache.AcquireForRequest(request_id, first_raw_position_of_op, target_raw_tokens_exclusive, paged_cache_hit); + } + + static void ReleaseRequest(HybridPrefixCache& cache, const std::string& request_id) { + cache.ReleaseRequest(request_id); + } + + static void CommitChunk(HybridPrefixCache& cache, const std::string& request_id, TreeNode* terminal) { + cache.CommitChunk(request_id, terminal); + } + + static bool AttachPagedCacheSnapshotToNode(HybridPrefixCache& cache, TreeNode* node, + std::unique_ptr snapshot) { + return cache.AttachPagedCacheSnapshotToNode(node, std::move(snapshot)); + } + + static std::unique_ptr DetachPagedCacheSnapshotFromNode(HybridPrefixCache& cache, + TreeNode* node) { + return cache.DetachPagedCacheSnapshotFromNode(node); + } +}; } // namespace tokenspeed diff --git a/tokenspeed-scheduler/tests/cpp/paged_cache_test_fixture.h b/tokenspeed-scheduler/tests/cpp/paged_cache_test_fixture.h index 981cfcd6b..0f933d1c4 100644 --- a/tokenspeed-scheduler/tests/cpp/paged_cache_test_fixture.h +++ b/tokenspeed-scheduler/tests/cpp/paged_cache_test_fixture.h @@ -30,6 +30,7 @@ #include "resource/allocator/page_allocator.h" #include "resource/allocator/paged_cache_group.h" #include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" +#include "hybrid_prefix_cache_test_peer.h" #include "resource/kv_prefix_cache/kv_prefix_cache.h" #include "resource/radix_tree/paged_cache_snapshot.h" #include "resource/radix_tree/radix_tree.h" @@ -72,13 +73,12 @@ class PagedCacheTestFixtureT : public ::testing::Test { fh_alloc_ = fh_owner.get(); swa_alloc_ = swa_owner.get(); - hybrid_ = std::make_unique(*kv_cache_, /*mamba=*/nullptr, + hybrid_ = std::make_unique(*kv_cache_, *device_alloc_, /*mamba=*/nullptr, /*mamba_chunk_size=*/0); hybrid_->RegisterPagedCacheGroup(std::move(fh_owner)); hybrid_->RegisterPagedCacheGroup(std::move(swa_owner)); std::unordered_map sliding{{"swa", kSlidingWindow}}; hybrid_->EnablePagedCacheAdjunct(/*required=*/{"fh", "swa"}, std::move(sliding)); - kv_cache_->GetDeviceManager().SetEvictionCallback([this](TreeNode* node) { hybrid_->OnKVEvict(node); }); } // Insert pages from `start_node` (nullptr=root); returns terminal node. @@ -114,10 +114,10 @@ class PagedCacheTestFixtureT : public ::testing::Test { // Detach and reattach without the state group; re-attach recomputes // `complete_families` and leaves only History present. void DowngradeSnapshotToHistoryOnly(TreeNode* node) { - auto snap = hybrid_->DetachPagedCacheSnapshotFromNode(node); + auto snap = HybridPrefixCacheTestPeer::DetachPagedCacheSnapshotFromNode(*hybrid_, node); ASSERT_NE(snap, nullptr); snap->groups.erase("swa"); - hybrid_->AttachPagedCacheSnapshotToNode(node, std::move(snap)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, node, std::move(snap)); } std::unique_ptr device_alloc_; diff --git a/tokenspeed-scheduler/tests/cpp/test_basic_lifecycle.cpp b/tokenspeed-scheduler/tests/cpp/test_basic_lifecycle.cpp index f4ba012c5..62a69a796 100644 --- a/tokenspeed-scheduler/tests/cpp/test_basic_lifecycle.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_basic_lifecycle.cpp @@ -18,6 +18,8 @@ // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE // SOFTWARE. +#include + #include "integration_test_helper.h" namespace tokenspeed::test { @@ -115,10 +117,62 @@ TEST_F(BasicLifecycleTestSuite, GetRequestTokenSize_UnknownRequest) { TEST_F(BasicLifecycleTestSuite, AvailableKvPages_DecreasesAfterPrefill) { auto before = scheduler_->AvailableKvPages(); + EXPECT_EQ(before, static_cast(Config().device_allocator.total_pages - 1)); Submit(MakeRequestSpec("r1", 2)); PlanOnce(); auto after = scheduler_->AvailableKvPages(); EXPECT_LT(after, before); + EXPECT_EQ(before - after, scheduler_->ActiveKvPages()); +} + +TEST_F(BasicLifecycleTestSuite, ActiveKvPagesCountsOnlyForwardStates) { + EXPECT_EQ(scheduler_->ActiveKvPages(), 0u); + + Submit(MakeRequestSpec("r1", 1)); + EXPECT_EQ(scheduler_->ActiveKvPages(), 0u) << "submitted requests must not contribute"; + + PlanOnce(); + EXPECT_EQ(scheduler_->ActiveKvPages(), 2u) << "active forward state contributes occupied prefix/local pages"; + + SendForwardDone("r1", {42}); + EXPECT_EQ(scheduler_->ActiveKvPages(), 2u); + + PlanOnce(); + EXPECT_EQ(scheduler_->DecodingSize(), 1u); + EXPECT_EQ(scheduler_->ActiveKvPages(), 2u) << "Decoding remains active"; + + SendFinish("r1"); + EXPECT_EQ(scheduler_->ActiveKvPages(), 0u) << "Draining/finished requests must not contribute"; + + const auto writeback_plan = PlanOnce(); + EXPECT_FALSE(ExtractCacheOpsOfKind(writeback_plan).empty()); + EXPECT_EQ(scheduler_->ActiveKvPages(), 0u) << "WritingBack requests must not contribute"; +} + +TEST_F(BasicLifecycleTestSuite, ActiveKvPagesDeduplicatesSharedPrefixPages) { + Submit(MakeRequestSpec("r_seed", 2, 1)); + PlanOnce(); + SendForwardDone("r_seed", {101}); + PlanOnce(); + ASSERT_EQ(scheduler_->DecodingSize(), 1u); + ASSERT_EQ(scheduler_->ActiveKvPages(), 3u); + + Submit(MakeRequestSpec("r_reuse", 2, 1)); + EXPECT_EQ(scheduler_->ActiveKvPages(), 3u) << "submitted requests must not add pages"; + + auto reuse_plan = PlanOnce(); + auto* reuse_fwd = GetForwardOp(reuse_plan); + ASSERT_NE(reuse_fwd, nullptr); + ASSERT_EQ(reuse_fwd->request_ids.size(), 1u); + EXPECT_EQ(reuse_fwd->request_ids[0], "r_reuse"); + ASSERT_EQ(reuse_fwd->extend_prefix_lens.size(), 1u); + EXPECT_EQ(reuse_fwd->extend_prefix_lens[0], PageSize()); + + // r_seed and r_reuse both observe the same prefix-cache page for the same + // prompt. ActiveKvPages counts that shared page id once; summing per-request + // occupied-page snapshots would report 6 pages instead of the 5 unique page + // ids visible through the public statistic. + EXPECT_EQ(scheduler_->ActiveKvPages(), 5u); } } // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_hybrid_cache_registry.cpp b/tokenspeed-scheduler/tests/cpp/test_hybrid_cache_registry.cpp new file mode 100644 index 000000000..9a30bfcb8 --- /dev/null +++ b/tokenspeed-scheduler/tests/cpp/test_hybrid_cache_registry.cpp @@ -0,0 +1,204 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "resource/allocator/mamba_chunk_allocator.h" +#include "resource/allocator/page_allocator.h" +#include "resource/allocator/paged_cache_group.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache_types.h" +#include "resource/kv_prefix_cache/kv_prefix_cache.h" + +namespace tokenspeed::test { + +namespace { + +constexpr std::int32_t kPageSize = 4; +constexpr std::int32_t kMambaChunkSize = 8; + +bool Contains(const std::vector& values, std::int32_t value) { + return std::find(values.begin(), values.end(), value) != values.end(); +} + +PagedCacheGroupConfig MakePagedGroup(std::string group_id, PagedCacheGroupFamily family, + PagedCacheGroupConfig::Retention retention, + std::optional sliding_window_tokens = std::nullopt) { + return PagedCacheGroupConfig{ + .group_id = std::move(group_id), + .rows_per_page = 4, + .entry_stride_tokens = 2, + .total_pages = 16, + .retention = retention, + .sliding_window_tokens = sliding_window_tokens, + .family = family, + }; +} + +} // namespace + +TEST(HybridPrefixCacheRegistryTest, KVOnlyRegistryHasSingleCarrierFamily) { + PageAllocator device_allocator{kPageSize, /*total_pages=*/8}; + PageAllocator host_allocator{kPageSize, /*total_pages=*/0}; + KVPrefixCache prefix_cache{&device_allocator, &host_allocator}; + HybridPrefixCache hybrid_prefix_cache{prefix_cache, device_allocator, /*allocator=*/nullptr, kMambaChunkSize}; + + const FamilyRegistry& registry = hybrid_prefix_cache.Registry(); + ASSERT_EQ(registry.specs.size(), 1u); + + const CacheResourceSpec* kv = registry.FindById("kv.token_page"); + ASSERT_NE(kv, nullptr); + EXPECT_EQ(kv->family_index, 0); + EXPECT_EQ(kv->family, CacheFamily::TokenPage); + EXPECT_EQ(kv->attachment_kind, TreeAttachmentKind::ReusableTree); + EXPECT_EQ(kv->recoverability, Recoverability::Exact); + EXPECT_EQ(kv->publication, PublicationKind::CanonicalPrefixIndex); + EXPECT_EQ(kv->split_policy, SplitPolicy::CarrierKV); + EXPECT_EQ(kv->rows_per_page, kPageSize); + EXPECT_TRUE(kv->required_for_recovery); + + EXPECT_EQ(registry.active_match_family_indices, std::vector{0}); + EXPECT_EQ(registry.active_admit_family_indices, std::vector{0}); + EXPECT_EQ(registry.active_commit_family_indices, std::vector{0}); + EXPECT_EQ(registry.active_evict_family_indices, std::vector{0}); + EXPECT_EQ(registry.active_finish_family_indices, std::vector{0}); + EXPECT_EQ(registry.active_stats_family_indices, std::vector{0}); + EXPECT_EQ(registry.active_compatibility_family_indices, std::vector{0}); + EXPECT_EQ(registry.FindById("mamba.checkpoint"), nullptr); +} + +TEST(HybridPrefixCacheRegistryTest, MambaAdjunctRegistersAlignedCheckpointFamily) { + PageAllocator device_allocator{kPageSize, /*total_pages=*/8}; + PageAllocator host_allocator{kPageSize, /*total_pages=*/0}; + KVPrefixCache prefix_cache{&device_allocator, &host_allocator}; + MambaChunkAllocator mamba_allocator{/*num_slots=*/4}; + HybridPrefixCache hybrid_prefix_cache{prefix_cache, device_allocator, &mamba_allocator, kMambaChunkSize}; + + const FamilyRegistry& registry = hybrid_prefix_cache.Registry(); + ASSERT_EQ(registry.specs.size(), 2u); + + const CacheResourceSpec* mamba = registry.FindById("mamba.checkpoint"); + ASSERT_NE(mamba, nullptr); + EXPECT_EQ(mamba->family_index, 1); + EXPECT_EQ(mamba->family, CacheFamily::RecurrentState); + EXPECT_EQ(mamba->attachment_kind, TreeAttachmentKind::ReusableTree); + EXPECT_EQ(mamba->recoverability, Recoverability::AlignedCheckpoint); + EXPECT_EQ(mamba->publication, PublicationKind::AuxiliaryLocalOnly); + EXPECT_EQ(mamba->split_policy, SplitPolicy::CheckpointBoundary); + EXPECT_EQ(mamba->checkpoint_chunk_tokens, kMambaChunkSize); + EXPECT_EQ(mamba->state_cohort_id, "mamba.checkpoint"); + EXPECT_TRUE(mamba->required_for_recovery); + + EXPECT_TRUE(Contains(registry.active_match_family_indices, mamba->family_index)); + EXPECT_TRUE(Contains(registry.active_admit_family_indices, mamba->family_index)); + EXPECT_TRUE(Contains(registry.active_commit_family_indices, mamba->family_index)); + EXPECT_TRUE(Contains(registry.active_evict_family_indices, mamba->family_index)); +} + +TEST(HybridPrefixCacheRegistryTest, PagedGroupsWithoutAdjunctAreRequestLocalCompatibilityFamilies) { + PageAllocator device_allocator{kPageSize, /*total_pages=*/8}; + PageAllocator host_allocator{kPageSize, /*total_pages=*/0}; + KVPrefixCache prefix_cache{&device_allocator, &host_allocator}; + HybridPrefixCache hybrid_prefix_cache{prefix_cache, device_allocator, /*allocator=*/nullptr, kMambaChunkSize}; + const std::vector groups = { + MakePagedGroup("v4.history", PagedCacheGroupFamily::History, PagedCacheGroupConfig::Retention::FullHistory), + MakePagedGroup("v4.swa", PagedCacheGroupFamily::State, PagedCacheGroupConfig::Retention::SlidingWindow, + /*sliding_window_tokens=*/16), + }; + + hybrid_prefix_cache.ConfigurePagedCacheAdjunct(std::span{groups}, std::nullopt); + + const FamilyRegistry& registry = hybrid_prefix_cache.Registry(); + const CacheResourceSpec* history = registry.FindById("v4.history"); + const CacheResourceSpec* state = registry.FindById("v4.swa"); + ASSERT_NE(history, nullptr); + ASSERT_NE(state, nullptr); + + EXPECT_EQ(history->family, CacheFamily::CompressedPage); + EXPECT_EQ(state->family, CacheFamily::SlidingWindowState); + EXPECT_EQ(history->attachment_kind, TreeAttachmentKind::NoneForRequestLocal); + EXPECT_EQ(state->attachment_kind, TreeAttachmentKind::NoneForRequestLocal); + EXPECT_EQ(history->recoverability, Recoverability::RequestLocalOnly); + EXPECT_EQ(state->recoverability, Recoverability::RequestLocalOnly); + EXPECT_FALSE(history->required_for_recovery); + EXPECT_FALSE(state->required_for_recovery); + + EXPECT_FALSE(Contains(registry.active_match_family_indices, history->family_index)); + EXPECT_FALSE(Contains(registry.active_match_family_indices, state->family_index)); + EXPECT_FALSE(Contains(registry.active_commit_family_indices, history->family_index)); + EXPECT_FALSE(Contains(registry.active_commit_family_indices, state->family_index)); + EXPECT_TRUE(Contains(registry.active_admit_family_indices, history->family_index)); + EXPECT_TRUE(Contains(registry.active_admit_family_indices, state->family_index)); + EXPECT_TRUE(Contains(registry.active_compatibility_family_indices, history->family_index)); + EXPECT_TRUE(Contains(registry.active_compatibility_family_indices, state->family_index)); +} + +TEST(HybridPrefixCacheRegistryTest, RequiredPagedAdjunctRegistersRecoverableHistoryAndStateCohort) { + PageAllocator device_allocator{kPageSize, /*total_pages=*/8}; + PageAllocator host_allocator{kPageSize, /*total_pages=*/0}; + KVPrefixCache prefix_cache{&device_allocator, &host_allocator}; + HybridPrefixCache hybrid_prefix_cache{prefix_cache, device_allocator, /*allocator=*/nullptr, kMambaChunkSize}; + const std::vector groups = { + MakePagedGroup("v4.history", PagedCacheGroupFamily::History, PagedCacheGroupConfig::Retention::FullHistory), + MakePagedGroup("v4.swa", PagedCacheGroupFamily::State, PagedCacheGroupConfig::Retention::SlidingWindow, + /*sliding_window_tokens=*/16), + }; + const std::vector required = {"v4.history", "v4.swa"}; + + hybrid_prefix_cache.ConfigurePagedCacheAdjunct(std::span{groups}, + std::span{required}); + + const FamilyRegistry& registry = hybrid_prefix_cache.Registry(); + const CacheResourceSpec* history = registry.FindById("v4.history"); + const CacheResourceSpec* state = registry.FindById("v4.swa"); + ASSERT_NE(history, nullptr); + ASSERT_NE(state, nullptr); + + EXPECT_EQ(history->attachment_kind, TreeAttachmentKind::ReusableTree); + EXPECT_EQ(state->attachment_kind, TreeAttachmentKind::ReusableTree); + EXPECT_EQ(history->recoverability, Recoverability::Exact); + EXPECT_EQ(state->recoverability, Recoverability::Exact); + EXPECT_EQ(history->publication, PublicationKind::CanonicalPrefixIndex); + EXPECT_EQ(state->publication, PublicationKind::AuxiliaryLocalOnly); + EXPECT_EQ(history->split_policy, SplitPolicy::SnapshotBoundary); + EXPECT_EQ(state->split_policy, SplitPolicy::SnapshotBoundary); + EXPECT_EQ(history->state_cohort_id, "paged.required"); + EXPECT_EQ(state->state_cohort_id, "paged.required"); + EXPECT_TRUE(history->required_for_recovery); + EXPECT_TRUE(state->required_for_recovery); + EXPECT_EQ(state->sliding_window_tokens, 16); + + EXPECT_TRUE(Contains(registry.active_match_family_indices, history->family_index)); + EXPECT_TRUE(Contains(registry.active_match_family_indices, state->family_index)); + EXPECT_TRUE(Contains(registry.active_commit_family_indices, history->family_index)); + EXPECT_TRUE(Contains(registry.active_commit_family_indices, state->family_index)); + EXPECT_TRUE(Contains(registry.active_evict_family_indices, history->family_index)); + EXPECT_TRUE(Contains(registry.active_evict_family_indices, state->family_index)); +} + +} // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_kv_cache_events.cpp b/tokenspeed-scheduler/tests/cpp/test_kv_cache_events.cpp index a91d8001b..72b7a1426 100644 --- a/tokenspeed-scheduler/tests/cpp/test_kv_cache_events.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_kv_cache_events.cpp @@ -32,6 +32,7 @@ #include "integration_test_helper.h" #include "resource/allocator/page_allocator.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" #include "resource/kv_prefix_cache/kv_prefix_cache.h" #include "scheduler/kv_cache_events.h" @@ -244,6 +245,71 @@ TEST_F(KVPrefixCacheEventTestSuite, HostRecoveryPublishesDeviceStoredEvents) { EXPECT_EQ(AsRemoved(events_[0]).block_hashes, (std::vector{first_hash, second_hash})); } +TEST(HybridPrefixCacheKvEventSinkTest, FacadeDelegatesEventsAndCanClearBinding) { + constexpr std::int32_t kPageSize = 2; + PageAllocator device_allocator{kPageSize, 4}; + PageAllocator host_allocator{kPageSize, 4}; + KVPrefixCache cache{&device_allocator, &host_allocator}; + HybridPrefixCache hybrid_cache{cache, device_allocator, nullptr, 0}; + std::vector events; + + hybrid_cache.SetKvEventSink([&](KvCacheEvent event) { events.push_back(std::move(event)); }); + + const token_vec_t first_tokens{1, 2}; + cache.Insert(first_tokens, {}, device_allocator.Allocate(1)); + + ASSERT_EQ(events.size(), 1u); + EXPECT_EQ(AsStored(events[0]).token_ids, first_tokens); + + hybrid_cache.SetKvEventSink({}); + + const token_vec_t second_tokens{3, 4}; + cache.Insert(second_tokens, {}, device_allocator.Allocate(1)); + + EXPECT_EQ(events.size(), 1u); +} + +TEST(HybridPrefixCacheKvEventSinkTest, DestructorClearsFacadeInstalledBinding) { + constexpr std::int32_t kPageSize = 2; + PageAllocator device_allocator{kPageSize, 4}; + PageAllocator host_allocator{kPageSize, 4}; + KVPrefixCache cache{&device_allocator, &host_allocator}; + std::vector events; + + { + HybridPrefixCache hybrid_cache{cache, device_allocator, nullptr, 0}; + hybrid_cache.SetKvEventSink([&](KvCacheEvent event) { events.push_back(std::move(event)); }); + + const token_vec_t first_tokens{1, 2}; + cache.Insert(first_tokens, {}, device_allocator.Allocate(1)); + } + + const token_vec_t second_tokens{3, 4}; + cache.Insert(second_tokens, {}, device_allocator.Allocate(1)); + + ASSERT_EQ(events.size(), 1u); + EXPECT_EQ(AsStored(events[0]).token_ids, (token_vec_t{1, 2})); +} + +TEST(HybridPrefixCacheKvEventSinkTest, DestructorDoesNotClearUnownedBinding) { + constexpr std::int32_t kPageSize = 2; + PageAllocator device_allocator{kPageSize, 4}; + PageAllocator host_allocator{kPageSize, 4}; + KVPrefixCache cache{&device_allocator, &host_allocator}; + std::vector events; + + cache.SetKvEventSink([&](KvCacheEvent event) { events.push_back(std::move(event)); }); + { + HybridPrefixCache hybrid_cache{cache, device_allocator, nullptr, 0}; + } + + const token_vec_t tokens{1, 2}; + cache.Insert(tokens, {}, device_allocator.Allocate(1)); + + ASSERT_EQ(events.size(), 1u); + EXPECT_EQ(AsStored(events[0]).token_ids, tokens); +} + TEST(KVPrefixCacheEventBenchTest, OptimizedInsertIsFasterThanLegacyAncestorRehashing) { constexpr std::int32_t kBenchPageSize = 16; constexpr std::int32_t kPageCount = 512; diff --git a/tokenspeed-scheduler/tests/cpp/test_mamba_cache.cpp b/tokenspeed-scheduler/tests/cpp/test_mamba_cache.cpp index 7f5958fd2..a3b19ed99 100644 --- a/tokenspeed-scheduler/tests/cpp/test_mamba_cache.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_mamba_cache.cpp @@ -20,14 +20,24 @@ #include +#include +#include +#include +#include +#include #include +#include #include +#include #include "core/token_container.h" #include "fsm/forward_events.h" #include "fsm/forward_states.h" #include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" +#include "hybrid_prefix_cache_test_peer.h" +#include "resource/allocator/kv_allocator.h" +#include "resource/allocator/local_mamba_allocator.h" #include "resource/allocator/mamba_chunk_allocator.h" #include "resource/allocator/mamba_host_allocator.h" #include "scheduler/operations/cache.h" @@ -36,6 +46,7 @@ #include "resource/radix_tree/node_range.h" #include "resource/allocator/page_allocator.h" #include "resource/allocator/req_pool_allocator.h" +#include "scheduler/operations/forward.h" #include "unit_test_helper.h" #include "scheduler/types.h" @@ -54,8 +65,8 @@ class MambaCacheTest : public ::testing::Test { host_alloc_ = std::make_unique(kPageSize, kHostPages); prefix_cache_ = std::make_unique(device_alloc_.get(), host_alloc_.get()); mamba_alloc_ = std::make_unique(kMambaSlots); - hybrid_prefix_cache_ = - std::make_unique(*prefix_cache_, mamba_alloc_.get(), kMambaCacheChunkSize); + hybrid_prefix_cache_ = std::make_unique(*prefix_cache_, *device_alloc_, mamba_alloc_.get(), + kMambaCacheChunkSize); } std::vector CollectPrefixPages(TreeNode* matched_node) { @@ -63,7 +74,7 @@ class MambaCacheTest : public ::testing::Test { return DevicePagesFromRoot(matched_node); } - void InsertKVAndMamba(const token_vec_t& tokens) { + TreeNode* InsertKVAndMamba(const token_vec_t& tokens) { auto match = prefix_cache_->Match(tokens); std::int32_t matched_pages = match.device.DepthInPage(); std::int32_t total_pages = static_cast(tokens.size()) / kPageSize; @@ -74,20 +85,46 @@ class MambaCacheTest : public ::testing::Test { prefix_cache_->Insert(tokens, prefix_pages, device_alloc_->Allocate(new_pages)); auto slot = mamba_alloc_->Allocate(); if (slot.has_value()) { - hybrid_prefix_cache_->InsertMamba(result.last_node, std::make_unique(std::move(*slot))); + HybridPrefixCacheTestPeer::InsertMamba(*hybrid_prefix_cache_, result.last_node, + std::make_unique(std::move(*slot))); } + return result.last_node; } + return match.device.last_node; } - void InsertKVOnly(const token_vec_t& tokens) { + TreeNode* InsertKVOnly(const token_vec_t& tokens) { auto match = prefix_cache_->Match(tokens); std::int32_t matched_pages = match.device.DepthInPage(); std::int32_t total_pages = static_cast(tokens.size()) / kPageSize; std::int32_t new_pages = total_pages - matched_pages; if (new_pages > 0) { auto prefix_pages = CollectPrefixPages(match.device.last_node); - prefix_cache_->Insert(tokens, prefix_pages, device_alloc_->Allocate(new_pages)); + auto result = + prefix_cache_->Insert(tokens, prefix_pages, device_alloc_->Allocate(new_pages)); + return result.last_node; + } + return match.device.last_node; + } + + std::vector> PagedTokenSpans(const token_vec_t& tokens) const { + std::vector> pages; + const auto page_count = static_cast(tokens.size()) / kPageSize; + pages.reserve(static_cast(page_count)); + for (std::int32_t page = 0; page < page_count; ++page) { + pages.emplace_back(tokens.data() + page * kPageSize, static_cast(kPageSize)); } + return pages; + } + + std::vector FillMambaSlots() { + std::vector nodes; + nodes.reserve(kMambaSlots); + for (std::int32_t i = 0; i < kMambaSlots; ++i) { + TreeNode* node = InsertKVAndMamba(MakeAlignedTokens(1, kPageSize, /*start=*/1000 + i * 10)); + nodes.push_back(node); + } + return nodes; } std::unique_ptr device_alloc_; @@ -101,8 +138,33 @@ TEST_F(MambaCacheTest, MatchWithoutMambaTruncatesToRoot) { auto tokens = MakeAlignedTokens(3, kPageSize); InsertKVOnly(tokens); - auto match = hybrid_prefix_cache_->Match(tokens); + auto match = hybrid_prefix_cache_->MatchPrefix(tokens).compat_match; + EXPECT_EQ(match.device.DepthInPage(), 0); + EXPECT_EQ(match.mamba_cow_src_index, -1); + EXPECT_EQ(match.mamba_branching_seqlen, 4); +} + +TEST(HybridPrefixCacheMambaRecoverablePrefixTest, HostOnlyKVWithoutMambaDoesNotProduceLoadBackPrefix) { + static constexpr std::int32_t kPageSize = 2; + PageAllocator device_alloc{kPageSize, 4}; + PageAllocator host_alloc{kPageSize, 16}; + KVPrefixCache prefix_cache{&device_alloc, &host_alloc, false}; + MambaChunkAllocator mamba_alloc{2}; + HybridPrefixCache hybrid_prefix_cache{prefix_cache, device_alloc, &mamba_alloc, + /*mamba_cache_chunk_size=*/4}; + + auto tokens = MakeAlignedTokens(/*num_pages=*/2, kPageSize); + prefix_cache.Insert(tokens, /*prefix_pages=*/{}, host_alloc.Allocate(/*num_pages=*/2)); + auto raw_match = prefix_cache.Match(tokens); + ASSERT_EQ(raw_match.device.DepthInPage(), 0); + ASSERT_EQ(raw_match.host.DepthInPage(), 2); + + auto match = hybrid_prefix_cache.MatchPrefix(tokens).compat_match; + EXPECT_EQ(match.device.DepthInPage(), 0); + EXPECT_EQ(match.host.DepthInPage(), 0); + EXPECT_TRUE(match.NodesWithout().empty()) + << "host-only KV without tree-owned Mamba state must not plan LoadBack"; EXPECT_EQ(match.mamba_cow_src_index, -1); EXPECT_EQ(match.mamba_branching_seqlen, 4); } @@ -111,7 +173,7 @@ TEST_F(MambaCacheTest, MatchWithFullMambaKeepsDepth) { auto tokens = MakeAlignedTokens(3, kPageSize); InsertKVAndMamba(tokens); - auto match = hybrid_prefix_cache_->Match(tokens); + auto match = hybrid_prefix_cache_->MatchPrefix(tokens).compat_match; EXPECT_EQ(match.device.DepthInPage(), 3); EXPECT_NE(match.mamba_cow_src_index, -1); EXPECT_EQ(match.mamba_branching_seqlen, -1); @@ -124,7 +186,7 @@ TEST_F(MambaCacheTest, MatchWithPartialMambaTruncatesToMambaDepth) { auto tokens4 = MakeAlignedTokens(4, kPageSize); InsertKVOnly(tokens4); - auto match = hybrid_prefix_cache_->Match(tokens4); + auto match = hybrid_prefix_cache_->MatchPrefix(tokens4).compat_match; EXPECT_EQ(match.device.DepthInPage(), 2); EXPECT_NE(match.mamba_cow_src_index, -1); EXPECT_NE(match.mamba_branching_seqlen, -1); @@ -140,7 +202,7 @@ TEST_F(MambaCacheTest, SplitPrefixWithoutMambaStillRequestsBranchingSnapshot) { diverged[2 * kPageSize] = 1001; diverged[2 * kPageSize + 1] = 1002; - auto match = hybrid_prefix_cache_->Match(diverged); + auto match = hybrid_prefix_cache_->MatchPrefix(diverged).compat_match; EXPECT_EQ(match.device.DepthInPage(), 0); EXPECT_EQ(match.mamba_cow_src_index, -1); EXPECT_EQ(match.mamba_branching_seqlen, 4); @@ -153,7 +215,7 @@ TEST_F(MambaCacheTest, BranchingSeqlenIsSuppressedWhenAlignedInsideMambaPrefix) auto tokens3 = MakeAlignedTokens(3, kPageSize); InsertKVOnly(tokens3); - auto match = hybrid_prefix_cache_->Match(tokens3); + auto match = hybrid_prefix_cache_->MatchPrefix(tokens3).compat_match; EXPECT_EQ(match.device.DepthInPage(), 2); EXPECT_NE(match.mamba_cow_src_index, -1); EXPECT_EQ(match.mamba_branching_seqlen, -1); @@ -180,13 +242,365 @@ TEST_F(MambaCacheTest, FindLastMambaNodeWalksUp) { auto match = prefix_cache_->Match(tokens4); TreeNode* terminal = match.device.last_node; - TreeNode* mamba_node = hybrid_prefix_cache_->FindLastMambaNode(terminal); + TreeNode* mamba_node = HybridPrefixCacheTestPeer::FindLastMambaNode(*hybrid_prefix_cache_, terminal); ASSERT_NE(mamba_node, nullptr); EXPECT_TRUE(mamba_node->HasMamba()); EXPECT_EQ(mamba_node->DepthInPage(kPageSize), 2); } +TEST_F(MambaCacheTest, MatchPrefixStateRecoverySetsCowSourceAndProtection) { + auto tokens = MakeAlignedTokens(2, kPageSize); + TreeNode* recovery_source = InsertKVAndMamba(tokens); + ASSERT_NE(recovery_source, nullptr); + ASSERT_TRUE(recovery_source->HasMamba()); + + RecoveryPlan recovery = hybrid_prefix_cache_->MatchPrefix(tokens, MatchIntent::StateRecovery); + + EXPECT_TRUE(recovery.recovery_state_available); + EXPECT_EQ(recovery.protected_recovery_node, recovery_source); + EXPECT_EQ(recovery.compat_match.mamba_cow_src_index, recovery_source->MambaSlotIndex()); +} + +TEST_F(MambaCacheTest, MatchPrefixStateRecoveryReportsMissingMambaState) { + auto tokens = MakeAlignedTokens(2, kPageSize); + InsertKVOnly(tokens); + + RecoveryPlan recovery = hybrid_prefix_cache_->MatchPrefix(tokens, MatchIntent::StateRecovery); + + EXPECT_FALSE(recovery.recovery_state_available); + EXPECT_EQ(recovery.protected_recovery_node, nullptr); + EXPECT_EQ(recovery.compat_match.mamba_cow_src_index, -1); +} + +TEST(HybridPrefixCacheRecoveryPlanTest, StateRecoverySucceedsWithoutMambaAdjunct) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/4); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, /*allocator=*/nullptr, + /*mamba_cache_chunk_size=*/4); + RecoveryPlan recovery = hybrid_prefix_cache.MatchPrefix(token_vec_t{}, MatchIntent::StateRecovery); + + EXPECT_TRUE(recovery.recovery_state_available); + EXPECT_EQ(recovery.protected_recovery_node, nullptr); + EXPECT_EQ(recovery.compat_match.mamba_cow_src_index, -1); +} + +TEST_F(MambaCacheTest, WorkerMetadataReceivesPrefixRecoveryAndRequestLocalInfo) { + MatchResult prefix_match{}; + prefix_match.mamba_cow_src_index = 7; + prefix_match.mamba_branching_seqlen = 12; + ForwardOperationBase prefix_op{}; + + hybrid_prefix_cache_->StepCommit({ + .worker_metadata = + WorkerCompatibilityCommitRequest{ + .op_base = &prefix_op, + .compat_match = &prefix_match, + .populate_prefix_reuse_metadata = true, + }, + }); + + EXPECT_EQ(prefix_op.mamba_cow_src_idx, 7); + EXPECT_EQ(prefix_op.mamba_branching_seqlen, 12); + + MatchResult recovery_match{}; + recovery_match.mamba_cow_src_index = 5; + recovery_match.mamba_branching_seqlen = 12; + ForwardOperationBase recovery_op{}; + recovery_op.mamba_branching_seqlen = 99; + + hybrid_prefix_cache_->StepCommit({ + .worker_metadata = + WorkerCompatibilityCommitRequest{ + .op_base = &recovery_op, + .compat_match = &recovery_match, + .populate_recovery_metadata = true, + }, + }); + + EXPECT_EQ(recovery_op.mamba_cow_src_idx, 5); + EXPECT_EQ(recovery_op.mamba_branching_seqlen, 99); + + LocalMambaAllocator local_mamba(mamba_alloc_.get()); + ASSERT_TRUE(local_mamba.AllocateWorking()); + ASSERT_TRUE(local_mamba.AllocateCheckpoint()); + ForwardOperationBase request_local_op{}; + + hybrid_prefix_cache_->StepCommit({ + .worker_metadata = + WorkerCompatibilityCommitRequest{ + .op_base = &request_local_op, + .local_mamba_allocator_view = &local_mamba, + }, + }); + + EXPECT_EQ(request_local_op.mamba_working_idx, local_mamba.WorkingIndex()); + EXPECT_EQ(request_local_op.mamba_checkpoint_dst_idx, local_mamba.CheckpointIndex()); + EXPECT_EQ(request_local_op.mamba_cow_src_idx, -1); + EXPECT_EQ(request_local_op.mamba_branching_seqlen, -1); + + ForwardOperationBase defaults_op{}; + + hybrid_prefix_cache_->StepCommit({ + .worker_metadata = + WorkerCompatibilityCommitRequest{ + .op_base = &defaults_op, + }, + }); + + EXPECT_EQ(defaults_op.mamba_working_idx, -1); + EXPECT_EQ(defaults_op.mamba_checkpoint_dst_idx, -1); +} + +TEST(HybridPrefixCacheMambaCompatibilityFieldsTest, WorkerMetadataLeavesDefaultsWithoutMambaAdjunct) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/4); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, /*allocator=*/nullptr, + /*mamba_cache_chunk_size=*/4); + MatchResult match{}; + match.mamba_cow_src_index = 3; + match.mamba_branching_seqlen = 8; + MambaChunkAllocator mamba_alloc(/*num_slots=*/2); + LocalMambaAllocator local_mamba(&mamba_alloc); + ASSERT_TRUE(local_mamba.AllocateWorking()); + ASSERT_TRUE(local_mamba.AllocateCheckpoint()); + ForwardOperationBase op{}; + + hybrid_prefix_cache.StepCommit({ + .worker_metadata = + WorkerCompatibilityCommitRequest{ + .op_base = &op, + .compat_match = &match, + .local_mamba_allocator_view = &local_mamba, + .populate_prefix_reuse_metadata = true, + }, + }); + + EXPECT_EQ(op.mamba_working_idx, -1); + EXPECT_EQ(op.mamba_checkpoint_dst_idx, -1); + EXPECT_EQ(op.mamba_cow_src_idx, -1); + EXPECT_EQ(op.mamba_branching_seqlen, -1); +} + +TEST(HybridPrefixCacheKVAllocationTest, PrefillFirstChunkCreatesRequestLocalStateAndPreservesAccounting) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/5); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + MambaChunkAllocator mamba_alloc(/*num_slots=*/4); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, &mamba_alloc, /*mamba_cache_chunk_size=*/4); + + auto result = hybrid_prefix_cache.StepCommit({ + .request_local_kv = + RequestLocalKVStateRequest{ + .create_allocator = true, + .initial_tokens = 1, + .acquire_tokens = 2, + }, + .request_local_mamba = + RequestLocalMambaStateRequest{ + .create_allocator = true, + }, + }); + auto local_kv = std::move(result.local_kv_allocator); + auto local_mamba = std::move(result.local_mamba_allocator); + + ASSERT_NE(local_kv, nullptr); + EXPECT_EQ(local_kv->Pages().size(), 2u); + EXPECT_EQ(local_kv->TailPageAvailableTokens(), 1); + EXPECT_EQ(device_alloc.AvailablePages(), 2); + ASSERT_NE(local_mamba, nullptr); + EXPECT_TRUE(local_mamba->HasWorking()); + EXPECT_TRUE(local_mamba->HasCheckpoint()); + EXPECT_EQ(mamba_alloc.AvailableSlots(), 2); + + local_kv.reset(); + local_mamba.reset(); + EXPECT_EQ(device_alloc.AvailablePages(), 4); + EXPECT_EQ(mamba_alloc.AvailableSlots(), 4); +} + +TEST(HybridPrefixCacheKVAllocationTest, PrefillFirstChunkAllocationFailurePreservesPageAccounting) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/2); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, /*allocator=*/nullptr, + /*mamba_cache_chunk_size=*/4); + + EXPECT_THROW((void)hybrid_prefix_cache.StepCommit({ + .request_local_kv = + RequestLocalKVStateRequest{ + .create_allocator = true, + .initial_tokens = 2, + .acquire_tokens = 1, + }, + }), + std::runtime_error); + EXPECT_EQ(device_alloc.AvailablePages(), 1); +} + +TEST(HybridPrefixCacheKVAllocationTest, PrefillContinuationRefreshesCheckpointWithoutLeakingPages) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/5); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + MambaChunkAllocator mamba_alloc(/*num_slots=*/4); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, &mamba_alloc, /*mamba_cache_chunk_size=*/4); + + { + LocalKVAllocator local_kv(&device_alloc, /*num_tokens=*/1); + LocalKVAllocator* original_allocator = &local_kv; + const std::vector original_pages = local_kv.Pages(); + ASSERT_EQ(original_pages.size(), 1u); + ASSERT_EQ(local_kv.TailPageAvailableTokens(), 1); + LocalMambaAllocator local_mamba(&mamba_alloc); + ASSERT_TRUE(local_mamba.AllocateWorking()); + ASSERT_TRUE(local_mamba.AllocateCheckpoint()); + const std::int32_t original_checkpoint = local_mamba.CheckpointIndex(); + + (void)hybrid_prefix_cache.StepCommit({ + .request_local_kv = + RequestLocalKVStateRequest{ + .allocator = &local_kv, + .acquire_tokens = 2, + }, + .request_local_mamba = + RequestLocalMambaStateRequest{ + .refresh_checkpoint_allocator = &local_mamba, + }, + }); + + EXPECT_EQ(&local_kv, original_allocator); + ASSERT_EQ(local_kv.Pages().size(), 2u); + EXPECT_EQ(local_kv.Pages().front(), original_pages.front()); + EXPECT_EQ(local_kv.TailPageAvailableTokens(), 1); + EXPECT_EQ(device_alloc.AvailablePages(), 2); + EXPECT_TRUE(local_mamba.HasWorking()); + EXPECT_TRUE(local_mamba.HasCheckpoint()); + EXPECT_NE(local_mamba.CheckpointIndex(), original_checkpoint); + EXPECT_EQ(mamba_alloc.AvailableSlots(), 2); + } + + EXPECT_EQ(device_alloc.AvailablePages(), 4); + EXPECT_EQ(mamba_alloc.AvailableSlots(), 4); +} + +TEST(HybridPrefixCacheKVAllocationTest, DecodeRetractFailureReleasesPartialResources) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/2); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, /*allocator=*/nullptr, + /*mamba_cache_chunk_size=*/4); + + { + LocalKVAllocator local_kv(&device_alloc, /*num_tokens=*/1); + const std::vector original_pages = local_kv.Pages(); + ASSERT_EQ(original_pages.size(), 1u); + ASSERT_EQ(local_kv.TailPageAvailableTokens(), 1); + ASSERT_EQ(device_alloc.AvailablePages(), 0); + + EXPECT_THROW((void)hybrid_prefix_cache.StepCommit({ + .request_local_kv = + RequestLocalKVStateRequest{ + .allocator = &local_kv, + .acquire_tokens = 3, + }, + }), + std::runtime_error); + + EXPECT_EQ(local_kv.Pages(), original_pages); + EXPECT_EQ(local_kv.TailPageAvailableTokens(), 1); + EXPECT_EQ(device_alloc.AvailablePages(), 0); + } + + EXPECT_EQ(device_alloc.AvailablePages(), 1); + + PageAllocator mamba_device_alloc(/*page_size=*/2, /*total_pages=*/4); + PageAllocator mamba_host_alloc(/*page_size=*/2, /*total_pages=*/0); + KVPrefixCache mamba_prefix_cache(&mamba_device_alloc, &mamba_host_alloc); + MambaChunkAllocator mamba_alloc(/*num_slots=*/1); + HybridPrefixCache mamba_hybrid(mamba_prefix_cache, mamba_device_alloc, &mamba_alloc, + /*mamba_cache_chunk_size=*/4); + EXPECT_THROW((void)mamba_hybrid.StepCommit({ + .request_local_mamba = + RequestLocalMambaStateRequest{ + .create_allocator = true, + .require_allocator = true, + }, + }), + std::logic_error); + EXPECT_EQ(mamba_alloc.AvailableSlots(), 1); +} + +TEST(HybridPrefixCacheMambaAllocationTest, RequestLocalMambaAllocationFailureReturnsNullAndReleasesPartialSlot) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/4); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + MambaChunkAllocator mamba_alloc(/*num_slots=*/1); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, &mamba_alloc, /*mamba_cache_chunk_size=*/4); + + auto mamba_result = hybrid_prefix_cache.StepCommit({ + .request_local_mamba = + RequestLocalMambaStateRequest{ + .create_allocator = true, + }, + }); + auto local_mamba = std::move(mamba_result.local_mamba_allocator); + + EXPECT_EQ(local_mamba, nullptr); + EXPECT_EQ(mamba_alloc.AvailableSlots(), 1); +} + +TEST(HybridPrefixCacheMambaAllocationTest, RequestLocalAllocationsReturnNullWithoutMambaAdjunct) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/4); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, /*allocator=*/nullptr, + /*mamba_cache_chunk_size=*/4); + + auto prefill_result = hybrid_prefix_cache.StepCommit({ + .request_local_mamba = + RequestLocalMambaStateRequest{ + .create_allocator = true, + }, + }); + auto retracted_result = hybrid_prefix_cache.StepCommit({ + .request_local_mamba = + RequestLocalMambaStateRequest{ + .create_allocator = true, + .require_allocator = true, + }, + }); + + EXPECT_EQ(prefill_result.local_mamba_allocator, nullptr); + EXPECT_EQ(retracted_result.local_mamba_allocator, nullptr); +} + +TEST(HybridPrefixCacheMambaAllocationTest, PrefillContinuationLeavesMambaStateUntouchedWithoutAdjunct) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/4); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, /*allocator=*/nullptr, + /*mamba_cache_chunk_size=*/4); + MambaChunkAllocator mamba_alloc(/*num_slots=*/2); + LocalMambaAllocator local_mamba(&mamba_alloc); + ASSERT_TRUE(local_mamba.AllocateWorking()); + + (void)hybrid_prefix_cache.StepCommit({ + .request_local_mamba = + RequestLocalMambaStateRequest{ + .refresh_checkpoint_allocator = &local_mamba, + }, + }); + (void)hybrid_prefix_cache.StepCommit({ + .request_local_mamba = RequestLocalMambaStateRequest{}, + }); + + EXPECT_TRUE(local_mamba.HasWorking()); + EXPECT_FALSE(local_mamba.HasCheckpoint()); + EXPECT_EQ(mamba_alloc.AvailableSlots(), 1); +} + TEST_F(MambaCacheTest, KVEvictionTriggersMambaEviction) { auto tokens = MakeAlignedTokens(2, kPageSize); InsertKVAndMamba(tokens); @@ -195,8 +609,6 @@ TEST_F(MambaCacheTest, KVEvictionTriggersMambaEviction) { TreeNode* node = match.device.last_node; EXPECT_TRUE(node->HasMamba()); - prefix_cache_->GetDeviceManager().SetEvictionCallback([this](TreeNode* n) { hybrid_prefix_cache_->OnKVEvict(n); }); - prefix_cache_->EnsureCapacityByEvict(kDevicePages); EXPECT_FALSE(node->HasMamba()); @@ -217,7 +629,7 @@ class MambaL2CacheTest : public ::testing::Test { prefix_cache_ = std::make_unique(device_alloc_.get(), host_alloc_.get()); mamba_alloc_ = std::make_unique(kMambaSlots); mamba_host_alloc_ = std::make_unique(kMambaHostSlots); - hybrid_prefix_cache_ = std::make_unique(*prefix_cache_, mamba_alloc_.get(), + hybrid_prefix_cache_ = std::make_unique(*prefix_cache_, *device_alloc_, mamba_alloc_.get(), kMambaCacheChunkSize, mamba_host_alloc_.get()); } @@ -243,7 +655,7 @@ TEST_F(MambaL2CacheTest, HostKVRequiresHostMambaForHybridMatch) { ASSERT_TRUE(device_slot.has_value()); node->AttachMamba(std::make_unique(std::move(*device_slot))); - auto mismatch = hybrid_prefix_cache_->Match(tokens); + auto mismatch = hybrid_prefix_cache_->MatchPrefix(tokens).compat_match; EXPECT_EQ(mismatch.host.DepthInPage(), 0); EXPECT_EQ(mismatch.device.DepthInPage(), 0); @@ -253,7 +665,7 @@ TEST_F(MambaL2CacheTest, HostKVRequiresHostMambaForHybridMatch) { const std::int32_t host_idx = host_slot->Index(); node->AttachMambaHost(std::make_unique(std::move(*host_slot))); - auto match = hybrid_prefix_cache_->Match(tokens); + auto match = hybrid_prefix_cache_->MatchPrefix(tokens).compat_match; EXPECT_EQ(match.host.DepthInPage(), 3); EXPECT_EQ(match.device.DepthInPage(), 0); EXPECT_EQ(match.mamba_host_src_index, host_idx); @@ -276,7 +688,7 @@ TEST_F(MambaL2CacheTest, DeeperHostMambaMatchTakesPriorityOverShallowDeviceMamba const std::int32_t host_idx = host_slot->Index(); host_node->AttachMambaHost(std::make_unique(std::move(*host_slot))); - auto match = hybrid_prefix_cache_->Match(tokens4); + auto match = hybrid_prefix_cache_->MatchPrefix(tokens4).compat_match; EXPECT_EQ(match.device.DepthInPage(), 2); EXPECT_EQ(match.host.DepthInPage(), 4); @@ -284,30 +696,6 @@ TEST_F(MambaL2CacheTest, DeeperHostMambaMatchTakesPriorityOverShallowDeviceMamba EXPECT_EQ(match.mamba_cow_src_index, -1) << "deeper host hit must trigger Mamba L2 loadback"; } -TEST_F(MambaL2CacheTest, PrefillFirstChunkRequiresCheckpointSlot) { - MambaChunkAllocator one_slot_mamba_alloc(1); - ReqPoolAllocator req_pool_alloc(1); - auto tokens = MakeAlignedTokens(1, kPageSize); - TokenContainer token_container(tokens); - auto match = prefix_cache_->Match(token_container.GetFullPagedTokens(kPageSize, true)); - - fsm::SchedulePrefillFirstChunkEvent event{ - static_cast(tokens.size()), - 0, - device_alloc_.get(), - &req_pool_alloc, - match, - Role::kP, - prefix_cache_.get(), - false, - {}, - hybrid_prefix_cache_.get(), - &one_slot_mamba_alloc, - }; - - EXPECT_THROW((void)event(fsm::Submitted{&token_container, kPageSize}), std::logic_error); -} - TEST_F(MambaL2CacheTest, PrepareMambaLoadBackAllocatesDeviceSlotAndTransferPair) { auto tokens = MakeAlignedTokens(2, kPageSize); TreeNode* node = InsertHostKV(tokens); @@ -316,7 +704,7 @@ TEST_F(MambaL2CacheTest, PrepareMambaLoadBackAllocatesDeviceSlotAndTransferPair) const std::int32_t host_idx = host_slot->Index(); node->AttachMambaHost(std::make_unique(std::move(*host_slot))); - auto transfers = hybrid_prefix_cache_->PrepareMambaDeviceLoadBack({node}); + auto transfers = HybridPrefixCacheTestPeer::PrepareMambaDeviceLoadBack(*hybrid_prefix_cache_, {node}); ASSERT_TRUE(node->HasMamba()); ASSERT_EQ(transfers.size(), 1u); @@ -341,8 +729,9 @@ TEST_F(MambaL2CacheTest, ExactWriteBackAckDoesNotPublishUnackedAncestor) { ASSERT_TRUE(descendant_slot.has_value()); descendant->AttachMamba(std::make_unique(std::move(*descendant_slot))); - auto ancestor_transfers = hybrid_prefix_cache_->PrepareMambaHostWriteBack({ancestor}); - auto descendant_transfers = hybrid_prefix_cache_->PrepareMambaHostWriteBack({descendant}); + auto ancestor_transfers = HybridPrefixCacheTestPeer::PrepareMambaHostWriteBack(*hybrid_prefix_cache_, {ancestor}); + auto descendant_transfers = + HybridPrefixCacheTestPeer::PrepareMambaHostWriteBack(*hybrid_prefix_cache_, {descendant}); ASSERT_EQ(ancestor_transfers.size(), 1u); ASSERT_EQ(descendant_transfers.size(), 1u); @@ -366,7 +755,7 @@ TEST_F(MambaL2CacheTest, PrepareMambaWriteBackPublishesHostSlotOnlyAfterAck) { const std::int32_t device_idx = device_slot->Index(); node->AttachMamba(std::make_unique(std::move(*device_slot))); - auto transfers = hybrid_prefix_cache_->PrepareMambaHostWriteBack({node}); + auto transfers = HybridPrefixCacheTestPeer::PrepareMambaHostWriteBack(*hybrid_prefix_cache_, {node}); ASSERT_EQ(transfers.size(), 1u); EXPECT_EQ(transfers[0].kind, CacheKind::kMamba); @@ -374,7 +763,7 @@ TEST_F(MambaL2CacheTest, PrepareMambaWriteBackPublishesHostSlotOnlyAfterAck) { const std::int32_t host_idx = transfers[0].dst; EXPECT_FALSE(node->HasMambaOnHost()) << "host mamba must remain invisible until writeback ack"; - auto pending_match = hybrid_prefix_cache_->Match(tokens); + auto pending_match = hybrid_prefix_cache_->MatchPrefix(tokens).compat_match; EXPECT_EQ(pending_match.host.DepthInPage(), 0); hybrid_prefix_cache_->OnMambaHostWriteBackDone(node); @@ -382,7 +771,7 @@ TEST_F(MambaL2CacheTest, PrepareMambaWriteBackPublishesHostSlotOnlyAfterAck) { ASSERT_TRUE(node->HasMambaOnHost()); EXPECT_EQ(node->MambaHostSlotIndex(), host_idx); EXPECT_FALSE(node->HasMamba()) << "idle device mamba copy should demote once host writeback is acknowledged"; - auto host_match = hybrid_prefix_cache_->Match(tokens); + auto host_match = hybrid_prefix_cache_->MatchPrefix(tokens).compat_match; EXPECT_EQ(host_match.host.DepthInPage(), 2); EXPECT_EQ(host_match.mamba_host_src_index, host_idx); EXPECT_EQ(host_match.mamba_cow_src_index, -1); @@ -398,7 +787,7 @@ TEST_F(MambaL2CacheTest, HostWriteBackDemotesAfterDeviceRefUnlock) { ASSERT_TRUE(device_slot.has_value()); node->AttachMamba(std::make_unique(std::move(*device_slot))); - auto transfers = hybrid_prefix_cache_->PrepareMambaHostWriteBack({node}); + auto transfers = HybridPrefixCacheTestPeer::PrepareMambaHostWriteBack(*hybrid_prefix_cache_, {node}); ASSERT_EQ(transfers.size(), 1u); { @@ -443,4 +832,370 @@ TEST_F(MambaL2CacheTest, WriteBackDoneDropsDeviceMambaWhenKVChildKeepsDeviceNode EXPECT_TRUE(node->HasMambaOnHost()); } +TEST(HybridPrefixCacheEvictionCallbackTest, KVOnlyEvictionIsNoOpCompatible) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/4); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, /*allocator=*/nullptr, + /*mamba_cache_chunk_size=*/4); + + auto tokens = MakeAlignedTokens(/*num_pages=*/1, /*page_size=*/2); + auto insert_result = + prefix_cache.Insert(tokens, /*prefix_pages=*/{}, device_alloc.Allocate(1)); + ASSERT_NE(insert_result.last_node, nullptr); + const std::int32_t pages_before_evict = device_alloc.AvailablePages(); + ASSERT_LT(pages_before_evict, device_alloc.TotalPages()); + + EXPECT_TRUE(prefix_cache.EnsureCapacityByEvict(pages_before_evict + 1)); + EXPECT_GT(device_alloc.AvailablePages(), pages_before_evict); +} + +TEST(HybridPrefixCacheEvictionCallbackTest, DestroyingWrapperClearsKVEvictionCallback) { + PageAllocator device_alloc(/*page_size=*/2, /*total_pages=*/4); + PageAllocator host_alloc(/*page_size=*/2, /*total_pages=*/0); + MambaChunkAllocator mamba_alloc(/*num_slots=*/1); + KVPrefixCache prefix_cache(&device_alloc, &host_alloc); + + TreeNode* node = nullptr; + { + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_alloc, &mamba_alloc, /*mamba_cache_chunk_size=*/4); + + auto tokens = MakeAlignedTokens(/*num_pages=*/1, /*page_size=*/2); + auto insert_result = + prefix_cache.Insert(tokens, /*prefix_pages=*/{}, device_alloc.Allocate(1)); + node = insert_result.last_node; + ASSERT_NE(node, nullptr); + + auto slot = mamba_alloc.Allocate(); + ASSERT_TRUE(slot.has_value()); + HybridPrefixCacheTestPeer::InsertMamba(hybrid_prefix_cache, node, + std::make_unique(std::move(*slot))); + ASSERT_TRUE(node->HasMamba()); + ASSERT_EQ(mamba_alloc.AvailableSlots(), 0); + } + + ASSERT_TRUE(node->HasMamba()); + auto evicted = prefix_cache.GetDeviceManager().Evict(/*num_pages=*/1); + + ASSERT_EQ(evicted.size(), 1); + EXPECT_EQ(evicted.front(), node); + EXPECT_TRUE(node->HasMamba()); + EXPECT_EQ(mamba_alloc.AvailableSlots(), 0); +} + +TEST_F(MambaCacheTest, StepCommitFinishStateInsertsKvPagesHashesAndMamba) { + auto tokens = MakeAlignedTokens(3, kPageSize); + auto page_hashes = MakePageHashes(/*num_pages=*/3); + TreeNode* root = prefix_cache_->Match(token_vec_t{}).device.last_node; + + LocalKVAllocator local_kv(device_alloc_.get(), static_cast(tokens.size())); + LocalMambaAllocator local_mamba(mamba_alloc_.get()); + ASSERT_TRUE(local_mamba.AllocateWorking()); + const std::int32_t working_index = local_mamba.WorkingIndex(); + ASSERT_TRUE(local_mamba.AllocateCheckpoint()); + const std::int32_t checkpoint_index = local_mamba.CheckpointIndex(); + + const auto full_pages = PagedTokenSpans(tokens); + StepCommitRequest request{ + .publish_finished_request = + FinishedRequestPublicationRequest{ + .full_paged_tokens = &full_pages, + .current_device_node = root, + .local_kv_allocator = &local_kv, + .local_mamba_allocator = &local_mamba, + .page_hashes = &page_hashes, + }, + }; + MatchResult match = hybrid_prefix_cache_->StepCommit(std::move(request)).match_result; + + EXPECT_EQ(match.device.DepthInPage(), 3); + TreeNode* terminal = match.device.last_node; + ASSERT_NE(terminal, nullptr); + ASSERT_TRUE(terminal->HasMamba()); + EXPECT_EQ(terminal->MambaSlotIndex(), checkpoint_index); + EXPECT_EQ(terminal->PageHashes(), page_hashes); + EXPECT_TRUE(local_kv.Pages().empty()); + EXPECT_FALSE(local_mamba.HasCheckpoint()); + EXPECT_TRUE(local_mamba.HasWorking()); + EXPECT_EQ(local_mamba.WorkingIndex(), working_index); +} + +TEST_F(MambaCacheTest, StepCommitFinishStateSkipsPublicationWhenNoNewPages) { + auto tokens = MakeAlignedTokens(2, kPageSize); + TreeNode* terminal = InsertKVOnly(tokens); + ASSERT_NE(terminal, nullptr); + ASSERT_FALSE(terminal->HasMamba()); + + LocalKVAllocator local_kv(device_alloc_.get(), /*num_tokens=*/0); + LocalMambaAllocator local_mamba(mamba_alloc_.get()); + ASSERT_TRUE(local_mamba.AllocateWorking()); + ASSERT_TRUE(local_mamba.AllocateCheckpoint()); + const std::int32_t checkpoint_index = local_mamba.CheckpointIndex(); + + const auto full_pages = PagedTokenSpans(tokens); + const auto page_hashes = MakePageHashes(/*num_pages=*/2); + StepCommitRequest request{ + .publish_finished_request = + FinishedRequestPublicationRequest{ + .full_paged_tokens = &full_pages, + .current_device_node = terminal, + .local_kv_allocator = &local_kv, + .local_mamba_allocator = &local_mamba, + .page_hashes = &page_hashes, + }, + }; + MatchResult match = hybrid_prefix_cache_->StepCommit(std::move(request)).match_result; + + EXPECT_EQ(match.device.DepthInPage(), 2); + EXPECT_EQ(match.device.last_node, terminal); + EXPECT_FALSE(terminal->HasMamba()); + EXPECT_TRUE(local_mamba.HasCheckpoint()); + EXPECT_EQ(local_mamba.CheckpointIndex(), checkpoint_index); + EXPECT_TRUE(local_kv.Pages().empty()); +} + +TEST_F(MambaCacheTest, StepCommitRetractPublicationInsertsKvAndReturnsRawStateRecoveryMatch) { + auto tokens = MakeAlignedTokens(3, kPageSize); + token_vec_t prefix_tokens(tokens.begin(), tokens.begin() + kPageSize); + TreeNode* prefix = InsertKVOnly(prefix_tokens); + ASSERT_NE(prefix, nullptr); + ASSERT_FALSE(prefix->HasMamba()); + + const auto full_pages = PagedTokenSpans(tokens); + StepCommitRequest count_request{ + .plan_device_prefix_insertion = + DevicePrefixInsertionPlanRequest{ + .full_paged_tokens = &full_pages, + .current_device_node = prefix, + }, + }; + EXPECT_EQ(hybrid_prefix_cache_->StepCommit(std::move(count_request)).device_insert_page_count, 2); + + const std::vector existing_pages = DevicePagesFromRoot(prefix); + auto pages_to_insert = device_alloc_->Allocate(/*num_pages=*/2); + const std::vector inserted_pages = pages_to_insert.Ids(); + + StepCommitRequest request{ + .publish_device_prefix_insertion = + DevicePrefixInsertionRequest{ + .full_paged_tokens = &full_pages, + .current_device_node = prefix, + .pages_to_insert = std::move(pages_to_insert), + }, + }; + MatchResult match = hybrid_prefix_cache_->StepCommit(std::move(request)).match_result; + + EXPECT_EQ(match.device.DepthInPage(), 3); + EXPECT_EQ(match.host.DepthInPage(), 0); + EXPECT_EQ(match.mamba_cow_src_index, -1); + EXPECT_EQ(match.mamba_branching_seqlen, -1); + + std::vector expected_pages = existing_pages; + expected_pages.insert(expected_pages.end(), inserted_pages.begin(), inserted_pages.end()); + EXPECT_EQ(DevicePagesFromRoot(match.device.last_node), expected_pages); + + // The retract facade intentionally returns the raw KV state-recovery match + // used for host writeback planning, not the Mamba-capped hybrid match. + MatchResult hybrid_match = hybrid_prefix_cache_->MatchPrefix(full_pages, MatchIntent::StateRecovery).compat_match; + EXPECT_EQ(hybrid_match.device.DepthInPage(), 0); +} + +TEST_F(MambaCacheTest, PublishFinishMambaStatePrefersCheckpointOverWorking) { + auto tokens = MakeAlignedTokens(2, kPageSize); + TreeNode* terminal = InsertKVOnly(tokens); + ASSERT_NE(terminal, nullptr); + ASSERT_FALSE(terminal->HasMamba()); + + LocalMambaAllocator local_mamba(mamba_alloc_.get()); + ASSERT_TRUE(local_mamba.AllocateWorking()); + const std::int32_t working_index = local_mamba.WorkingIndex(); + ASSERT_TRUE(local_mamba.AllocateCheckpoint()); + const std::int32_t checkpoint_index = local_mamba.CheckpointIndex(); + + HybridPrefixCacheTestPeer::PublishFinishMambaState(*hybrid_prefix_cache_, PagedTokenSpans(tokens), &local_mamba); + + ASSERT_TRUE(terminal->HasMamba()); + EXPECT_EQ(terminal->MambaSlotIndex(), checkpoint_index); + EXPECT_FALSE(local_mamba.HasCheckpoint()); + EXPECT_TRUE(local_mamba.HasWorking()); + EXPECT_EQ(local_mamba.WorkingIndex(), working_index); +} + +TEST_F(MambaCacheTest, PublishFinishMambaStateSkipsTerminalWithExistingMamba) { + auto tokens = MakeAlignedTokens(2, kPageSize); + TreeNode* terminal = InsertKVAndMamba(tokens); + ASSERT_NE(terminal, nullptr); + ASSERT_TRUE(terminal->HasMamba()); + const std::int32_t original_index = terminal->MambaSlotIndex(); + + LocalMambaAllocator local_mamba(mamba_alloc_.get()); + ASSERT_TRUE(local_mamba.AllocateWorking()); + ASSERT_TRUE(local_mamba.AllocateCheckpoint()); + + HybridPrefixCacheTestPeer::PublishFinishMambaState(*hybrid_prefix_cache_, PagedTokenSpans(tokens), &local_mamba); + + EXPECT_TRUE(terminal->HasMamba()); + EXPECT_EQ(terminal->MambaSlotIndex(), original_index); + EXPECT_TRUE(local_mamba.HasWorking()); + EXPECT_TRUE(local_mamba.HasCheckpoint()); +} + +TEST_F(MambaCacheTest, StepCommitRetractMambaStateAttachesCheckpointAndReleasesRequestLocalSlots) { + auto tokens = MakeAlignedTokens(2, kPageSize); + TreeNode* terminal = InsertKVOnly(tokens); + ASSERT_NE(terminal, nullptr); + ASSERT_FALSE(terminal->HasMamba()); + + auto local_mamba = std::make_unique(mamba_alloc_.get()); + ASSERT_TRUE(local_mamba->AllocateWorking()); + ASSERT_TRUE(local_mamba->AllocateCheckpoint()); + const std::int32_t checkpoint_index = local_mamba->CheckpointIndex(); + + hybrid_prefix_cache_->StepCommit({ + .publish_tree_owned_request_state = + TreeOwnedRequestStatePublicationRequest{ + .terminal = terminal, + .local_mamba_allocator_owner = &local_mamba, + }, + }); + + EXPECT_EQ(local_mamba, nullptr); + ASSERT_TRUE(terminal->HasMamba()); + EXPECT_EQ(terminal->MambaSlotIndex(), checkpoint_index); + EXPECT_EQ(mamba_alloc_->AvailableSlots(), kMambaSlots - 1); +} + +TEST_F(MambaCacheTest, StepCommitRetractMambaStateFallsBackToWorkingAndReleasesRequestLocalSlots) { + auto tokens = MakeAlignedTokens(2, kPageSize); + TreeNode* terminal = InsertKVOnly(tokens); + ASSERT_NE(terminal, nullptr); + ASSERT_FALSE(terminal->HasMamba()); + + auto local_mamba = std::make_unique(mamba_alloc_.get()); + ASSERT_TRUE(local_mamba->AllocateWorking()); + const std::int32_t working_index = local_mamba->WorkingIndex(); + + hybrid_prefix_cache_->StepCommit({ + .publish_tree_owned_request_state = + TreeOwnedRequestStatePublicationRequest{ + .terminal = terminal, + .local_mamba_allocator_owner = &local_mamba, + }, + }); + + EXPECT_EQ(local_mamba, nullptr); + ASSERT_TRUE(terminal->HasMamba()); + EXPECT_EQ(terminal->MambaSlotIndex(), working_index); + EXPECT_EQ(mamba_alloc_->AvailableSlots(), kMambaSlots - 1); +} + +TEST_F(MambaCacheTest, StepCommitRetractMambaStateReleasesRequestLocalSlotsWhenTerminalAlreadyHasMamba) { + auto tokens = MakeAlignedTokens(2, kPageSize); + TreeNode* terminal = InsertKVAndMamba(tokens); + ASSERT_NE(terminal, nullptr); + ASSERT_TRUE(terminal->HasMamba()); + const std::int32_t original_index = terminal->MambaSlotIndex(); + + auto local_mamba = std::make_unique(mamba_alloc_.get()); + ASSERT_TRUE(local_mamba->AllocateWorking()); + ASSERT_TRUE(local_mamba->AllocateCheckpoint()); + + hybrid_prefix_cache_->StepCommit({ + .publish_tree_owned_request_state = + TreeOwnedRequestStatePublicationRequest{ + .terminal = terminal, + .local_mamba_allocator_owner = &local_mamba, + }, + }); + + EXPECT_EQ(local_mamba, nullptr); + EXPECT_TRUE(terminal->HasMamba()); + EXPECT_EQ(terminal->MambaSlotIndex(), original_index); + EXPECT_EQ(mamba_alloc_->AvailableSlots(), kMambaSlots - 1); +} + +TEST_F(MambaCacheTest, AdmitComputesBranchingCheckpointWithoutMutatingMatch) { + FillMambaSlots(); + ASSERT_EQ(mamba_alloc_->AvailableSlots(), 0); + + MatchResult match{}; + std::map simulated_free; + AdmissionRequest request{ + .request_id = "r-prefill-first", + .device_pages_needed = 0, + .tokens_this_round = 5, + .first_raw_position_of_op = 0, + .target_raw_tokens_exclusive = 5, + .compat_match = &match, + .auxiliary_tree_slots_needed = 2, + .compute_branching_checkpoint = true, + }; + AdmissionVerdict result = hybrid_prefix_cache_->Admit(request, simulated_free); + + EXPECT_TRUE(result.admitted); + ASSERT_TRUE(result.mamba_branching_seqlen.has_value()); + EXPECT_EQ(*result.mamba_branching_seqlen, 4); + EXPECT_EQ(match.mamba_branching_seqlen, -1); + EXPECT_GE(mamba_alloc_->AvailableSlots(), 2); +} + +TEST_F(MambaCacheTest, AdmitContinuePrefillReservesOneMambaSlot) { + FillMambaSlots(); + ASSERT_EQ(mamba_alloc_->AvailableSlots(), 0); + + std::map simulated_free; + AdmissionRequest request{ + .request_id = "r-prefill", + .device_pages_needed = 0, + .first_raw_position_of_op = 4, + .target_raw_tokens_exclusive = 8, + .auxiliary_tree_slots_needed = 1, + }; + EXPECT_TRUE(hybrid_prefix_cache_->Admit(request, simulated_free).admitted); + + EXPECT_GE(mamba_alloc_->AvailableSlots(), 1); +} + +TEST_F(MambaCacheTest, AdmitDecodeDoesNotReserveMambaSlots) { + auto nodes = FillMambaSlots(); + ASSERT_EQ(mamba_alloc_->AvailableSlots(), 0); + + std::map simulated_free; + AdmissionRequest request{ + .request_id = "r-decode", + .device_pages_needed = 0, + .first_raw_position_of_op = 8, + .target_raw_tokens_exclusive = 9, + }; + EXPECT_TRUE(hybrid_prefix_cache_->Admit(request, simulated_free).admitted); + + EXPECT_EQ(mamba_alloc_->AvailableSlots(), 0); + EXPECT_EQ(std::count_if(nodes.begin(), nodes.end(), [](const TreeNode* node) { return node->HasMamba(); }), + kMambaSlots); +} + +TEST_F(MambaCacheTest, AdmitRecoverDecodeProtectsMambaRecoverySource) { + auto nodes = FillMambaSlots(); + ASSERT_EQ(mamba_alloc_->AvailableSlots(), 0); + TreeNode* recovery_source = nodes.front(); + ASSERT_NE(recovery_source, nullptr); + ASSERT_TRUE(recovery_source->HasMamba()); + + std::map simulated_free; + MatchResult match{}; + AdmissionRequest request{ + .request_id = "r-retracted", + .device_pages_needed = 0, + .target_raw_tokens_exclusive = 8, + .compat_match = &match, + .protected_recovery_node = recovery_source, + .auxiliary_tree_slots_needed = 2, + .fresh_request_table_view = true, + }; + EXPECT_TRUE(hybrid_prefix_cache_->Admit(request, simulated_free).admitted); + + EXPECT_TRUE(recovery_source->HasMamba()); + EXPECT_GE(mamba_alloc_->AvailableSlots(), 2); +} + } // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_eviction.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_eviction.cpp index 178a400e3..8c8459ff4 100644 --- a/tokenspeed-scheduler/tests/cpp/test_paged_cache_eviction.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_eviction.cpp @@ -43,7 +43,7 @@ TEST_F(PagedCacheEvictionTest, PassiveEvictionReleasesPagedCachePages) { const std::int32_t fh_before = fh_alloc_->AvailablePages(); const std::int32_t swa_before = swa_alloc_->AvailablePages(); - hybrid_->AttachPagedCacheSnapshotToNode(attach_a, MakeCompleteSnapshot(kLcm)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, attach_a, MakeCompleteSnapshot(kLcm)); EXPECT_TRUE(attach_a->HasPagedCacheSnapshot()); // The snapshot must hold *some* pages from each group, otherwise the test // below ("eviction returns them") is vacuous. We do NOT assert the exact @@ -74,4 +74,51 @@ TEST_F(PagedCacheEvictionTest, PassiveEvictionReleasesPagedCachePages) { EXPECT_EQ(swa_alloc_->AvailablePages(), swa_before); } +TEST_F(PagedCacheEvictionTest, StatePressurePrunesOnlyStateWhenHistoryHasCapacity) { + TreeNode* terminal = InsertDevicePages(/*num_pages=*/2, /*token_start=*/1); + ASSERT_NE(terminal, nullptr); + + TreeNode* attach = kv_cache_->GetRadixTree().SplitAt(terminal, kLcm); + ASSERT_NE(attach, nullptr); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, attach, MakeCompleteSnapshot(kLcm)); + ASSERT_TRUE(attach->HasPagedCacheSnapshot()); + ASSERT_TRUE(attach->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::History)); + ASSERT_TRUE(attach->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::State)); + + const std::int32_t fh_available_before_admit = fh_alloc_->AvailablePages(); + ASSERT_GE(fh_available_before_admit, 1); + + std::vector state_saturator = swa_alloc_->Allocate(swa_alloc_->AvailablePages()); + ASSERT_FALSE(state_saturator.empty()); + ASSERT_EQ(swa_alloc_->AvailablePages(), 0); + + auto simulated_free = hybrid_->InitialSimulatedFree(); + MatchResult match = kv_cache_->Match(token_vec_t{}); + AdmissionRequest request{ + .request_id = "state-pressure", + .device_pages_needed = 0, + .tokens_this_round = kLcm, + .first_raw_position_of_op = 0, + .target_raw_tokens_exclusive = kLcm, + .compat_match = &match, + .auxiliary_tree_slots_needed = 2, + .compute_branching_checkpoint = true, + }; + + auto result = hybrid_->Admit(request, simulated_free); + + ASSERT_TRUE(result.admitted); + ASSERT_TRUE(attach->HasPagedCacheSnapshot()) + << "History had enough free pages; State pressure must not full-prune the snapshot"; + const auto* snap = attach->GetPagedCacheSnapshot(); + ASSERT_NE(snap, nullptr); + EXPECT_TRUE(snap->IsCompleteFor(PagedCacheGroupFamily::History)); + EXPECT_FALSE(snap->IsCompleteFor(PagedCacheGroupFamily::State)); + EXPECT_NE(snap->groups.find("fh"), snap->groups.end()); + EXPECT_EQ(snap->groups.find("swa"), snap->groups.end()); + + HybridPrefixCacheTestPeer::ReleaseRequest(*hybrid_, "state-pressure"); + swa_alloc_->Deallocate(state_saturator); +} + } // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_family_split.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_family_split.cpp index aee02ae3b..9bf6623b1 100644 --- a/tokenspeed-scheduler/tests/cpp/test_paged_cache_family_split.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_family_split.cpp @@ -41,9 +41,9 @@ TEST_F(PagedCacheFamilySplitTest, HistoryCompleteStateMissingFallback) { ASSERT_NE(n512, nullptr); ASSERT_NE(n768, nullptr); - hybrid_->AttachPagedCacheSnapshotToNode(n256, MakeCompleteSnapshot(256)); - hybrid_->AttachPagedCacheSnapshotToNode(n512, MakeCompleteSnapshot(512)); - hybrid_->AttachPagedCacheSnapshotToNode(n768, MakeCompleteSnapshot(768)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n256, MakeCompleteSnapshot(256)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n512, MakeCompleteSnapshot(512)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n768, MakeCompleteSnapshot(768)); // Downgrade only the deepest snapshot: history-only at 768. DowngradeSnapshotToHistoryOnly(n768); @@ -51,7 +51,7 @@ TEST_F(PagedCacheFamilySplitTest, HistoryCompleteStateMissingFallback) { EXPECT_TRUE(n768->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::History)); EXPECT_FALSE(n768->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::State)); - auto match = hybrid_->Match(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)); + auto match = hybrid_->MatchPrefix(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)).compat_match; ASSERT_NE(match.paged_cache.last_node, nullptr); // History chain reaches 768 but state at 768 is missing; segments_needed=1 // forces fallback to 512. @@ -74,13 +74,13 @@ TEST_F(PagedCacheFamilyWideWindowTest, StateWindowDiscontinuityFallback) { ASSERT_NE(n512, nullptr); ASSERT_NE(n768, nullptr); - hybrid_->AttachPagedCacheSnapshotToNode(n256, MakeCompleteSnapshot(256)); - hybrid_->AttachPagedCacheSnapshotToNode(n512, MakeCompleteSnapshot(512)); - hybrid_->AttachPagedCacheSnapshotToNode(n768, MakeCompleteSnapshot(768)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n256, MakeCompleteSnapshot(256)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n512, MakeCompleteSnapshot(512)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n768, MakeCompleteSnapshot(768)); DowngradeSnapshotToHistoryOnly(n512); - auto match = hybrid_->Match(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)); + auto match = hybrid_->MatchPrefix(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)).compat_match; ASSERT_NE(match.paged_cache.last_node, nullptr); EXPECT_EQ(match.paged_cache.last_node, n256); EXPECT_EQ(match.paged_cache.prefix_len_tokens, 256); @@ -100,9 +100,9 @@ TEST_F(PagedCacheFamilySplitTest, StateDetachDoesNotBreakHistoryChain) { ASSERT_NE(n512, nullptr); ASSERT_NE(n768, nullptr); - hybrid_->AttachPagedCacheSnapshotToNode(n256, MakeCompleteSnapshot(256)); - hybrid_->AttachPagedCacheSnapshotToNode(n512, MakeCompleteSnapshot(512)); - hybrid_->AttachPagedCacheSnapshotToNode(n768, MakeCompleteSnapshot(768)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n256, MakeCompleteSnapshot(256)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n512, MakeCompleteSnapshot(512)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n768, MakeCompleteSnapshot(768)); DowngradeSnapshotToHistoryOnly(n512); ASSERT_TRUE(n512->HasPagedCacheSnapshot()); @@ -110,7 +110,7 @@ TEST_F(PagedCacheFamilySplitTest, StateDetachDoesNotBreakHistoryChain) { EXPECT_FALSE(n512->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::State)); EXPECT_TRUE(n768->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::State)); - auto match = hybrid_->Match(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)); + auto match = hybrid_->MatchPrefix(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)).compat_match; ASSERT_NE(match.paged_cache.last_node, nullptr); // History chain unbroken; state at 768 (only the trailing segment) is fine. EXPECT_EQ(match.paged_cache.last_node, n768); diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_hit_commit.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_hit_commit.cpp index 049113942..5d64ca917 100644 --- a/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_hit_commit.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_hit_commit.cpp @@ -42,27 +42,27 @@ TEST_F(PagedCachePrefixHitCommitTest, PrefixHitFollowedByCheckpointDoesNotOverfl TreeNode* n256 = kv_cache_->GetRadixTree().SplitAt(terminal, 256); ASSERT_NE(n256, nullptr); - hybrid_->AttachPagedCacheSnapshotToNode(n256, MakeCompleteSnapshot(256)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n256, MakeCompleteSnapshot(256)); ASSERT_TRUE(n256->HasPagedCacheSnapshot()); const auto tokens = MakeAlignedTokens(num_pages, kPageSize, /*start=*/1); // The second request: prefix-cache match returns the depth-256 hit. - auto pre_match = hybrid_->Match(tokens); + auto pre_match = hybrid_->MatchPrefix(tokens).compat_match; ASSERT_NE(pre_match.paged_cache.last_node, nullptr); EXPECT_EQ(pre_match.paged_cache.last_node, n256); EXPECT_EQ(pre_match.paged_cache.prefix_len_tokens, 256); // Import borrowed prefix + acquire fresh pages for the remaining LCM segment. const std::string request_id = "r-prefix-hit"; - hybrid_->AcquireForRequest(request_id, - /*first_raw_position_of_op=*/256, - /*target_raw_tokens_exclusive=*/512, pre_match.paged_cache); + HybridPrefixCacheTestPeer::AcquireForRequest(*hybrid_, request_id, + /*first_raw_position_of_op=*/256, + /*target_raw_tokens_exclusive=*/512, pre_match.paged_cache); // Trigger CheckpointStateToSnapshot at the next LCM boundary. Pre-fix this // throws std::logic_error("not enough owned pages for window"); post-fix it // commits only the new LCM segment's delta to the snapshot. - ASSERT_NO_THROW(hybrid_->CommitChunk(request_id, terminal)); + ASSERT_NO_THROW(HybridPrefixCacheTestPeer::CommitChunk(*hybrid_, request_id, terminal)); // After commit, n512 (=terminal) must hold a complete snapshot covering // both required families. @@ -75,7 +75,7 @@ TEST_F(PagedCachePrefixHitCommitTest, PrefixHitFollowedByCheckpointDoesNotOverfl // Observable: a fresh Match now reconstructs the full trailing window // (state_span = [n256, n512]) and exposes window/raw_per_page page ids // for the sliding "swa" group. - auto post_match = hybrid_->Match(tokens); + auto post_match = hybrid_->MatchPrefix(tokens).compat_match; ASSERT_NE(post_match.paged_cache.last_node, nullptr); EXPECT_EQ(post_match.paged_cache.prefix_len_tokens, 512); @@ -92,7 +92,7 @@ TEST_F(PagedCachePrefixHitCommitTest, PrefixHitFollowedByCheckpointDoesNotOverfl EXPECT_EQ(static_cast(swa_ids.size()), expected_state_pages); // Clean up the request tables; owned pages return via RAII / ReleaseAll. - hybrid_->ReleaseRequest(request_id); + HybridPrefixCacheTestPeer::ReleaseRequest(*hybrid_, request_id); } } // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_match.cpp b/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_match.cpp index 0c2f7de84..aa938439b 100644 --- a/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_match.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_paged_cache_prefix_match.cpp @@ -32,7 +32,7 @@ TEST_F(PagedCachePrefixMatchTest, CapVsNoCap320) { const auto tokens = MakeAlignedTokens(num_pages, kPageSize, /*start=*/1); // No snapshot: paged_cache empty; device/host capped to root. - auto match = hybrid_->Match(tokens); + auto match = hybrid_->MatchPrefix(tokens).compat_match; EXPECT_EQ(match.paged_cache.last_node, nullptr); EXPECT_EQ(match.paged_cache.prefix_len_tokens, 0); ASSERT_NE(match.device.last_node, nullptr); @@ -46,12 +46,12 @@ TEST_F(PagedCachePrefixMatchTest, CapVsNoCap320) { TreeNode* boundary_256 = kv_cache_->GetRadixTree().SplitAt(terminal, 256); ASSERT_NE(boundary_256, nullptr); EXPECT_EQ(boundary_256->DepthInTokens(), 256u); - hybrid_->AttachPagedCacheSnapshotToNode(boundary_256, MakeCompleteSnapshot(256)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, boundary_256, MakeCompleteSnapshot(256)); ASSERT_TRUE(boundary_256->HasPagedCacheSnapshot()); EXPECT_TRUE(boundary_256->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::History)); EXPECT_TRUE(boundary_256->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::State)); - match = hybrid_->Match(tokens); + match = hybrid_->MatchPrefix(tokens).compat_match; ASSERT_NE(match.paged_cache.last_node, nullptr); EXPECT_EQ(match.paged_cache.last_node, boundary_256); EXPECT_EQ(match.paged_cache.prefix_len_tokens, 256); @@ -73,18 +73,18 @@ TEST_F(PagedCachePrefixMatchTest, ContiguousChainBreakMid) { ASSERT_NE(n512, nullptr); ASSERT_NE(n768, nullptr); - hybrid_->AttachPagedCacheSnapshotToNode(n256, MakeCompleteSnapshot(256)); - hybrid_->AttachPagedCacheSnapshotToNode(n512, MakeCompleteSnapshot(512)); - hybrid_->AttachPagedCacheSnapshotToNode(n768, MakeCompleteSnapshot(768)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n256, MakeCompleteSnapshot(256)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n512, MakeCompleteSnapshot(512)); + HybridPrefixCacheTestPeer::AttachPagedCacheSnapshotToNode(*hybrid_, n768, MakeCompleteSnapshot(768)); // Drop the middle snapshot; chain scan must stop at the gap. - auto dropped = hybrid_->DetachPagedCacheSnapshotFromNode(n512); + auto dropped = HybridPrefixCacheTestPeer::DetachPagedCacheSnapshotFromNode(*hybrid_, n512); EXPECT_TRUE(dropped != nullptr); EXPECT_FALSE(n512->HasPagedCacheSnapshot()); ASSERT_TRUE(n768->HasPagedCacheSnapshot()); EXPECT_TRUE(n768->GetPagedCacheSnapshot()->IsCompleteFor(PagedCacheGroupFamily::History)); - auto match = hybrid_->Match(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)); + auto match = hybrid_->MatchPrefix(MakeAlignedTokens(num_pages, kPageSize, /*start=*/1)).compat_match; ASSERT_NE(match.paged_cache.last_node, nullptr); EXPECT_EQ(match.paged_cache.last_node, n256); EXPECT_EQ(match.paged_cache.prefix_len_tokens, 256); diff --git a/tokenspeed-scheduler/tests/cpp/test_prefix_cache_host_match.cpp b/tokenspeed-scheduler/tests/cpp/test_prefix_cache_host_match.cpp index 3883385c7..8957c6f07 100644 --- a/tokenspeed-scheduler/tests/cpp/test_prefix_cache_host_match.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_prefix_cache_host_match.cpp @@ -24,14 +24,31 @@ #include #include "unit_test_helper.h" +#include "resource/allocator/mamba_chunk_allocator.h" #include "resource/allocator/owned_pages.h" #include "resource/allocator/page_allocator.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" #include "resource/kv_prefix_cache/kv_prefix_cache.h" #include "resource/radix_tree/tree_node.h" #include "resource/types.h" +#include "scheduler/page_hasher.h" namespace tokenspeed::test { +namespace { + +std::vector> TokenPages(const token_vec_t& tokens, std::int32_t page_size) { + const std::size_t num_pages = tokens.size() / page_size; + std::vector> pages; + pages.reserve(num_pages); + for (std::size_t i = 0; i < num_pages; ++i) { + pages.emplace_back(tokens.data() + i * page_size, static_cast(page_size)); + } + return pages; +} + +} // namespace + // page_size=2, device_total=2, host_total=16. // Step 1: Insert([1,2], pages=[0]) → node([1,2]) has device=[0]. // Step 2: Match([1,2]) → device.matched=1; Insert → host=[1]. @@ -116,4 +133,50 @@ TEST(PrefixCacheHostMatchDiag, HostCacheMatchAfterDirectAttachResource) { EXPECT_EQ(match12c.host.DepthInPage(), 1) << "host still matches"; } +TEST(PrefixCacheHostMatchDiag, RawHostStorageHashSeedIgnoresAugmentedMambaMatchCapping) { + static constexpr std::int32_t kPageSize = 2; + PageAllocator device_alloc{kPageSize, 16}; + PageAllocator host_alloc{kPageSize, 16}; + KVPrefixCache cache{&device_alloc, &host_alloc, false}; + MambaChunkAllocator mamba_alloc{2}; + HybridPrefixCache hybrid_cache{cache, device_alloc, &mamba_alloc, /*mamba_cache_chunk_size=*/4}; + + token_vec_t stored_tokens = MakeAlignedTokens(/*num_pages=*/2, kPageSize, /*start=*/1); + auto stored_pages = TokenPages(stored_tokens, kPageSize); + const auto stored_hashes = ComputePagedHashes(stored_pages, ""); + + cache.Insert(stored_pages, {}, device_alloc.Allocate(2), stored_hashes); + cache.Insert(stored_pages, {}, host_alloc.Allocate(2), stored_hashes); + + const MatchResult augmented_match = hybrid_cache.MatchPrefix(stored_pages).compat_match; + ASSERT_EQ(augmented_match.host.DepthInPage(), 0) + << "regular HybridPrefixCache::Match is Mamba-recovery capped when no Mamba slot exists"; + + token_vec_t lookup_tokens = MakeAlignedTokens(/*num_pages=*/3, kPageSize, /*start=*/1); + auto lookup_pages = TokenPages(lookup_tokens, kPageSize); + const auto seed = hybrid_cache.LookupRawHostStorageHashSeed(lookup_pages); + + EXPECT_EQ(seed.host_matched_pages, 2); + EXPECT_EQ(seed.prior_hash_seed, stored_hashes.back()); +} + +TEST(PrefixCacheHostMatchDiag, RawHostStorageHashSeedReturnsEmptyPriorWhenHostNodeHasNoPageHashes) { + static constexpr std::int32_t kPageSize = 2; + PageAllocator device_alloc{kPageSize, 16}; + PageAllocator host_alloc{kPageSize, 16}; + KVPrefixCache cache{&device_alloc, &host_alloc, false}; + HybridPrefixCache hybrid_cache{cache, device_alloc, nullptr, /*mamba_cache_chunk_size=*/0}; + + token_vec_t stored_tokens = MakeAlignedTokens(/*num_pages=*/2, kPageSize, /*start=*/1); + auto stored_pages = TokenPages(stored_tokens, kPageSize); + cache.Insert(stored_pages, {}, host_alloc.Allocate(2)); + + token_vec_t lookup_tokens = MakeAlignedTokens(/*num_pages=*/3, kPageSize, /*start=*/1); + auto lookup_pages = TokenPages(lookup_tokens, kPageSize); + const auto seed = hybrid_cache.LookupRawHostStorageHashSeed(lookup_pages); + + EXPECT_EQ(seed.host_matched_pages, 2); + EXPECT_TRUE(seed.prior_hash_seed.empty()); +} + } // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_request_cache_context.cpp b/tokenspeed-scheduler/tests/cpp/test_request_cache_context.cpp new file mode 100644 index 000000000..6a0fa402e --- /dev/null +++ b/tokenspeed-scheduler/tests/cpp/test_request_cache_context.cpp @@ -0,0 +1,221 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +#include +#include +#include + +#include "fsm/forward_events.h" +#include "resource/allocator/mamba_chunk_allocator.h" +#include "resource/allocator/page_allocator.h" +#include "resource/allocator/req_pool_allocator.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" +#include "resource/kv_prefix_cache/kv_prefix_cache.h" +#include "scheduler/operations/forward.h" +#include "scheduler/request.h" +#include "scheduler/request_cache_context.h" +#include "unit_test_helper.h" + +namespace tokenspeed::test { + +namespace { + +constexpr std::int32_t kPageSize = 2; +constexpr std::int32_t kMambaCacheChunkSize = 4; + +Request MakeRequest(const std::string& request_id, std::int32_t num_pages) { + return Request{RequestSpec{.request_id = request_id, .tokens = MakeAlignedTokens(num_pages, kPageSize)}, kPageSize, + Role::kFused}; +} + +void ApplyFirstChunkToPrefillDone(Request& request, ReqPoolAllocator& req_pool_allocator, + HybridPrefixCache& hybrid_prefix_cache) { + auto match = hybrid_prefix_cache.MatchPrefix(request.GetFullPagedTokens(/*except_last=*/true)).compat_match; + request.Apply(fsm::SchedulePrefillFirstChunkEvent{ + request.PrefillSize(), /*decode_input_tokens=*/0, &req_pool_allocator, match, Role::kFused, + /*disable_l2_cache=*/false, std::vector{}, std::vector{}, hybrid_prefix_cache}); + ASSERT_TRUE(request.Is()); +} + +} // namespace + +TEST(RequestCacheContextTest, SubmittedViewExposesEmptyOccupiedPagesOnly) { + Request request = MakeRequest("r_submitted", /*num_pages=*/1); + + RequestCacheContext context(request); + + EXPECT_TRUE(context.OccupiedPagesSnapshot().empty()); + EXPECT_EQ(context.OccupiedPageCountSnapshot(), 0); +} + +TEST(RequestCacheContextTest, ForwardViewExposesPostPrefillFlatteningInputs) { + PageAllocator device_allocator(kPageSize, /*total_pages=*/16); + PageAllocator host_allocator(kPageSize, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_allocator, &host_allocator); + MambaChunkAllocator mamba_allocator(/*num_slots=*/4); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_allocator, &mamba_allocator, kMambaCacheChunkSize); + ReqPoolAllocator req_pool_allocator(/*size=*/2); + Request request = MakeRequest("r_prefill", /*num_pages=*/2); + + RequestCacheContext pre_apply_context(request); + EXPECT_EQ(pre_apply_context.OccupiedPageCountSnapshot(), 0); + + auto match = hybrid_prefix_cache.MatchPrefix(request.GetFullPagedTokens(/*except_last=*/true)).compat_match; + request.Apply(fsm::SchedulePrefillFirstChunkEvent{ + request.PrefillSize(), /*decode_input_tokens=*/1, &req_pool_allocator, match, Role::kFused, + /*disable_l2_cache=*/false, std::vector{}, std::vector{}, hybrid_prefix_cache}); + + RequestCacheContext context(request); + std::vector occupied_pages = context.OccupiedPagesSnapshot(); + + EXPECT_EQ(context.OccupiedPageCountSnapshot(), static_cast(occupied_pages.size())); + EXPECT_EQ(occupied_pages.size(), 3u); + EXPECT_GT(context.RequestPoolIndex(), 0); + ASSERT_NE(context.LocalMambaAllocatorView(), nullptr); + + RequestCacheMutation mutation(request); + EXPECT_NE(mutation.MutableTerminalDeviceNode(), nullptr); + + ForwardOperationBase op{}; + hybrid_prefix_cache.StepCommit({ + .worker_metadata = + WorkerCompatibilityCommitRequest{ + .op_base = &op, + .local_mamba_allocator_view = context.LocalMambaAllocatorView(), + }, + }); + EXPECT_NE(op.mamba_working_idx, -1); + EXPECT_NE(op.mamba_checkpoint_dst_idx, -1); +} + +TEST(RequestCacheContextTest, ForwardViewExposesPostDecodeFromRetractedRecoveryInputs) { + PageAllocator device_allocator(kPageSize, /*total_pages=*/16); + PageAllocator host_allocator(kPageSize, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_allocator, &host_allocator); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_allocator, /*allocator=*/nullptr, kMambaCacheChunkSize); + ReqPoolAllocator req_pool_allocator(/*size=*/2); + Request request = MakeRequest("r_recovered", /*num_pages=*/1); + + auto prefill_match = hybrid_prefix_cache.MatchPrefix(request.GetFullPagedTokens(/*except_last=*/true)).compat_match; + request.Apply(fsm::SchedulePrefillFirstChunkEvent{ + request.PrefillSize(), /*decode_input_tokens=*/0, &req_pool_allocator, prefill_match, Role::kFused, + /*disable_l2_cache=*/false, std::vector{}, std::vector{}, hybrid_prefix_cache}); + + auto retract_match = hybrid_prefix_cache.MatchPrefix(request.GetFullPagedTokens(/*except_last=*/true)).compat_match; + request.Apply(fsm::ScheduleRetractEvent{retract_match, hybrid_prefix_cache}); + request.Apply(fsm::WriteBackDoneEvent{}); + ASSERT_TRUE(request.Is()); + + auto recovery_match = + hybrid_prefix_cache.MatchPrefix(request.GetFullPagedTokens(/*except_last=*/true), MatchIntent::StateRecovery) + .compat_match; + request.Apply(fsm::ScheduleDecodeFromRetractedEvent{/*decode_input_tokens=*/1, &req_pool_allocator, recovery_match, + std::vector{}, std::vector{}, + hybrid_prefix_cache}); + ASSERT_TRUE(request.Is()); + + RequestCacheContext context(request); + std::vector occupied_pages = context.OccupiedPagesSnapshot(); + DecodeOperation recovered_op{{ + .request_id = request.Id(), + .request_pool_index = context.RequestPoolIndex(), + .input_length = 1, + .occupied_pages = occupied_pages, + .begin = 0, + .size = context.OccupiedPageCountSnapshot(), + .prefill_length = request.PrefillSize(), + }}; + + EXPECT_EQ(context.OccupiedPageCountSnapshot(), static_cast(occupied_pages.size())); + EXPECT_GT(context.RequestPoolIndex(), 0); + EXPECT_EQ(context.LocalMambaAllocatorView(), nullptr); + RequestCacheMutation mutation(request); + EXPECT_NE(mutation.MutableTerminalDeviceNode(), nullptr); + EXPECT_EQ(recovered_op.begin, 0); + EXPECT_EQ(recovered_op.size, static_cast(recovered_op.occupied_pages.size())); +} + +TEST(RequestCacheContextTest, MutationBridgeDelegatesFrontToBackOwnershipTransfer) { + PageAllocator device_allocator(kPageSize, /*total_pages=*/16); + PageAllocator host_allocator(kPageSize, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_allocator, &host_allocator); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_allocator, /*allocator=*/nullptr, kMambaCacheChunkSize); + ReqPoolAllocator req_pool_allocator(/*size=*/2); + Request request = MakeRequest("r_transfer", /*num_pages=*/3); + ApplyFirstChunkToPrefillDone(request, req_pool_allocator, hybrid_prefix_cache); + + std::vector original_local_pages = request.GetLocalAllocatorPages(); + ASSERT_EQ(original_local_pages.size(), 3u); + + RequestCacheMutation mutation(request); + OwnedPages zero_pages = mutation.TakeFirstLocalKVPages(/*alloc_count=*/0); + EXPECT_TRUE(zero_pages.Empty()); + EXPECT_EQ(request.GetLocalAllocatorPages(), original_local_pages); + + OwnedPages taken_pages = mutation.TakeFirstLocalKVPages(/*alloc_count=*/2); + const std::vector expected_taken(original_local_pages.begin(), original_local_pages.begin() + 2); + const std::vector expected_remaining(original_local_pages.begin() + 2, original_local_pages.end()); + + EXPECT_EQ(taken_pages.Ids(), expected_taken); + EXPECT_EQ(request.GetLocalAllocatorPages(), expected_remaining); +} + +TEST(RequestCacheContextTest, LocalKVPagesSnapshotExposesDiagnosticRequestLocalPages) { + PageAllocator device_allocator(kPageSize, /*total_pages=*/16); + PageAllocator host_allocator(kPageSize, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_allocator, &host_allocator); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_allocator, /*allocator=*/nullptr, kMambaCacheChunkSize); + ReqPoolAllocator req_pool_allocator(/*size=*/2); + Request request = MakeRequest("r_debug_pages", /*num_pages=*/2); + ApplyFirstChunkToPrefillDone(request, req_pool_allocator, hybrid_prefix_cache); + + RequestCacheContext context(request); + + EXPECT_EQ(context.LocalKVPagesSnapshot(), request.GetLocalAllocatorPages()); + EXPECT_FALSE(context.LocalKVPagesSnapshot().empty()); +} + +TEST(RequestCacheContextTest, MutationBridgePreservesRequestErrorSemantics) { + Request submitted_request = MakeRequest("r_submitted_transfer", /*num_pages=*/1); + RequestCacheMutation submitted_mutation(submitted_request); + EXPECT_THROW(submitted_mutation.TakeFirstLocalKVPages(/*alloc_count=*/1), std::logic_error); + + PageAllocator device_allocator(kPageSize, /*total_pages=*/16); + PageAllocator host_allocator(kPageSize, /*total_pages=*/0); + KVPrefixCache prefix_cache(&device_allocator, &host_allocator); + HybridPrefixCache hybrid_prefix_cache(prefix_cache, device_allocator, /*allocator=*/nullptr, kMambaCacheChunkSize); + ReqPoolAllocator req_pool_allocator(/*size=*/2); + Request request = MakeRequest("r_bad_count", /*num_pages=*/1); + ApplyFirstChunkToPrefillDone(request, req_pool_allocator, hybrid_prefix_cache); + + std::vector original_local_pages = request.GetLocalAllocatorPages(); + RequestCacheMutation mutation(request); + + EXPECT_THROW(mutation.TakeFirstLocalKVPages(/*alloc_count=*/-1), std::out_of_range); + EXPECT_EQ(request.GetLocalAllocatorPages(), original_local_pages); + + EXPECT_THROW(mutation.TakeFirstLocalKVPages(static_cast(original_local_pages.size() + 1)), + std::out_of_range); + EXPECT_EQ(request.GetLocalAllocatorPages(), original_local_pages); +} + +} // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_scheduler_memory_diagnostics.cpp b/tokenspeed-scheduler/tests/cpp/test_scheduler_memory_diagnostics.cpp new file mode 100644 index 000000000..fecc90981 --- /dev/null +++ b/tokenspeed-scheduler/tests/cpp/test_scheduler_memory_diagnostics.cpp @@ -0,0 +1,160 @@ +// Copyright (c) 2026 LightSeek Foundation +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "resource/allocator/owned_pages.h" +#include "resource/allocator/page_allocator.h" +#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" +#include "resource/kv_prefix_cache/kv_prefix_cache.h" +#include "resource/types.h" +#include "scheduler/device_memory_diagnostics.h" + +namespace tokenspeed::test { + +namespace { + +constexpr std::int32_t kPageSize = 2; +constexpr std::int32_t kMambaCacheChunkSize = 4; + +std::vector> TokenPages(const token_vec_t& tokens) { + std::vector> pages; + pages.reserve(tokens.size() / kPageSize); + for (std::size_t i = 0; i < tokens.size(); i += kPageSize) { + pages.emplace_back(tokens.data() + i, kPageSize); + } + return pages; +} + +HybridPrefixCache::DeviceMemoryDiagnosticsSnapshot DeviceSnapshot( + std::unordered_map tree_device_pages, std::int32_t free_device_pages, + std::int32_t total_device_pages) { + return HybridPrefixCache::DeviceMemoryDiagnosticsSnapshot{ + .tree_device_pages = std::move(tree_device_pages), + .free_device_pages = free_device_pages, + .total_device_pages = total_device_pages, + }; +} + +} // namespace + +TEST(HybridPrefixCacheDeviceStatsTest, AvailableDevicePagesMirrorsAllocatorAcrossAllocateAndRelease) { + PageAllocator device_allocator{kPageSize, /*total_pages=*/8}; + PageAllocator host_allocator{kPageSize, /*total_pages=*/0}; + KVPrefixCache prefix_cache{&device_allocator, &host_allocator}; + HybridPrefixCache hybrid_prefix_cache{prefix_cache, device_allocator, /*allocator=*/nullptr, kMambaCacheChunkSize}; + + EXPECT_EQ(hybrid_prefix_cache.AvailableDevicePages(), static_cast(device_allocator.AvailablePages())); + EXPECT_EQ(hybrid_prefix_cache.AvailableDevicePages(), 7u); + + { + OwnedPages pages = device_allocator.Allocate(/*num_pages=*/2); + ASSERT_EQ(pages.Size(), 2); + EXPECT_EQ(hybrid_prefix_cache.AvailableDevicePages(), + static_cast(device_allocator.AvailablePages())); + EXPECT_EQ(hybrid_prefix_cache.AvailableDevicePages(), 5u); + } + + EXPECT_EQ(hybrid_prefix_cache.AvailableDevicePages(), static_cast(device_allocator.AvailablePages())); + EXPECT_EQ(hybrid_prefix_cache.AvailableDevicePages(), 7u); +} + +TEST(HybridPrefixCacheDeviceStatsTest, DiagnosticsSnapshotReportsTreeAndAllocatorPages) { + PageAllocator device_allocator{kPageSize, /*total_pages=*/8}; + PageAllocator host_allocator{kPageSize, /*total_pages=*/0}; + KVPrefixCache prefix_cache{&device_allocator, &host_allocator}; + HybridPrefixCache hybrid_prefix_cache{prefix_cache, device_allocator, /*allocator=*/nullptr, kMambaCacheChunkSize}; + + auto initial_snapshot = hybrid_prefix_cache.CollectDeviceMemoryDiagnostics(); + EXPECT_TRUE(initial_snapshot.tree_device_pages.empty()); + EXPECT_EQ(initial_snapshot.free_device_pages, 7); + EXPECT_EQ(initial_snapshot.total_device_pages, 7); + + OwnedPages pages = device_allocator.Allocate(/*num_pages=*/2); + std::vector inserted_page_ids = pages.Ids(); + const token_vec_t tokens = {1, 2, 3, 4}; + prefix_cache.Insert(TokenPages(tokens), {}, std::move(pages)); + + auto snapshot = hybrid_prefix_cache.CollectDeviceMemoryDiagnostics(); + ASSERT_EQ(snapshot.tree_device_pages.size(), 2u); + for (std::int32_t page_id : inserted_page_ids) { + EXPECT_EQ(snapshot.tree_device_pages.at(page_id), 1); + } + EXPECT_EQ(snapshot.free_device_pages, 5); + EXPECT_EQ(snapshot.total_device_pages, 7); + EXPECT_TRUE(ValidateDeviceMemoryDiagnostics(/*request_pages=*/{}, snapshot)); +} + +TEST(DeviceMemoryDiagnosticsValidationTest, AcceptsBalancedSnapshot) { + const std::vector request_pages = { + {.request_id = "r1", .state_name = "Prefilling", .pages = {1, 2}}, + {.request_id = "r2", .state_name = "Decoding", .pages = {3}}, + }; + auto device_snapshot = DeviceSnapshot({{4, 1}, {5, 1}}, /*free_device_pages=*/4, /*total_device_pages=*/9); + + EXPECT_TRUE(ValidateDeviceMemoryDiagnostics(request_pages, device_snapshot)); +} + +TEST(DeviceMemoryDiagnosticsValidationTest, RejectsDuplicateRequestLocalPages) { + const std::vector request_pages = { + {.request_id = "r1", .state_name = "Prefilling", .pages = {1}}, + {.request_id = "r2", .state_name = "Decoding", .pages = {1}}, + }; + auto device_snapshot = DeviceSnapshot({{2, 1}}, /*free_device_pages=*/2, /*total_device_pages=*/5); + + EXPECT_FALSE(ValidateDeviceMemoryDiagnostics(request_pages, device_snapshot)); +} + +TEST(DeviceMemoryDiagnosticsValidationTest, RejectsDuplicateTreePages) { + const std::vector request_pages = { + {.request_id = "r1", .state_name = "Prefilling", .pages = {1}}, + }; + auto device_snapshot = DeviceSnapshot({{2, 2}}, /*free_device_pages=*/2, /*total_device_pages=*/4); + + EXPECT_FALSE(ValidateDeviceMemoryDiagnostics(request_pages, device_snapshot)); +} + +TEST(DeviceMemoryDiagnosticsValidationTest, RejectsOutOfRangePageIds) { + const std::vector request_pages = { + {.request_id = "r1", .state_name = "Prefilling", .pages = {0}}, + }; + auto device_snapshot = DeviceSnapshot({{5, 1}}, /*free_device_pages=*/2, /*total_device_pages=*/4); + + EXPECT_FALSE(ValidateDeviceMemoryDiagnostics(request_pages, device_snapshot)); +} + +TEST(DeviceMemoryDiagnosticsValidationTest, RejectsAccountingMismatch) { + const std::vector request_pages = { + {.request_id = "r1", .state_name = "Prefilling", .pages = {1}}, + }; + auto device_snapshot = DeviceSnapshot({{2, 1}}, /*free_device_pages=*/0, /*total_device_pages=*/4); + + EXPECT_FALSE(ValidateDeviceMemoryDiagnostics(request_pages, device_snapshot)); +} + +} // namespace tokenspeed::test diff --git a/tokenspeed-scheduler/tests/cpp/test_scheduler_plan.cpp b/tokenspeed-scheduler/tests/cpp/test_scheduler_plan.cpp index c44e07ae4..41a4732b9 100644 --- a/tokenspeed-scheduler/tests/cpp/test_scheduler_plan.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_scheduler_plan.cpp @@ -19,9 +19,97 @@ // SOFTWARE. #include "integration_test_helper.h" +#include "scheduler/page_hasher.h" + +#include +#include namespace tokenspeed::test { +namespace { + +const FlatForwardOperation* GetForwardOp(const ExecutionPlan& plan) { + for (const auto& op : plan.Operations()) { + if (const auto* fwd = std::get_if(&op)) { + return fwd; + } + } + return nullptr; +} + +const FlatWriteBackOperation* GetWriteBackOp(const ExecutionPlan& plan) { + for (const auto& op : plan.Operations()) { + if (const auto* cache_op = std::get_if(&op)) { + if (const auto* writeback = std::get_if(cache_op)) { + return writeback; + } + } + } + return nullptr; +} + +std::vector> TokenPages(const token_vec_t& tokens, std::int32_t page_size) { + const std::size_t num_pages = tokens.size() / page_size; + std::vector> pages; + pages.reserve(num_pages); + for (std::size_t i = 0; i < num_pages; ++i) { + pages.emplace_back(tokens.data() + i * page_size, static_cast(page_size)); + } + return pages; +} + +void ExpectNoAdjunctMetadata(const FlatForwardOperation& fwd) { + ASSERT_EQ(fwd.request_ids.size(), fwd.mamba_working_indices.size()); + ASSERT_EQ(fwd.request_ids.size(), fwd.mamba_checkpoint_dst_indices.size()); + ASSERT_EQ(fwd.request_ids.size(), fwd.mamba_cow_src_indices.size()); + ASSERT_EQ(fwd.request_ids.size(), fwd.mamba_branching_seqlens.size()); + for (std::size_t i = 0; i < fwd.request_ids.size(); ++i) { + EXPECT_EQ(fwd.mamba_working_indices[i], -1); + EXPECT_EQ(fwd.mamba_checkpoint_dst_indices[i], -1); + EXPECT_EQ(fwd.mamba_cow_src_indices[i], -1); + EXPECT_EQ(fwd.mamba_branching_seqlens[i], -1); + } + EXPECT_TRUE(fwd.paged_cache_block_tables.empty()); + EXPECT_TRUE(fwd.paged_cache_block_table_base_offsets.empty()); +} + +SchedulerConfig MakePagedCacheConfigTestBase() { + SchedulerConfig cfg{}; + cfg.page_size = 2; + cfg.device_allocator.total_pages = 16; + cfg.host_allocator.total_pages = 16; + cfg.max_scheduled_tokens = 16; + cfg.max_batch_size = 4; + cfg.enable_l3_storage = false; + return cfg; +} + +PagedCacheGroupConfig MakePagedCacheHistoryGroup(std::string group_id = "fh") { + PagedCacheGroupConfig group{}; + group.group_id = std::move(group_id); + group.rows_per_page = 4; + group.entry_stride_tokens = 1; + group.total_pages = 8; + group.retention = PagedCacheGroupConfig::Retention::FullHistory; + group.family = PagedCacheGroupFamily::History; + return group; +} + +PagedCacheGroupConfig MakePagedCacheSlidingStateGroup(std::string group_id = "swa", + std::int32_t sliding_window_tokens = 8) { + PagedCacheGroupConfig group{}; + group.group_id = std::move(group_id); + group.rows_per_page = 2; + group.entry_stride_tokens = 1; + group.total_pages = 8; + group.retention = PagedCacheGroupConfig::Retention::SlidingWindow; + group.sliding_window_tokens = sliding_window_tokens; + group.family = PagedCacheGroupFamily::State; + return group; +} + +} // namespace + class LoadBackViaCacheTestSuite : public SchedulerTestSuite { protected: SchedulerConfig MakeConfig() override { @@ -111,6 +199,188 @@ TEST_F(SchedulerTestSuite, NoCacheOps_PlainRequestNoCacheHit) { EXPECT_TRUE(cache_ops.empty()); } +TEST(SchedulerPagedCacheConfigurationTest, ConfiguredGroupsWithoutAdjunctRemainPubliclyIntrospectable) { + auto cfg = MakePagedCacheConfigTestBase(); + cfg.paged_cache_groups.push_back(MakePagedCacheHistoryGroup()); + + Scheduler scheduler{cfg}; + + const auto group_ids = scheduler.PagedCacheGroupIds(); + ASSERT_EQ(group_ids.size(), 1u); + EXPECT_EQ(group_ids[0], "fh"); + EXPECT_EQ(scheduler.PagedCacheGroupTotalPages("fh"), 8); + EXPECT_EQ(scheduler.PagedCacheGroupAvailablePages("fh"), 7); + EXPECT_EQ(scheduler.PagedCacheGroupFailedAllocCount("fh"), 0); + EXPECT_TRUE(scheduler.GetRequestPagedCachePageIds("missing", "fh").empty()); + EXPECT_EQ(scheduler.GetRequestPagedCacheBaseLogicalPage("missing", "fh"), 0); +} + +TEST(SchedulerPagedCacheConfigurationTest, InvalidGroupConfigFailsAtConstruction) { + auto cfg = MakePagedCacheConfigTestBase(); + auto invalid_group = MakePagedCacheHistoryGroup(); + invalid_group.rows_per_page = 0; + cfg.paged_cache_groups.push_back(invalid_group); + + EXPECT_THROW({ Scheduler scheduler{cfg}; }, std::invalid_argument); +} + +TEST(SchedulerPagedCacheConfigurationTest, DuplicateGroupIdsFailAtConstruction) { + auto cfg = MakePagedCacheConfigTestBase(); + cfg.paged_cache_groups.push_back(MakePagedCacheHistoryGroup("dup")); + cfg.paged_cache_groups.push_back(MakePagedCacheHistoryGroup("dup")); + + EXPECT_THROW({ Scheduler scheduler{cfg}; }, std::invalid_argument); +} + +TEST(SchedulerPagedCacheConfigurationTest, EmptyAdjunctRequiredGroupsFailAtConstruction) { + auto cfg = MakePagedCacheConfigTestBase(); + cfg.paged_cache_groups.push_back(MakePagedCacheHistoryGroup()); + cfg.prefix_cache_adjunct = PrefixCacheAdjunctSpec{}; + + EXPECT_THROW({ Scheduler scheduler{cfg}; }, std::invalid_argument); +} + +TEST(SchedulerPagedCacheConfigurationTest, MissingAdjunctRequiredGroupFailsAtConstruction) { + auto cfg = MakePagedCacheConfigTestBase(); + cfg.paged_cache_groups.push_back(MakePagedCacheHistoryGroup()); + PrefixCacheAdjunctSpec spec{}; + spec.required_groups = {"fh", "missing"}; + cfg.prefix_cache_adjunct = spec; + + EXPECT_THROW({ Scheduler scheduler{cfg}; }, std::invalid_argument); +} + +TEST(SchedulerPagedCacheConfigurationTest, SlidingRequiredGroupWithoutPositiveWindowFailsAtConstruction) { + auto cfg = MakePagedCacheConfigTestBase(); + auto sliding_group = MakePagedCacheSlidingStateGroup(); + sliding_group.sliding_window_tokens.reset(); + cfg.paged_cache_groups.push_back(MakePagedCacheHistoryGroup()); + cfg.paged_cache_groups.push_back(sliding_group); + PrefixCacheAdjunctSpec spec{}; + spec.required_groups = {"fh", "swa"}; + cfg.prefix_cache_adjunct = spec; + + EXPECT_THROW({ Scheduler scheduler{cfg}; }, std::invalid_argument); +} + +class KVOnlyAlwaysPresentFacadeTestSuite : public SchedulerTestSuite { +protected: + SchedulerConfig MakeConfig() override { + auto cfg = SchedulerTestSuite::MakeConfig(); + cfg.enable_l3_storage = false; + cfg.disable_l2_cache = true; + cfg.enable_mamba = false; + cfg.mamba_pool_total_chunks = 0; + cfg.paged_cache_groups.clear(); + cfg.prefix_cache_adjunct.reset(); + return cfg; + } +}; + +TEST_F(KVOnlyAlwaysPresentFacadeTestSuite, NextExecutionPlanKeepsKVOnlyPrefillDecodeAndPrefixReuseStable) { + EXPECT_TRUE(scheduler_->PagedCacheGroupIds().empty()); + EXPECT_THROW((void)scheduler_->PagedCacheGroupTotalPages("missing"), std::out_of_range); + EXPECT_THROW((void)scheduler_->PagedCacheGroupAvailablePages("missing"), std::out_of_range); + EXPECT_THROW((void)scheduler_->PagedCacheGroupFailedAllocCount("missing"), std::out_of_range); + EXPECT_THROW((void)scheduler_->GetRequestPagedCachePageIds("r_seed", "missing"), std::out_of_range); + EXPECT_THROW((void)scheduler_->GetRequestPagedCacheBaseLogicalPage("r_seed", "missing"), std::out_of_range); + + Submit(MakeRequestSpec("r_seed", /*num_pages=*/2, /*start=*/1)); + auto seed_prefill = PlanOnce(); + const auto* seed_prefill_fwd = GetForwardOp(seed_prefill); + ASSERT_NE(seed_prefill_fwd, nullptr); + ASSERT_EQ(seed_prefill_fwd->request_ids.size(), 1u); + ASSERT_EQ(seed_prefill_fwd->extend_prefix_lens.size(), 1u); + EXPECT_EQ(seed_prefill_fwd->request_ids[0], "r_seed"); + EXPECT_EQ(seed_prefill_fwd->extend_prefix_lens[0], 0); + EXPECT_EQ(seed_prefill_fwd->input_lengths[0], 4); + ExpectNoAdjunctMetadata(*seed_prefill_fwd); + + SendForwardDone("r_seed", {101}); + auto seed_decode = PlanOnce(); + const auto* seed_decode_fwd = GetForwardOp(seed_decode); + ASSERT_NE(seed_decode_fwd, nullptr); + ASSERT_EQ(seed_decode_fwd->request_ids.size(), 1u); + EXPECT_EQ(seed_decode_fwd->request_ids[0], "r_seed"); + EXPECT_EQ(seed_decode_fwd->input_lengths[0], 1); + ExpectNoAdjunctMetadata(*seed_decode_fwd); + EXPECT_EQ(scheduler_->DecodingSize(), 1u); + + SendFinish("r_seed"); + PlanOnce(); + + Submit(MakeRequestSpec("r_reuse", /*num_pages=*/2, /*start=*/1)); + auto reuse_prefill = PlanOnce(); + const auto* reuse_prefill_fwd = GetForwardOp(reuse_prefill); + ASSERT_NE(reuse_prefill_fwd, nullptr); + ASSERT_EQ(reuse_prefill_fwd->request_ids.size(), 1u); + ASSERT_EQ(reuse_prefill_fwd->extend_prefix_lens.size(), 1u); + EXPECT_EQ(reuse_prefill_fwd->request_ids[0], "r_reuse"); + EXPECT_EQ(reuse_prefill_fwd->extend_prefix_lens[0], PageSize()); + EXPECT_EQ(reuse_prefill_fwd->input_lengths[0], PageSize()); + ExpectNoAdjunctMetadata(*reuse_prefill_fwd); + EXPECT_TRUE(ExtractCacheOpsOfKind(reuse_prefill).empty()); +} + +class RollingHashSeedFacadeTestSuite : public SchedulerTestSuite { +protected: + SchedulerConfig MakeConfig() override { + auto cfg = SchedulerTestSuite::MakeConfig(); + cfg.device_allocator.total_pages = 32; + cfg.host_allocator.total_pages = 32; + cfg.enable_l3_storage = true; + cfg.disable_l2_cache = false; + return cfg; + } + + void StoreHostPrefix(std::int32_t num_pages, token_t start = 1) { + Submit(MakeRequestSpec("r_seed", num_pages, start)); + PlanOnce(); + SendForwardDone("r_seed", {900}); + PlanOnce(); + SendFinish("r_seed"); + const auto plan = PlanOnce(); + const auto* writeback = GetWriteBackOp(plan); + ASSERT_NE(writeback, nullptr); + ASSERT_FALSE(writeback->op_ids.empty()); + SendWriteBackDone(writeback->op_ids[0]); + PlanOnce(); + } +}; + +TEST_F(RollingHashSeedFacadeTestSuite, CalcRollingHashWithoutMatchHashesFullInputFromEmptySeed) { + StoreHostPrefix(/*num_pages=*/2); + + const token_vec_t input_tokens = MakeAlignedTokens(/*num_pages=*/3, PageSize(), /*start=*/1); + const auto pages = TokenPages(input_tokens, PageSize()); + const auto full_hashes = ComputePagedHashes(pages, ""); + + EXPECT_EQ(scheduler_->CalcRollingHash(input_tokens, /*apply_match=*/false), full_hashes); + + const std::vector> suffix{pages.begin() + 2, pages.end()}; + EXPECT_EQ(scheduler_->CalcRollingHash(input_tokens, /*apply_match=*/true), ComputePagedHashes(suffix, "")); +} + +TEST_F(RollingHashSeedFacadeTestSuite, CalcRollingHashWithFullHostMatchReturnsEmpty) { + StoreHostPrefix(/*num_pages=*/2); + + const token_vec_t input_tokens = MakeAlignedTokens(/*num_pages=*/2, PageSize(), /*start=*/1); + EXPECT_TRUE(scheduler_->CalcRollingHash(input_tokens, /*apply_match=*/true).empty()); +} + +TEST_F(RollingHashSeedFacadeTestSuite, CalcRollingHashUsesEmptyPriorSeedWhenHostNodeHasNoPageHashes) { + auto cfg = Config(); + cfg.enable_l3_storage = false; + scheduler_ = std::make_unique(cfg); + StoreHostPrefix(/*num_pages=*/2); + + const token_vec_t input_tokens = MakeAlignedTokens(/*num_pages=*/3, PageSize(), /*start=*/1); + const auto pages = TokenPages(input_tokens, PageSize()); + const std::vector> suffix{pages.begin() + 2, pages.end()}; + + EXPECT_EQ(scheduler_->CalcRollingHash(input_tokens, /*apply_match=*/true), ComputePagedHashes(suffix, "")); +} + class DisablePrefixCacheTestSuite : public SchedulerTestSuite { protected: SchedulerConfig MakeConfig() override {