Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
1c84488
feat(lora): end-to-end LoRA adapter serving
qywu May 25, 2026
7f0e675
chore: remove dco.yml — no longer needed after squash
qywu May 25, 2026
98adfca
chore: move benchmark files to qywu/lora-dev branch
qywu May 25, 2026
de4cf46
chore: move LoRA doc files to qywu/lora-dev branch
qywu May 25, 2026
dc7a35b
chore: move LoRA test files to qywu/lora-dev branch
qywu May 25, 2026
0646be2
chore: move scheduler LoRA test to qywu/lora-dev branch
qywu May 25, 2026
3477a66
chore: revert CMakeLists.txt LoRA test entry (moved to qywu/lora-dev)
qywu May 25, 2026
ddbd79a
chore: restore docs/index.md and test/runners.py from upstream
qywu May 25, 2026
89655b6
chore: revert _triton.py; remove unused fused kernel imports
qywu May 25, 2026
d6e442b
chore: revert tokenspeed_kernel/__init__.py to upstream
qywu May 25, 2026
3ad51aa
chore: revert attention/__init__.py to upstream
qywu May 25, 2026
d4cbc00
chore: revert tokenspeed_scheduler exports; move kernel LoRA test
qywu May 25, 2026
1127517
fix(scheduler): remove PagedCacheGroupFamily/PrefixCacheAdjunctSpec f…
qywu May 25, 2026
46f20d4
fix(lora): two-phase prepare_loras to prevent silent wrong-output bug
qywu May 25, 2026
bc7b4ff
fix(lora): defer GPU weight eviction on mid-decode unload
qywu May 25, 2026
9bd96b8
fix(lora): add flush_pending_evictions for explicit slot reclaim
qywu May 25, 2026
c645b02
fix(lora): harden deferred eviction against re-registration and retra…
qywu May 25, 2026
6a03eb3
fix(lora): remove GPU zeroing from _reset_slot to eliminate CUDA stre…
qywu May 25, 2026
0e2d70d
feat(scheduler): enforce max_loras adapter cap per batch in C++ sched…
qywu May 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions python/tokenspeed/bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,7 +776,7 @@ def get_lora_request(
self,
index: int,
max_loras: int | None = None,
lora_path: str | None = None,
lora_name: str | None = None,
lora_assignment: str = "random",
) -> None:
return None
Expand Down Expand Up @@ -821,7 +821,7 @@ def sample(
output_len: int = DEFAULT_OUTPUT_LEN,
batchsize: int = 1,
max_loras: int | None = None,
lora_path: str | None = None,
lora_name: str | None = None,
lora_assignment: str = "random",
**kwargs,
) -> list[SampleRequest]:
Expand Down Expand Up @@ -879,7 +879,7 @@ def sample(
lora_req = self.get_lora_request(
index=i,
max_loras=max_loras,
lora_path=lora_path,
lora_name=lora_name,
lora_assignment=lora_assignment,
)
requests.append(
Expand Down
2 changes: 2 additions & 0 deletions python/tokenspeed/runtime/engine/async_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,8 @@ def __init__(
# Read model args
self.model_path = server_args.model
self.served_model_name = server_args.served_model_name
# LoRA adapter name → integer lora_id (populated by load_lora_adapter).
self._lora_name_to_id: dict[str, int] = {}
self.model_config = ModelConfig(
server_args.model,
trust_remote_code=server_args.trust_remote_code,
Expand Down
100 changes: 97 additions & 3 deletions python/tokenspeed/runtime/engine/event_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
# SOFTWARE.

import faulthandler
import os
import signal
import time
from collections import OrderedDict
Expand Down Expand Up @@ -50,7 +51,6 @@
cache_sync_debug_enabled,
make_config,
pool_to_paged_cache_groups,
pool_to_prefix_cache_adjunct_spec,
pop_common_cache_event_payloads,
)
from tokenspeed.runtime.execution.distributed_initializer import (
Expand Down Expand Up @@ -299,8 +299,11 @@ def __init__(
f"(ratio={server_args.mamba_full_memory_ratio})."
)

# Adjunct enabled only when pool opts in AND prefix-caching switch is on.
enable_mixed_prefill_decode = (
server_args.enable_mixed_batch and server_args.speculative_algorithm is None
)

# Adjunct enabled only when pool opts in AND prefix-caching switch is on.
paged_cache_groups = pool_to_paged_cache_groups(token_to_kv_pool)
prefix_cache_adjunct = None
required_groups = token_to_kv_pool.prefix_cache_required_group_ids
Expand Down Expand Up @@ -329,8 +332,9 @@ def __init__(
enable_mamba_l2=server_args.enable_mamba_l2,
mamba_l2_host_slots=mamba_l2_host_slots,
paged_cache_groups=paged_cache_groups,
enable_mixed_prefill_decode=server_args.enable_mixed_batch,
enable_mixed_prefill_decode=enable_mixed_prefill_decode,
prefix_cache_adjunct=prefix_cache_adjunct,
max_loras=server_args.max_loras if server_args.enable_lora else 0,
)
logger.info(
"Scheduler config: page_size=%s num_device_pages=%s "
Expand Down Expand Up @@ -381,6 +385,8 @@ def __init__(
send_func=self.send_to_tokenizer,
get_load_fn=self._get_load,
architectures=self.model_config.hf_config.architectures,
load_lora_fn=self.load_lora_adapter,
unload_lora_fn=self.unload_lora_adapter,
)

self.output_processor = OutputProcesser(
Expand Down Expand Up @@ -436,6 +442,60 @@ def __init__(
else:
self.pd_kv_transfer = None

# ── LoRA ─────────────────────────────────────────────────────────────
self._lora_manager = None # LoraManager (lazy init)
self._lora_name_to_id: dict[str, int] = {} # name → integer lora_id
self._request_lora_ids: dict[str, int] = {} # rid → lora_id

if server_args.enable_lora:
self._init_lora_manager()

def _init_lora_manager(self) -> None:
"""Bind to the LoraManager owned by the model executor.

The model executor creates the manager during its own ``__init__`` so
that the CUDA-graph capture sees a live manager (and bakes the LoRA
delta path into the captured graphs). The event loop only borrows
the reference and shares its request-id → lora-id map.
"""
self._lora_manager = self.model_executor.lora_manager
if self._lora_manager is None:
raise RuntimeError(
"Model executor was not configured with --enable-lora; "
"cannot initialize LoRA support."
)
self.model_executor.request_lora_ids = self._request_lora_ids
logger.info("LoraManager bound (max_loras=%d)", self.server_args.max_loras)

def load_lora_adapter(self, lora_name: str, adapter_path: str) -> int:
"""Load a PEFT LoRA adapter and make it available for serving.

Returns the integer lora_id assigned to this adapter.
"""
if not self.server_args.enable_lora:
raise ValueError(
"Server was not started with --enable-lora. "
"Restart with --enable-lora to use LoRA adapters."
)
if self._lora_manager is None:
self._init_lora_manager()
lora_id = self._lora_manager.load_adapter(lora_name, adapter_path)
self._lora_name_to_id[lora_name] = lora_id
logger.info("Loaded LoRA adapter '%s' → lora_id=%d", lora_name, lora_id)
return lora_id

def unload_lora_adapter(self, lora_name: str) -> None:
"""Unload a LoRA adapter and free its GPU slot."""
if self._lora_manager is None:
raise KeyError(f"No LoRA adapters loaded; '{lora_name}' not found.")
lora_id = self._lora_name_to_id.get(lora_name)
self._lora_manager.unload_adapter(lora_name)
self._lora_name_to_id.pop(lora_name, None)
# Proactively evict the KV cache namespace for this adapter so pages
# are freed immediately rather than waiting for LRU eviction pressure.
if lora_id is not None:
self.scheduler.evict_lora_namespace(lora_id)

def _setup_pd_layerwise_transfer(self, interval: int) -> None:
if not isinstance(self.pd_kv_transfer, DisaggPrefillExecutor):
return
Expand Down Expand Up @@ -838,8 +898,42 @@ def _process_new_requests(self):
spec.rolling_hashes = hashes
spec.storage_hit_pages = hit_pages
admitted_specs.append(spec)
# Track lora_id per request for forward-pass injection
if spec.lora_id != 0:
self._request_lora_ids[spec.request_id] = spec.lora_id
# Async-prefetch the adapter into the CPU pool so the
# disk read is overlapped with the previous forward step
# rather than blocking ``prepare_loras`` of the step that
# actually consumes it. No-op when already CPU-resident.
if (
self._lora_manager is not None
and os.environ.get("TOKENSPEED_LORA_PREFETCH", "1") == "1"
):
name = self._lora_manager._id_to_name.get(spec.lora_id)
if name is not None:
self._lora_manager.prefetch(name)

if admitted_specs:
# Optional ``pack`` policy: cluster admissions by lora_id so
# adapter-shared requests batch together at the C++ scheduler.
# Reduces GPU/CPU eviction churn under heavy mixed-adapter
# traffic (multiple distinct adapters > max_loras).
#
# Sort is stable: requests for the same adapter keep their
# arrival order, base-model (lora_id == 0) requests stay
# together at the front (their slot is the no-op sentinel).
#
# The benchmark in benchmark/test_lora_eviction_latency.py
# shows that CPU↔GPU promotion is essentially free; the
# only meaningful eviction cost is CPU→disk re-read (~30 ms).
# ``pack`` therefore mainly helps when ``working_set >
# max_loras_cpu`` and incoming traffic is bursty enough that
# multiple cold requests arrive in one event-loop iteration.
if (
self._lora_manager is not None
and self.server_args.lora_scheduling_policy == "pack"
):
admitted_specs.sort(key=lambda s: s.lora_id)
self.scheduler.submit_requests(admitted_specs)

@nvtx_range("loop:commit", color="rapids")
Expand Down
15 changes: 15 additions & 0 deletions python/tokenspeed/runtime/engine/input_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ async def tokenize_one_request(
created_time=time.time(),
input_multi_ids=obj.input_multi_ids,
input_extra_infos=obj.input_extra_infos,
lora_id=self._resolve_lora_id(obj),
)

return TokenizedEmbeddingReqInput(
Expand All @@ -198,3 +199,17 @@ async def tokenize_one_request(
sampling_params,
created_time=time.time(),
)

def _resolve_lora_id(self, obj: "GenerateReqInput") -> int:
"""Map request LoRA adapter name to an integer lora_id."""
lora_name = getattr(obj, "lora_name", None)
if lora_name is None:
return 0
lora_registry: dict = getattr(self.engine, "_lora_name_to_id", {})
lora_id = lora_registry.get(lora_name, 0)
if lora_id == 0:
raise ValueError(
f"lora_name={lora_name!r} is not a registered adapter. "
"Call load_lora_adapter(name, adapter_path) before using it in a request."
)
return lora_id
51 changes: 51 additions & 0 deletions python/tokenspeed/runtime/engine/io_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ class GenerateReqInput:
bootstrap_port: list[int] | int | None = None
bootstrap_room: list[int] | int | None = None

# LoRA adapter to use for this request. Supply the name under which the
# adapter was registered via Engine.load_lora_adapter(). None means use the
# base model. Requests do not load adapters from disk; adapter filesystem
# paths belong to load_lora_adapter().
lora_name: list[str | None] | str | None = None

def normalize_batch_and_arguments(self):
if (
self.text is None and self.input_ids is None and self.input_embeds is None
Expand Down Expand Up @@ -228,6 +234,11 @@ def normalize_batch_and_arguments(self):
self.token_ids_logprob = None
if isinstance(self.input_extra_infos, dict):
self.input_extra_infos = [self.input_extra_infos]
if isinstance(self.lora_name, list):
assert (
len(self.lora_name) == 1
), "lora_name list should have length 1 for single request."
self.lora_name = self.lora_name[0]
else:
if self.parallel_sample_num == 1:
num = self.batch_size
Expand Down Expand Up @@ -320,6 +331,15 @@ def normalize_batch_and_arguments(self):
else:
assert self.parallel_sample_num == 1

if self.lora_name is None:
self.lora_name = [None] * num
elif not isinstance(self.lora_name, list):
self.lora_name = [self.lora_name] * num
else:
assert (
len(self.lora_name) == num
), "lora_name should be a str or a list of matching length."

# Other checks
if self.session_params is not None:
assert isinstance(self.session_params, dict) or isinstance(
Expand Down Expand Up @@ -372,6 +392,11 @@ def __getitem__(self, i):
bootstrap_room=(
self.bootstrap_room[i] if self.bootstrap_room is not None else None
),
lora_name=(
self.lora_name[i]
if isinstance(self.lora_name, list)
else self.lora_name
),
)
sub.rid = self.rid[i]
return sub
Expand Down Expand Up @@ -422,6 +447,8 @@ class TokenizedGenerateReqInput:

input_multi_ids: list[list[int]] = None
input_extra_infos: list[dict] | None = None
# Integer lora_id resolved from lora_name (0 = base model)
lora_id: int = 0


@dataclass
Expand Down Expand Up @@ -852,6 +879,30 @@ class RpcReqOutput:
message: str


@dataclass
class LoadLoraReqInput:
lora_name: str
adapter_path: str


@dataclass
class LoadLoraReqOutput:
success: bool
lora_id: int = 0
message: str = ""


@dataclass
class UnloadLoraReqInput:
lora_name: str


@dataclass
class UnloadLoraReqOutput:
success: bool
message: str = ""


@dataclass
class GetLoadReqInput(BaseReq):
pass
Expand Down
37 changes: 37 additions & 0 deletions python/tokenspeed/runtime/engine/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,16 @@
GetInternalStateReqOutput,
GetLoadReqInput,
GetLoadReqOutput,
LoadLoraReqInput,
LoadLoraReqOutput,
ProfileReq,
ProfileReqOutput,
ProfileReqType,
SetInternalStateReq,
SetInternalStateReqOutput,
TokenizedGenerateReqInput,
UnloadLoraReqInput,
UnloadLoraReqOutput,
)
from tokenspeed.runtime.engine.request_types import FINISH_ABORT
from tokenspeed.runtime.engine.scheduler_utils import make_spec
Expand Down Expand Up @@ -80,6 +84,8 @@ def __init__(
send_func,
get_load_fn=None,
architectures: list[str] | None = None,
load_lora_fn=None,
unload_lora_fn=None,
) -> None:

self.forward_ct = 0
Expand All @@ -97,6 +103,8 @@ def __init__(
self.max_req_len = max_req_len
self.vocab_size = vocab_size
self.get_load_fn = get_load_fn
self.load_lora_fn = load_lora_fn
self.unload_lora_fn = unload_lora_fn

self.tokenizer = get_tokenizer(
server_args.tokenizer,
Expand Down Expand Up @@ -176,6 +184,34 @@ def process_requests(self, recv_reqs: list):
self.send_func.send_pyobj(self.get_load_fn())
else:
self.send_func.send_pyobj(GetLoadReqOutput())
elif isinstance(recv_req, LoadLoraReqInput):
try:
if self.load_lora_fn is not None:
lora_id = self.load_lora_fn(
recv_req.lora_name, recv_req.adapter_path
)
self.send_func.send_pyobj(
LoadLoraReqOutput(success=True, lora_id=lora_id)
)
else:
self.send_func.send_pyobj(
LoadLoraReqOutput(
success=False, message="LoRA not enabled on this server"
)
)
except Exception as e:
self.send_func.send_pyobj(
LoadLoraReqOutput(success=False, message=str(e))
)
elif isinstance(recv_req, UnloadLoraReqInput):
try:
if self.unload_lora_fn is not None:
self.unload_lora_fn(recv_req.lora_name)
self.send_func.send_pyobj(UnloadLoraReqOutput(success=True))
except Exception as e:
self.send_func.send_pyobj(
UnloadLoraReqOutput(success=False, message=str(e))
)
else:
raise NotImplementedError(f"Unsupported request type: {type(recv_req)}")
return new_req_specs, req_states, bootstrap_infos, abort_rids
Expand All @@ -190,6 +226,7 @@ def handle_generate_request(
req_spec = make_spec(
rid=recv_req.rid,
tokens=recv_req.input_ids,
lora_id=getattr(recv_req, "lora_id", 0),
)
req_state = RequestState.from_recv_req(
recv_req,
Expand Down
Loading
Loading