Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
626 changes: 626 additions & 0 deletions docs/mamba-dsv4-refactor.md

Large diffs are not rendered by default.

8 changes: 6 additions & 2 deletions tokenspeed-scheduler/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ add_library(tokenspeed_scheduler_core STATIC
csrc/resource/radix_tree/radix_tree.cpp
csrc/resource/radix_tree/tree_node.cpp
csrc/resource/kv_prefix_cache/kv_prefix_cache.cpp
csrc/resource/kv_prefix_cache/cache_coordinator.cpp
csrc/resource/hybrid_prefix_cache/family_registry.cpp
csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp
csrc/resource/hybrid_prefix_cache/mamba_family_ops.cpp
csrc/resource/hybrid_prefix_cache/mamba_eviction_manager.cpp
csrc/resource/hybrid_prefix_cache/paged_cache_family_ops.cpp
csrc/resource/allocator/mamba_chunk_allocator.cpp
csrc/resource/allocator/mamba_host_allocator.cpp
csrc/resource/allocator/local_mamba_allocator.cpp
Expand All @@ -50,7 +52,6 @@ add_library(tokenspeed_scheduler_core STATIC
csrc/fsm/cache_events.cpp
csrc/fsm/pd_events.cpp

csrc/scheduler/operations/cache.cpp
csrc/scheduler/operations/forward.cpp

csrc/scheduler/kv_cache_events.cpp
Expand Down Expand Up @@ -112,12 +113,15 @@ if(TOKENSPEED_SCHEDULER_BUILD_TESTS)
tests/cpp/test_batch_scheduling.cpp
tests/cpp/test_prefetch.cpp
tests/cpp/test_owned_pages.cpp
tests/cpp/test_hybrid_cache_registry.cpp
tests/cpp/test_paged_cache_prefix_match.cpp
tests/cpp/test_paged_cache_attach_loop.cpp
tests/cpp/test_paged_cache_eviction.cpp
tests/cpp/test_paged_cache_family_split.cpp
tests/cpp/test_paged_cache_prefix_hit_commit.cpp
tests/cpp/test_retract_abort_pages.cpp
tests/cpp/test_request_cache_context.cpp
tests/cpp/test_scheduler_memory_diagnostics.cpp
tests/cpp/test_req_pool_allocator.cpp
tests/cpp/test_mamba_slot.cpp
tests/cpp/test_mamba_eviction.cpp
Expand Down
326 changes: 141 additions & 185 deletions tokenspeed-scheduler/csrc/fsm/forward_events.cpp

Large diffs are not rendered by default.

88 changes: 27 additions & 61 deletions tokenspeed-scheduler/csrc/fsm/forward_events.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,9 @@
#include "fsm/forward_states.h"
#include "resource/types.h"
#include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h"
#include "resource/allocator/mamba_chunk_allocator.h"
#include "resource/allocator/local_mamba_allocator.h"
#include "utils.h"

namespace tokenspeed {
class PageAllocator;
class KVPrefixCache;
class ReqPoolAllocator;
class TreeNode;
} // namespace tokenspeed
Expand All @@ -52,61 +48,47 @@ namespace tokenspeed::fsm {
struct PrefetchDone;
struct Prefetching;

void InsertHybridCache(HybridPrefixCache* hybrid_prefix_cache,
const std::vector<std::span<const std::int32_t>>& full_paged_tokens,
std::unique_ptr<DeviceNodeRef>& device_node_ref, LocalKVAllocator* local_kv_allocator,
LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size,
std::int32_t page_size);

struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler<SchedulePrefillFirstChunkEvent> {
using InvalidTransitionHandler<SchedulePrefillFirstChunkEvent>::operator();
SchedulePrefillFirstChunkEvent(std::int32_t tokens_this_round, std::int32_t decode_input_tokens,
PageAllocator* device_allocator, ReqPoolAllocator* req_pool_allocator,
MatchResult match_result, Role role, KVPrefixCache* kv_prefix_cache,
ReqPoolAllocator* req_pool_allocator, MatchResult match_result, Role role,
bool disable_l2_cache, std::vector<TreeNode*> loadback_diff,
HybridPrefixCache* hybrid_prefix_cache = nullptr,
MambaChunkAllocator* mamba_allocator = nullptr,
std::vector<TreeNode*> mamba_loadback_nodes = {})
std::vector<TransferPair> cache_transfer_pairs,
HybridPrefixCache& hybrid_prefix_cache)
: tokens_this_round_(tokens_this_round),
decode_input_tokens_(decode_input_tokens),
device_allocator_(device_allocator),
req_pool_allocator_(req_pool_allocator),
match_result_(match_result),
role_{role},
disable_l2_cache_{disable_l2_cache},
loadback_diff_(std::move(loadback_diff)),
mamba_loadback_nodes_(std::move(mamba_loadback_nodes)),
kv_prefix_cache_(kv_prefix_cache),
hybrid_prefix_cache_(hybrid_prefix_cache),
mamba_allocator_(mamba_allocator) {}
cache_transfer_pairs_(std::move(cache_transfer_pairs)),
hybrid_prefix_cache_(hybrid_prefix_cache) {}

// Returns PrefillDone (single-chunk or last chunk) or Prefilling (more chunks remain).
std::variant<PrefillDone, Prefilling> operator()(Submitted&& state);

const MatchResult GetMatchResult() const { return match_result_; }

const std::vector<TreeNode*>& GetLoadbackDiff() const { return loadback_diff_; }
const std::vector<TreeNode*>& GetMambaLoadbackNodes() const { return mamba_loadback_nodes_; }
const std::vector<TransferPair>& GetCacheTransferPairs() const { return cache_transfer_pairs_; }

private:
std::int32_t tokens_this_round_{};
std::int32_t decode_input_tokens_{};
PageAllocator* device_allocator_{};
ReqPoolAllocator* req_pool_allocator_{};
const MatchResult match_result_{};
const Role role_;
bool disable_l2_cache_{};
std::vector<TreeNode*> loadback_diff_;
std::vector<TreeNode*> mamba_loadback_nodes_;
KVPrefixCache* kv_prefix_cache_;
HybridPrefixCache* hybrid_prefix_cache_{};
MambaChunkAllocator* mamba_allocator_{};
std::vector<TransferPair> cache_transfer_pairs_;
HybridPrefixCache& hybrid_prefix_cache_;
};

struct SchedulePrefillEvent : InvalidTransitionHandler<SchedulePrefillEvent> {
using InvalidTransitionHandler<SchedulePrefillEvent>::operator();
SchedulePrefillEvent(std::int32_t tokens_this_round, std::int32_t reserve_num_tokens_in_next_schedule_event,
HybridPrefixCache* hybrid_prefix_cache = nullptr)
HybridPrefixCache& hybrid_prefix_cache)
: tokens_this_round_(tokens_this_round),
reserve_num_tokens_in_next_schedule_event_(reserve_num_tokens_in_next_schedule_event),
hybrid_prefix_cache_(hybrid_prefix_cache) {}
Expand All @@ -117,67 +99,59 @@ struct SchedulePrefillEvent : InvalidTransitionHandler<SchedulePrefillEvent> {
private:
std::int32_t tokens_this_round_{};
std::int32_t reserve_num_tokens_in_next_schedule_event_{};
HybridPrefixCache* hybrid_prefix_cache_{};
HybridPrefixCache& hybrid_prefix_cache_;
};

struct ScheduleDecodeEvent : InvalidTransitionHandler<ScheduleDecodeEvent> {
using InvalidTransitionHandler<ScheduleDecodeEvent>::operator();

ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr)
ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache& hybrid_prefix_cache)
: decode_input_tokens_(decode_input_tokens), hybrid_prefix_cache_(hybrid_prefix_cache) {}

Decoding operator()(PrefillDone&& state);
Decoding operator()(Decoding&& state);

private:
std::int32_t decode_input_tokens_;
HybridPrefixCache* hybrid_prefix_cache_{};
HybridPrefixCache& hybrid_prefix_cache_;
};

