Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 49 additions & 4 deletions python/tokenspeed/runtime/cache/executor/host_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,22 @@ def __init__(
self.layer_done_counter = LayerDoneCounter(layer_num)
device_pool.register_layer_transfer_counter(self.layer_done_counter)

# Drafter pool needs its own counter so its get_key_buffer waits
# per-layer; otherwise the drafter forward races load_stream's draft
# copies and reads uninitialized KV.
self.draft_layer_done_counter: LayerDoneCounter | None = None
if (
draft_device_pool is not None
and draft_layer_num > 0
and hasattr(draft_device_pool, "register_layer_transfer_counter")
):
self.draft_layer_done_counter = LayerDoneCounter(draft_layer_num)
draft_device_pool.register_layer_transfer_counter(
self.draft_layer_done_counter
)

self._producer_map: OrderedDict[int, int] = OrderedDict()
self._draft_producer_map: OrderedDict[int, int] = OrderedDict()
self._producer_map_limit = 1024

def enqueue_writeback(
Expand Down Expand Up @@ -319,19 +334,33 @@ def _start_loading(self) -> None:
host_indices, device_indices = self._move_indices(op, self.host_pool)
self.load_queue.clear()

producer_event = self.layer_done_counter.events[producer_id]
producer_event.start_event.record()

# Prepare draft indices once if draft pool is present.
# Issue draft-index H2D before recording start events so both events
# cover all index copies; otherwise load_stream may consume them
# before the H2D completes.
if self.draft_host_pool is not None:
draft_host_indices, draft_device_indices = self._move_indices(
op, self.draft_host_pool
)
else:
draft_host_indices = draft_device_indices = None

draft_producer_id: int | None = None
draft_producer_event = None
if self.draft_layer_done_counter is not None:
draft_producer_id = self.draft_layer_done_counter.update_producer()
draft_producer_event = self.draft_layer_done_counter.events[
draft_producer_id
]

producer_event = self.layer_done_counter.events[producer_id]
producer_event.start_event.record()
if draft_producer_event is not None:
draft_producer_event.start_event.record()

with device_module.stream(self.load_stream):
producer_event.start_event.wait(self.load_stream)
if draft_producer_event is not None:
draft_producer_event.start_event.wait(self.load_stream)
for layer_index in range(self.layer_num):
self.host_pool.load_to_device_per_layer(
self.device_pool,
Expand All @@ -351,6 +380,8 @@ def _start_loading(self) -> None:
layer_index,
self.io_backend,
)
if draft_producer_event is not None:
draft_producer_event.complete(layer_index)
if draft_host_indices.is_cuda:
draft_host_indices.record_stream(self.load_stream)
if draft_device_indices.is_cuda:
Expand All @@ -363,8 +394,12 @@ def _start_loading(self) -> None:
self.ack_load_queue.append(_Ack(producer_event.finish_event, op.node_ids))
for op_id in op.node_ids:
self._producer_map[op_id] = producer_id
if draft_producer_id is not None:
self._draft_producer_map[op_id] = draft_producer_id
while len(self._producer_map) > self._producer_map_limit:
self._producer_map.popitem(last=False)
while len(self._draft_producer_map) > self._producer_map_limit:
self._draft_producer_map.popitem(last=False)

def _move_indices(self, op: _TransferOp, host_pool):
host_indices = op.host_indices
Expand Down Expand Up @@ -423,6 +458,13 @@ def get_producer_index(self, op_id: int) -> int | None:
def set_consumer(self, producer_index: int | Iterable[int]) -> None:
self.layer_done_counter.set_consumer(producer_index)

def get_draft_producer_index(self, op_id: int) -> int | None:
return self._draft_producer_map.pop(op_id, None)

def set_draft_consumer(self, producer_index: int | Iterable[int]) -> None:
if self.draft_layer_done_counter is not None:
self.draft_layer_done_counter.set_consumer(producer_index)

def shutdown(self) -> None:
self.write_stream.synchronize()
self.load_stream.synchronize()
Expand All @@ -435,4 +477,7 @@ def reset(self) -> None:
self.ack_write_queue.clear()
self.ack_load_queue.clear()
self._producer_map.clear()
self._draft_producer_map.clear()
self.layer_done_counter.reset()
if self.draft_layer_done_counter is not None:
self.draft_layer_done_counter.reset()
6 changes: 6 additions & 0 deletions python/tokenspeed/runtime/cache/executor/memory_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,12 @@ def get_producer_index(self, op_id: int) -> Optional[int]:
def set_consumer(self, producer_index: int | Iterable[int]) -> None:
self.host_exec.set_consumer(producer_index)

def get_draft_producer_index(self, op_id: int) -> Optional[int]:
return self.host_exec.get_draft_producer_index(op_id)

def set_draft_consumer(self, producer_index: int | Iterable[int]) -> None:
self.host_exec.set_draft_consumer(producer_index)

def query_l3_pages(self, hashes: list[str]) -> int:
return self.storage_exec.query_exists(hashes)

Expand Down
13 changes: 10 additions & 3 deletions python/tokenspeed/runtime/engine/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,7 @@ def _submit_cache_ops(self, execution_plan) -> None:

def _setup_layerwise_loadback(self, execution_plan) -> None:
consumer_indices = []
draft_consumer_indices = []
for cache_op in execution_plan.cache:
if isinstance(cache_op, Cache.LoadBackOp):
for op_id in cache_op.op_ids:
Expand All @@ -622,16 +623,22 @@ def _setup_layerwise_loadback(self, execution_plan) -> None:
and producer_idx not in consumer_indices
):
consumer_indices.append(producer_idx)
draft_idx = self.memory_executor.get_draft_producer_index(op_id)
if (
draft_idx is not None
and draft_idx not in draft_consumer_indices
):
draft_consumer_indices.append(draft_idx)
self.memory_executor.set_consumer(consumer_indices if consumer_indices else -1)
self.memory_executor.set_draft_consumer(
draft_consumer_indices if draft_consumer_indices else -1
)
# Fence WriteBack against this iter's ``set_kv_buffer``: the
# scheduler can re-allocate a freed-but-not-yet-written-back slot
# to a new prefill / decode within the same iter. ``set_kv_buffer``
# runs before any ``wait_until`` in attention, so nothing else
# orders writeback's reads against the new writes. Cheap when
# write_stream is idle.
# LoadBack does not need fencing here: it only fires on admission
# iters (eager prefill), whose per-layer ``wait_until`` drains
# ``load_stream`` before the iter ends.
host_exec = getattr(self.memory_executor, "host_exec", None)
if host_exec is not None:
self.model_executor.execution_stream.wait_stream(host_exec.write_stream)
Expand Down
Loading