From bccd75025da2366af2e8a7eda0fc1e0dfdd354c5 Mon Sep 17 00:00:00 2001 From: Jue Wang Date: Thu, 7 May 2026 02:57:28 +0000 Subject: [PATCH] fix prefill cache update. --- tokenspeed-scheduler/csrc/fsm/forward_events.cpp | 14 +++++++------- tokenspeed-scheduler/csrc/fsm/forward_events.h | 14 ++++++++++---- .../csrc/scheduler/operations/forward.cpp | 4 ++-- .../tests/cpp/test_chunked_prefill.cpp | 15 +++++++++++++++ 4 files changed, 34 insertions(+), 13 deletions(-) diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp index 65a749003..63d505202 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp @@ -60,11 +60,11 @@ std::vector> BuildWriteBackPairs( namespace tokenspeed::fsm { -void InsertHybridCache(HybridPrefixCache* hybrid_cache, +void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_cache, const std::vector>& full_paged_tokens, std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, LocalMambaAllocator* local_mamba_allocator) { - if (hybrid_cache == nullptr) return; + if (kv_prefix_cache == nullptr) return; std::vector prefix_pages = DevicePagesFromRoot(device_node_ref->Node()); std::int32_t new_page_count = @@ -72,10 +72,10 @@ void InsertHybridCache(HybridPrefixCache* hybrid_cache, if (new_page_count <= 0) return; OwnedPages pages_to_insert = local_kv_allocator->TakeFirst(new_page_count); - auto insert_result = hybrid_cache->GetKVPrefixCache().Insert(full_paged_tokens, prefix_pages, - std::move(pages_to_insert)); + auto insert_result = + kv_prefix_cache->Insert(full_paged_tokens, prefix_pages, std::move(pages_to_insert)); - if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { + if (hybrid_cache != nullptr && local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { hybrid_cache->InsertMamba(insert_result.last_node, local_mamba_allocator->DetachCheckpoint()); } device_node_ref = std::make_unique(insert_result.last_node); @@ -155,7 +155,7 @@ std::variant SchedulePrefillEvent::operator()(Prefillin if (end_of_window_pages < static_cast(paged_tokens.size())) { paged_tokens.resize(end_of_window_pages); } - InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), + InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), local_mamba_allocator.get()); // Allocate KV pages for the new chunk local_kv_allocator->Acquire(tokens_this_round_); @@ -202,7 +202,7 @@ Decoding ScheduleDecodeEvent::operator()(PrefillDone&& state) { if (end_of_window_pages < static_cast(paged_tokens.size())) { paged_tokens.resize(end_of_window_pages); } - InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), + InsertPrefixCache(kv_prefix_cache_, hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), local_mamba_allocator.get()); // Allocate fresh checkpoint for decode-phase mamba state tracking diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.h b/tokenspeed-scheduler/csrc/fsm/forward_events.h index 12d3afdb0..96fd89b9e 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.h +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.h @@ -52,7 +52,7 @@ namespace tokenspeed::fsm { struct PrefetchDone; struct Prefetching; -void InsertHybridCache(HybridPrefixCache* hybrid_prefix_cache, +void InsertPrefixCache(KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_prefix_cache, const std::vector>& full_paged_tokens, std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, LocalMambaAllocator* local_mamba_allocator); @@ -101,9 +101,10 @@ struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); SchedulePrefillEvent(std::int32_t tokens_this_round, std::int32_t reserve_num_tokens_in_next_schedule_event, - HybridPrefixCache* hybrid_prefix_cache = nullptr) + KVPrefixCache* kv_prefix_cache, HybridPrefixCache* hybrid_prefix_cache = nullptr) : tokens_this_round_(tokens_this_round), reserve_num_tokens_in_next_schedule_event_(reserve_num_tokens_in_next_schedule_event), + kv_prefix_cache_(kv_prefix_cache), hybrid_prefix_cache_(hybrid_prefix_cache) {} // Returns PrefillDone (last chunk) or Prefilling (more chunks remain). @@ -112,20 +113,25 @@ struct SchedulePrefillEvent : InvalidTransitionHandler { private: std::int32_t tokens_this_round_{}; std::int32_t reserve_num_tokens_in_next_schedule_event_{}; + KVPrefixCache* kv_prefix_cache_{}; HybridPrefixCache* hybrid_prefix_cache_{}; }; struct ScheduleDecodeEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); - ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr) - : decode_input_tokens_(decode_input_tokens), hybrid_prefix_cache_(hybrid_prefix_cache) {} + ScheduleDecodeEvent(std::int32_t decode_input_tokens, KVPrefixCache* kv_prefix_cache, + HybridPrefixCache* hybrid_prefix_cache = nullptr) + : decode_input_tokens_(decode_input_tokens), + kv_prefix_cache_(kv_prefix_cache), + hybrid_prefix_cache_(hybrid_prefix_cache) {} Decoding operator()(PrefillDone&& state); Decoding operator()(Decoding&& state); private: std::int32_t decode_input_tokens_; + KVPrefixCache* kv_prefix_cache_{}; HybridPrefixCache* hybrid_prefix_cache_{}; }; diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index 89ebdef00..d4ab1dc10 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -121,7 +121,7 @@ std::optional Scheduler::schedulePrefill( return {}; } - return fsm::SchedulePrefillEvent{tokens_this_round, reserve_num_tokens_in_next_schedule_event, + return fsm::SchedulePrefillEvent{tokens_this_round, reserve_num_tokens_in_next_schedule_event, &kv_prefix_cache_, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}; } @@ -134,7 +134,7 @@ std::optional Scheduler::scheduleDecode(Request* reque return {}; } - return fsm::ScheduleDecodeEvent{config_.decode_input_tokens, + return fsm::ScheduleDecodeEvent{config_.decode_input_tokens, &kv_prefix_cache_, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}; } diff --git a/tokenspeed-scheduler/tests/cpp/test_chunked_prefill.cpp b/tokenspeed-scheduler/tests/cpp/test_chunked_prefill.cpp index 80faf5fca..439f6a673 100644 --- a/tokenspeed-scheduler/tests/cpp/test_chunked_prefill.cpp +++ b/tokenspeed-scheduler/tests/cpp/test_chunked_prefill.cpp @@ -87,6 +87,21 @@ TEST_F(ChunkedPrefillTestSuite, PrefillFirst_ContinuesPrefillBeforeNewSubmitted) EXPECT_EQ(fwd->request_ids[0], "r1"); } +TEST_F(ChunkedPrefillTestSuite, CompletedChunk_IsVisibleToPrefixCacheWithoutHybridCache) { + Submit(MakeRequestSpec("r1", 4)); // 8 tokens, needs 2 chunks + PlanOnce(); // r1 chunk 1 + + Submit(MakeRequestSpec("r2", 4)); // same prefix as r1 + PlanOnce(); // r1 chunk 2; inserts chunk 1 into KV prefix cache + + auto plan = PlanOnce(); + auto* fwd = GetForwardOp(plan); + ASSERT_NE(fwd, nullptr); + ASSERT_EQ(fwd->request_ids.size(), 1u); + EXPECT_EQ(fwd->request_ids[0], "r2"); + EXPECT_EQ(fwd->extend_prefix_lens[0], 4); +} + TEST_F(ChunkedPrefillTestSuite, InputIds_CorrectPerChunk) { Submit(MakeRequestSpec("r1", 3)); // 6 tokens: [1,2,3,4,5,6] auto plan1 = PlanOnce();