struct ScheduleDecodeFromRetractedEvent : InvalidTransitionHandler<ScheduleDecodeFromRetractedEvent> {
using InvalidTransitionHandler<ScheduleDecodeFromRetractedEvent>::operator();

// Constructor for Retracted → Decoding recovery (LoadBack from host).
ScheduleDecodeFromRetractedEvent(std::int32_t decode_input_tokens, PageAllocator* device_allocator,
ReqPoolAllocator* req_pool_allocator, KVPrefixCache* kv_prefix_cache,
ScheduleDecodeFromRetractedEvent(std::int32_t decode_input_tokens, ReqPoolAllocator* req_pool_allocator,
MatchResult match_result, std::vector<TreeNode*> loadback_diff,
MambaChunkAllocator* mamba_allocator = nullptr,
std::vector<TreeNode*> mamba_loadback_nodes = {})
std::vector<TransferPair> cache_transfer_pairs,
HybridPrefixCache& hybrid_prefix_cache)
: decode_input_tokens_(decode_input_tokens),
device_allocator_(device_allocator),
req_pool_allocator_(req_pool_allocator),
kv_prefix_cache_(kv_prefix_cache),
match_result_(std::move(match_result)),
loadback_diff_(std::move(loadback_diff)),
mamba_loadback_nodes_(std::move(mamba_loadback_nodes)),
mamba_allocator_(mamba_allocator) {}
cache_transfer_pairs_(std::move(cache_transfer_pairs)),
hybrid_prefix_cache_(hybrid_prefix_cache) {}

Decoding operator()(Retracted&& state);

const MatchResult& GetMatchResult() const { return match_result_; }

const std::vector<TreeNode*>& GetLoadbackDiff() const { return loadback_diff_; }
const std::vector<TreeNode*>& GetMambaLoadbackNodes() const { return mamba_loadback_nodes_; }
const std::vector<TransferPair>& GetCacheTransferPairs() const { return cache_transfer_pairs_; }

private:
std::int32_t decode_input_tokens_{};
PageAllocator* device_allocator_{};
ReqPoolAllocator* req_pool_allocator_{};
KVPrefixCache* kv_prefix_cache_{};
MatchResult match_result_{};
std::vector<TreeNode*> loadback_diff_;
std::vector<TreeNode*> mamba_loadback_nodes_;
MambaChunkAllocator* mamba_allocator_{};
std::vector<TransferPair> cache_transfer_pairs_;
HybridPrefixCache& hybrid_prefix_cache_;
};

struct FinishEvent : InvalidTransitionHandler<FinishEvent> {
using InvalidTransitionHandler<FinishEvent>::operator();
explicit FinishEvent(KVPrefixCache* kv_prefix_cache, PageAllocator* host_allocator,
std::vector<std::string> page_hashes = {}, bool disable_l2_cache = false,
HybridPrefixCache* hybrid_prefix_cache = nullptr)
: kv_prefix_cache_(kv_prefix_cache),
host_allocator_(host_allocator),
page_hashes_(std::move(page_hashes)),
explicit FinishEvent(std::vector<std::string> page_hashes, bool disable_l2_cache,
HybridPrefixCache& hybrid_prefix_cache)
: page_hashes_(std::move(page_hashes)),
disable_l2_cache_(disable_l2_cache),
hybrid_prefix_cache_(hybrid_prefix_cache) {}

Expand All @@ -192,11 +166,9 @@ struct FinishEvent : InvalidTransitionHandler<FinishEvent> {
Finished operator()(Finished&& state) { return std::move(state); }

private:
KVPrefixCache* kv_prefix_cache_{};
std::vector<std::string> page_hashes_;
PageAllocator* host_allocator_;
bool disable_l2_cache_;
HybridPrefixCache* hybrid_prefix_cache_{};
HybridPrefixCache& hybrid_prefix_cache_;

template <typename ForwardStateT>
std::variant<Draining, Finished> apply(ForwardStateT&& state);
Expand All @@ -221,12 +193,8 @@ struct AbortEvent : InvalidTransitionHandler<AbortEvent> {

struct ScheduleRetractEvent : InvalidTransitionHandler<ScheduleRetractEvent> {
using InvalidTransitionHandler<ScheduleRetractEvent>::operator();
ScheduleRetractEvent(KVPrefixCache* kv_prefix_cache, PageAllocator* host_allocator, MatchResult match_result,
HybridPrefixCache* hybrid_prefix_cache = nullptr)
: kv_prefix_cache_(kv_prefix_cache),
host_allocator_(host_allocator),
match_result_(match_result),
hybrid_prefix_cache_(hybrid_prefix_cache) {}
ScheduleRetractEvent(MatchResult match_result, HybridPrefixCache& hybrid_prefix_cache)
: match_result_(match_result), hybrid_prefix_cache_(hybrid_prefix_cache) {}

Retracting operator()(Decoding&& state);
Retracting operator()(PrefillDone&& state);
Expand All @@ -237,10 +205,8 @@ struct ScheduleRetractEvent : InvalidTransitionHandler<ScheduleRetractEvent> {
template <typename ForwardStateT>
Retracting applyRetract(ForwardStateT&& state);

KVPrefixCache* kv_prefix_cache_{};
PageAllocator* host_allocator_{};
const MatchResult match_result_{};
HybridPrefixCache* hybrid_prefix_cache_{};
HybridPrefixCache& hybrid_prefix_cache_;
};

// Draining → WritingBack: WriteBack op has been generated this round; transfer
Expand Down
1 change: 1 addition & 0 deletions tokenspeed-scheduler/csrc/fsm/forward_states.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ struct BaseState {
}

const TreeNode* GetDeviceNode() const { return device_node_ref_->Node(); }
TreeNode* GetMutableDeviceNode() const { return device_node_ref_->Node(); }

std::int32_t TailPageAvailableTokens() const { return local_kv_allocator_->TailPageAvailableTokens(); }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,11 @@ std::int32_t LocalMambaAllocator::WorkingIndex() const {
return working_ ? working_->Index() : -1;
}

bool LocalMambaAllocator::AllocateCheckpoint() {
bool LocalMambaAllocator::AllocateCheckpoint(std::int32_t raw_position) {
auto slot = allocator_->Allocate();
if (!slot.has_value()) return false;
checkpoint_ = std::make_unique<MambaSlot>(std::move(*slot));
checkpoint_position_ = raw_position;
return true;
}

Expand All @@ -52,6 +53,7 @@ std::int32_t LocalMambaAllocator::CheckpointIndex() const {
}

std::unique_ptr<MambaSlot> LocalMambaAllocator::DetachCheckpoint() {
checkpoint_position_ = -1;
return std::move(checkpoint_);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,9 @@ class LocalMambaAllocator {
std::int32_t WorkingIndex() const;
bool HasWorking() const { return working_ != nullptr; }

bool AllocateCheckpoint();
bool AllocateCheckpoint(std::int32_t raw_position = -1);
std::int32_t CheckpointIndex() const;
std::int32_t CheckpointPosition() const { return checkpoint_position_; }
bool HasCheckpoint() const { return checkpoint_ != nullptr; }
std::unique_ptr<MambaSlot> DetachCheckpoint();
std::unique_ptr<MambaSlot> DetachWorking();
Expand All @@ -54,6 +55,7 @@ class LocalMambaAllocator {
MambaChunkAllocator* allocator_;
std::unique_ptr<MambaSlot> working_{};
std::unique_ptr<MambaSlot> checkpoint_{};
std::int32_t checkpoint_position_{-1};
};

} // namespace tokenspeed
Loading
Loading