diff --git a/python/tokenspeed/bench.py b/python/tokenspeed/bench.py index c9a61f3ec..08adba1be 100755 --- a/python/tokenspeed/bench.py +++ b/python/tokenspeed/bench.py @@ -776,7 +776,7 @@ def get_lora_request( self, index: int, max_loras: int | None = None, - lora_path: str | None = None, + lora_name: str | None = None, lora_assignment: str = "random", ) -> None: return None @@ -821,7 +821,7 @@ def sample( output_len: int = DEFAULT_OUTPUT_LEN, batchsize: int = 1, max_loras: int | None = None, - lora_path: str | None = None, + lora_name: str | None = None, lora_assignment: str = "random", **kwargs, ) -> list[SampleRequest]: @@ -879,7 +879,7 @@ def sample( lora_req = self.get_lora_request( index=i, max_loras=max_loras, - lora_path=lora_path, + lora_name=lora_name, lora_assignment=lora_assignment, ) requests.append( diff --git a/python/tokenspeed/runtime/engine/async_llm.py b/python/tokenspeed/runtime/engine/async_llm.py index eaa0173f2..def2892cf 100755 --- a/python/tokenspeed/runtime/engine/async_llm.py +++ b/python/tokenspeed/runtime/engine/async_llm.py @@ -143,6 +143,8 @@ def __init__( # Read model args self.model_path = server_args.model self.served_model_name = server_args.served_model_name + # LoRA adapter name → integer lora_id (populated by load_lora_adapter). + self._lora_name_to_id: dict[str, int] = {} self.model_config = ModelConfig( server_args.model, trust_remote_code=server_args.trust_remote_code, diff --git a/python/tokenspeed/runtime/engine/event_loop.py b/python/tokenspeed/runtime/engine/event_loop.py index ae1ceab44..092e89b03 100644 --- a/python/tokenspeed/runtime/engine/event_loop.py +++ b/python/tokenspeed/runtime/engine/event_loop.py @@ -19,6 +19,7 @@ # SOFTWARE. import faulthandler +import os import signal import time from collections import OrderedDict @@ -50,7 +51,6 @@ cache_sync_debug_enabled, make_config, pool_to_paged_cache_groups, - pool_to_prefix_cache_adjunct_spec, pop_common_cache_event_payloads, ) from tokenspeed.runtime.execution.distributed_initializer import ( @@ -299,8 +299,11 @@ def __init__( f"(ratio={server_args.mamba_full_memory_ratio})." ) - # Adjunct enabled only when pool opts in AND prefix-caching switch is on. + enable_mixed_prefill_decode = ( + server_args.enable_mixed_batch and server_args.speculative_algorithm is None + ) + # Adjunct enabled only when pool opts in AND prefix-caching switch is on. paged_cache_groups = pool_to_paged_cache_groups(token_to_kv_pool) prefix_cache_adjunct = None required_groups = token_to_kv_pool.prefix_cache_required_group_ids @@ -329,8 +332,9 @@ def __init__( enable_mamba_l2=server_args.enable_mamba_l2, mamba_l2_host_slots=mamba_l2_host_slots, paged_cache_groups=paged_cache_groups, - enable_mixed_prefill_decode=server_args.enable_mixed_batch, + enable_mixed_prefill_decode=enable_mixed_prefill_decode, prefix_cache_adjunct=prefix_cache_adjunct, + max_loras=server_args.max_loras if server_args.enable_lora else 0, ) logger.info( "Scheduler config: page_size=%s num_device_pages=%s " @@ -381,6 +385,8 @@ def __init__( send_func=self.send_to_tokenizer, get_load_fn=self._get_load, architectures=self.model_config.hf_config.architectures, + load_lora_fn=self.load_lora_adapter, + unload_lora_fn=self.unload_lora_adapter, ) self.output_processor = OutputProcesser( @@ -436,6 +442,60 @@ def __init__( else: self.pd_kv_transfer = None + # ── LoRA ───────────────────────────────────────────────────────────── + self._lora_manager = None # LoraManager (lazy init) + self._lora_name_to_id: dict[str, int] = {} # name → integer lora_id + self._request_lora_ids: dict[str, int] = {} # rid → lora_id + + if server_args.enable_lora: + self._init_lora_manager() + + def _init_lora_manager(self) -> None: + """Bind to the LoraManager owned by the model executor. + + The model executor creates the manager during its own ``__init__`` so + that the CUDA-graph capture sees a live manager (and bakes the LoRA + delta path into the captured graphs). The event loop only borrows + the reference and shares its request-id → lora-id map. + """ + self._lora_manager = self.model_executor.lora_manager + if self._lora_manager is None: + raise RuntimeError( + "Model executor was not configured with --enable-lora; " + "cannot initialize LoRA support." + ) + self.model_executor.request_lora_ids = self._request_lora_ids + logger.info("LoraManager bound (max_loras=%d)", self.server_args.max_loras) + + def load_lora_adapter(self, lora_name: str, adapter_path: str) -> int: + """Load a PEFT LoRA adapter and make it available for serving. + + Returns the integer lora_id assigned to this adapter. + """ + if not self.server_args.enable_lora: + raise ValueError( + "Server was not started with --enable-lora. " + "Restart with --enable-lora to use LoRA adapters." + ) + if self._lora_manager is None: + self._init_lora_manager() + lora_id = self._lora_manager.load_adapter(lora_name, adapter_path) + self._lora_name_to_id[lora_name] = lora_id + logger.info("Loaded LoRA adapter '%s' → lora_id=%d", lora_name, lora_id) + return lora_id + + def unload_lora_adapter(self, lora_name: str) -> None: + """Unload a LoRA adapter and free its GPU slot.""" + if self._lora_manager is None: + raise KeyError(f"No LoRA adapters loaded; '{lora_name}' not found.") + lora_id = self._lora_name_to_id.get(lora_name) + self._lora_manager.unload_adapter(lora_name) + self._lora_name_to_id.pop(lora_name, None) + # Proactively evict the KV cache namespace for this adapter so pages + # are freed immediately rather than waiting for LRU eviction pressure. + if lora_id is not None: + self.scheduler.evict_lora_namespace(lora_id) + def _setup_pd_layerwise_transfer(self, interval: int) -> None: if not isinstance(self.pd_kv_transfer, DisaggPrefillExecutor): return @@ -838,8 +898,42 @@ def _process_new_requests(self): spec.rolling_hashes = hashes spec.storage_hit_pages = hit_pages admitted_specs.append(spec) + # Track lora_id per request for forward-pass injection + if spec.lora_id != 0: + self._request_lora_ids[spec.request_id] = spec.lora_id + # Async-prefetch the adapter into the CPU pool so the + # disk read is overlapped with the previous forward step + # rather than blocking ``prepare_loras`` of the step that + # actually consumes it. No-op when already CPU-resident. + if ( + self._lora_manager is not None + and os.environ.get("TOKENSPEED_LORA_PREFETCH", "1") == "1" + ): + name = self._lora_manager._id_to_name.get(spec.lora_id) + if name is not None: + self._lora_manager.prefetch(name) if admitted_specs: + # Optional ``pack`` policy: cluster admissions by lora_id so + # adapter-shared requests batch together at the C++ scheduler. + # Reduces GPU/CPU eviction churn under heavy mixed-adapter + # traffic (multiple distinct adapters > max_loras). + # + # Sort is stable: requests for the same adapter keep their + # arrival order, base-model (lora_id == 0) requests stay + # together at the front (their slot is the no-op sentinel). + # + # The benchmark in benchmark/test_lora_eviction_latency.py + # shows that CPU↔GPU promotion is essentially free; the + # only meaningful eviction cost is CPU→disk re-read (~30 ms). + # ``pack`` therefore mainly helps when ``working_set > + # max_loras_cpu`` and incoming traffic is bursty enough that + # multiple cold requests arrive in one event-loop iteration. + if ( + self._lora_manager is not None + and self.server_args.lora_scheduling_policy == "pack" + ): + admitted_specs.sort(key=lambda s: s.lora_id) self.scheduler.submit_requests(admitted_specs) @nvtx_range("loop:commit", color="rapids") diff --git a/python/tokenspeed/runtime/engine/input_processor.py b/python/tokenspeed/runtime/engine/input_processor.py index 040ae6675..0e6b8d0d0 100644 --- a/python/tokenspeed/runtime/engine/input_processor.py +++ b/python/tokenspeed/runtime/engine/input_processor.py @@ -189,6 +189,7 @@ async def tokenize_one_request( created_time=time.time(), input_multi_ids=obj.input_multi_ids, input_extra_infos=obj.input_extra_infos, + lora_id=self._resolve_lora_id(obj), ) return TokenizedEmbeddingReqInput( @@ -198,3 +199,17 @@ async def tokenize_one_request( sampling_params, created_time=time.time(), ) + + def _resolve_lora_id(self, obj: "GenerateReqInput") -> int: + """Map request LoRA adapter name to an integer lora_id.""" + lora_name = getattr(obj, "lora_name", None) + if lora_name is None: + return 0 + lora_registry: dict = getattr(self.engine, "_lora_name_to_id", {}) + lora_id = lora_registry.get(lora_name, 0) + if lora_id == 0: + raise ValueError( + f"lora_name={lora_name!r} is not a registered adapter. " + "Call load_lora_adapter(name, adapter_path) before using it in a request." + ) + return lora_id diff --git a/python/tokenspeed/runtime/engine/io_struct.py b/python/tokenspeed/runtime/engine/io_struct.py index 5782d30c0..e592da5bf 100755 --- a/python/tokenspeed/runtime/engine/io_struct.py +++ b/python/tokenspeed/runtime/engine/io_struct.py @@ -136,6 +136,12 @@ class GenerateReqInput: bootstrap_port: list[int] | int | None = None bootstrap_room: list[int] | int | None = None + # LoRA adapter to use for this request. Supply the name under which the + # adapter was registered via Engine.load_lora_adapter(). None means use the + # base model. Requests do not load adapters from disk; adapter filesystem + # paths belong to load_lora_adapter(). + lora_name: list[str | None] | str | None = None + def normalize_batch_and_arguments(self): if ( self.text is None and self.input_ids is None and self.input_embeds is None @@ -228,6 +234,11 @@ def normalize_batch_and_arguments(self): self.token_ids_logprob = None if isinstance(self.input_extra_infos, dict): self.input_extra_infos = [self.input_extra_infos] + if isinstance(self.lora_name, list): + assert ( + len(self.lora_name) == 1 + ), "lora_name list should have length 1 for single request." + self.lora_name = self.lora_name[0] else: if self.parallel_sample_num == 1: num = self.batch_size @@ -320,6 +331,15 @@ def normalize_batch_and_arguments(self): else: assert self.parallel_sample_num == 1 + if self.lora_name is None: + self.lora_name = [None] * num + elif not isinstance(self.lora_name, list): + self.lora_name = [self.lora_name] * num + else: + assert ( + len(self.lora_name) == num + ), "lora_name should be a str or a list of matching length." + # Other checks if self.session_params is not None: assert isinstance(self.session_params, dict) or isinstance( @@ -372,6 +392,11 @@ def __getitem__(self, i): bootstrap_room=( self.bootstrap_room[i] if self.bootstrap_room is not None else None ), + lora_name=( + self.lora_name[i] + if isinstance(self.lora_name, list) + else self.lora_name + ), ) sub.rid = self.rid[i] return sub @@ -422,6 +447,8 @@ class TokenizedGenerateReqInput: input_multi_ids: list[list[int]] = None input_extra_infos: list[dict] | None = None + # Integer lora_id resolved from lora_name (0 = base model) + lora_id: int = 0 @dataclass @@ -852,6 +879,30 @@ class RpcReqOutput: message: str +@dataclass +class LoadLoraReqInput: + lora_name: str + adapter_path: str + + +@dataclass +class LoadLoraReqOutput: + success: bool + lora_id: int = 0 + message: str = "" + + +@dataclass +class UnloadLoraReqInput: + lora_name: str + + +@dataclass +class UnloadLoraReqOutput: + success: bool + message: str = "" + + @dataclass class GetLoadReqInput(BaseReq): pass diff --git a/python/tokenspeed/runtime/engine/request_handler.py b/python/tokenspeed/runtime/engine/request_handler.py index aa0b31fc5..3480aa4b4 100644 --- a/python/tokenspeed/runtime/engine/request_handler.py +++ b/python/tokenspeed/runtime/engine/request_handler.py @@ -41,12 +41,16 @@ GetInternalStateReqOutput, GetLoadReqInput, GetLoadReqOutput, + LoadLoraReqInput, + LoadLoraReqOutput, ProfileReq, ProfileReqOutput, ProfileReqType, SetInternalStateReq, SetInternalStateReqOutput, TokenizedGenerateReqInput, + UnloadLoraReqInput, + UnloadLoraReqOutput, ) from tokenspeed.runtime.engine.request_types import FINISH_ABORT from tokenspeed.runtime.engine.scheduler_utils import make_spec @@ -80,6 +84,8 @@ def __init__( send_func, get_load_fn=None, architectures: list[str] | None = None, + load_lora_fn=None, + unload_lora_fn=None, ) -> None: self.forward_ct = 0 @@ -97,6 +103,8 @@ def __init__( self.max_req_len = max_req_len self.vocab_size = vocab_size self.get_load_fn = get_load_fn + self.load_lora_fn = load_lora_fn + self.unload_lora_fn = unload_lora_fn self.tokenizer = get_tokenizer( server_args.tokenizer, @@ -176,6 +184,34 @@ def process_requests(self, recv_reqs: list): self.send_func.send_pyobj(self.get_load_fn()) else: self.send_func.send_pyobj(GetLoadReqOutput()) + elif isinstance(recv_req, LoadLoraReqInput): + try: + if self.load_lora_fn is not None: + lora_id = self.load_lora_fn( + recv_req.lora_name, recv_req.adapter_path + ) + self.send_func.send_pyobj( + LoadLoraReqOutput(success=True, lora_id=lora_id) + ) + else: + self.send_func.send_pyobj( + LoadLoraReqOutput( + success=False, message="LoRA not enabled on this server" + ) + ) + except Exception as e: + self.send_func.send_pyobj( + LoadLoraReqOutput(success=False, message=str(e)) + ) + elif isinstance(recv_req, UnloadLoraReqInput): + try: + if self.unload_lora_fn is not None: + self.unload_lora_fn(recv_req.lora_name) + self.send_func.send_pyobj(UnloadLoraReqOutput(success=True)) + except Exception as e: + self.send_func.send_pyobj( + UnloadLoraReqOutput(success=False, message=str(e)) + ) else: raise NotImplementedError(f"Unsupported request type: {type(recv_req)}") return new_req_specs, req_states, bootstrap_infos, abort_rids @@ -190,6 +226,7 @@ def handle_generate_request( req_spec = make_spec( rid=recv_req.rid, tokens=recv_req.input_ids, + lora_id=getattr(recv_req, "lora_id", 0), ) req_state = RequestState.from_recv_req( recv_req, diff --git a/python/tokenspeed/runtime/engine/scheduler_control_client.py b/python/tokenspeed/runtime/engine/scheduler_control_client.py index 52fb5d9c7..8325a2527 100755 --- a/python/tokenspeed/runtime/engine/scheduler_control_client.py +++ b/python/tokenspeed/runtime/engine/scheduler_control_client.py @@ -47,6 +47,8 @@ GetWeightsByNameReqOutput, InitWeightsUpdateGroupReqInput, InitWeightsUpdateGroupReqOutput, + LoadLoraReqInput, + LoadLoraReqOutput, ProfileReq, ProfileReqOutput, ProfileReqType, @@ -56,6 +58,8 @@ ResumeMemoryOccupationReqOutput, SetInternalStateReq, SetInternalStateReqOutput, + UnloadLoraReqInput, + UnloadLoraReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, UpdateWeightsFromTensorReqInput, @@ -95,7 +99,7 @@ async def queueing_call(self, obj: T): assert self._result_values is None if obj: - self._sender.send_pyobj(obj) + await self._sender.send_pyobj(obj) self._result_event = asyncio.Event() self._result_values = [] @@ -115,7 +119,7 @@ async def watching_call(self, obj): self._result_event = asyncio.Event() if obj: - self._sender.send_pyobj(obj) + await self._sender.send_pyobj(obj) await self._result_event.wait() result_values = copy.deepcopy(self._result_values) @@ -178,6 +182,12 @@ def init_communicators(self: AsyncLLM, server_args: ServerArgs): server_args.mapping.attn.dp_size, mode="watching", ) + self.load_lora_communicator = _Communicator( + self.engine_core_client.send_to_scheduler, server_args.mapping.attn.dp_size + ) + self.unload_lora_communicator = _Communicator( + self.engine_core_client.send_to_scheduler, server_args.mapping.attn.dp_size + ) self._result_dispatcher += self._get_communicator_dispatcher() @@ -232,9 +242,39 @@ def _get_communicator_dispatcher(self: AsyncLLM): GetLoadReqOutput, self.get_load_communicator.handle_recv, ), + ( + LoadLoraReqOutput, + self.load_lora_communicator.handle_recv, + ), + ( + UnloadLoraReqOutput, + self.unload_lora_communicator.handle_recv, + ), ] ) + async def load_lora_adapter( + self: "AsyncLLM", + lora_name: str, + adapter_path: str, + ) -> tuple[bool, int, str]: + """Send a LoadLoraReqInput to the scheduler subprocess.""" + self.auto_create_handle_loop() + result = ( + await self.load_lora_communicator( + LoadLoraReqInput(lora_name=lora_name, adapter_path=adapter_path) + ) + )[0] + return result.success, result.lora_id, result.message + + async def unload_lora_adapter(self: "AsyncLLM", lora_name: str) -> tuple[bool, str]: + """Send an UnloadLoraReqInput to the scheduler subprocess.""" + self.auto_create_handle_loop() + result = ( + await self.unload_lora_communicator(UnloadLoraReqInput(lora_name=lora_name)) + )[0] + return result.success, result.message + async def flush_cache(self: AsyncLLM) -> FlushCacheReqOutput: return (await self.flush_cache_communicator(FlushCacheReqInput()))[0] diff --git a/python/tokenspeed/runtime/engine/scheduler_utils.py b/python/tokenspeed/runtime/engine/scheduler_utils.py index 7ddee553e..0b6a3b8e5 100644 --- a/python/tokenspeed/runtime/engine/scheduler_utils.py +++ b/python/tokenspeed/runtime/engine/scheduler_utils.py @@ -30,9 +30,7 @@ ExecutionEvent, ForwardEvent, PagedCacheGroupConfig, - PagedCacheGroupFamily, PagedCacheRetention, - PrefixCacheAdjunctSpec, RequestSpec, SchedulerConfig, ) @@ -44,10 +42,11 @@ _TRUTHY_ENV_VALUES = {"1", "true", "yes", "on"} -def make_spec(rid: str, tokens: list[int]) -> RequestSpec: +def make_spec(rid: str, tokens: list[int], lora_id: int = 0) -> RequestSpec: spec = RequestSpec() spec.request_id = rid spec.tokens = tokens + spec.lora_id = lora_id return spec @@ -71,7 +70,7 @@ def make_config( mamba_l2_host_slots: int = 0, paged_cache_groups: Sequence["PagedCacheGroupConfig"] | None = None, enable_mixed_prefill_decode: bool = False, - prefix_cache_adjunct: "PrefixCacheAdjunctSpec | None" = None, + max_loras: int = 0, ) -> SchedulerConfig: cfg = SchedulerConfig() cfg.num_device_pages = num_device_pages @@ -101,17 +100,15 @@ def make_config( cfg.enable_mamba_l2 = enable_mamba_l2 cfg.mamba_l2_host_slots = mamba_l2_host_slots cfg.enable_mixed_prefill_decode = enable_mixed_prefill_decode + cfg.max_loras = max_loras if paged_cache_groups: cfg.paged_cache_groups = list(paged_cache_groups) - # Opt-in; unset means paged-cache groups are transport-only. - if prefix_cache_adjunct is not None: - cfg.prefix_cache_adjunct = prefix_cache_adjunct return cfg def pool_to_paged_cache_groups(pool: Any) -> list: """Convert a KV pool's paged_cache_group_specs to scheduler configs.""" - specs = pool.paged_cache_group_specs + specs = getattr(pool, "paged_cache_group_specs", ()) if not specs: return [] counts = pool.paged_cache_group_page_counts @@ -126,23 +123,12 @@ def pool_to_paged_cache_groups(pool: Any) -> list: f"pool_to_paged_cache_groups: unsupported retention " f"{spec.retention!r} for group {spec.group_id!r}" ) - family_str = getattr(spec, "family", "history") - if family_str == "history": - family = PagedCacheGroupFamily.History - elif family_str == "state": - family = PagedCacheGroupFamily.State - else: - raise ValueError( - f"pool_to_paged_cache_groups: unsupported family " - f"{family_str!r} for group {spec.group_id!r}" - ) kwargs = dict( group_id=spec.group_id, rows_per_page=int(spec.rows_per_page), entry_stride_tokens=int(spec.entry_stride_tokens), total_pages=int(counts[spec.group_id]), retention=retention, - family=family, ) if spec.retention == "sliding_window": kwargs["sliding_window_tokens"] = int(spec.sliding_window_tokens) @@ -150,19 +136,6 @@ def pool_to_paged_cache_groups(pool: Any) -> list: return out -def pool_to_prefix_cache_adjunct_spec( - required_group_ids: Sequence[str], -) -> "PrefixCacheAdjunctSpec": - """Build a PrefixCacheAdjunctSpec from a non-empty required-group-id list.""" - if not required_group_ids: - raise ValueError( - "pool_to_prefix_cache_adjunct_spec: required_group_ids must be non-empty" - ) - spec = PrefixCacheAdjunctSpec() - spec.required_groups = [str(gid) for gid in required_group_ids] - return spec - - def make_extend_result_event(request_id: str, tokens: list[int] = ()) -> None: fe = ForwardEvent.ExtendResult() fe.request_id = request_id diff --git a/python/tokenspeed/runtime/entrypoints/engine.py b/python/tokenspeed/runtime/entrypoints/engine.py index 048964e9a..156508022 100755 --- a/python/tokenspeed/runtime/entrypoints/engine.py +++ b/python/tokenspeed/runtime/entrypoints/engine.py @@ -170,6 +170,7 @@ def generate( bootstrap_port: list[int] | int | None = None, bootstrap_room: list[int] | int | None = None, data_parallel_rank: int | None = None, + lora_name: list[str | None] | str | None = None, ) -> dict | Iterator[dict]: """ The arguments of this function match @@ -209,6 +210,7 @@ def generate( bootstrap_host=bootstrap_host, bootstrap_port=bootstrap_port, bootstrap_room=bootstrap_room, + lora_name=lora_name, ) if stream: return self.llm.generate_stream(obj) @@ -245,6 +247,7 @@ async def async_generate( bootstrap_port: list[int] | int | None = None, bootstrap_room: list[int] | int | None = None, user_rid: list[str] | str | None = None, + lora_name: list[str | None] | str | None = None, ) -> dict | AsyncIterator[dict]: """ The arguments of this function match @@ -279,6 +282,7 @@ async def async_generate( bootstrap_port=bootstrap_port, bootstrap_room=bootstrap_room, user_rid=user_rid, + lora_name=lora_name, ) generator = self.tokenizer_manager.generate_request(obj) @@ -435,6 +439,32 @@ def collective_rpc(self, method: str, **kwargs): assert isinstance(recv_req, RpcReqOutput) assert recv_req.success, recv_req.message + def load_lora_adapter( + self, + lora_name: str, + adapter_path: str, + ) -> int: + """Load a PEFT LoRA adapter. Returns the integer lora_id.""" + success, lora_id, message = self.llm.run( + self.tokenizer_manager.load_lora_adapter(lora_name, adapter_path) + ) + if not success: + raise RuntimeError(f"Failed to load LoRA adapter '{lora_name}': {message}") + # Update the local name→id registry so future requests resolve correctly. + self.tokenizer_manager._lora_name_to_id[lora_name] = lora_id + return lora_id + + def unload_lora_adapter(self, lora_name: str) -> None: + """Unload a previously loaded LoRA adapter.""" + success, message = self.llm.run( + self.tokenizer_manager.unload_lora_adapter(lora_name) + ) + if not success: + raise RuntimeError( + f"Failed to unload LoRA adapter '{lora_name}': {message}" + ) + self.tokenizer_manager._lora_name_to_id.pop(lora_name, None) + def save_remote_model(self, **kwargs): self.collective_rpc("save_remote_model", **kwargs) diff --git a/python/tokenspeed/runtime/entrypoints/engine_base.py b/python/tokenspeed/runtime/entrypoints/engine_base.py index 4654f25d6..c4e141d76 100755 --- a/python/tokenspeed/runtime/entrypoints/engine_base.py +++ b/python/tokenspeed/runtime/entrypoints/engine_base.py @@ -56,6 +56,7 @@ def generate( bootstrap_port: list[int] | int | None = None, bootstrap_room: list[int] | int | None = None, data_parallel_rank: int | None = None, + lora_name: list[str | None] | str | None = None, ) -> dict | Iterator[dict]: """Generate outputs based on given inputs.""" @@ -83,3 +84,32 @@ def resume_memory_occupation(self) -> None: @abstractmethod def shutdown(self) -> None: """Shutdown the engine and clean up resources.""" + + # ------------------------------------------------------------------ + # LoRA adapter management + # ------------------------------------------------------------------ + + def load_lora_adapter( + self, + lora_name: str, + adapter_path: str, + ) -> int: + """Load a PEFT LoRA adapter and make it available for serving. + + Args: + lora_name: Short identifier used by request-time lora_name. + adapter_path: Filesystem path to the PEFT adapter directory. + + Returns: + Integer lora_id assigned to this adapter. + """ + raise NotImplementedError( + "load_lora_adapter() is not implemented on this engine type. " + "Use the tokenspeed serve engine." + ) + + def unload_lora_adapter(self, lora_name: str) -> None: + """Unload a previously loaded LoRA adapter and free its GPU slot.""" + raise NotImplementedError( + "unload_lora_adapter() is not implemented on this engine type." + ) diff --git a/python/tokenspeed/runtime/execution/context.py b/python/tokenspeed/runtime/execution/context.py index e5cb59f39..324a61971 100644 --- a/python/tokenspeed/runtime/execution/context.py +++ b/python/tokenspeed/runtime/execution/context.py @@ -20,8 +20,11 @@ from __future__ import annotations +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch @@ -33,6 +36,24 @@ if TYPE_CHECKING: from tokenspeed.runtime.layers.attention.backends.base import AttentionBackend from tokenspeed.runtime.layers.attention.kv_cache.base import BaseTokenToKVPool + from tokenspeed.runtime.lora.lora_manager import LoraManager + +_CURRENT_LORA_MANAGER: ContextVar[Optional["LoraManager"]] = ContextVar( + "tokenspeed_current_lora_manager", default=None +) + + +def get_current_lora_manager() -> Optional["LoraManager"]: + return _CURRENT_LORA_MANAGER.get() + + +@contextmanager +def bind_forward_context(ctx: "ForwardContext") -> Iterator[None]: + token = _CURRENT_LORA_MANAGER.set(ctx.lora_manager) + try: + yield + finally: + _CURRENT_LORA_MANAGER.reset(token) @dataclass @@ -58,3 +79,11 @@ class ForwardContext: # --- logits processor --- gather_ids: torch.Tensor | None = None + + # --- LoRA --- + # Reference to the LoraManager. When set, forward layers call + # ``lora_manager.apply_qkv_lora`` / ``apply_o_lora`` which read from + # the manager's persistent batch_info. Set at capture time when + # ``--enable-lora`` is on so the LoRA path is recorded into the graph + # (NO_LORA_SLOT = no adapter), otherwise None. + lora_manager: Optional["LoraManager"] = None diff --git a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py index 9fa151db1..ee4403d01 100644 --- a/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py +++ b/python/tokenspeed/runtime/execution/cuda_graph_wrapper.py @@ -50,6 +50,7 @@ from tokenspeed.runtime.execution.runtime_states import RuntimeStates from tokenspeed.runtime.layers.attention.backends.base import AttentionBackend from tokenspeed.runtime.layers.attention.kv_cache.base import BaseTokenToKVPool + from tokenspeed.runtime.lora.lora_manager import LoraManager from tokenspeed.runtime.sampling.backends.base import SamplingBackend logger = get_colorful_logger(__name__) @@ -194,6 +195,7 @@ def __init__( eager_grammar_buffers=None, sampling_backend: SamplingBackend | None = None, runtime_states: RuntimeStates | None = None, + lora_manager: LoraManager | None = None, ): self.config = config self.attn_backend = attn_backend @@ -206,6 +208,7 @@ def __init__( self.capturable_grammar = capturable_grammar self.eager_grammar_buffers = eager_grammar_buffers self.runtime_states = runtime_states + self.lora_manager = lora_manager self.enable_torch_compile = getattr(config, "enable_torch_compile", False) self.disable_padding = config.disable_cuda_graph_padding self.enable_cudagraph_gc = getattr(config, "enable_cudagraph_gc", True) @@ -255,6 +258,12 @@ def __init__( self.graphs: dict[int, torch.cuda.CUDAGraph] = {} self.output_buffers: dict[int, tuple] = {} + # Per-bs no-LoRA variant. Populated only when ``lora_manager`` is + # configured: a second captured graph that omits the LoRA Triton + # kernels entirely, replayed when ``LoraManager.has_active_lora`` + # is False so base-model decode pays no LoRA overhead at all. + self.graphs_no_lora: dict[int, torch.cuda.CUDAGraph] = {} + self.output_buffers_no_lora: dict[int, tuple] = {} self._forward_func: Callable | None = forward_func self.disable = config.enforce_eager @@ -270,15 +279,26 @@ def capture(self): """ Capture CUDA graphs for all configured batch sizes. + When a ``lora_manager`` is attached, captures TWO graphs per batch + size: a with-LoRA graph (records the segmented-GEMM Triton kernels + and feeds them with the manager's persistent batch_info) and a + no-LoRA graph (omits those kernels entirely). Replay picks the + no-LoRA variant when ``has_active_lora`` is False. + Args: forward_func: ModelExecutor.forward_step(bs, ctx, sampling_info). """ rank = self.global_rank + capture_no_lora_too = self.lora_manager is not None with freeze_gc(self.enable_cudagraph_gc): self.stream = torch.cuda.Stream() capture_range = tqdm.tqdm(self.capture_bs) if rank == 0 else self.capture_bs if rank == 0: - logger.info("Capturing batches: %s", self.capture_bs) + logger.info( + "Capturing batches: %s%s", + self.capture_bs, + " (×2: with-LoRA + no-LoRA)" if capture_no_lora_too else "", + ) for bs in capture_range: if rank == 0: avail_mem = get_available_gpu_memory( @@ -287,11 +307,15 @@ def capture(self): capture_range.set_description( f"Capturing batches ({bs=} {avail_mem=:.2f} GB)" ) - graph, output_buffers = self._capture_one(bs) + graph, output_buffers = self._capture_one(bs, attach_lora=True) self.graphs[bs] = graph self.output_buffers[bs] = output_buffers + if capture_no_lora_too: + graph_nl, output_nl = self._capture_one(bs, attach_lora=False) + self.graphs_no_lora[bs] = graph_nl + self.output_buffers_no_lora[bs] = output_nl - def _capture_one(self, bs: int): + def _capture_one(self, bs: int, attach_lora: bool = True): graph = torch.cuda.CUDAGraph() ctx = ForwardContext( @@ -314,6 +338,44 @@ def _capture_one(self, bs: int): if self.dp_size > 1: ctx.global_num_tokens = [bs * self.max_tokens_per_req] * self.world_size + # Bind LoRA only for the with-LoRA variant. When ``attach_lora`` + # is False we capture a parallel graph that omits the LoRA Triton + # kernels entirely (qwen3's ``if ctx.lora_manager is not None`` + # branch falls through), used at replay when no request in the + # batch has an active adapter. + if attach_lora and self.lora_manager is not None: + ctx.lora_manager = self.lora_manager + # Pre-fill batch_info so the captured kernels see a stable + # set of pointers; runtime updates the same tensors before + # each ``graph.replay()`` and the kernels re-read seg_lens / + # weight_indices / lora_ranks. + # + # Use lora_id=0 (base model) which resolves to NO_LORA_SLOT, BUT + # force has_active_lora=True so LoRA kernels ARE captured in the + # graph. With dynamic GPU-tensor weight indexing (w13_A_buffers + # etc.) the captured kernels read weight_indices at replay time, + # so the correct adapter slot is used regardless of what was set + # during capture. Slot 0 weights are all-zero at capture time + # (no adapter loaded yet), so the model output is unaffected. + self.lora_manager.prepare_loras( + [0] * bs, per_request_token_counts=self.max_tokens_per_req + ) + # Force has_active_lora and single_lora_slot so ALL LoRA kernels + # (MoE, attention, MLP) are included in the captured graph. + # This applies to any enabled LoRA type — without this, kernels that + # check has_active_lora (e.g. apply_qkv_lora) return early during + # capture, recording a no-op that is then replayed at every decode step. + if ( + self.lora_manager.enable_moe_lora + or self.lora_manager.enable_attn_lora + or self.lora_manager.enable_mlp_lora + ): + self.lora_manager.has_active_lora = True + bi = self.lora_manager._batch_info + bi.single_lora_slot = 0 + bi.single_lora_rank = self.lora_manager.max_lora_rank + bi.weight_indices[:bs].fill_(0) + # Capture with is_all_greedy=False so the graph records the full # top_k_top_p_sampling path (greedy-only requests are served by the # same path with top_k=1 in the buffer, which effectively argmaxes). @@ -332,6 +394,7 @@ def _capture_one(self, bs: int): device=self.device, ) + from tokenspeed.runtime.execution.context import bind_forward_context from tokenspeed.runtime.grammar.capturable_grammar import ( bind_grammar_mask_buf, ) @@ -359,7 +422,8 @@ def run_once(): self.capturable_grammar.add_batch( grammars=[None] * bs, bs=bs, has_candidates=False ) - return self._forward_func(bs=bs, ctx=ctx, sampling_info=sampling_info) + with bind_forward_context(ctx): + return self._forward_func(bs=bs, ctx=ctx, sampling_info=sampling_info) # Warm up before capture. for _ in range(4): @@ -790,12 +854,25 @@ def __call__( # the per-request generators with the capture-stub generator. self.deepep_adapter.replay() + # Pick the no-LoRA variant when --enable-lora is on but no + # request in this batch uses an adapter — that graph omits the + # per-layer Triton LoRA kernels entirely. + use_no_lora_variant = ( + self.lora_manager is not None + and not self.lora_manager.has_active_lora + and padded_bs in self.graphs_no_lora + ) + if use_no_lora_variant: + graph = self.graphs_no_lora[padded_bs] + output_buffers = self.output_buffers_no_lora[padded_bs] + else: + graph = self.graphs[padded_bs] + output_buffers = self.output_buffers[padded_bs] + with nvtx_range("graph_replay", color="red"): - self.graphs[padded_bs].replay() + graph.replay() - output_tokens, output_lengths, output_logprobs = self.output_buffers[ - padded_bs - ] + output_tokens, output_lengths, output_logprobs = output_buffers result = ( output_tokens[: bs * self.max_tokens_per_req], @@ -839,7 +916,10 @@ def __call__( **mamba_kwargs, ) - result = self._forward_func(bs=bs, ctx=ctx, sampling_info=sampling_info) + from tokenspeed.runtime.execution.context import bind_forward_context + + with bind_forward_context(ctx): + result = self._forward_func(bs=bs, ctx=ctx, sampling_info=sampling_info) # Update mamba/GDN state after speculative verify if ( diff --git a/python/tokenspeed/runtime/execution/model_executor.py b/python/tokenspeed/runtime/execution/model_executor.py index 9c36205d9..30ba0fb9b 100644 --- a/python/tokenspeed/runtime/execution/model_executor.py +++ b/python/tokenspeed/runtime/execution/model_executor.py @@ -108,6 +108,18 @@ class ModelExecutorConfig: disable_capturable_grammar: bool = False mamba_cache_chunk_size: int = 64 + # ====== LORA ========= + enable_lora: bool = False + max_loras: int = 4 + max_lora_rank: int = 64 + # Tiered residence: at most ``max_loras`` adapters in GPU buffers, + # at most ``max_loras_cpu`` cached in pinned host memory; beyond + # that adapters fall back to their disk_path on next use. + max_loras_cpu: int = 16 + lora_buffer_groups: str = "attn,mlp,moe" + lora_moe_compressed_shared_outer: bool = False + lora_scheduling_policy: str = "lru" + @staticmethod def from_server_args( server_args: ServerArgs, @@ -147,6 +159,15 @@ def from_server_args( spec_num_tokens=server_args.speculative_num_draft_tokens, grammar_backend=server_args.grammar_backend, disable_capturable_grammar=server_args.disable_capturable_grammar, + enable_lora=server_args.enable_lora, + max_loras=server_args.max_loras, + max_lora_rank=server_args.max_lora_rank, + max_loras_cpu=server_args.max_loras_cpu or 4 * server_args.max_loras, + lora_buffer_groups=server_args.lora_buffer_groups, + lora_moe_compressed_shared_outer=( + server_args.lora_moe_compressed_shared_outer + ), + lora_scheduling_policy=server_args.lora_scheduling_policy, mamba_cache_chunk_size=server_args.mamba_cache_chunk_size, ) @@ -177,6 +198,11 @@ def __init__( self.draft_attn_backend = draft_attn_backend self.draft_token_to_kv_pool = draft_token_to_kv_pool + # LoRA — created below before CudaGraphWrapper so that the captured + # graphs include the LoRA delta path (NO_LORA_SLOT = no adapter). + self.lora_manager = None + self.request_lora_ids: dict[str, int] = {} + if config.spec_algo is not None: max_num_pages_per_req = ( config.context_len + config.spec_num_tokens + config.block_size - 1 @@ -274,6 +300,39 @@ def __init__( req_to_page=self.req_to_page, ) + if config.enable_lora: + from tokenspeed.runtime.lora.lora_manager import LoraManager + + model = self.model_runner.model + lora_dtype = next(model.parameters()).dtype + lora_device = next(model.parameters()).device + attn_mapping = model_runner.mapping.attn + tp_size = attn_mapping.tp_size + tp_rank = attn_mapping.tp_rank + # ``tp_group`` is the rank-tuple expected by comm_ops.all_reduce + # (it routes through the codebase's graph-capturable backend). + tp_group = attn_mapping.tp_group if tp_size > 1 else None + self.lora_manager = LoraManager( + model_config=model_runner.model_config.hf_config, + max_loras=config.max_loras, + max_lora_rank=config.max_lora_rank, + max_num_tokens=config.chunked_prefill_size, + max_loras_cpu=config.max_loras_cpu, + dtype=lora_dtype, + device=lora_device, + tp_rank=tp_rank, + tp_size=tp_size, + tp_group=tp_group, + lora_buffer_groups={ + group.strip() + for group in config.lora_buffer_groups.split(",") + if group.strip() + }, + lora_moe_compressed_shared_outer=( + config.lora_moe_compressed_shared_outer + ), + ) + self.forward_step = CudaGraphWrapper( forward_func=self._forward_step, attn_backend=attn_backend, @@ -287,6 +346,7 @@ def __init__( eager_grammar_buffers=self.eager_grammar_buffers, sampling_backend=self.sampling_backend, runtime_states=self.runtime_states, + lora_manager=self.lora_manager, ) self.execution_stream = torch.cuda.Stream() @@ -1069,6 +1129,21 @@ def execute_forward_op( ), gather_ids=gather_ids, ) + # Bind LoRA when adapters are active. ``prepare_loras`` + # writes per-segment metadata into the manager's persistent + # ``batch_info`` (the captured graph already references + # those tensors); we set ``ctx.lora_manager`` so the + # forward layers call into the LoRA delta path. + if self.lora_manager is not None and bs > 0: + lora_ids = [ + self.request_lora_ids.get(rid, 0) + for rid in forward_op.request_ids + ] + self.lora_manager.prepare_loras( + lora_ids, list(forward_op.input_lengths) + ) + if any(lid != 0 for lid in lora_ids): + ctx.lora_manager = self.lora_manager if self.config.data_parallel_size > 1: if dp_global_num_tokens is None: raise RuntimeError( diff --git a/python/tokenspeed/runtime/execution/model_runner.py b/python/tokenspeed/runtime/execution/model_runner.py index bb57b7ad5..62f0ad218 100644 --- a/python/tokenspeed/runtime/execution/model_runner.py +++ b/python/tokenspeed/runtime/execution/model_runner.py @@ -24,6 +24,7 @@ import torch +from tokenspeed.runtime.execution.context import bind_forward_context from tokenspeed.runtime.execution.weight_loader import WeightLoader from tokenspeed.runtime.utils import get_colorful_logger from tokenspeed.runtime.utils.env import global_server_args_dict_update @@ -136,11 +137,12 @@ def forward( if captured_hidden_states is not None: kwargs["captured_hidden_states"] = captured_hidden_states - return self.model.forward( - ctx, - input_ids, - positions, - out_cache_loc, - input_lengths, - **kwargs, - ) + with bind_forward_context(ctx): + return self.model.forward( + ctx, + input_ids, + positions, + out_cache_loc, + input_lengths, + **kwargs, + ) diff --git a/python/tokenspeed/runtime/layers/logits_processor.py b/python/tokenspeed/runtime/layers/logits_processor.py index 0b7a26865..ae264b9d4 100755 --- a/python/tokenspeed/runtime/layers/logits_processor.py +++ b/python/tokenspeed/runtime/layers/logits_processor.py @@ -28,7 +28,10 @@ from torch import nn from tokenspeed.runtime.distributed.comm_ops import all_gather_into_tensor -from tokenspeed.runtime.execution.context import ForwardContext +from tokenspeed.runtime.execution.context import ( + ForwardContext, + get_current_lora_manager, +) from tokenspeed.runtime.execution.forward_batch_info import ( CaptureHiddenMode, ForwardMode, @@ -396,6 +399,10 @@ def _get_logits( if self.logit_scale is not None: logits.mul_(self.logit_scale) + lora_manager = get_current_lora_manager() + if lora_manager is not None and lora_manager.enable_head_lora: + logits = lora_manager.apply_lm_head_lora(hidden_states, logits) + if self.tp_size > 1 and not self.skip_all_gather: gathered_logits = torch.empty( self.tp_size * logits.size(0), diff --git a/python/tokenspeed/runtime/layers/moe/backends/E=128,inter_size=384,hidden_size=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16_down.json b/python/tokenspeed/runtime/layers/moe/backends/E=128,inter_size=384,hidden_size=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16_down.json new file mode 100644 index 000000000..1e28041de --- /dev/null +++ b/python/tokenspeed/runtime/layers/moe/backends/E=128,inter_size=384,hidden_size=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16_down.json @@ -0,0 +1,11 @@ +{ + "1": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "USE_TMA": false}, + "8": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "USE_TMA": false}, + "16": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "USE_TMA": false}, + "32": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "USE_TMA": false}, + "64": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "USE_TMA": false}, + "128": {"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64, "GROUP_SIZE_M": 1, "USE_TMA": false}, + "129": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "USE_TMA": false}, + "512": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "USE_TMA": false}, + "2048": {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32, "GROUP_SIZE_M": 8, "USE_TMA": false} +} diff --git a/python/tokenspeed/runtime/layers/moe/backends/base.py b/python/tokenspeed/runtime/layers/moe/backends/base.py index 1dfe8e51d..b1f7b3fa2 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/base.py +++ b/python/tokenspeed/runtime/layers/moe/backends/base.py @@ -95,6 +95,10 @@ def supports_deferred_finalize(self) -> bool: """ return False + @property + def supports_moe_lora(self) -> bool: + return False + @property def topk_output_format(self) -> TopKOutputFormat: return TopKOutputFormat.STANDARD diff --git a/python/tokenspeed/runtime/layers/moe/backends/fp8/triton.py b/python/tokenspeed/runtime/layers/moe/backends/fp8/triton.py index 4dc4ebccb..5cd0de555 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/fp8/triton.py +++ b/python/tokenspeed/runtime/layers/moe/backends/fp8/triton.py @@ -78,6 +78,7 @@ def forward( topk_output, num_global_tokens, max_num_tokens_per_gpu, + moe_lora_context=None, ): del num_global_tokens, max_num_tokens_per_gpu return triton_forward( @@ -88,7 +89,12 @@ def forward( layer, hidden_states, topk_output, + moe_lora_context=moe_lora_context, ) + @property + def supports_moe_lora(self) -> bool: + return True + __all__ = ["Fp8TritonBackend"] diff --git a/python/tokenspeed/runtime/layers/moe/backends/triton_common.py b/python/tokenspeed/runtime/layers/moe/backends/triton_common.py index c67208400..4de5ac6c4 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/triton_common.py +++ b/python/tokenspeed/runtime/layers/moe/backends/triton_common.py @@ -121,6 +121,7 @@ def triton_forward( layer: nn.Module, hidden_states: torch.Tensor, topk_output: object, + moe_lora_context=None, ) -> torch.Tensor: from tokenspeed.runtime.layers.activation import silu_and_mul @@ -193,6 +194,11 @@ def triton_forward( dtype=dtype, ) + # Prefetch gate_up LoRA A-shrink on secondary stream, concurrent with gate_up_gemm. + if moe_lora_context is not None: + moe_lora_context.launch_gate_up_shrink( + layer.layer_index, hidden_states, topk_ids + ) gate_up_gemm( A=hidden_states, B=layer.w13_weight, @@ -208,6 +214,14 @@ def triton_forward( b_use_tma=gate_up_moe_use_tma, c_sorted=down_moe_use_tma, ) + if moe_lora_context is not None: + moe_lora_context.apply_gate_up_lora( + layer.layer_index, + hidden_states, + topk_ids, + intermediate_cache1, + sorted_token_ids=sorted_token_ids if down_moe_use_tma else None, + ) if activation == "silu": silu_and_mul( @@ -217,6 +231,14 @@ def triton_forward( else: raise ValueError(f"Unsupported activation: {activation}") + # Prefetch down LoRA A-shrink on secondary stream, concurrent with down_gemm. + if moe_lora_context is not None and not down_moe_use_tma: + moe_lora_context.launch_down_shrink( + layer.layer_index, + intermediate_cache2, + topk_ids, + m_tokens * top_k, + ) down_gemm( A=intermediate_cache2, B=layer.w2_weight, @@ -231,6 +253,15 @@ def triton_forward( a_use_tma=down_moe_use_tma, b_use_tma=down_moe_use_tma, ) + if moe_lora_context is not None: + moe_lora_context.apply_down_lora( + layer.layer_index, + intermediate_cache2, + topk_ids, + topk_weights, + intermediate_cache3, + sorted_token_ids=sorted_token_ids if down_moe_use_tma else None, + ) out_hidden_states = torch.empty_like(hidden_states) # Current limitation: Should avoid using runtime shapes as traits diff --git a/python/tokenspeed/runtime/layers/moe/backends/unquantized/triton.py b/python/tokenspeed/runtime/layers/moe/backends/unquantized/triton.py index 77cc34b56..f44840e66 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/unquantized/triton.py +++ b/python/tokenspeed/runtime/layers/moe/backends/unquantized/triton.py @@ -60,6 +60,7 @@ def forward( topk_output, num_global_tokens, max_num_tokens_per_gpu, + moe_lora_context=None, ): del num_global_tokens, max_num_tokens_per_gpu return triton_forward( @@ -70,7 +71,12 @@ def forward( layer, hidden_states, topk_output, + moe_lora_context=moe_lora_context, ) + @property + def supports_moe_lora(self) -> bool: + return True + __all__ = ["Bf16TritonBackend"] diff --git a/python/tokenspeed/runtime/layers/moe/backends/w8a8_fp8/triton.py b/python/tokenspeed/runtime/layers/moe/backends/w8a8_fp8/triton.py index 35061ef35..fec9f1e7a 100644 --- a/python/tokenspeed/runtime/layers/moe/backends/w8a8_fp8/triton.py +++ b/python/tokenspeed/runtime/layers/moe/backends/w8a8_fp8/triton.py @@ -78,6 +78,7 @@ def forward( topk_output, num_global_tokens, max_num_tokens_per_gpu, + moe_lora_context=None, ): del num_global_tokens, max_num_tokens_per_gpu return triton_forward( @@ -88,8 +89,13 @@ def forward( layer, hidden_states, topk_output, + moe_lora_context=moe_lora_context, ) + @property + def supports_moe_lora(self) -> bool: + return True + W8A8Fp8TritonBackend = W8A8PerTokenPerChannelFp8TritonBackend diff --git a/python/tokenspeed/runtime/layers/moe/layer.py b/python/tokenspeed/runtime/layers/moe/layer.py index ef5790969..2f3e2da8d 100755 --- a/python/tokenspeed/runtime/layers/moe/layer.py +++ b/python/tokenspeed/runtime/layers/moe/layer.py @@ -21,6 +21,7 @@ import torch +from tokenspeed.runtime.execution.context import get_current_lora_manager from tokenspeed.runtime.layers.activation import SwigluArg from tokenspeed.runtime.layers.moe.core import MoELayerSpec, select_backend from tokenspeed.runtime.layers.moe.utils import get_all2all_backend @@ -155,6 +156,7 @@ def forward( num_global_tokens: int, max_num_tokens_per_gpu: int, do_finalize: bool = True, + lora_manager=None, ): # Only pass ``do_finalize`` through when the caller actually wants # the deferred path. Other backends do not accept this kwarg; @@ -166,6 +168,21 @@ def forward( self.backend.supports_deferred_finalize ), f"{type(self.backend).__name__} does not support do_finalize=False" kwargs["do_finalize"] = False + if lora_manager is None: + lora_manager = get_current_lora_manager() + if lora_manager is not None: + if not self.backend.supports_moe_lora: + raise NotImplementedError( + f"{type(self.backend).__name__} does not support MoE LoRA; " + "use the Triton backend instead." + ) + if self.ep_size != 1: + raise NotImplementedError( + "MoE LoRA currently supports local/Tensor-Parallel MoE only; " + "expert-parallel dispatch needs the LoRA slot map to be " + "dispatched with tokens." + ) + kwargs["moe_lora_context"] = lora_manager.moe_lora_context return self.backend.forward( self, hidden_states, diff --git a/python/tokenspeed/runtime/lora/__init__.py b/python/tokenspeed/runtime/lora/__init__.py new file mode 100644 index 000000000..57692962f --- /dev/null +++ b/python/tokenspeed/runtime/lora/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""LoRA adapter serving runtime.""" + +from tokenspeed.runtime.lora.lora_config import LoraConfig + +__all__ = ["LoraConfig", "LoraRegistry"] + + +def __getattr__(name: str): + if name == "LoraRegistry": + from tokenspeed.runtime.lora.lora_registry import LoraRegistry + + return LoraRegistry + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/python/tokenspeed/runtime/lora/adapter_io.py b/python/tokenspeed/runtime/lora/adapter_io.py new file mode 100644 index 000000000..d92020c13 --- /dev/null +++ b/python/tokenspeed/runtime/lora/adapter_io.py @@ -0,0 +1,142 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""PEFT LoRA adapter loading and metadata helpers.""" + +from __future__ import annotations + +import json +import os +import re + +import torch + +PEFT_ATTN_MODULES = ("q_proj", "k_proj", "v_proj", "o_proj") +PEFT_MLP_MODULES = ("gate_proj", "up_proj", "down_proj") +PEFT_EXPERT_MODULES = PEFT_MLP_MODULES +PEFT_HEAD_MODULE = "lm_head" +PEFT_MODULES = (*PEFT_ATTN_MODULES, *PEFT_MLP_MODULES) + +# Sentinel layer_id used for model-global modules (e.g. lm_head) that have no +# per-layer index in AdapterWeights. +LORA_HEAD_LAYER_ID = -1 + +AdapterWeights = dict[int, dict[str, tuple[torch.Tensor, torch.Tensor]]] + + +def resolve_adapter_weight_path(adapter_path: str) -> str: + safetensors_path = os.path.join(adapter_path, "adapter_model.safetensors") + return safetensors_path if os.path.exists(safetensors_path) else adapter_path + + +def load_adapter_weights(adapter_path: str) -> AdapterWeights: + return parse_adapter_weights( + load_safetensors(resolve_adapter_weight_path(adapter_path)) + ) + + +def load_safetensors(path: str) -> dict[str, torch.Tensor]: + from safetensors import safe_open + + tensors: dict[str, torch.Tensor] = {} + with safe_open(path, framework="pt", device="cpu") as f: + for key in f.keys(): + tensors[key] = f.get_tensor(key) + return tensors + + +def parse_adapter_weights(tensors: dict[str, torch.Tensor]) -> AdapterWeights: + """Return ``{layer_id: {module_name: (lora_A, lora_B)}}``. + + Matches attention (``self_attn.{q,k,v,o}_proj``), MLP + (``mlp.{gate,up,down}_proj``), and lm_head PEFT module names. + lm_head weights are stored under ``LORA_HEAD_LAYER_ID`` (-1). + """ + dense_pattern = re.compile( + r"base_model\.model\.model\.layers\.(\d+)\." + r"(?:self_attn|mlp)\." + r"(q_proj|k_proj|v_proj|o_proj|gate_proj|up_proj|down_proj)\." + r"lora_(A|B)\.weight" + ) + expert_pattern = re.compile( + r"base_model\.model\.model\.layers\.(\d+)\." + r"mlp\.experts\.(\d+)\." + r"(gate_proj|up_proj|down_proj)\." + r"lora_(A|B)\.weight" + ) + expert_3d_pattern = re.compile( + r"base_model\.model\.model\.layers\.(\d+)\." + r"mlp\.experts\." + r"(w1|w2|w3)\." + r"lora_(A|B)\.weight" + ) + # PEFT uses ``lora_embedding_A/B`` (no ``.weight`` suffix) for modules + # treated as embedding tables (lm_head, embed_tokens). + head_pattern = re.compile( + r"base_model\.model\.lm_head\." r"(?:lora_(A|B)\.weight|lora_embedding_(A|B))" + ) + weights: dict[int, dict[str, dict[str, torch.Tensor]]] = {} + for key, tensor in tensors.items(): + m = dense_pattern.match(key) + if m: + layer_id, module, ab = int(m.group(1)), m.group(2), m.group(3) + else: + m = expert_pattern.match(key) + if m: + layer_id = int(m.group(1)) + module = f"experts.{int(m.group(2))}.{m.group(3)}" + ab = m.group(4) + else: + m = expert_3d_pattern.match(key) + if m: + layer_id = int(m.group(1)) + module = f"experts.{m.group(2)}" + ab = m.group(3) + else: + m = head_pattern.match(key) + if not m: + continue + layer_id = LORA_HEAD_LAYER_ID + module = PEFT_HEAD_MODULE + ab = m.group(1) or m.group(2) + weights.setdefault(layer_id, {}).setdefault(module, {})[ab] = tensor + + result: AdapterWeights = {} + for layer_id, modules in weights.items(): + result[layer_id] = {} + for module, ab_dict in modules.items(): + result[layer_id][module] = (ab_dict["A"], ab_dict["B"]) + return result + + +def read_adapter_scaling(adapter_path: str | None, rank: int) -> float: + if adapter_path is None: + return 1.0 + config_file = os.path.join(adapter_path, "adapter_config.json") + if not os.path.exists(config_file): + return 1.0 + try: + with open(config_file) as f: + cfg = json.load(f) + alpha = float(cfg.get("lora_alpha", rank)) + r = int(cfg.get("r", rank)) + return alpha / r if r > 0 else 1.0 + except Exception: + return 1.0 diff --git a/python/tokenspeed/runtime/lora/lora_batch.py b/python/tokenspeed/runtime/lora/lora_batch.py new file mode 100644 index 000000000..23064db4b --- /dev/null +++ b/python/tokenspeed/runtime/lora/lora_batch.py @@ -0,0 +1,98 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Batch metadata structures for segmented LoRA kernels.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +NO_LORA_SLOT = -1 + + +@dataclass +class LoraBatchInfo: + """Per-step segment metadata read by the LoRA kernels. + + All tensors live on the LoRA device. When the captured CUDA graph needs + persistent storage, :class:`LoraManager` pre-allocates these tensors with + maximum sizes; runtime fills the prefix and updates ``bs`` / ``max_len``. + """ + + bs: int + num_segments: int + max_len: int + seg_lens: torch.Tensor # (num_segments,) int32 + seg_indptr: torch.Tensor # (num_segments + 1,) int32 + weight_indices: torch.Tensor # (num_segments,) int32 + lora_ranks: torch.Tensor # (n_slots,) int32; NO_LORA_SLOT means base model + scalings: torch.Tensor # (n_slots,) float32 + permutation: torch.Tensor | None = None # unused (no sort by adapter yet) + # Adapter-group metadata for lora_expand_grouped_v2_fwd (decode path only). + # Populated by prepare_loras when max_len == 1. + sort_order: torch.Tensor | None = None # (bs,) int64 + group_slots: torch.Tensor | None = None # (num_groups,) int32 + group_starts: torch.Tensor | None = None # (num_groups,) int32 + group_sizes: torch.Tensor | None = None # (num_groups,) int32 + num_groups: int = 0 + # Largest group size; pre-computed on CPU so the kernel grid avoids a + # GPU-CPU sync. Equals max(group_sizes) when num_groups > 0, else 0. + max_group_size: int = 0 + # Host-only fast-path metadata. Non-negative iff every segment in this step + # uses the same real adapter slot; NO_LORA_SLOT means mixed/base-only. + single_lora_slot: int = NO_LORA_SLOT + # Host-only active rank for ``single_lora_slot``. Zero when no single + # nonzero adapter slot is active. + single_lora_rank: int = 0 + # Host-only metadata for the multi-adapter batched CuTeDSL fast path. + # Non-negative iff segments are equal-length, slots are consecutive, and + # all participating slots share rank/scaling. + multi_lora_start_slot: int = NO_LORA_SLOT + multi_lora_count: int = 0 + multi_lora_segment_len: int = 0 + multi_lora_rank: int = 0 + + +def build_decode_lora_groups( + per_request_slots: list[int], +) -> tuple[list[int], list[int], list[int], list[int]]: + """Group decode requests by adapter slot for the grouped expand kernel. + + Returns ``(sort_order, group_slots, group_starts, group_sizes)``. + ``group_starts`` are offsets into ``sort_order``. + """ + sort_order = sorted( + (i for i, slot in enumerate(per_request_slots) if slot != NO_LORA_SLOT), + key=lambda i: per_request_slots[i], + ) + group_slots: list[int] = [] + group_starts: list[int] = [] + group_sizes: list[int] = [] + for pos, orig in enumerate(sort_order): + slot = per_request_slots[orig] + if not group_slots or group_slots[-1] != slot: + group_slots.append(slot) + group_starts.append(pos) + group_sizes.append(1) + else: + group_sizes[-1] += 1 + return sort_order, group_slots, group_starts, group_sizes diff --git a/python/tokenspeed/runtime/lora/lora_buffers.py b/python/tokenspeed/runtime/lora/lora_buffers.py new file mode 100644 index 000000000..2b024f4d8 --- /dev/null +++ b/python/tokenspeed/runtime/lora/lora_buffers.py @@ -0,0 +1,332 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""GPU-resident LoRA weight buffer layout and slot loading.""" + +from __future__ import annotations + +import torch + +from tokenspeed.runtime.lora.adapter_io import ( + LORA_HEAD_LAYER_ID, + PEFT_HEAD_MODULE, + AdapterWeights, +) + +LORA_BUFFER_GROUPS = frozenset({"attn", "mlp", "moe", "lm_head"}) + + +class LoraWeightBuffers: + def __init__( + self, + *, + n_layers: int, + n_slots: int, + max_lora_rank: int, + hidden_size: int, + q_size_per_tp: int, + kv_size_per_tp: int, + o_in_per_tp: int, + intermediate_per_tp: int, + vocab_per_tp: int, + dtype: torch.dtype, + device: torch.device, + tp_rank: int, + tp_size: int, + buffer_groups: set[str] | frozenset[str] = LORA_BUFFER_GROUPS, + ) -> None: + self.n_layers = n_layers + self.n_slots = n_slots + self.max_lora_rank = max_lora_rank + self.hidden_size = hidden_size + self.q_size_per_tp = q_size_per_tp + self.kv_size_per_tp = kv_size_per_tp + self.o_in_per_tp = o_in_per_tp + self.intermediate_per_tp = intermediate_per_tp + self.vocab_per_tp = vocab_per_tp + self.dtype = dtype + self.device = device + self.tp_rank = tp_rank + self.tp_size = tp_size + unknown_groups = set(buffer_groups) - LORA_BUFFER_GROUPS + if unknown_groups: + raise ValueError(f"Unknown LoRA buffer groups: {sorted(unknown_groups)}") + self.buffer_groups = frozenset(buffer_groups) + self.enable_attn = "attn" in self.buffer_groups + self.enable_mlp = "mlp" in self.buffer_groups + self.enable_head = "lm_head" in self.buffer_groups + + self.qkv_A_buffers: list[torch.Tensor] = [] + self.qkv_B_buffers: list[torch.Tensor] = [] + self.o_A_buffers: list[torch.Tensor] = [] + self.o_B_buffers: list[torch.Tensor] = [] + self.gate_up_A_buffers: list[torch.Tensor] = [] + self.gate_up_B_buffers: list[torch.Tensor] = [] + self.down_A_buffers: list[torch.Tensor] = [] + self.down_B_buffers: list[torch.Tensor] = [] + # lm_head LoRA — single pair of buffers (not per-layer). + # A: (n_slots, r, hidden) — replicated across TP ranks. + # B: (n_slots, vocab_per_tp, r) — column-parallel shard. + self.lm_head_A_buffer: torch.Tensor + self.lm_head_B_buffer: torch.Tensor + + self.qkv_output_offset = torch.tensor( + [ + 0, + q_size_per_tp, + q_size_per_tp + kv_size_per_tp, + q_size_per_tp + 2 * kv_size_per_tp, + ], + dtype=torch.int32, + device=device, + ) + self.max_qkv_out_dim = max(q_size_per_tp, kv_size_per_tp) + + self.o_slice_offsets = torch.tensor( + [0, hidden_size], dtype=torch.int32, device=device + ) + self.gate_up_slice_offsets = torch.tensor( + [0, intermediate_per_tp, 2 * intermediate_per_tp], + dtype=torch.int32, + device=device, + ) + self.down_slice_offsets = torch.tensor( + [0, hidden_size], dtype=torch.int32, device=device + ) + + self._alloc() + + def _alloc(self) -> None: + r = self.max_lora_rank + h = self.hidden_size + q = self.q_size_per_tp + kv = self.kv_size_per_tp + o_in = self.o_in_per_tp + i = self.intermediate_per_tp + v = self.vocab_per_tp + n = self.n_slots + + for _ in range(self.n_layers): + if self.enable_attn: + self.qkv_A_buffers.append( + torch.zeros((n, 3 * r, h), dtype=self.dtype, device=self.device) + ) + self.qkv_B_buffers.append( + torch.zeros( + (n, q + 2 * kv, r), dtype=self.dtype, device=self.device + ) + ) + self.o_A_buffers.append( + torch.zeros((n, r, o_in), dtype=self.dtype, device=self.device) + ) + self.o_B_buffers.append( + torch.zeros((n, h, r), dtype=self.dtype, device=self.device) + ) + if self.enable_mlp: + self.gate_up_A_buffers.append( + torch.zeros((n, 2 * r, h), dtype=self.dtype, device=self.device) + ) + self.gate_up_B_buffers.append( + torch.zeros((n, 2 * i, r), dtype=self.dtype, device=self.device) + ) + self.down_A_buffers.append( + torch.zeros((n, r, i), dtype=self.dtype, device=self.device) + ) + self.down_B_buffers.append( + torch.zeros((n, h, r), dtype=self.dtype, device=self.device) + ) + if self.enable_head: + self.lm_head_A_buffer = torch.zeros( + (n, r, h), dtype=self.dtype, device=self.device + ) + self.lm_head_B_buffer = torch.zeros( + (n, v, r), dtype=self.dtype, device=self.device + ) + + def load_adapter_to_slot( + self, + cpu_weights: AdapterWeights, + slot: int, + rank: int, + ) -> None: + for layer_id, modules in cpu_weights.items(): + if layer_id == LORA_HEAD_LAYER_ID: + if PEFT_HEAD_MODULE in modules: + self._load_lm_head_to_slot(modules[PEFT_HEAD_MODULE], slot, rank) + continue + for mod, (lora_A_full, lora_B_full) in modules.items(): + if mod.startswith("experts."): + continue + self._check_module_enabled(mod) + lora_A_shard_cpu, lora_B_shard_cpu = self.shard_weights( + mod, lora_A_full, lora_B_full + ) + r = min(lora_A_shard_cpu.shape[0], rank) + lora_A_shard = lora_A_shard_cpu[:r].to( + device=self.device, + dtype=self.dtype, + non_blocking=True, + ) + lora_B_shard = lora_B_shard_cpu[:, :r].to( + device=self.device, + dtype=self.dtype, + non_blocking=True, + ) + + if mod in ("q_proj", "k_proj", "v_proj"): + qkv_idx = ("q_proj", "k_proj", "v_proj").index(mod) + rank_off = qkv_idx * r + out_off, out_size = self.qkv_b_slice(mod) + self.qkv_A_buffers[layer_id][ + slot, rank_off : rank_off + r, : + ].copy_(lora_A_shard, non_blocking=True) + self.qkv_B_buffers[layer_id][ + slot, out_off : out_off + out_size, :r + ].copy_(lora_B_shard, non_blocking=True) + elif mod == "o_proj": + self.o_A_buffers[layer_id][slot, :r, :].copy_( + lora_A_shard, non_blocking=True + ) + self.o_B_buffers[layer_id][slot, :, :r].copy_( + lora_B_shard, non_blocking=True + ) + elif mod in ("gate_proj", "up_proj"): + gate_up_idx = 0 if mod == "gate_proj" else 1 + rank_off = gate_up_idx * r + out_off = gate_up_idx * self.intermediate_per_tp + self.gate_up_A_buffers[layer_id][ + slot, rank_off : rank_off + r, : + ].copy_(lora_A_shard, non_blocking=True) + self.gate_up_B_buffers[layer_id][ + slot, out_off : out_off + self.intermediate_per_tp, :r + ].copy_(lora_B_shard, non_blocking=True) + else: + self.down_A_buffers[layer_id][slot, :r, :].copy_( + lora_A_shard, non_blocking=True + ) + self.down_B_buffers[layer_id][slot, :, :r].copy_( + lora_B_shard, non_blocking=True + ) + + def _load_lm_head_to_slot( + self, + ab: tuple[torch.Tensor, torch.Tensor], + slot: int, + rank: int, + ) -> None: + if not self.enable_head: + raise ValueError( + "Adapter targets lm_head, but LoRA buffer group 'head' is disabled." + ) + lora_A_full, lora_B_full = ab + lora_A_cpu, lora_B_cpu = self.shard_weights( + PEFT_HEAD_MODULE, lora_A_full, lora_B_full + ) + r = min(lora_A_cpu.shape[0], rank) + self.lm_head_A_buffer[slot, :r, :].copy_( + lora_A_cpu[:r].to(device=self.device, dtype=self.dtype, non_blocking=True), + non_blocking=True, + ) + self.lm_head_B_buffer[slot, :, :r].copy_( + lora_B_cpu[:, :r].to( + device=self.device, dtype=self.dtype, non_blocking=True + ), + non_blocking=True, + ) + + def zero_slot(self, slot: int) -> None: + if self.enable_attn: + for layer_id in range(self.n_layers): + self.qkv_A_buffers[layer_id][slot].zero_() + self.qkv_B_buffers[layer_id][slot].zero_() + self.o_A_buffers[layer_id][slot].zero_() + self.o_B_buffers[layer_id][slot].zero_() + if self.enable_mlp: + for layer_id in range(self.n_layers): + self.gate_up_A_buffers[layer_id][slot].zero_() + self.gate_up_B_buffers[layer_id][slot].zero_() + self.down_A_buffers[layer_id][slot].zero_() + self.down_B_buffers[layer_id][slot].zero_() + if self.enable_head: + self.lm_head_A_buffer[slot].zero_() + self.lm_head_B_buffer[slot].zero_() + + def _check_module_enabled(self, module: str) -> None: + if module in ("q_proj", "k_proj", "v_proj", "o_proj"): + if not self.enable_attn: + raise ValueError( + f"Adapter targets {module}, but LoRA buffer group 'attn' " + "is disabled." + ) + return + if module in ("gate_proj", "up_proj", "down_proj"): + if not self.enable_mlp: + raise ValueError( + f"Adapter targets {module}, but LoRA buffer group 'mlp' " + "is disabled." + ) + return + if module == PEFT_HEAD_MODULE: + if not self.enable_head: + raise ValueError( + "Adapter targets lm_head, but LoRA buffer group 'head' " + "is disabled." + ) + return + raise ValueError(f"Unsupported dense LoRA module: {module}") + + def qkv_b_slice(self, module: str) -> tuple[int, int]: + """Return ``(offset, size)`` of a projection inside fused QKV B.""" + if module == "q_proj": + return 0, self.q_size_per_tp + if module == "k_proj": + return self.q_size_per_tp, self.kv_size_per_tp + return self.q_size_per_tp + self.kv_size_per_tp, self.kv_size_per_tp + + def shard_weights( + self, + module: str, + lora_A: torch.Tensor, + lora_B: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + if self.tp_size == 1: + return lora_A, lora_B + # Column-parallel (attn q/k/v, MLP gate/up, lm_head): shard B along output dim. + if module in ( + "q_proj", + "k_proj", + "v_proj", + "gate_proj", + "up_proj", + PEFT_HEAD_MODULE, + ): + out_total = lora_B.shape[0] + out_per = out_total // self.tp_size + return ( + lora_A, + lora_B[self.tp_rank * out_per : (self.tp_rank + 1) * out_per], + ) + # Row-parallel (attn o_proj, MLP down_proj): shard A along input dim. + in_total = lora_A.shape[1] + in_per = in_total // self.tp_size + return ( + lora_A[:, self.tp_rank * in_per : (self.tp_rank + 1) * in_per], + lora_B, + ) diff --git a/python/tokenspeed/runtime/lora/lora_cache.py b/python/tokenspeed/runtime/lora/lora_cache.py new file mode 100644 index 000000000..185ca791b --- /dev/null +++ b/python/tokenspeed/runtime/lora/lora_cache.py @@ -0,0 +1,189 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Tier-2 CPU LoRA adapter cache with async disk prefetch.""" + +from __future__ import annotations + +import threading +from collections import OrderedDict +from collections.abc import Callable +from concurrent.futures import Future, ThreadPoolExecutor + +import torch + +from tokenspeed.runtime.lora.adapter_io import AdapterWeights, load_adapter_weights +from tokenspeed.runtime.utils import get_colorful_logger + +logger = get_colorful_logger(__name__) + + +class LoraCpuCache: + def __init__( + self, + *, + capacity: int, + is_gpu_resident: Callable[[str], bool], + ) -> None: + self.capacity = capacity + self.is_gpu_resident = is_gpu_resident + self.cache: dict[str, AdapterWeights] = {} + self.lru: OrderedDict[str, None] = OrderedDict() + self.adapter_paths: dict[str, str] = {} + self.loader_executor = ThreadPoolExecutor( + max_workers=2, thread_name_prefix="lora-loader" + ) + self.lock = threading.Lock() + self.pending_loads: dict[str, Future] = {} + + def set_path(self, name: str, adapter_path: str) -> None: + self.adapter_paths[name] = adapter_path + + def remove(self, name: str) -> None: + self.evict(name) + self.adapter_paths.pop(name, None) + with self.lock: + self.pending_loads.pop(name, None) + + def prefetch(self, name: str) -> None: + """Best-effort async warm of the CPU pool for *name*.""" + with self.lock: + if name in self.cache: + self.lru.move_to_end(name) + return + if name in self.pending_loads: + return + adapter_path = self.adapter_paths.get(name) + if adapter_path is None: + return + fut = self.loader_executor.submit( + self._async_load_weights, name, adapter_path + ) + self.pending_loads[name] = fut + + def ensure( + self, + name: str, + weights: AdapterWeights | None = None, + ) -> None: + """Synchronously ensure *name* is CPU-resident.""" + with self.lock: + if name in self.cache: + self.lru.move_to_end(name) + return + pending = self.pending_loads.get(name) + + if pending is not None: + pending.result() + with self.lock: + if name in self.cache: + self.lru.move_to_end(name) + return + + if weights is None: + adapter_path = self.adapter_paths.get(name) + if adapter_path is None: + raise KeyError(f"Adapter '{name}' has no recorded disk path.") + weights = load_adapter_weights(adapter_path) + + with self.lock: + if name in self.cache: + self.lru.move_to_end(name) + return + self._install_locked(name, weights) + + def evict(self, name: str) -> None: + with self.lock: + self._evict_locked(name) + + def _async_load_weights(self, name: str, adapter_path: str) -> None: + try: + weights = load_adapter_weights(adapter_path) + except Exception: + logger.exception("Async LoRA load failed for '%s'", name) + with self.lock: + self.pending_loads.pop(name, None) + return + with self.lock: + try: + if name not in self.cache: + self._install_locked(name, weights) + finally: + self.pending_loads.pop(name, None) + + def _install_locked(self, name: str, weights: AdapterWeights) -> None: + while len(self.cache) >= self.capacity: + evicted = False + # Prefer evicting non-GPU-resident entries first: they cost a disk + # read to bring back, while GPU-resident ones cost nothing until + # their GPU slot is also evicted. + for stage in ("non_gpu", "gpu_resident"): + for candidate in list(self.lru.keys()): + if candidate == name: + continue + is_gpu = self.is_gpu_resident(candidate) + if stage == "non_gpu" and is_gpu: + continue + self._evict_locked(candidate) + evicted = True + break + if evicted: + break + if not evicted: + raise RuntimeError( + f"CPU LoRA pool is full ({len(self.cache)}/{self.capacity}) " + "and no evictable entry was found. " + f"cpu_lru={list(self.lru.keys())}. " + "Increase max_loras_cpu." + ) + self.cache[name] = self._pin_weights(weights) + self.lru[name] = None + + def _evict_locked(self, name: str) -> None: + if name in self.cache: + del self.cache[name] + self.lru.pop(name, None) + logger.debug( + "Evicted '%s' from CPU pool (now %d/%d)", + name, + len(self.cache), + self.capacity, + ) + + def _pin_weights(self, weights: AdapterWeights) -> AdapterWeights: + return { + layer_id: { + module: ( + self._pin_tensor(lora_A), + self._pin_tensor(lora_B), + ) + for module, (lora_A, lora_B) in modules.items() + } + for layer_id, modules in weights.items() + } + + @staticmethod + def _pin_tensor(tensor: torch.Tensor) -> torch.Tensor: + if tensor.device.type != "cpu" or tensor.is_pinned(): + return tensor + try: + return tensor.pin_memory() + except RuntimeError: + return tensor diff --git a/python/tokenspeed/runtime/lora/lora_config.py b/python/tokenspeed/runtime/lora/lora_config.py new file mode 100644 index 000000000..7938b7d38 --- /dev/null +++ b/python/tokenspeed/runtime/lora/lora_config.py @@ -0,0 +1,79 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""LoRA adapter configuration and metadata.""" + +from __future__ import annotations + +import json +import os +from dataclasses import dataclass, field +from typing import Optional + + +@dataclass +class LoraConfig: + """Configuration for a single LoRA adapter. + + Loaded from the adapter's ``adapter_config.json`` (PEFT format). + """ + + # Identifier used at request time (e.g. "sql-expert") + name: str + + # Filesystem path to the adapter directory or file + path: str + + # LoRA rank (r) + r: int = 16 + + # LoRA alpha scaling factor + lora_alpha: int = 16 + + # Target modules (e.g. ["q_proj", "v_proj"]) + target_modules: list[str] = field(default_factory=list) + + # Base model name for compatibility checking + base_model_name_or_path: Optional[str] = None + + @classmethod + def from_path(cls, name: str, path: str) -> "LoraConfig": + """Load LoraConfig from a PEFT adapter directory.""" + config_file = os.path.join(path, "adapter_config.json") + if not os.path.exists(config_file): + raise FileNotFoundError( + f"adapter_config.json not found at {config_file}. " + "The path must point to a PEFT-format adapter directory." + ) + with open(config_file) as f: + raw = json.load(f) + + return cls( + name=name, + path=path, + r=raw.get("r", 16), + lora_alpha=raw.get("lora_alpha", 16), + target_modules=raw.get("target_modules") or [], + base_model_name_or_path=raw.get("base_model_name_or_path"), + ) + + @property + def scaling(self) -> float: + return self.lora_alpha / self.r if self.r > 0 else 1.0 diff --git a/python/tokenspeed/runtime/lora/lora_manager.py b/python/tokenspeed/runtime/lora/lora_manager.py new file mode 100644 index 000000000..01c529dfc --- /dev/null +++ b/python/tokenspeed/runtime/lora/lora_manager.py @@ -0,0 +1,1155 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""LoRA adapter weight manager (segment-grouped Triton path). + +Adapted from sglang/Punica's S-LoRA design. + +Memory layout +------------- +For each layer the manager owns: + +* ``qkv_A_buffers[layer]``: ``(n_slots, 3 * max_rank, hidden)`` — fused + q_proj/k_proj/v_proj A matrices, stack-major (q first, then k, then v). +* ``qkv_B_buffers[layer]``: ``(n_slots, q_per_tp + 2 * kv_per_tp, max_rank)`` + — fused output-side, ``[q_per_tp | kv_per_tp | kv_per_tp]`` along dim 1. +* ``o_A_buffers[layer]``: ``(n_slots, max_rank, in_per_tp)`` — row-parallel + A, sharded along input dim. +* ``o_B_buffers[layer]``: ``(n_slots, hidden, max_rank)`` — full B. + +No-LoRA requests use ``NO_LORA_SLOT`` (-1), matching vLLM's convention. +Real adapters occupy slots ``0 .. max_loras - 1``. + +Tensor parallelism +------------------ +* QKV is column-parallel: A is full, B is sharded along output dim + (``q_per_tp + 2 * kv_per_tp``). No collective inside the LoRA path. +* O is row-parallel: A is sharded along input dim, B is full. The host + module (qwen3 ``o_proj``) runs with ``reduce_results=False`` and has its + partial sum all-reduced downstream by ``post_attention_layernorm``; the + LoRA delta rides that same reduction (full ``B @ lora_a`` is added to the + partial output and the downstream reduce sums it ``tp_size`` times — see + ``apply_o_lora`` for the resulting numerical caveat). +""" + +from __future__ import annotations + +import os +from collections import OrderedDict + +import torch +from tokenspeed_kernel.ops.lora.triton import ( + lora_expand_fwd, + lora_expand_grouped_v2_fwd, + lora_expand_prefill_fwd, + lora_gate_up_expand_fwd, + lora_qkv_expand_fwd, + lora_shrink_fwd, + lora_shrink_prefill_fwd, +) + +from tokenspeed.runtime.lora.adapter_io import ( + LORA_HEAD_LAYER_ID, + PEFT_HEAD_MODULE, + PEFT_MODULES, + read_adapter_scaling, + resolve_adapter_weight_path, +) +from tokenspeed.runtime.lora.lora_batch import ( + NO_LORA_SLOT, + LoraBatchInfo, + build_decode_lora_groups, +) +from tokenspeed.runtime.lora.lora_buffers import LORA_BUFFER_GROUPS, LoraWeightBuffers +from tokenspeed.runtime.lora.lora_cache import LoraCpuCache +from tokenspeed.runtime.lora.moe_lora import MoeLoraBuffers, MoeLoraContext +from tokenspeed.runtime.utils import get_colorful_logger + +# Segments longer than this use the prefill (chunked-SGMV) expand kernel, +# which specialises strides and loop counts at compile time. Shorter +# segments (decode) use the decode-tuned kernels. Threshold chosen from +# benchmarks: chunked-SGMV wins above ~32 tokens/segment at rank ≥ 64. +_CHUNKED_THRESHOLD = 32 + +# With max_group_size-based grid, the kernel degenerates to the segmented +# layout when every group has 1 token (n_unique = n), so no threshold is +# needed for correctness. Keep a minimum of 1 (always use grpv2). +_TRITON_GROUPED_DECODE_MIN_GROUP_SIZE = 1 + +logger = get_colorful_logger(__name__) + + +# ── Manager ───────────────────────────────────────────────────────────────── + + +def _use_triton_grouped_decode(bi: LoraBatchInfo) -> bool: + """Return whether grouped Triton decode expand should beat basic decode.""" + return ( + bi.single_lora_slot == NO_LORA_SLOT + and bi.num_groups > 0 + and bi.bs // bi.num_groups >= _TRITON_GROUPED_DECODE_MIN_GROUP_SIZE + ) + + +class LoraManager: + """Owns GPU-resident LoRA weights and dispatches the segmented-GEMM path. + + Public surface (used by the model + executor): + + * :meth:`load_adapter` / :meth:`unload_adapter` — adapter lifecycle. + * :attr:`batch_info` — persistent :class:`LoraBatchInfo` whose tensor + pointers are stable across forward steps (so they can be baked into + the captured CUDA graph). + * :meth:`prepare_loras` — fill the persistent batch_info for one step. + * :meth:`apply_qkv_lora` / :meth:`apply_o_lora` — Triton-backed deltas. + """ + + def __init__( + self, + model_config, + max_loras: int, + max_lora_rank: int, + max_num_tokens: int, + dtype: torch.dtype, + device: torch.device, + tp_rank: int = 0, + tp_size: int = 1, + tp_group=None, + max_loras_cpu: int | None = None, + lora_buffer_groups: set[str] | frozenset[str] = LORA_BUFFER_GROUPS, + lora_moe_compressed_shared_outer: bool = False, + ) -> None: + self.max_loras = max_loras + self.max_lora_rank = max_lora_rank + self.max_num_tokens = max_num_tokens + self.dtype = dtype + self.device = device + self.tp_rank = tp_rank + self.tp_size = tp_size + self.tp_group = tp_group + unknown_groups = set(lora_buffer_groups) - LORA_BUFFER_GROUPS + if unknown_groups: + raise ValueError(f"Unknown LoRA buffer groups: {sorted(unknown_groups)}") + self.lora_buffer_groups = frozenset(lora_buffer_groups) + self.enable_attn_lora = "attn" in self.lora_buffer_groups + self.enable_mlp_lora = "mlp" in self.lora_buffer_groups + self.enable_moe_lora = "moe" in self.lora_buffer_groups + self.enable_head_lora = "lm_head" in self.lora_buffer_groups + self.lora_moe_compressed_shared_outer = lora_moe_compressed_shared_outer + # Tier-2 CPU cache cap. Defaults to 4× the GPU pool so adapter + # spill-out to disk is rare in steady state. + self.max_loras_cpu: int = ( + max_loras_cpu if max_loras_cpu is not None else 4 * max_loras + ) + if self.max_loras_cpu < max_loras: + raise ValueError( + f"max_loras_cpu ({self.max_loras_cpu}) must be ≥ " + f"max_loras ({max_loras}); GPU-resident adapters live in " + "the CPU pool too." + ) + + self.n_layers: int = model_config.num_hidden_layers + hidden: int = model_config.hidden_size + n_heads: int = model_config.num_attention_heads + n_kv: int = model_config.num_key_value_heads + # Use the model's explicit head_dim when available (some architectures like + # Qwen3.5 decouple head_dim from hidden/n_heads, e.g. hidden=2048, n_heads=16 + # but head_dim=256). + head_dim: int = getattr(model_config, "head_dim", None) or (hidden // n_heads) + # attn_output_gate doubles the Q projection size (2× heads in qkv_proj). + # The o_proj input is n_heads × head_dim (no doubling). + q_multiplier: int = 2 if getattr(model_config, "attn_output_gate", False) else 1 + q_size_base: int = (n_heads // tp_size) * head_dim + + self.q_size_per_tp: int = q_multiplier * q_size_base + self.kv_size_per_tp: int = max(1, n_kv // tp_size) * head_dim + self.o_in_per_tp: int = q_size_base # o_proj reads un-gated heads + self.hidden_size: int = hidden + + from tokenspeed.runtime.layers.vocab_parallel_embedding import pad_vocab_size + + vocab_size: int = model_config.vocab_size + self.vocab_per_tp: int = pad_vocab_size(vocab_size) // tp_size + + # Qwen3MLP is TP-aware: ``gate_up_proj`` is column-parallel (each rank + # holds ``intermediate_size // tp_size`` output cols) and ``down_proj`` + # is row-parallel (each rank holds ``intermediate_size // tp_size`` + # input cols). The LoRA deltas ride the partial outputs of those base + # linears, and the existing downstream all-reduce sums per-rank + # partials — see ``apply_down_lora``/``apply_gate_up_lora``. + self.intermediate_size: int = getattr( + model_config, "intermediate_size", 4 * hidden + ) + self.intermediate_per_tp: int = self.intermediate_size // self.tp_size + self.moe_intermediate_size: int = getattr( + model_config, "moe_intermediate_size", self.intermediate_size + ) + self.moe_intermediate_per_tp: int = self.moe_intermediate_size // self.tp_size + self.num_experts: int = int( + getattr( + model_config, + "num_experts", + getattr( + model_config, + "num_local_experts", + getattr(model_config, "n_routed_experts", 0), + ), + ) + ) + + # CPU-side flag: True when at least one segment in the current + # batch_info uses a real adapter. CudaGraphWrapper + # reads this to pick the with-LoRA vs no-LoRA captured graph. + self.has_active_lora: bool = False + + # ── Tier 1: GPU pool ───────────────────────────────────────────── + # Real adapters take slots 0 .. max_loras - 1. Base/no-LoRA requests + # use NO_LORA_SLOT in batch metadata and do not consume a GPU slot. + self._n_slots: int = max_loras + self._slot_to_name: list[str | None] = [None] * self._n_slots + self._name_to_slot: dict[str, int] = {} + self._gpu_lru: OrderedDict[str, None] = OrderedDict() # alias of _lru + + # Mid-decode unload safety: track which adapter names were active in + # the most recent prepare_loras call, and defer GPU weight eviction for + # any adapter that is unloaded while still in use. + self._active_names: set[str] = set() + self._pending_eviction: set[tuple[str, int]] = ( + set() + ) # (name, lora_id) to evict when safe + + # ── Tier 2: pinned CPU pool ───────────────────────────────────── + # ``_cpu_cache[name]`` holds parsed weights in pinned host memory. + # ``_cpu_lru`` tracks LRU order for CPU eviction back to disk. An + # adapter is "CPU-resident" iff its name is in ``_cpu_cache``. + # GPU-resident adapters are also kept in ``_cpu_cache`` (we pay + # the host RAM cost once; reload to GPU is cheap and re-evicting + # GPU then re-promoting only needs an H2D copy, not a disk read). + self._name_to_id: dict[str, int] = {} + self._id_to_name: dict[int, str] = {} + self._next_id: int = 1 + + # ── Tier 2/3: CPU pinned pool + disk source of truth ───────────── + self._cpu_store = LoraCpuCache( + capacity=self.max_loras_cpu, + is_gpu_resident=lambda name: name in self._name_to_slot, + ) + # Compatibility aliases for existing tests/debug tooling. + self._cpu_cache = self._cpu_store.cache + self._cpu_lru = self._cpu_store.lru + self._adapter_paths = self._cpu_store.adapter_paths + self._pending_loads = self._cpu_store.pending_loads + + # Per-slot rank + scaling for real adapter slots only. + self._lora_ranks: torch.Tensor = torch.zeros( + self._n_slots, dtype=torch.int32, device=device + ) + self._slot_ranks: list[int] = [0] * self._n_slots + self._slot_scalings: list[float] = [0.0] * self._n_slots + self._scalings: torch.Tensor = torch.zeros( + self._n_slots, dtype=torch.float32, device=device + ) + + # ── Persistent batch_info ────────────────────────────────────────── + # All tensors are sized for the worst case so their pointers are + # stable across forward steps; per-step updates are in-place. + # ``num_segments`` may equal ``bs`` (one segment per token in the + # current path — no sort-by-adapter yet). + self._batch_info = LoraBatchInfo( + bs=0, + num_segments=0, + max_len=0, + seg_lens=torch.zeros(max_num_tokens, dtype=torch.int32, device=device), + seg_indptr=torch.zeros( + max_num_tokens + 1, dtype=torch.int32, device=device + ), + weight_indices=torch.full( + (max_num_tokens,), NO_LORA_SLOT, dtype=torch.int32, device=device + ), + lora_ranks=self._lora_ranks, + scalings=self._scalings, + permutation=None, + ) + + # CPU staging buffers (pinned) for the per-step H2D copy. + self._seg_lens_cpu = torch.zeros( + max_num_tokens, dtype=torch.int32, pin_memory=True + ) + self._weight_indices_cpu = torch.full( + (max_num_tokens,), NO_LORA_SLOT, dtype=torch.int32, pin_memory=True + ) + # Adapter-group buffers for the decode grouped expand kernel. + # Computed on CPU in prepare_loras (no GPU sync) and transferred + # non-blocking. Using stable GPU addresses so decode CUDA graphs + # can capture the pointers; num_groups on axis=1 changes per step + # so the graph grid must be re-evaluated outside the captured region. + _mg = self._n_slots # upper bound: one group per loaded adapter + self._sort_order_cpu = torch.zeros( + max_num_tokens, dtype=torch.int64, pin_memory=True + ) + self._group_slots_cpu = torch.zeros(_mg, dtype=torch.int32, pin_memory=True) + self._group_starts_cpu = torch.zeros(_mg, dtype=torch.int32, pin_memory=True) + self._group_sizes_cpu = torch.zeros(_mg, dtype=torch.int32, pin_memory=True) + self._sort_order_buf = torch.zeros( + max_num_tokens, dtype=torch.int64, device=device + ) + self._group_slots_buf = torch.zeros(_mg, dtype=torch.int32, device=device) + self._group_starts_buf = torch.zeros(_mg, dtype=torch.int32, device=device) + self._group_sizes_buf = torch.zeros(_mg, dtype=torch.int32, device=device) + + # ── GPU weight buffers ───────────────────────────────────────────── + # Attention: + # qkv_A_buffers: (n_slots, 3 * max_rank, hidden) — stacked q/k/v A. + # qkv_B_buffers: (n_slots, q_per_tp + 2 * kv_per_tp, max_rank). + # o_A_buffers: (n_slots, max_rank, o_in_per_tp). + # o_B_buffers: (n_slots, hidden, max_rank). + # MLP (TP-aware, mirrors qwen3 ``Qwen3MLP``): + # gate_up_A_buffers: (n_slots, 2 * max_rank, hidden) — A replicated. + # gate_up_B_buffers: (n_slots, 2 * intermediate_per_tp, max_rank) — column-parallel. + # down_A_buffers: (n_slots, max_rank, intermediate_per_tp) — row-parallel. + # down_B_buffers: (n_slots, hidden, max_rank) — B replicated. + self._weight_buffers = LoraWeightBuffers( + n_layers=self.n_layers, + n_slots=self._n_slots, + max_lora_rank=self.max_lora_rank, + hidden_size=self.hidden_size, + q_size_per_tp=self.q_size_per_tp, + kv_size_per_tp=self.kv_size_per_tp, + o_in_per_tp=self.o_in_per_tp, + intermediate_per_tp=self.intermediate_per_tp, + vocab_per_tp=self.vocab_per_tp, + dtype=self.dtype, + device=self.device, + tp_rank=self.tp_rank, + tp_size=self.tp_size, + buffer_groups=self.lora_buffer_groups, + ) + self.qkv_A_buffers = self._weight_buffers.qkv_A_buffers + self.qkv_B_buffers = self._weight_buffers.qkv_B_buffers + self.o_A_buffers = self._weight_buffers.o_A_buffers + self.o_B_buffers = self._weight_buffers.o_B_buffers + self.gate_up_A_buffers = self._weight_buffers.gate_up_A_buffers + self.gate_up_B_buffers = self._weight_buffers.gate_up_B_buffers + self.down_A_buffers = self._weight_buffers.down_A_buffers + self.down_B_buffers = self._weight_buffers.down_B_buffers + self.lm_head_A_buffer = ( + self._weight_buffers.lm_head_A_buffer if self.enable_head_lora else None + ) + self.lm_head_B_buffer = ( + self._weight_buffers.lm_head_B_buffer if self.enable_head_lora else None + ) + self._qkv_output_offset = self._weight_buffers.qkv_output_offset + self._max_qkv_out_dim = self._weight_buffers.max_qkv_out_dim + self._o_slice_offsets = self._weight_buffers.o_slice_offsets + self._gate_up_slice_offsets = self._weight_buffers.gate_up_slice_offsets + self._down_slice_offsets = self._weight_buffers.down_slice_offsets + self._moe_lora_buffers = MoeLoraBuffers( + n_layers=self.n_layers, + n_slots=self._n_slots, + max_lora_rank=self.max_lora_rank, + num_experts=self.num_experts, + hidden_size=self.hidden_size, + intermediate_per_tp=self.moe_intermediate_per_tp, + dtype=self.dtype, + device=self.device, + shard_weights=self._weight_buffers.shard_weights, + enabled=self.enable_moe_lora, + compressed_shared_outer=self.lora_moe_compressed_shared_outer, + ) + # Compatibility alias for tests/debug tooling that inspected the old + # manager-owned storage directly. + self._moe_lora_weights = self._moe_lora_buffers.weights_by_layer + + logger.info( + "LoraManager initialized: max_loras=%d max_rank=%d " + "tp_rank=%d/%d device=%s dtype=%s buffer_groups=%s " + "moe_compressed_shared_outer=%s", + max_loras, + max_lora_rank, + tp_rank, + tp_size, + device, + dtype, + ",".join(sorted(self.lora_buffer_groups)), + self.lora_moe_compressed_shared_outer, + ) + + # ── Public API ────────────────────────────────────────────────────────── + + @property + def batch_info(self) -> LoraBatchInfo: + return self._batch_info + + @property + def moe_lora_context(self) -> MoeLoraContext: + return self._moe_lora_buffers.build_context( + batch_info=self._batch_info, + scalings=self._scalings, + has_active_lora=self.has_active_lora, + ) + + def load_adapter(self, name: str, path: str) -> int: + """Register a PEFT adapter from *path* and warm the CPU pool. + + ``path`` is recorded as the adapter's durable disk path; it must + remain accessible for the lifetime of the manager because the CPU + pool may evict the adapter back to disk under memory pressure. + + Returns the integer ``lora_id`` to use in subsequent + ``prepare_loras`` calls. + """ + if name in self._name_to_id: + logger.warning("Adapter '%s' is already loaded; re-loading.", name) + self._evict_by_name(name) + self._evict_from_cpu(name) + + # Resolve the durable disk path now (used by future re-reads when + # the CPU pool evicts these weights). + adapter_path = path + weight_path = resolve_adapter_weight_path(adapter_path) + if not os.path.exists(weight_path): + raise FileNotFoundError( + f"Adapter weights not found at {weight_path!r} or {path!r}" + ) + + lora_id = self._next_id + self._next_id += 1 + self._name_to_id[name] = lora_id + self._id_to_name[lora_id] = name + self._cpu_store.set_path(name, adapter_path) + + # Warm the CPU pool — bounded by ``max_loras_cpu``, may evict + # other CPU-resident adapters back to disk. + self._cpu_store.ensure(name) + + logger.info( + "Registered adapter '%s' (lora_id=%d) from %s; CPU pool: %d/%d", + name, + lora_id, + path, + len(self._cpu_cache), + self.max_loras_cpu, + ) + return lora_id + + def unload_adapter(self, name: str) -> None: + if name not in self._name_to_id: + raise KeyError(f"Adapter '{name}' is not loaded.") + + lora_id = self._name_to_id.pop(name) + # Keep _id_to_name[lora_id] alive until eviction fires so retracted + # requests that resume with this lora_id can still be recognised and + # get the correct weights rather than silently falling back to the base + # model. It is cleared when the GPU slot is finally freed below. + + if name in self._active_names: + # The adapter was used in the most recent forward step; its GPU + # weights may still be needed if the scheduler is mid-decode or has + # retracted a request that it will later reschedule. Defer the + # weight zeroing until a batch arrives that does not include this + # adapter — at that point the previous step is confirmed complete. + logger.warning( + "Adapter '%s' (lora_id=%d) unloaded while potentially in-flight; " + "GPU slot eviction deferred until next batch that does not use it.", + name, + lora_id, + ) + # Store (name, lora_id) so the flush can distinguish this entry from + # a same-name re-registration that might occur before flushing. + self._pending_eviction.add((name, lora_id)) + # CPU weights kept alive so retracted requests can still reload. + else: + # Safe to evict immediately — not active in the current batch. + del self._id_to_name[lora_id] + self._evict_by_name(name) + self._cpu_store.remove(name) + + logger.info("Unloaded adapter '%s' (lora_id=%d)", name, lora_id) + + def get_id(self, name: str) -> int | None: + return self._name_to_id.get(name) + + def _flush_one_pending(self, name: str, lora_id: int) -> None: + """Carry out the deferred GPU+CPU eviction for one (name, lora_id) entry. + + Guards against two edge cases: + - Re-registration: the same name was re-loaded after unload; the new + slot should NOT be zeroed. Detected by checking that _id_to_name + still maps lora_id → name (the old entry, kept alive for retracted + requests) and that the current name→id mapping no longer exists. + - Already-evicted slot: LRU pressure may have freed the GPU slot + before the deferred flush fires; _evict_by_name is idempotent. + """ + # If the same name was re-registered, _name_to_id[name] exists again + # with a NEW lora_id. Skip GPU eviction — the slot now belongs to the + # new adapter. The CPU copy for the OLD weights was already removed or + # never loaded under the new id. + if self._name_to_id.get(name) is not None: + # Name was re-registered; clear the stale _id_to_name entry for the + # old lora_id only if it still points to this name. + if self._id_to_name.get(lora_id) == name: + del self._id_to_name[lora_id] + return + + # Clear the reverse mapping kept alive for retracted-request safety. + if self._id_to_name.get(lora_id) == name: + del self._id_to_name[lora_id] + + self._evict_by_name(name) # idempotent if already LRU-evicted + self._cpu_store.remove(name) + + def flush_pending_evictions(self) -> None: + """Evict all deferred adapter weights immediately. + + Safe to call at any time, including while the GPU is running a forward + pass. Because _reset_slot no longer issues GPU zero operations, the + only side effects are CPU-side: removing the slot from _name_to_slot + and cleaning up CPU weight caches. The GPU memory retains stale values + until the slot is reused, but no kernel ever reads from an evicted slot + (prepare_loras only assigns weight_indices to slots present in + _name_to_slot). + + Call this when the server has no pending decode requests and you want + to reclaim GPU slots occupied by deferred-unloaded adapters. + """ + for name, lora_id in list(self._pending_eviction): + logger.info("Flushing deferred eviction for adapter '%s'.", name) + self._flush_one_pending(name, lora_id) + self._pending_eviction.discard((name, lora_id)) + + def prepare_loras( + self, + lora_ids: list[int], + per_request_token_counts: list[int] | int = 1, + ) -> int: + """Fill :attr:`batch_info` for the upcoming forward. + + Each request becomes one segment. Returns the total number of + tokens written. All updates are in place on the persistent + batch_info tensors so the captured CUDA graph keeps replaying + against the same pointers. + """ + bs = len(lora_ids) + + # Phase 1: resolve all unique adapters and promote them to GPU before + # assigning any per-request slots. A single-pass approach would silently + # produce wrong outputs: if the batch needs more adapters than max_loras, + # _find_free_slot evicts an already-assigned adapter (e.g. A), then the + # later request for A gets NO_LORA_SLOT and runs as the base model. + unique_names: dict[int, str] = {} + for lid in lora_ids: + if lid == 0 or lid in unique_names: + continue + name = self._id_to_name.get(lid) + if name is not None: + unique_names[lid] = name + + n_unique = len(unique_names) + if n_unique > self.max_loras: + raise RuntimeError( + f"Batch requires {n_unique} unique LoRA adapters but " + f"max_loras={self.max_loras}. Reduce adapter diversity per " + "batch (use pack scheduling) or increase max_loras." + ) + + # The previous forward step is now complete (prepare_loras is called + # synchronously before each forward). Flush any deferred evictions for + # adapters that are NOT needed by the current batch. + current_batch_names = set(unique_names.values()) + for pending_name, pending_lora_id in list(self._pending_eviction): + if pending_name not in current_batch_names: + logger.info( + "Deferred eviction: removing adapter '%s' GPU weights now.", + pending_name, + ) + self._flush_one_pending(pending_name, pending_lora_id) + self._pending_eviction.discard((pending_name, pending_lora_id)) + + # Track which adapters are active in this batch for mid-decode unload safety. + self._active_names = current_batch_names + + # Promote all needed adapters before touching per_request_slots so that + # LRU eviction only targets adapters NOT in this batch. After each + # promotion, move the adapter to MRU so subsequent promotions within + # this loop don't evict an already-promoted or already-resident batch + # adapter (which would be LRU if it was loaded in a previous step). + for name in unique_names.values(): + self._ensure_in_gpu(name) + self._gpu_lru.move_to_end(name) # protect from intra-phase eviction + + # Phase 2: assign per-request slots from the now-stable _name_to_slot + # map (no further evictions occur here). + per_request_slots: list[int] = [] + for lid in lora_ids: + if lid == 0: + per_request_slots.append(NO_LORA_SLOT) + continue + name = self._id_to_name.get(lid) + if name is None: + logger.warning("Unknown lora_id %d; treating as base model.", lid) + per_request_slots.append(NO_LORA_SLOT) + continue + slot = self._name_to_slot[name] # guaranteed present after phase 1 + per_request_slots.append(slot) + self._gpu_lru.move_to_end(name) + + # Per-request seg_lens. + if isinstance(per_request_token_counts, int): + seg_lens_list = [per_request_token_counts] * bs + else: + if len(per_request_token_counts) != bs: + raise ValueError( + "per_request_token_counts length must match lora_ids length" + ) + seg_lens_list = list(per_request_token_counts) + + total_tokens = sum(seg_lens_list) + if total_tokens > self.max_num_tokens: + raise ValueError( + f"LoRA batch_info overflow: {total_tokens} > {self.max_num_tokens}" + ) + max_len = max(seg_lens_list) if seg_lens_list else 0 + + bi = self._batch_info + + # For decode batches (max_len == 1): compute adapter groups on CPU + # so the grouped expand kernel can batch same-adapter tokens into a + # full BLOCK_S=16 GEMM tile, recovering tensor-core efficiency. + if max_len == 1 and bs > 1: + sort_order, group_slots, group_starts, group_sizes = ( + build_decode_lora_groups(per_request_slots) + ) + ng = len(group_slots) + active_count = len(sort_order) + self._sort_order_cpu[:active_count] = torch.as_tensor( + sort_order, dtype=torch.int64 + ) + self._group_slots_cpu[:ng] = torch.as_tensor(group_slots, dtype=torch.int32) + self._group_starts_cpu[:ng] = torch.as_tensor( + group_starts, dtype=torch.int32 + ) + self._group_sizes_cpu[:ng] = torch.as_tensor(group_sizes, dtype=torch.int32) + bi.sort_order = self._sort_order_buf + bi.group_slots = self._group_slots_buf + bi.group_starts = self._group_starts_buf + bi.group_sizes = self._group_sizes_buf + bi.sort_order[:active_count].copy_( + self._sort_order_cpu[:active_count], non_blocking=True + ) + bi.group_slots[:ng].copy_(self._group_slots_cpu[:ng], non_blocking=True) + bi.group_starts[:ng].copy_(self._group_starts_cpu[:ng], non_blocking=True) + bi.group_sizes[:ng].copy_(self._group_sizes_cpu[:ng], non_blocking=True) + bi.num_groups = ng + bi.max_group_size = max(group_sizes) if group_sizes else 0 + else: + bi.sort_order = bi.group_slots = bi.group_starts = bi.group_sizes = None + bi.num_groups = 0 + bi.max_group_size = 0 + + first_slot = per_request_slots[0] if per_request_slots else NO_LORA_SLOT + bi.single_lora_slot = ( + first_slot + if first_slot != NO_LORA_SLOT + and all(slot == first_slot for slot in per_request_slots) + else NO_LORA_SLOT + ) + bi.single_lora_rank = ( + self._slot_ranks[bi.single_lora_slot] + if bi.single_lora_slot != NO_LORA_SLOT + else 0 + ) + bi.multi_lora_start_slot = NO_LORA_SLOT + bi.multi_lora_count = 0 + bi.multi_lora_segment_len = 0 + bi.multi_lora_rank = 0 + if ( + bs > 1 + and bi.single_lora_slot == NO_LORA_SLOT + and max_len > _CHUNKED_THRESHOLD + and len(set(seg_lens_list)) == 1 + and all(slot != NO_LORA_SLOT for slot in per_request_slots) + ): + start_slot = per_request_slots[0] + consecutive_slots = all( + slot == start_slot + i for i, slot in enumerate(per_request_slots) + ) + rank = self._slot_ranks[start_slot] + scaling = self._slot_scalings[start_slot] + same_rank_and_scaling = all( + self._slot_ranks[slot] == rank and self._slot_scalings[slot] == scaling + for slot in per_request_slots + ) + if consecutive_slots and rank > 0 and same_rank_and_scaling: + bi.multi_lora_start_slot = start_slot + bi.multi_lora_count = bs + bi.multi_lora_segment_len = seg_lens_list[0] + bi.multi_lora_rank = rank + + # Stage on CPU then a single non-blocking H2D. + self._seg_lens_cpu[:bs] = torch.as_tensor(seg_lens_list, dtype=torch.int32) + self._weight_indices_cpu[:bs] = torch.as_tensor( + per_request_slots, dtype=torch.int32 + ) + + self.has_active_lora = any(s != NO_LORA_SLOT for s in per_request_slots) + + bi = self._batch_info + bi.bs = bs + bi.num_segments = bs + bi.max_len = max_len + + # Skip the H2D copies and on-device cumsum when no adapter is active: + # the no-LoRA CUDA graph omits all LoRA kernels and never reads + # weight_indices / seg_lens / seg_indptr, so updating them is wasted work. + if self.has_active_lora: + bi.seg_lens[:bs].copy_(self._seg_lens_cpu[:bs], non_blocking=True) + bi.weight_indices[:bs].copy_( + self._weight_indices_cpu[:bs], non_blocking=True + ) + bi.seg_indptr[0] = 0 + torch.cumsum(bi.seg_lens[:bs], dim=0, out=bi.seg_indptr[1 : bs + 1]) + + return total_tokens + + def apply_qkv_lora( + self, + hidden_states: torch.Tensor, + qkv: torch.Tensor, + layer_id: int, + ) -> torch.Tensor: + """Fused QKV LoRA delta: ``qkv += B @ A @ x * scaling``. + + ``hidden_states``: ``(s, hidden)`` (full input). + ``qkv``: ``(s, q_per_tp + 2 * kv_per_tp)`` (output of qkv_proj + on this rank). Updated in place via the kernel's fused-add. + """ + if hidden_states.shape[0] == 0: + return qkv + if not self.enable_attn_lora: + return qkv + bi = self._batch_info + if bi.bs == 0 or not self.has_active_lora: + return qkv + + A_buf = self.qkv_A_buffers[layer_id] + B_buf = self.qkv_B_buffers[layer_id] + # lora_a: (s, 3 * max_rank) + lora_a = ( + lora_shrink_prefill_fwd(hidden_states, A_buf, bi, stack_num=3) + if bi.max_len > _CHUNKED_THRESHOLD + else lora_shrink_fwd(hidden_states, A_buf, bi, stack_num=3) + ) + if bi.max_len > _CHUNKED_THRESHOLD: + lora_expand_prefill_fwd( + lora_a, + B_buf, + bi, + self._qkv_output_offset, + self._max_qkv_out_dim, + base_output=qkv, + ) + else: + lora_qkv_expand_fwd( + lora_a, + B_buf, + bi, + self._qkv_output_offset, + self._max_qkv_out_dim, + base_output=qkv, + ) + return qkv + + def apply_o_lora( + self, + attn_output: torch.Tensor, + o_output: torch.Tensor, + layer_id: int, + ) -> torch.Tensor: + """Row-parallel O-projection LoRA delta. + + ``attn_output``: ``(s, q_per_tp)`` per-rank attention output (input + to o_proj). + ``o_output``: ``(s, hidden)`` partial sum from the host o_proj + (``reduce_results=False`` on this codebase). Updated in place. + + Each rank computes ``B @ A_local @ x_local`` — a partial of shape + ``(s, hidden)``. A is sharded along its input dim and B is + replicated, so the sum of partials over ranks equals + ``B @ A_full @ x_full``. The host layer's downstream fused + all-reduce in ``post_attention_layernorm`` sums the base partial + and the LoRA partial together, producing the correct full output. + """ + if attn_output.shape[0] == 0: + return o_output + if not self.enable_attn_lora: + return o_output + bi = self._batch_info + if bi.bs == 0 or not self.has_active_lora: + return o_output + + A_buf = self.o_A_buffers[layer_id] + B_buf = self.o_B_buffers[layer_id] + # lora_a (partial per rank): (s, max_rank). No internal all-reduce — + # the partial flows into B and the result rides the downstream sum. + lora_a = ( + lora_shrink_prefill_fwd(attn_output, A_buf, bi, stack_num=1) + if bi.max_len > _CHUNKED_THRESHOLD + else lora_shrink_fwd(attn_output, A_buf, bi, stack_num=1) + ) + if bi.max_len > _CHUNKED_THRESHOLD: + lora_expand_prefill_fwd( + lora_a, + B_buf, + bi, + self._o_slice_offsets, + self.hidden_size, + base_output=o_output, + ) + elif _use_triton_grouped_decode(bi): + lora_expand_grouped_v2_fwd(lora_a, B_buf, bi, base_output=o_output) + else: + lora_expand_fwd(lora_a, B_buf, bi, base_output=o_output) + return o_output + + def apply_gate_up_lora( + self, + hidden_states: torch.Tensor, + gate_up: torch.Tensor, + layer_id: int, + ) -> torch.Tensor: + """Fused gate/up LoRA delta: ``gate_up += B @ A @ x * scaling``. + + ``hidden_states``: ``(s, hidden)``. + ``gate_up``: ``(s, 2 * intermediate_per_tp)`` — output of the + column-parallel ``gate_up_proj`` (each rank holds its own output + shard). Updated in place via the kernel's fused-add. + """ + if hidden_states.shape[0] == 0: + return gate_up + if not self.enable_mlp_lora: + return gate_up + bi = self._batch_info + if bi.bs == 0 or not self.has_active_lora: + return gate_up + + A_buf = self.gate_up_A_buffers[layer_id] + B_buf = self.gate_up_B_buffers[layer_id] + # lora_a: (s, 2 * max_rank) — gate's lora_a in [:, :r], up's in [:, r:]. + lora_a = ( + lora_shrink_prefill_fwd(hidden_states, A_buf, bi, stack_num=2) + if bi.max_len > _CHUNKED_THRESHOLD + else lora_shrink_fwd(hidden_states, A_buf, bi, stack_num=2) + ) + if bi.max_len > _CHUNKED_THRESHOLD: + lora_expand_prefill_fwd( + lora_a, + B_buf, + bi, + self._gate_up_slice_offsets, + self.intermediate_per_tp, + base_output=gate_up, + ) + else: + lora_gate_up_expand_fwd( + lora_a, + B_buf, + bi, + self.intermediate_per_tp, + base_output=gate_up, + ) + return gate_up + + def apply_down_lora( + self, + x: torch.Tensor, + down_output: torch.Tensor, + layer_id: int, + ) -> torch.Tensor: + """Down-projection LoRA delta (row-parallel under MLP TP). + + ``x``: ``(s, intermediate_per_tp)`` — input to the + row-parallel ``down_proj`` (this rank's input shard). + ``down_output``: ``(s, hidden)`` — partial output of ``down_proj`` + before its all-reduce. Updated in place. + + Each rank's delta is ``B @ A_local @ x_local``: A is sharded along + the input dim and B is replicated, so summing per-rank deltas yields + the full ``B @ A_full @ x_full``. The base linear runs with + ``reduce_results=False``; the downstream all-reduce that sums the + base partial also sums the LoRA partials. + """ + if x.shape[0] == 0: + return down_output + if not self.enable_mlp_lora: + return down_output + bi = self._batch_info + if bi.bs == 0 or not self.has_active_lora: + return down_output + + A_buf = self.down_A_buffers[layer_id] + B_buf = self.down_B_buffers[layer_id] + lora_a = ( + lora_shrink_prefill_fwd(x, A_buf, bi, stack_num=1) + if bi.max_len > _CHUNKED_THRESHOLD + else lora_shrink_fwd(x, A_buf, bi, stack_num=1) + ) + if bi.max_len > _CHUNKED_THRESHOLD: + lora_expand_prefill_fwd( + lora_a, + B_buf, + bi, + self._down_slice_offsets, + self.hidden_size, + base_output=down_output, + ) + elif _use_triton_grouped_decode(bi): + lora_expand_grouped_v2_fwd(lora_a, B_buf, bi, base_output=down_output) + else: + lora_expand_fwd(lora_a, B_buf, bi, base_output=down_output) + return down_output + + def apply_lm_head_lora( + self, + hidden_states: torch.Tensor, + logits: torch.Tensor, + ) -> torch.Tensor: + """lm_head LoRA delta: ``logits += B @ A @ x * scaling``. + + ``hidden_states``: ``(s, hidden)`` — one token per request (pruned). + ``logits``: ``(s, vocab_per_tp)`` — pre-all-gather logits shard. + Applied before the TP all-gather so each rank contributes its vocab + shard correctly. + + Note: when ``extend_return_logprob`` is True the caller may pass more + than ``bi.bs`` tokens. In that case this method is a no-op because + the per-token slot mapping is not available here; sampling logits are + still correct for the last token of each request. + """ + if hidden_states.shape[0] == 0: + return logits + if not self.enable_head_lora: + return logits + bi = self._batch_info + if bi.bs == 0 or not self.has_active_lora: + return logits + if hidden_states.shape[0] != bi.bs: + return logits + + slots = bi.weight_indices[: bi.bs] # (bs,) + valid = slots != NO_LORA_SLOT + if not valid.any(): + return logits + + # Fast path: all requests use the same adapter slot. + # Use plain matmul to avoid a gather of the B matrix (vocab_per_tp × rank + # bytes) for every request. Guarded from CUDA graph capture because the + # Python branch is frozen at capture time — replaying with a different + # single_lora_slot would silently use stale weights. + if ( + bi.single_lora_slot != NO_LORA_SLOT + and not torch.cuda.is_current_stream_capturing() + ): + slot = bi.single_lora_slot + scaling = self._scalings[slot].item() + A = self.lm_head_A_buffer[slot] # (r, hidden) + B = self.lm_head_B_buffer[slot] # (vocab_per_tp, r) + lora_a = hidden_states @ A.T # (bs, r) + delta = lora_a @ B.T # (bs, vocab_per_tp) + return logits + delta * scaling + + valid_slots = slots.clamp(min=0) + # A: (bs, r, hidden), B: (bs, vocab_per_tp, r) + A = self.lm_head_A_buffer[valid_slots] + B = self.lm_head_B_buffer[valid_slots] + # lora_a: (bs, r) = A @ hidden_states[..., None] + lora_a = torch.bmm(A, hidden_states.unsqueeze(-1)).squeeze(-1) + # delta: (bs, vocab_per_tp) + delta = torch.bmm(B, lora_a.unsqueeze(-1)).squeeze(-1) + # Zero out requests with no adapter; scale the rest. + scale = self._scalings[valid_slots] * valid.to(self._scalings.dtype) + return logits + delta * scale.unsqueeze(-1) + + def apply_moe_gate_up_lora( + self, + layer_id: int, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + gate_up_output: torch.Tensor, + *, + sorted_token_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + """Compatibility wrapper; MoE-specific work lives in MoeLoraContext.""" + if not self.enable_moe_lora: + return gate_up_output + return self.moe_lora_context.apply_gate_up_lora( + layer_id, + hidden_states, + topk_ids, + gate_up_output, + sorted_token_ids=sorted_token_ids, + ) + + def apply_moe_down_lora( + self, + layer_id: int, + intermediate: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + down_output: torch.Tensor, + *, + sorted_token_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + """Compatibility wrapper; MoE-specific work lives in MoeLoraContext.""" + if not self.enable_moe_lora: + return down_output + return self.moe_lora_context.apply_down_lora( + layer_id, + intermediate, + topk_ids, + topk_weights, + down_output, + sorted_token_ids=sorted_token_ids, + ) + + def set_adapter_scaling(self, name: str, scaling: float) -> None: + slot = self._name_to_slot.get(name) + if slot is not None: + self._slot_scalings[slot] = scaling + self._scalings[slot] = scaling + + # ── Slot allocation ───────────────────────────────────────────────────── + + def _ensure_in_gpu(self, name: str) -> int: + if name in self._name_to_slot: + return self._name_to_slot[name] + # Tier-2 → Tier-1 promotion; may need to read from disk if the + # CPU pool has evicted this adapter since registration. + self._cpu_store.ensure(name) + slot = self._find_free_slot() + self._load_to_slot(name, slot) + self._name_to_slot[name] = slot + self._slot_to_name[slot] = name + self._gpu_lru[name] = None + return slot + + def prefetch(self, name: str) -> None: + """Best-effort async warm of the CPU pool for *name*. + + Called from the request-admission path: when a request with a + non-zero ``lora_id`` arrives the manager kicks off a background + disk read so the safetensors I/O is overlapped with the previous + forward step rather than blocking ``prepare_loras`` of the step + that actually consumes the adapter. + + No-op when the adapter is already CPU-resident or a load is + already in flight. Silently ignores unknown adapters (the + request will fall back to base via NO_LORA_SLOT). + """ + self._cpu_store.prefetch(name) + + def _evict_from_cpu(self, name: str) -> None: + """Public helper, takes the lock. Caller must ensure *name* is + not currently GPU-resident.""" + self._cpu_store.evict(name) + + def _find_free_slot(self) -> int: + for slot in range(self._n_slots): + if self._slot_to_name[slot] is None: + return slot + for candidate_name in list(self._gpu_lru.keys()): + slot = self._name_to_slot[candidate_name] + logger.debug("Evicting adapter '%s' from GPU slot %d", candidate_name, slot) + self._evict_by_name(candidate_name) + return slot + raise RuntimeError( + "LoRA GPU pool is full and no evictable adapter was found. " + f"Increase max_loras (current: {self.max_loras})." + ) + + def _load_to_slot(self, name: str, slot: int) -> None: + cpu_weights = self._cpu_cache[name] + rank = self._get_rank_for(name) + scaling = self._get_scaling_for(name, rank) + self._reset_slot(slot) + self._lora_ranks[slot] = rank + self._slot_ranks[slot] = rank + self._slot_scalings[slot] = scaling + self._scalings[slot] = scaling + self._weight_buffers.load_adapter_to_slot(cpu_weights, slot, rank) + self._moe_lora_buffers.load_adapter_to_slot(cpu_weights, slot, rank) + + logger.debug("Loaded adapter '%s' into GPU slot %d (rank=%d)", name, slot, rank) + + def _get_rank_for(self, name: str) -> int: + cpu_weights = self._cpu_cache.get(name, {}) + if not cpu_weights: + return self.max_lora_rank + # Check layer 0 first (dense attn/MLP modules). + if 0 in cpu_weights: + for mod in PEFT_MODULES: + if mod in cpu_weights[0]: + return cpu_weights[0][mod][0].shape[0] + for mod, tensors in cpu_weights[0].items(): + if mod.startswith("experts."): + lora_A = tensors[0] + if lora_A.dim() == 3: + return lora_A.shape[1] + return lora_A.shape[0] + # Fall back to lm_head (head-only adapters). + if LORA_HEAD_LAYER_ID in cpu_weights: + head = cpu_weights[LORA_HEAD_LAYER_ID] + if PEFT_HEAD_MODULE in head: + return head[PEFT_HEAD_MODULE][0].shape[0] + return self.max_lora_rank + + def _get_scaling_for(self, name: str, rank: int) -> float: + return read_adapter_scaling(self._adapter_paths.get(name), rank) + + def _evict_by_name(self, name: str) -> None: + if name in self._name_to_slot: + slot = self._name_to_slot.pop(name) + self._slot_to_name[slot] = None + self._reset_slot(slot) + self._gpu_lru.pop(name, None) + + def _reset_slot(self, slot: int) -> None: + # GPU weight tensors are intentionally NOT zeroed here. + # + # Correctness argument: prepare_loras assigns weight_indices[i] only to + # slots present in _name_to_slot. _evict_by_name removes the slot from + # _name_to_slot before calling _reset_slot, so no kernel ever reads from + # an evicted slot's GPU memory regardless of what values are there. + # _load_to_slot overwrites the stale values when the slot is reused. + # + # Skipping the GPU zeros removes potentially hundreds of kernel launches + # per eviction (one zero_() per buffer per layer) and — more critically — + # eliminates a CUDA stream race: graph.replay() uses a dedicated stream + # while tensor.zero_() runs on the default stream; without explicit + # inter-stream synchronisation, an immediate GPU zero could race with an + # in-flight graph kernel still reading the old weights. + # + # MoE buffers: _moe_lora_buffers.clear_slot does both GPU zeroing AND + # CPU dict cleanup (weights_by_layer.pop). Keep the CPU cleanup, skip + # the GPU zeros. + self._moe_lora_buffers.clear_slot_cpu_only(slot) + self._lora_ranks[slot] = 0 + self._slot_ranks[slot] = 0 + self._slot_scalings[slot] = 0.0 + self._scalings[slot] = 0.0 diff --git a/python/tokenspeed/runtime/lora/lora_registry.py b/python/tokenspeed/runtime/lora/lora_registry.py new file mode 100644 index 000000000..9ee651f1a --- /dev/null +++ b/python/tokenspeed/runtime/lora/lora_registry.py @@ -0,0 +1,105 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""In-process registry that tracks loaded LoRA adapters and maps names to IDs.""" + +from __future__ import annotations + +from typing import Iterator, Optional + +from tokenspeed.runtime.lora.lora_config import LoraConfig +from tokenspeed.runtime.utils import get_colorful_logger + +logger = get_colorful_logger(__name__) + +# Sentinel value meaning "no adapter" — maps cleanly to int for scheduling. +NO_LORA_ID: int = 0 + + +class LoraRegistry: + """Thread-unsafe registry; call from the scheduler/engine main thread only. + + TODO: add locking when multi-threaded engine support is needed. + """ + + def __init__(self, max_loras: int) -> None: + self.max_loras = max_loras + self._configs: dict[str, LoraConfig] = {} # name → config + self._name_to_id: dict[str, int] = {} # name → integer ID + self._id_to_name: dict[int, str] = {} # integer ID → name + self._next_id: int = 1 # 0 is reserved for "no lora" + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def register(self, config: LoraConfig) -> int: + """Register a new adapter and return its integer ID. + + Raises ``ValueError`` if the adapter is already registered or the + registry is at capacity. + """ + if config.name in self._name_to_id: + raise ValueError(f"LoRA adapter '{config.name}' is already registered.") + if len(self._configs) >= self.max_loras: + raise ValueError( + f"LoRA registry is full ({self.max_loras} adapters). " + "Unload an adapter before loading a new one." + ) + lora_id = self._next_id + self._next_id += 1 + self._configs[config.name] = config + self._name_to_id[config.name] = lora_id + self._id_to_name[lora_id] = config.name + logger.info("Registered LoRA adapter '%s' → id=%d", config.name, lora_id) + return lora_id + + def unregister(self, name: str) -> None: + """Remove an adapter from the registry. + + Raises ``KeyError`` if the name is not registered. + """ + if name not in self._name_to_id: + raise KeyError(f"LoRA adapter '{name}' is not registered.") + lora_id = self._name_to_id.pop(name) + del self._id_to_name[lora_id] + del self._configs[name] + logger.info("Unregistered LoRA adapter '%s' (id=%d)", name, lora_id) + + def get_id(self, name: str) -> Optional[int]: + """Return the integer ID for an adapter name, or None if not found.""" + return self._name_to_id.get(name) + + def get_config(self, name: str) -> Optional[LoraConfig]: + """Return the LoraConfig for a registered adapter name.""" + return self._configs.get(name) + + def get_config_by_id(self, lora_id: int) -> Optional[LoraConfig]: + name = self._id_to_name.get(lora_id) + return self._configs.get(name) if name else None + + def __contains__(self, name: str) -> bool: + return name in self._name_to_id + + def __len__(self) -> int: + return len(self._name_to_id) + + def __iter__(self) -> Iterator[LoraConfig]: + return iter(self._configs.values()) diff --git a/python/tokenspeed/runtime/lora/moe_lora.py b/python/tokenspeed/runtime/lora/moe_lora.py new file mode 100644 index 000000000..9957bfd30 --- /dev/null +++ b/python/tokenspeed/runtime/lora/moe_lora.py @@ -0,0 +1,1338 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +import torch + +from tokenspeed.runtime.lora.lora_batch import NO_LORA_SLOT, LoraBatchInfo + +try: + from tokenspeed_kernel.ops.moe_lora import ( + gate_up_b_expand, + per_expert_a_shrink, + per_expert_b_down_expand, + per_expert_gate_up_b_expand, + shared_a_shrink, + shared_b_down_expand, + sorted_a_down_shrink, + sorted_gate_up_b_expand, + ) + + _FUSED_MOE_LORA_AVAILABLE = True +except Exception: + _FUSED_MOE_LORA_AVAILABLE = False + +MoeLayerSlotWeights = dict[int, dict[str, torch.Tensor]] +MoeWeightsByLayer = dict[int, MoeLayerSlotWeights] + + +@dataclass(frozen=True) +class MoeLoraContext: + """Narrow per-forward view of MoE LoRA state consumed by MoE backends.""" + + weights_by_layer: MoeWeightsByLayer + batch_info: LoraBatchInfo + scalings: torch.Tensor + has_active_lora: bool + # Per-layer buffer lists for CUDA-graph-compatible dynamic slot indexing. + # When set, _apply_*_slot uses GPU tensor indexing via batch_info.weight_indices + # instead of Python dict lookup, so the CUDA graph can replay with any adapter. + w13_A_buffers: list | None + w13_B_buffers: list | None + down_A_buffers: list | None + down_B_buffers: list | None + # Multi-stream prefetch: secondary stream + pre-allocated output buffers. + # Shrink ops run on _lora_stream concurrently with the base MoE GEMMs. + _lora_stream: object | None = None # torch.cuda.Stream + _lora_a_m_buf: torch.Tensor | None = None # (max_bs, 2*max_r) + _lora_a_flat_buf: torch.Tensor | None = None # (max_bs*max_topk, max_r) + # Mutable flags (list elements are mutable even in frozen dataclass): + # _prefetch_flags[0] = gate_up shrink pending; [1] = down shrink pending. + _prefetch_flags: list | None = None + + def apply_gate_up_lora( + self, + layer_id: int, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + gate_up_output: torch.Tensor, + *, + sorted_token_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + """Apply expert-scoped LoRA to routed MoE gate/up output.""" + if hidden_states.shape[0] == 0 or topk_ids.numel() == 0: + return gate_up_output + slots, single_slot = self._token_slots(hidden_states.shape[0]) + if single_slot == NO_LORA_SLOT and slots is None: + return gate_up_output + if single_slot != NO_LORA_SLOT: + self._apply_gate_up_slot( + layer_id, + single_slot, + hidden_states, + topk_ids, + gate_up_output, + sorted_token_ids=sorted_token_ids, + ) + return gate_up_output + assert slots is not None + for slot_t in torch.unique(slots): + slot = int(slot_t.item()) + if slot == NO_LORA_SLOT: + continue + self._apply_gate_up_slot( + layer_id, + slot, + hidden_states, + topk_ids, + gate_up_output, + token_mask=slots == slot, + sorted_token_ids=sorted_token_ids, + ) + return gate_up_output + + def apply_down_lora( + self, + layer_id: int, + intermediate: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + down_output: torch.Tensor, + *, + sorted_token_ids: torch.Tensor | None = None, + ) -> torch.Tensor: + """Apply expert-scoped LoRA to routed MoE down output.""" + if intermediate.shape[0] == 0 or topk_ids.numel() == 0: + return down_output + num_tokens = topk_ids.shape[0] + slots, single_slot = self._token_slots(num_tokens) + if single_slot == NO_LORA_SLOT and slots is None: + return down_output + # Sorted-space fast path: work directly on sorted intermediate, skipping + # _route_rows_from_cache. Only applies when sorted dispatch is active (TMA + # config), since the fused shrink kernel has poor utilization for small + # flat-pair batches. + if ( + _FUSED_MOE_LORA_AVAILABLE + and sorted_token_ids is not None + and single_slot != NO_LORA_SLOT + and self.down_A_buffers is not None + and self.batch_info.single_lora_slot != -1 + ): + if self._apply_down_sorted( + layer_id, + single_slot, + intermediate, + topk_ids, + topk_weights, + down_output, + sorted_token_ids, + ): + return down_output + route_input = self._route_rows_from_cache( + intermediate, + topk_ids.numel(), + sorted_token_ids=sorted_token_ids, + ).view(topk_ids.shape[0], topk_ids.shape[1], -1) + if single_slot != NO_LORA_SLOT: + self._apply_down_slot( + layer_id, + single_slot, + route_input, + topk_ids, + topk_weights, + down_output, + ) + return down_output + assert slots is not None + for slot_t in torch.unique(slots): + slot = int(slot_t.item()) + if slot == NO_LORA_SLOT: + continue + self._apply_down_slot( + layer_id, + slot, + route_input, + topk_ids, + topk_weights, + down_output, + token_mask=slots == slot, + ) + return down_output + + def _token_slots(self, num_tokens: int) -> tuple[torch.Tensor | None, int]: + bi = self.batch_info + if bi.bs == 0 or not self.has_active_lora: + return None, NO_LORA_SLOT + if bi.single_lora_slot != NO_LORA_SLOT: + return None, bi.single_lora_slot + slots = torch.repeat_interleave( + bi.weight_indices[: bi.bs], bi.seg_lens[: bi.bs] + ) + if slots.numel() != num_tokens: + # Token ownership changed under TP/EP communication. Mixed LoRA + # cannot be applied safely without transforming the slot map too. + return None, NO_LORA_SLOT + return slots, NO_LORA_SLOT + + # ── Multi-stream prefetch API ────────────────────────────────────────────── + # Called from triton_common.py BEFORE each base GEMM to overlap the LoRA + # shrink kernel with the base model's gate_up / down computation. + + def launch_gate_up_shrink( + self, + layer_id: int, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + ) -> None: + """Fork: launch gate_up LoRA A-shrink on secondary stream. + + Must be called immediately BEFORE gate_up_gemm so that shared_a_shrink + (torch.mm) runs concurrently on _lora_stream while gate_up_gemm runs + on the main stream. apply_gate_up_lora will join the stream and use + the pre-filled _lora_a_m_buf instead of recomputing. + """ + if self._prefetch_flags is None: + return + self._prefetch_flags[0] = False # default: no prefetch + bi = self.batch_info + if ( + not self.has_active_lora + or bi.single_lora_slot == NO_LORA_SLOT + or self.w13_A_buffers is None + or self._lora_stream is None + or self._lora_a_m_buf is None + ): + return + m = hidden_states.shape[0] + w13_A_buf = self.w13_A_buffers[layer_id] + if w13_A_buf.shape[1] != 1: # only sglang_shared format (shared A) + return + if m > self._lora_a_m_buf.shape[0]: + return # prefill with too many tokens — skip prefetch to avoid OOB + # Fork to secondary stream: launch torch.mm concurrently with gate_up_gemm. + self._lora_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._lora_stream): + torch.mm(hidden_states, w13_A_buf[0, 0].T, out=self._lora_a_m_buf[:m]) + self._prefetch_flags[0] = True + + def launch_down_shrink( + self, + layer_id: int, + intermediate: torch.Tensor, + topk_ids: torch.Tensor, + m_k: int, + ) -> None: + """Fork: launch down LoRA A-shrink on secondary stream. + + Must be called immediately BEFORE down_gemm so that per_expert_a_shrink + runs concurrently on _lora_stream while down_gemm runs on main stream. + intermediate is intermediate_cache2 (silu output), shape (m*topk, INTER). + m_k is m_tokens * top_k (non-padded). + """ + if self._prefetch_flags is None: + return + self._prefetch_flags[1] = False + bi = self.batch_info + down_A_buf = self.down_A_buffers[layer_id] if self.down_A_buffers else None + down_B_buf = self.down_B_buffers[layer_id] if self.down_B_buffers else None + if ( + not self.has_active_lora + or bi.single_lora_slot == NO_LORA_SLOT + or down_A_buf is None + or down_B_buf is None + or self._lora_stream is None + or self._lora_a_flat_buf is None + or not _FUSED_MOE_LORA_AVAILABLE + or down_A_buf.shape[1] <= 1 # per-expert A only + or down_B_buf.shape[1] != 1 # shared B only + or not down_A_buf.is_contiguous() + ): + return + if m_k > self._lora_a_flat_buf.shape[0]: + return # prefill with too many tokens — skip prefetch to avoid OOB + ri_flat = intermediate[:m_k].view(m_k, -1) + safe_ids = topk_ids.clamp(0, down_A_buf.shape[1] - 1).to(torch.long) + slot_idx = bi.weight_indices[:1].clamp(0) + self._lora_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(self._lora_stream): + per_expert_a_shrink( + ri_flat, + down_A_buf, + slot_idx, + safe_ids, + out=self._lora_a_flat_buf[:m_k], + ) + self._prefetch_flags[1] = True + + def _apply_gate_up_slot( + self, + layer_id: int, + slot: int, + hidden_states: torch.Tensor, + topk_ids: torch.Tensor, + gate_up_output: torch.Tensor, + *, + token_mask: torch.Tensor | None = None, + sorted_token_ids: torch.Tensor | None = None, + ) -> None: + # For the single-slot case (all tokens same adapter), use dynamic GPU tensor + # indexing so the CUDA graph can replay with any loaded adapter. + # For multi-slot batches, fall back to Python dict lookup (eager only). + bi = self.batch_info + # Determine if we're on the CUDA-graph buffer path (single slot, all tokens + # same adapter). In this path we keep slot_idx as a GPU tensor so the CUDA + # graph can replay with any loaded adapter without re-capture. + _use_buffer_path = self.w13_A_buffers is not None and bi.single_lora_slot != -1 + slot_idx = None + if _use_buffer_path: + slot_idx = bi.weight_indices[:1].clamp(0) + w13_B_buf = self.w13_B_buffers[layer_id] # (n_slots, E, I2, MAX_R) + w13_B = None + # For the sglang_shared fast path (shared A, per-expert B) with fused kernels + # available, skip the w13_A gather entirely — shared_a_shrink reads directly from + # the buffer. For all other paths, gather as before. + _w13_A_buf = self.w13_A_buffers[layer_id] + _skip_a_gather = ( + _FUSED_MOE_LORA_AVAILABLE + and _w13_A_buf.shape[1] == 1 # shared outer (sglang_shared) + and w13_B_buf.shape[1] > 1 # per-expert B + ) + # Also skip the buffer copy for per_expert format (both A and B per-expert): + # per_expert_a_shrink + per_expert_gate_up_b_expand read the full buffer + # directly, making the 32MB w13_A buffer copy unnecessary. + _skip_a_gather_per_expert = ( + _FUSED_MOE_LORA_AVAILABLE + and _w13_A_buf.shape[1] > 1 # per-expert A + and w13_B_buf.shape[1] > 1 # per-expert B + and token_mask is None + and _w13_A_buf.is_contiguous() + and w13_B_buf.is_contiguous() + ) + if _skip_a_gather or _skip_a_gather_per_expert: + # Use slot-0 view (Python int index = no copy) — correct shape for checks. + # Actual compute reads from the full buffers directly. + w13_A = _w13_A_buf[0] # view: (1_or_E, R, H) — no copy! + else: + w13_A = _w13_A_buf[slot_idx].squeeze(0) + else: + weights = self.weights_by_layer.get(layer_id, {}).get(slot) + if weights is None: + return + w13_A = weights["w13_A"] + w13_B = weights["w13_B"] + w13_B_buf = None + + # Determine shapes without materialising w13_B when on the buffer path. + if _use_buffer_path: + w13_A_experts = w13_A.shape[0] + w13_B_experts = w13_B_buf.shape[1] # E dimension of buffer + else: + w13_A_experts = w13_A.shape[0] + w13_B_experts = w13_B.shape[0] + num_experts = max(w13_A_experts, w13_B_experts) + safe_ids = topk_ids.clamp(0, num_experts - 1).to(torch.long) + m, k = safe_ids.shape + # Build the validity mask without torch.any() to avoid GPU→CPU synchronisation. + if token_mask is not None: + valid = (topk_ids >= 0) & (topk_ids < num_experts) & token_mask[:, None] + else: + valid = None + + # Check if per_expert fast path is available (avoids the 32MB+16MB gather copies). + # Must be determined before the A-shrink so we can skip the expensive gather+einsum. + _use_flat_per_expert = ( + w13_A.shape[0] > 1 # per-expert A + and w13_B_buf is not None + and w13_B_experts > 1 # per-expert B + and _FUSED_MOE_LORA_AVAILABLE + and _use_buffer_path + and token_mask is None + and self.w13_A_buffers[layer_id].is_contiguous() + and w13_B_buf.is_contiguous() + ) + + # Shared A (sglang_shared format): one matmul for all tokens. + # lora_a_m (m, r) is only computed here; lora_a (m, k, r) is deferred until + # actually needed (not needed for the all-experts or shared-B paths). + lora_a_m = None + if w13_A.shape[0] == 1: + # Skip cuBLAS GEMM when shared_a_shrink will compute it without the gather. + if _use_buffer_path and _skip_a_gather: + lora_a_m = None # computed by shared_a_shrink in the fused branch below + else: + lora_a_m = hidden_states @ w13_A[0].T + lora_a = None # computed lazily below only if per-expert B path is taken + elif _use_flat_per_expert: + lora_a = ( + None # computed inline by per_expert_a_shrink + per_expert_*_b_expand + ) + else: + selected_A = self._select_expert_weights(w13_A, safe_ids) + lora_a = torch.einsum("mh,mkrh->mkr", hidden_states, selected_A) + + # Compute lora_a only when needed (per-expert B path). + # For shared-A + all-experts or shared-A + shared-B, lora_a_m is used directly. + # Lazily materialise w13_B for non-fused fallback paths on the buffer path. + def _get_w13_B(): + nonlocal w13_B + if w13_B is None: + w13_B = w13_B_buf[slot_idx].squeeze(0) + return w13_B + + if w13_B_experts == 1: + # Shared B: expand lora_a_m to (m*k, r) via repeat_interleave (no contiguous copy). + w13_B_local = _get_w13_B() + r = lora_a_m.shape[-1] if lora_a_m is not None else lora_a.shape[-1] + la_flat = ( + lora_a_m.repeat_interleave(k, dim=0) + if lora_a_m is not None + else lora_a.reshape(-1, r) + ) + delta = la_flat @ w13_B_local[0].T # (m*k, n) + delta = delta.view(m, k, -1) + elif w13_A.shape[0] == 1: + # Shared-A + per-expert B. + if _FUSED_MOE_LORA_AVAILABLE and token_mask is None: + if sorted_token_ids is not None: + # TMA sorted path: write to sorted output positions (SCATTER=False). + _scaling = ( + self.scalings[slot_idx] + if _use_buffer_path + else self.scalings[slot] + ) + w13_B_local = _get_w13_B() + assert w13_B_local.is_contiguous(), "w13_B must be contiguous" + sorted_gate_up_b_expand( + lora_a_m, + w13_B_local, + safe_ids, + sorted_token_ids, + gate_up_output, + _scaling, + m * k, + k, + ) + elif _use_buffer_path: + # Decode path (buffer path): use pre-fetched lora_a_m if available + # (launched on secondary stream before gate_up_gemm), else compute inline. + _gu_prefetched = ( + self._prefetch_flags is not None + and self._prefetch_flags[0] + and self._lora_a_m_buf is not None + and self._lora_stream is not None + ) + if _gu_prefetched: + # Join secondary stream: wait for torch.mm to complete. + torch.cuda.current_stream().wait_stream(self._lora_stream) + lora_a_m = self._lora_a_m_buf[: hidden_states.shape[0]] + self._prefetch_flags[0] = False + else: + lora_a_m = shared_a_shrink( + hidden_states, self.w13_A_buffers[layer_id], slot_idx + ) + gate_up_b_expand( + lora_a_m, + w13_B_buf, + slot_idx, + safe_ids, + gate_up_output, + self.scalings, # full buffer; kernel loads scalings[slot] + ) + else: + # Non-buffer decode path (multi-slot eager). + w13_B_local = _get_w13_B() + assert w13_B_local.is_contiguous(), "w13_B must be contiguous" + gate_up_b_expand( + lora_a_m, + w13_B_local.unsqueeze(0), + torch.zeros(1, dtype=torch.int32, device=w13_B_local.device), + safe_ids, + gate_up_output, + self.scalings[slot].unsqueeze(0), # (1,) for slot 0 of fake buf + ) + return + # Fallback: all-experts GEMM + gather (no expand+copy needed). + w13_B_local = _get_w13_B() + E_fb, n_out, r = w13_B_local.shape + candidates = ( + lora_a_m @ w13_B_local.permute(2, 0, 1).reshape(r, E_fb * n_out) + ).view(m, E_fb, n_out) + delta = candidates.gather(1, safe_ids.unsqueeze(-1).expand(-1, -1, n_out)) + else: + # Per-expert A + per-expert B. + if _use_flat_per_expert: + # Fast flat path: avoid two buffer gather copies (w13_A_buf[slot] = 32MB, + # w13_B_buf[slot] = 16MB) by reading directly from the full buffers. + # per_expert_a_shrink reused: treats w13_A (n_slots, E, 2r, H) as + # down_A (n_slots, E, MAX_R, INTER) with MAX_R=2r, INTER=H. + hidden_flat = hidden_states.repeat_interleave(k, dim=0) # (m*k, H) + lora_a_flat = per_expert_a_shrink( + hidden_flat, + self.w13_A_buffers[layer_id], + slot_idx, + safe_ids, + ) # (m*k, 2r) + per_expert_gate_up_b_expand( + lora_a_flat, + w13_B_buf, + slot_idx, + safe_ids, + gate_up_output, + self.scalings, + ) + return + # Fallback: gather + einsum for non-buffer or masked paths. + w13_B_local = _get_w13_B() + if lora_a is None: + lora_a = lora_a_m.unsqueeze(1).expand(-1, k, -1).contiguous() + selected_B = self._select_expert_weights(w13_B_local, safe_ids) + delta = torch.einsum("mkr,mknr->mkn", lora_a, selected_B) + + # Reuse slot_idx already computed above (avoid extra clamp+gather for scalings). + scaling = self.scalings[slot_idx] if _use_buffer_path else self.scalings[slot] + delta = delta * scaling + if valid is not None: + delta = delta.masked_fill(~valid[:, :, None], 0.0) + self._add_route_delta( + gate_up_output, + delta.reshape(-1, delta.shape[-1]), + sorted_token_ids=sorted_token_ids, + ) + + def _apply_down_sorted( + self, + layer_id: int, + slot: int, + intermediate: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + down_output: torch.Tensor, + sorted_token_ids: torch.Tensor, + ) -> bool: + """Sorted-space down LoRA: skip route_from_cache, fuse per-expert shrink. + + Returns True if the fast path was taken (per-expert A + shared B format), + False if the format requires the generic path. + """ + bi = self.batch_info + slot_idx = bi.weight_indices[:1].clamp(0) + down_A = self.down_A_buffers[layer_id][slot_idx].squeeze(0) + down_B = self.down_B_buffers[layer_id][slot_idx].squeeze(0) + # Only handles per-expert A + shared B (sglang_shared format for down). + if down_A.shape[0] <= 1 or down_B.shape[0] != 1: + return False + if not down_A.is_contiguous(): + return False + + m, k = topk_ids.shape + num_experts = down_A.shape[0] + safe_ids = topk_ids.clamp(0, num_experts - 1).to(torch.long) + route_count = m * k + r = down_A.shape[1] + + # moe_dispatch pre-allocates sorted_token_ids for all potential experts, which + # can exceed the intermediate cache size. All valid entries (≥0) lie within + # the first intermediate.shape[0] rows (bound = m*k + max_active*(BM-1)). + inter_flat = intermediate.reshape(intermediate.shape[0], -1) + padded = inter_flat.shape[0] + sti = sorted_token_ids[:padded] # truncate to intermediate size + + # Fused per-expert shrink: lora_a[s] = intermediate[s] @ down_A[exp[s]].T + lora_a_sorted = sorted_a_down_shrink( + inter_flat, # (padded, INTER) + down_A, # (E, r, INTER) + safe_ids, + sti, + route_count=route_count, + K=k, + ) + + # Shared B GEMM: (padded, r) @ (r, h) → (padded, h) + delta = lora_a_sorted @ down_B[0].T + + # Scale each sorted position by its topk_weight * adapter scaling. + valid = (sti >= 0) & (sti < route_count) + # Clamp to [0, route_count-1]: sorted_token_ids may contain route_count as + # a sentinel value, which would be OOB without the upper bound. + flat_j_safe = sti.clamp(0, route_count - 1) + weights_sorted = topk_weights.reshape(-1)[flat_j_safe].to(delta.dtype) + scaling_t = self.scalings[slot_idx].to(delta.dtype) + delta = delta * (weights_sorted * scaling_t * valid.to(delta.dtype)).unsqueeze( + -1 + ) + + # Scatter-add to token-ordered down_output. + h = delta.shape[-1] + down_output.view(route_count, h).scatter_add_( + 0, flat_j_safe.unsqueeze(-1).expand(-1, h), delta + ) + return True + + def _apply_down_slot( + self, + layer_id: int, + slot: int, + route_input: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + down_output: torch.Tensor, + *, + token_mask: torch.Tensor | None = None, + ) -> None: + bi = self.batch_info + # Determine if we're on the CUDA-graph buffer path (single slot, all tokens + # same adapter). In this path we keep slot_idx as a GPU tensor so the CUDA + # graph can replay with any loaded adapter without re-capture. + _use_buffer_path = self.down_A_buffers is not None and bi.single_lora_slot != -1 + slot_idx = None + if _use_buffer_path: + # (1,) GPU tensor — changes at CUDA-graph replay without re-capture. + slot_idx = bi.weight_indices[:1].clamp(0) + # Keep references to the full buffers; slicing is done lazily or inside kernels. + down_A_buf = self.down_A_buffers[layer_id] # (n_slots, E, MAX_R, INTER) + down_B_buf = self.down_B_buffers[layer_id] # (n_slots, 1_or_E, H, MAX_R) + # Sliced views are populated lazily to avoid redundant gathers. + down_A = None + down_B = None + else: + weights = self.weights_by_layer.get(layer_id, {}).get(slot) + if weights is None: + return + down_A = weights["down_A"] + down_B = weights["down_B"] + down_A_buf = None + down_B_buf = None + + # Determine shapes without materialising tensors when on the buffer path. + if _use_buffer_path: + down_A_experts = down_A_buf.shape[1] # E dimension of buffer + down_B_experts = down_B_buf.shape[1] # 1 for shared-B + else: + down_A_experts = down_A.shape[0] + down_B_experts = down_B.shape[0] + num_experts = max(down_A_experts, down_B_experts) + safe_ids = topk_ids.clamp(0, num_experts - 1).to(torch.long) + m, k = safe_ids.shape + if token_mask is not None: + valid = (topk_ids >= 0) & (topk_ids < num_experts) & token_mask[:, None] + else: + valid = None + + # Helpers to lazily materialise sliced tensors for fallback paths. + def _get_down_A(): + nonlocal down_A + if down_A is None: + down_A = down_A_buf[slot_idx].squeeze(0) + return down_A + + def _get_down_B(): + nonlocal down_B + if down_B is None: + down_B = down_B_buf[slot_idx].squeeze(0) + return down_B + + # Fast fused path: per-expert A + shared B on the CUDA-graph buffer path. + # Eliminates both gather copies (down_A gather + down_B gather) and the + # separate GEMM + scale + add chain. + if ( + _FUSED_MOE_LORA_AVAILABLE + and _use_buffer_path + and token_mask is None + and down_A_experts > 1 + and down_B_experts == 1 + and down_A_buf.is_contiguous() + and down_B_buf.is_contiguous() + ): + _down_prefetched = ( + self._prefetch_flags is not None + and self._prefetch_flags[1] + and self._lora_a_flat_buf is not None + and self._lora_stream is not None + ) + if _down_prefetched: + # Join secondary stream: wait for per_expert_a_shrink to complete. + torch.cuda.current_stream().wait_stream(self._lora_stream) + lora_a_flat = self._lora_a_flat_buf[: m * k] + self._prefetch_flags[1] = False + else: + ri_flat = route_input.reshape(m * k, -1) # (m*k, INTER) + lora_a_flat = per_expert_a_shrink( + ri_flat, down_A_buf, slot_idx, safe_ids + ) + shared_b_down_expand( + lora_a_flat, + down_B_buf, + slot_idx, + down_output.view(m, k, -1), + topk_weights, + self.scalings, # full buffer; kernel loads scalings[slot] + k, + ) + return + + # Shared A (sglang_shared down_proj): one matmul per token-topk group. + if down_A_experts == 1: + down_A_local = _get_down_A() + ri = route_input.reshape(m * k, -1) # (m*k, i) + lora_a = (ri @ down_A_local[0].T).view(m, k, -1) # (m, k, r) + elif _FUSED_MOE_LORA_AVAILABLE and token_mask is None: + # Flat per-expert shrink: avoids the (m*k, r, INTER) gather intermediate + # and replaces the batched einsum with a single fused Triton kernel. + if _use_buffer_path: + # Buffer path: pass full buffer + slot_idx to avoid gather. + lora_a = per_expert_a_shrink( + route_input.reshape(m * k, -1), down_A_buf, slot_idx, safe_ids + ).view(m, k, -1) + else: + down_A_local = _get_down_A() + assert down_A_local.is_contiguous(), "down_A must be contiguous" + lora_a = per_expert_a_shrink( + route_input.reshape(m * k, -1), + down_A_local.unsqueeze(0), # fake (1, E, MAX_R, INTER) buffer + torch.zeros(1, dtype=torch.int32, device=down_A_local.device), + safe_ids, + ).view(m, k, -1) + else: + down_A_local = _get_down_A() + selected_A = self._select_expert_weights(down_A_local, safe_ids) + lora_a = torch.einsum("mki,mkri->mkr", route_input, selected_A) + + # Shared B (sglang_shared down_proj): one batched matmul. + if down_B_experts == 1: + down_B_local = _get_down_B() + r = lora_a.shape[-1] + delta = lora_a.reshape(-1, r) @ down_B_local[0].T # (m*k, h) + delta = delta.view(m, k, -1) + elif ( + _FUSED_MOE_LORA_AVAILABLE + and _use_buffer_path + and token_mask is None + and down_B_buf.is_contiguous() + ): + # Per-expert B fast path: avoid the 16MB buffer copy + gather. + # lora_a computed via per_expert_a_shrink is already (m*k, r); reshape to flat. + lora_a_flat = lora_a.reshape(m * k, -1) + per_expert_b_down_expand( + lora_a_flat, + down_B_buf, + slot_idx, + safe_ids, + down_output.view(m, k, -1), + topk_weights, + self.scalings, + k, + ) + return # accumulation already done inside the kernel + else: + down_B_local = _get_down_B() + selected_B = self._select_expert_weights(down_B_local, safe_ids) + delta = torch.einsum("mkr,mkhr->mkh", lora_a, selected_B) + + delta = delta * topk_weights[:, :, None].to(delta.dtype) + # Reuse slot_idx computed above for scalings (avoid extra clamp+gather). + scaling = self.scalings[slot_idx] if _use_buffer_path else self.scalings[slot] + delta = delta * scaling + if valid is not None: + delta = delta.masked_fill(~valid[:, :, None], 0.0) + down_output.view(topk_ids.shape[0], topk_ids.shape[1], -1).add_(delta) + + @staticmethod + def _select_expert_weights( + weights: torch.Tensor, + safe_ids: torch.Tensor, + ) -> torch.Tensor: + if weights.shape[0] == 1: + return weights[0].expand(*safe_ids.shape, *weights.shape[1:]) + return weights[safe_ids] + + @staticmethod + def _add_route_delta( + output: torch.Tensor, + route_delta: torch.Tensor, + *, + sorted_token_ids: torch.Tensor | None, + ) -> None: + if sorted_token_ids is None: + output.view(route_delta.shape[0], -1).add_(route_delta) + return + # moe_dispatch may pre-allocate sorted_token_ids larger than output. + # Truncate: all valid entries lie within the first output.shape[0] rows. + padded = output.shape[0] + sti = sorted_token_ids[:padded] + # Gather route_delta into output-layout, zero invalid (padding) entries, + # then add in one vectorised kernel — avoids boolean-index tensor creation. + route_count = route_delta.shape[0] + clipped = sti.clamp(0, route_count - 1).to(torch.long) + reordered = route_delta[clipped] # (padded, n) + invalid = (sti < 0) | (sti >= route_count) + reordered.masked_fill_(invalid.unsqueeze(-1), 0) + output.add_(reordered) + + @staticmethod + def _route_rows_from_cache( + cache: torch.Tensor, + route_count: int, + *, + sorted_token_ids: torch.Tensor | None, + ) -> torch.Tensor: + if sorted_token_ids is None: + return cache.view(route_count, -1) + # moe_dispatch may pre-allocate sorted_token_ids larger than cache. + # Truncate: all valid entries lie within the first cache.shape[0] rows. + sti = sorted_token_ids[: cache.shape[0]] + # Use scatter_ with an extra dummy row (index 0) for padding positions. + # Avoids boolean-index tensor creation; one scatter_ + one slice. + n = cache.shape[-1] + rows = torch.zeros((route_count + 1, n), dtype=cache.dtype, device=cache.device) + # Shift: -1 (padding) → 0 (dummy), valid 0..route_count-1 → 1..route_count. + clipped = (sti.clamp(-1, route_count - 1) + 1).to(torch.long) + rows.scatter_(0, clipped.unsqueeze(-1).expand(-1, n), cache) + return rows[1:] # drop dummy row → (route_count, n) + + +class MoeLoraBuffers: + """Own expert-scoped MoE LoRA weights independently from dense buffers.""" + + def __init__( + self, + *, + n_layers: int, + n_slots: int, + max_lora_rank: int, + num_experts: int, + hidden_size: int, + intermediate_per_tp: int, + dtype: torch.dtype, + device: torch.device, + shard_weights: Callable[ + [str, torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor] + ], + enabled: bool = True, + compressed_shared_outer: bool = False, + ) -> None: + self.n_layers = n_layers + self.n_slots = n_slots + self.max_lora_rank = max_lora_rank + self.num_experts = num_experts + self.hidden_size = hidden_size + self.intermediate_per_tp = intermediate_per_tp + self.dtype = dtype + self.device = device + self._shard_weights = shard_weights + self.enabled = enabled + self.compressed_shared_outer = compressed_shared_outer + self.weights_by_layer: MoeWeightsByLayer = {} + self.w13_A_buffers: list[torch.Tensor] = [] + self.w13_B_buffers: list[torch.Tensor] = [] + self.down_A_buffers: list[torch.Tensor] = [] + self.down_B_buffers: list[torch.Tensor] = [] + self._alloc() + # Multi-stream prefetch: overlap LoRA shrink ops with base MoE GEMMs. + # Shrink kernels run on a secondary stream in parallel with gate_up/down GEMMs. + # Pre-allocated output buffers avoid torch.empty inside CUDA graphs. + _max_bs = 128 + _max_topk = 8 + self._lora_stream: torch.cuda.Stream | None = ( + torch.cuda.Stream() if torch.cuda.is_available() else None + ) + self._lora_a_m_buf: torch.Tensor | None = None + self._lora_a_flat_buf: torch.Tensor | None = None + if self.enabled and torch.cuda.is_available(): + # gate/up shrink: (m, 2*r); down shrink: (m*topk, r) + self._lora_a_m_buf = torch.zeros( + _max_bs, 2 * max_lora_rank, dtype=dtype, device=device + ) + self._lora_a_flat_buf = torch.zeros( + _max_bs * _max_topk, max_lora_rank, dtype=dtype, device=device + ) + # Pre-warm cuBLAS and Triton kernels on _lora_stream before any CUDA graph + # capture. torch.mm (cuBLAS) requires its handle to be initialized on each + # stream; failing to do so causes CUBLAS_STATUS_NOT_INITIALIZED during capture. + if self._lora_stream is not None: + _d = torch.zeros(1, dtype=dtype, device=device) + with torch.cuda.stream(self._lora_stream): + torch.mm(_d.unsqueeze(0), _d.unsqueeze(1)) + del _d + torch.cuda.synchronize() + # Mutable flags shared between MoeLoraBuffers and MoeLoraContext instances: + # [0] = gate_up shrink launched; [1] = down shrink launched. + self._prefetch_flags: list[bool] = [False, False] + + def _alloc(self) -> None: + if not self.enabled: + return + n = self.n_slots + e = max(self.num_experts, 0) + r = self.max_lora_rank + h = self.hidden_size + i = self.intermediate_per_tp + w13_a_experts = 1 if self.compressed_shared_outer else e + w13_b_experts = e + down_a_experts = e + down_b_experts = 1 if self.compressed_shared_outer else e + for _ in range(self.n_layers): + self.w13_A_buffers.append( + torch.zeros( + (n, w13_a_experts, 2 * r, h), + dtype=self.dtype, + device=self.device, + ) + ) + self.w13_B_buffers.append( + torch.zeros( + (n, w13_b_experts, 2 * i, 2 * r), + dtype=self.dtype, + device=self.device, + ) + ) + self.down_A_buffers.append( + torch.zeros( + (n, down_a_experts, r, i), dtype=self.dtype, device=self.device + ) + ) + self.down_B_buffers.append( + torch.zeros( + (n, down_b_experts, h, r), dtype=self.dtype, device=self.device + ) + ) + + def load_adapter_to_slot(self, cpu_weights, slot: int, rank: int) -> None: + has_moe = any( + mod.startswith("experts.") + for modules in cpu_weights.values() + for mod in modules + ) + if has_moe and not self.enabled: + raise ValueError( + "Adapter contains MoE LoRA weights, but LoRA buffer group 'moe' " + "is disabled." + ) + if self.num_experts <= 0: + if has_moe: + raise ValueError( + "MoE LoRA adapter requires model_config.num_experts or " + "model_config.num_local_experts." + ) + return + rank = min(rank, self.max_lora_rank) + for layer_id, modules in cpu_weights.items(): + if not any(mod.startswith("experts.") for mod in modules): + continue + self._clear_layer_slot(layer_id, slot) + if any( + mod in modules for mod in ("experts.w1", "experts.w2", "experts.w3") + ): + self._load_3d_adapter_layer(layer_id, modules, slot, rank) + else: + self._load_2d_adapter_layer(layer_id, modules, slot, rank) + + def _load_2d_adapter_layer(self, layer_id: int, modules, slot: int, rank: int): + expert_ids = [ + int(mod.split(".")[1]) for mod in modules if mod.startswith("experts.") + ] + if not expert_ids: + return + if self.compressed_shared_outer: + raise ValueError( + "Compressed MoE shared-outer storage only supports 3D " + "experts.w1/w2/w3 adapters." + ) + num_experts = max(expert_ids) + 1 + self._check_num_experts(layer_id, num_experts) + w13_A, w13_B, down_A, down_B = self._slot_layer_tensors(layer_id, slot) + r = rank + for mod, (lora_A_full, lora_B_full) in modules.items(): + if not mod.startswith("experts."): + continue + _, expert_id_s, module = mod.split(".", 2) + expert_id = int(expert_id_s) + # Normalize A/B convention: standard PEFT stores A as (rank, in_features) + # and B as (out_features, rank). Some adapters use the transposed layout + # (in_features, rank) and (rank, out_features). Detect by comparing dims: + # if the first dim is larger than the second, A is in (in, rank) format. + if lora_A_full.dim() == 2 and lora_A_full.shape[0] > lora_A_full.shape[1]: + lora_A_full = lora_A_full.T # (in, rank) → (rank, in) + if lora_B_full.dim() == 2 and lora_B_full.shape[0] < lora_B_full.shape[1]: + lora_B_full = lora_B_full.T # (rank, out) → (out, rank) + lora_A_shard_cpu, lora_B_shard_cpu = self._shard_weights( + module, lora_A_full, lora_B_full + ) + actual_rank = min(lora_A_shard_cpu.shape[0], r) + lora_A_shard = lora_A_shard_cpu[:actual_rank].to( + device=self.device, + dtype=self.dtype, + non_blocking=True, + ) + lora_B_shard = lora_B_shard_cpu[:, :actual_rank].to( + device=self.device, + dtype=self.dtype, + non_blocking=True, + ) + self._copy_projection( + module, + expert_id, + actual_rank, + lora_A_shard, + lora_B_shard, + w13_A, + w13_B, + down_A, + down_B, + rank=r, + ) + self.weights_by_layer.setdefault(layer_id, {})[slot] = { + "w13_A": w13_A, + "w13_B": w13_B, + "down_A": down_A, + "down_B": down_B, + } + + def _load_3d_adapter_layer(self, layer_id: int, modules, slot: int, rank: int): + required = ("experts.w1", "experts.w2", "experts.w3") + missing = [name for name in required if name not in modules] + if missing: + raise ValueError( + f"3D MoE LoRA layer {layer_id} is missing modules: {missing}" + ) + w1_A, w1_B = modules["experts.w1"] + w2_A, w2_B = modules["experts.w2"] + w3_A, w3_B = modules["experts.w3"] + num_experts = self._infer_3d_num_experts((w1_A, w1_B, w2_A, w2_B, w3_A, w3_B)) + self._check_num_experts(layer_id, num_experts) + if self.compressed_shared_outer: + self._check_shared_outer_layer(layer_id, modules, num_experts) + w13_A, w13_B, down_A, down_B = self._slot_layer_tensors(layer_id, slot) + self._copy_3d_projection( + "gate_proj", w1_A, w1_B, w13_A, w13_B, down_A, down_B, rank + ) + self._copy_3d_projection( + "up_proj", w3_A, w3_B, w13_A, w13_B, down_A, down_B, rank + ) + self._copy_3d_projection( + "down_proj", w2_A, w2_B, w13_A, w13_B, down_A, down_B, rank + ) + E_b, I2, R = w13_B.shape + w13_B_T = w13_B.permute(2, 0, 1).reshape(R, E_b * I2).contiguous() + self.weights_by_layer.setdefault(layer_id, {})[slot] = { + "w13_A": w13_A, + "w13_B": w13_B, + "w13_B_T": w13_B_T, + "down_A": down_A, + "down_B": down_B, + } + + def _check_num_experts(self, layer_id: int, adapter_num_experts: int) -> None: + if adapter_num_experts > self.num_experts: + raise ValueError( + f"MoE LoRA layer {layer_id} has {adapter_num_experts} experts, " + f"but the model has {self.num_experts}." + ) + + def _slot_layer_tensors(self, layer_id: int, slot: int): + return ( + self.w13_A_buffers[layer_id][slot], + self.w13_B_buffers[layer_id][slot], + self.down_A_buffers[layer_id][slot], + self.down_B_buffers[layer_id][slot], + ) + + def _clear_layer_slot(self, layer_id: int, slot: int) -> None: + self.w13_A_buffers[layer_id][slot].zero_() + self.w13_B_buffers[layer_id][slot].zero_() + self.down_A_buffers[layer_id][slot].zero_() + self.down_B_buffers[layer_id][slot].zero_() + + @staticmethod + def _check_shared_outer_layer( + layer_id: int, + modules, + num_experts: int, + ) -> None: + expected = { + "experts.w1": (1, num_experts), + "experts.w2": (num_experts, 1), + "experts.w3": (1, num_experts), + } + for module, (expected_a, expected_b) in expected.items(): + lora_A, lora_B = modules[module] + if lora_A.shape[0] != expected_a or lora_B.shape[0] != expected_b: + raise ValueError( + "Compressed MoE shared-outer storage expects " + f"{module} A/B dim0=({expected_a}, {expected_b}) for " + f"layer {layer_id}; got {tuple(lora_A.shape)}, " + f"{tuple(lora_B.shape)}." + ) + + @staticmethod + def _infer_3d_num_experts(tensors: tuple[torch.Tensor, ...]) -> int: + num_experts = 0 + for tensor in tensors: + if tensor.dim() != 3: + raise ValueError( + f"3D MoE LoRA tensors must be rank-3, got shape {tuple(tensor.shape)}" + ) + if tensor.shape[0] != 1: + num_experts = max(num_experts, int(tensor.shape[0])) + if num_experts <= 0: + raise ValueError("3D MoE LoRA layer has no per-expert tensor dimension") + for tensor in tensors: + if tensor.shape[0] not in (1, num_experts): + raise ValueError( + "3D MoE LoRA dim0 must be either 1 (shared) or num_experts " + f"({num_experts}); got {tuple(tensor.shape)}" + ) + return num_experts + + def _copy_3d_projection( + self, + module: str, + lora_A_full: torch.Tensor, + lora_B_full: torch.Tensor, + w13_A: torch.Tensor, + w13_B: torch.Tensor, + down_A: torch.Tensor, + down_B: torch.Tensor, + rank: int, + ) -> None: + num_experts = max( + w13_A.shape[0], w13_B.shape[0], down_A.shape[0], down_B.shape[0] + ) + if self.compressed_shared_outer: + self._copy_3d_projection_compressed( + module, + lora_A_full, + lora_B_full, + w13_A, + w13_B, + down_A, + down_B, + rank, + num_experts, + ) + return + for expert_id in range(num_experts): + lora_A = self._select_3d_expert_tensor(lora_A_full, expert_id) + lora_B = self._select_3d_expert_tensor(lora_B_full, expert_id) + lora_A_shard_cpu, lora_B_shard_cpu = self._shard_weights( + module, lora_A, lora_B + ) + actual_rank = min(lora_A_shard_cpu.shape[0], rank) + lora_A_shard = lora_A_shard_cpu[:actual_rank].to( + device=self.device, + dtype=self.dtype, + non_blocking=True, + ) + lora_B_shard = lora_B_shard_cpu[:, :actual_rank].to( + device=self.device, + dtype=self.dtype, + non_blocking=True, + ) + self._copy_projection( + module, + expert_id, + actual_rank, + lora_A_shard, + lora_B_shard, + w13_A, + w13_B, + down_A, + down_B, + rank=rank, + a_expert_id=self._dst_expert_id(module, "A", expert_id), + b_expert_id=self._dst_expert_id(module, "B", expert_id), + ) + + def _copy_3d_projection_compressed( + self, + module: str, + lora_A_full: torch.Tensor, + lora_B_full: torch.Tensor, + w13_A: torch.Tensor, + w13_B: torch.Tensor, + down_A: torch.Tensor, + down_B: torch.Tensor, + rank: int, + num_experts: int, + ) -> None: + if module in ("gate_proj", "up_proj"): + shared_A = self._select_3d_expert_tensor(lora_A_full, 0) + first_B = self._select_3d_expert_tensor(lora_B_full, 0) + lora_A_shard_cpu, _ = self._shard_weights(module, shared_A, first_B) + actual_rank = min(lora_A_shard_cpu.shape[0], rank) + lora_A_shard = lora_A_shard_cpu[:actual_rank].to( + device=self.device, + dtype=self.dtype, + non_blocking=True, + ) + if module == "gate_proj": + w13_A[0, :actual_rank, :].copy_(lora_A_shard, non_blocking=True) + else: + w13_A[0, rank : rank + actual_rank, :].copy_( + lora_A_shard, non_blocking=True + ) + for expert_id in range(num_experts): + expert_B = self._select_3d_expert_tensor(lora_B_full, expert_id) + _, lora_B_shard_cpu = self._shard_weights(module, shared_A, expert_B) + b_rank = min(lora_B_shard_cpu.shape[1], rank) + lora_B_shard = lora_B_shard_cpu[:, :b_rank].to( + device=self.device, + dtype=self.dtype, + non_blocking=True, + ) + if module == "gate_proj": + w13_B[expert_id, : self.intermediate_per_tp, :b_rank].copy_( + lora_B_shard, non_blocking=True + ) + else: + w13_B[ + expert_id, + self.intermediate_per_tp : 2 * self.intermediate_per_tp, + rank : rank + b_rank, + ].copy_(lora_B_shard, non_blocking=True) + return + + if module == "down_proj": + first_A = self._select_3d_expert_tensor(lora_A_full, 0) + shared_B = self._select_3d_expert_tensor(lora_B_full, 0) + _, lora_B_shard_cpu = self._shard_weights(module, first_A, shared_B) + b_rank = min(lora_B_shard_cpu.shape[1], rank) + lora_B_shard = lora_B_shard_cpu[:, :b_rank].to( + device=self.device, + dtype=self.dtype, + non_blocking=True, + ) + down_B[0, :, :b_rank].copy_(lora_B_shard, non_blocking=True) + for expert_id in range(num_experts): + expert_A = self._select_3d_expert_tensor(lora_A_full, expert_id) + lora_A_shard_cpu, _ = self._shard_weights(module, expert_A, shared_B) + actual_rank = min(lora_A_shard_cpu.shape[0], rank) + lora_A_shard = lora_A_shard_cpu[:actual_rank].to( + device=self.device, + dtype=self.dtype, + non_blocking=True, + ) + down_A[expert_id, :actual_rank, :].copy_( + lora_A_shard, non_blocking=True + ) + return + + raise ValueError(f"Unsupported MoE LoRA projection: {module}") + + @staticmethod + def _select_3d_expert_tensor(tensor: torch.Tensor, expert_id: int) -> torch.Tensor: + return tensor[0 if tensor.shape[0] == 1 else expert_id] + + def _copy_projection( + self, + module: str, + expert_id: int, + actual_rank: int, + lora_A_shard: torch.Tensor, + lora_B_shard: torch.Tensor, + w13_A: torch.Tensor, + w13_B: torch.Tensor, + down_A: torch.Tensor, + down_B: torch.Tensor, + *, + rank: int, + a_expert_id: int | None = None, + b_expert_id: int | None = None, + ) -> None: + a_expert_id = expert_id if a_expert_id is None else a_expert_id + b_expert_id = expert_id if b_expert_id is None else b_expert_id + if module == "gate_proj": + w13_A[a_expert_id, :actual_rank, :].copy_(lora_A_shard, non_blocking=True) + w13_B[ + b_expert_id, + : self.intermediate_per_tp, + :actual_rank, + ].copy_(lora_B_shard, non_blocking=True) + elif module == "up_proj": + w13_A[a_expert_id, rank : rank + actual_rank, :].copy_( + lora_A_shard, non_blocking=True + ) + w13_B[ + b_expert_id, + self.intermediate_per_tp : 2 * self.intermediate_per_tp, + rank : rank + actual_rank, + ].copy_(lora_B_shard, non_blocking=True) + elif module == "down_proj": + down_A[a_expert_id, :actual_rank, :].copy_(lora_A_shard, non_blocking=True) + down_B[b_expert_id, :, :actual_rank].copy_(lora_B_shard, non_blocking=True) + else: + raise ValueError(f"Unsupported MoE LoRA projection: {module}") + + def _dst_expert_id(self, module: str, side: str, expert_id: int) -> int: + if not self.compressed_shared_outer: + return expert_id + if module in ("gate_proj", "up_proj") and side == "A": + return 0 + if module == "down_proj" and side == "B": + return 0 + return expert_id + + def clear_slot(self, slot: int) -> None: + if not self.enabled: + return + for layer_id in range(self.n_layers): + self._clear_layer_slot(layer_id, slot) + for layer_slots in self.weights_by_layer.values(): + layer_slots.pop(slot, None) + + def clear_slot_cpu_only(self, slot: int) -> None: + """Remove slot from CPU-side tracking without GPU zeroing. + + The GPU weight tensors for this slot are NOT zeroed. This is safe + because prepare_loras only assigns weight_indices[i] to slots present + in _name_to_slot, which is cleared before this method is called. + No kernel can read from an evicted slot. Stale GPU values are + overwritten when _load_to_slot reuses the slot for a new adapter. + """ + if not self.enabled: + return + for layer_slots in self.weights_by_layer.values(): + layer_slots.pop(slot, None) + + def build_context( + self, + *, + batch_info: LoraBatchInfo, + scalings: torch.Tensor, + has_active_lora: bool, + ) -> "MoeLoraContext": + return MoeLoraContext( + weights_by_layer=self.weights_by_layer, + batch_info=batch_info, + scalings=scalings, + has_active_lora=has_active_lora, + w13_A_buffers=self.w13_A_buffers if self.enabled else None, + w13_B_buffers=self.w13_B_buffers if self.enabled else None, + down_A_buffers=self.down_A_buffers if self.enabled else None, + down_B_buffers=self.down_B_buffers if self.enabled else None, + _lora_stream=self._lora_stream, + _lora_a_m_buf=self._lora_a_m_buf, + _lora_a_flat_buf=self._lora_a_flat_buf, + _prefetch_flags=self._prefetch_flags, + ) diff --git a/python/tokenspeed/runtime/models/qwen3.py b/python/tokenspeed/runtime/models/qwen3.py index 43465b476..9d3fe0081 100755 --- a/python/tokenspeed/runtime/models/qwen3.py +++ b/python/tokenspeed/runtime/models/qwen3.py @@ -62,11 +62,13 @@ def __init__( intermediate_size: int, hidden_act: str, quant_config: QuantizationConfig | None = None, + layer_id: int = 0, tp_rank: int | None = None, tp_size: int | None = None, tp_group: tuple[int, ...] | None = None, ) -> None: super().__init__() + self.layer_id = layer_id self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, @@ -93,11 +95,17 @@ def __init__( ) self.act_fn = SiluAndMul() - def forward(self, x): + def forward(self, x, ctx: ForwardContext | None = None): gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) - x, _ = self.down_proj(x) - return x + # LoRA delta on the fused gate/up output (added before SiluAndMul, + # matching PEFT semantics). + if ctx is not None and ctx.lora_manager is not None: + gate_up = ctx.lora_manager.apply_gate_up_lora(x, gate_up, self.layer_id) + intermediate = self.act_fn(gate_up) + out, _ = self.down_proj(intermediate) + if ctx is not None and ctx.lora_manager is not None: + out = ctx.lora_manager.apply_down_lora(intermediate, out, self.layer_id) + return out class Qwen3Attention(nn.Module): @@ -119,6 +127,7 @@ def __init__( prefix: str = "", ) -> None: super().__init__() + self.layer_id = layer_id self.mapping = mapping self.hidden_size = hidden_size self.tp_rank = self.mapping.attn.tp_rank @@ -213,6 +222,14 @@ def forward( cos_sin: tuple[torch.Tensor, torch.Tensor] | None = None, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) + + # LoRA delta for Q/K/V projections (segment-grouped Triton path). + # The manager's batch_info holds persistent buffers, so this call + # is safe to record into a CUDA graph: replay updates batch_info + # in place before graph.replay(). + if ctx.lora_manager is not None: + qkv = ctx.lora_manager.apply_qkv_lora(hidden_states, qkv, self.layer_id) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self._apply_qk_norm(q, k) q, k = self.rotary_emb(positions, q, k) @@ -220,6 +237,11 @@ def forward( if len(attn_output.size()) == 3: attn_output = attn_output.reshape(attn_output.shape[0], -1) output, _ = self.o_proj(attn_output) + + # LoRA delta for O projection + if ctx.lora_manager is not None: + output = ctx.lora_manager.apply_o_lora(attn_output, output, self.layer_id) + return output @@ -263,6 +285,7 @@ def __init__( intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + layer_id=layer_id, tp_rank=self.mapping.dense.tp_rank, tp_size=self.mapping.dense.tp_size, tp_group=self.mapping.dense.tp_group, @@ -327,7 +350,7 @@ def forward( residual, ) ) - hidden_states = self.mlp(hidden_states) + hidden_states = self.mlp(hidden_states, ctx) return hidden_states, residual diff --git a/python/tokenspeed/runtime/models/qwen3_5.py b/python/tokenspeed/runtime/models/qwen3_5.py index a64ed69fd..1bb8266d3 100644 --- a/python/tokenspeed/runtime/models/qwen3_5.py +++ b/python/tokenspeed/runtime/models/qwen3_5.py @@ -713,6 +713,11 @@ def self_attention( """Full attention forward pass.""" qkv, _ = self.qkv_proj(hidden_states) + # Apply QKV LoRA delta (same as qwen3.py; qkv layout matches the buffer + # offsets because q_size_per_tp already accounts for attn_output_gate). + if ctx.lora_manager is not None: + qkv = ctx.lora_manager.apply_qkv_lora(hidden_states, qkv, self.layer_id) + if self.attn_output_gate: q_gate, k, v = qkv.split( [self.q_size * 2, self.kv_size, self.kv_size], dim=-1 @@ -741,6 +746,11 @@ def self_attention( sigmoid_mul(attn_output, gate) output, _ = self.o_proj(attn_output) + + # Apply O-projection LoRA delta. + if ctx.lora_manager is not None: + output = ctx.lora_manager.apply_o_lora(attn_output, output, self.layer_id) + return output def forward( diff --git a/python/tokenspeed/runtime/utils/server_args.py b/python/tokenspeed/runtime/utils/server_args.py index f1da1111a..5d93fdb16 100755 --- a/python/tokenspeed/runtime/utils/server_args.py +++ b/python/tokenspeed/runtime/utils/server_args.py @@ -226,6 +226,30 @@ class ServerArgs: # server started without the matching flag will receive empty logprobs. enable_output_logprobs: bool = False + # LoRA adapter serving + enable_lora: bool = False + # Maximum number of LoRA adapters resident in GPU memory at once. + # Adapters beyond this cap are LRU-evicted to the CPU pool. + max_loras: int = 4 + # Maximum LoRA rank supported (caps adapter loading; larger = more GPU memory). + max_lora_rank: int = 64 + # Maximum number of LoRA adapters cached in CPU pinned memory. When + # an adapter is evicted from this pool it falls back to its disk path + # (assumed durable) and is reloaded on next use. ``None`` ⇒ default + # to ``4 * max_loras``. + max_loras_cpu: int | None = None + # Comma-separated coarse GPU buffer families to allocate for LoRA. + # Valid groups: attn, mlp, moe, lm_head. + lora_buffer_groups: str = "attn,mlp,moe" + # Store 3D MoE shared-outer adapters in compressed shared/per-expert + # buffers instead of fully expanding all sides to num_experts. + lora_moe_compressed_shared_outer: bool = False + # Scheduler-side LoRA scheduling policy. ``"lru"`` (default) just + # relies on the manager's LRU; ``"admission"`` (future) gates batches + # that don't fit in GPU; ``"pack"`` (future) sorts the queue to reuse + # resident adapters. + lora_scheduling_policy: str = "lru" + # Runtime options disable_pdl: bool = False enable_prefix_caching: bool = True @@ -554,6 +578,43 @@ def resolve_communication(self): ) def resolve_disaggregation(self): + if self.enable_lora: + # LoRA delta path is baked into the captured graph: the per-token + # slot index buffer (LoraManager.weight_indices_buf) is bound at + # capture and updated in place at replay. Base/no-LoRA requests + # use NO_LORA_SLOT in metadata and do not consume a GPU slot. + # + # Default the CPU pool to 4× the GPU pool so adapter swap-out + # to disk is rare in steady state. + if self.max_loras_cpu is None: + self.max_loras_cpu = 4 * self.max_loras + if self.max_loras_cpu < self.max_loras: + raise ValueError( + f"max_loras_cpu ({self.max_loras_cpu}) must be ≥ " + f"max_loras ({self.max_loras}) — every GPU-resident " + "adapter must also fit in the CPU pool." + ) + groups = { + group.strip() + for group in self.lora_buffer_groups.split(",") + if group.strip() + } + valid_groups = {"attn", "mlp", "moe", "lm_head"} + unknown_groups = groups - valid_groups + if not groups: + raise ValueError("lora_buffer_groups must include at least one group.") + if unknown_groups: + raise ValueError( + "lora_buffer_groups contains unknown groups: " + f"{sorted(unknown_groups)}. Valid groups: {sorted(valid_groups)}." + ) + self.lora_buffer_groups = ",".join(sorted(groups)) + if self.lora_moe_compressed_shared_outer and "moe" not in groups: + raise ValueError( + "--lora-moe-compressed-shared-outer requires " + "--lora-buffer-groups to include 'moe'." + ) + # PD disaggregation if self.disaggregation_mode == "prefill": self.enforce_eager = True @@ -1465,6 +1526,70 @@ def add_cli_args(parser: argparse.ArgumentParser): action="store_true", help="Disable PDL launch.", ) + # LoRA adapter serving + parser.add_argument( + "--enable-lora", + action="store_true", + default=ServerArgs.enable_lora, + help="Enable LoRA adapter serving.", + ) + parser.add_argument( + "--max-loras", + type=int, + default=ServerArgs.max_loras, + help="Maximum number of LoRA adapters in GPU memory at once.", + ) + parser.add_argument( + "--max-lora-rank", + type=int, + default=ServerArgs.max_lora_rank, + help="Maximum LoRA rank supported across all loaded adapters.", + ) + parser.add_argument( + "--max-loras-cpu", + type=int, + default=ServerArgs.max_loras_cpu, + help=( + "Maximum number of LoRA adapters cached in CPU pinned " + "memory. Defaults to 4 × --max-loras. Adapters evicted " + "from this pool are reloaded from disk on next use." + ), + ) + parser.add_argument( + "--lora-buffer-groups", + type=str, + default=ServerArgs.lora_buffer_groups, + help=( + "Comma-separated LoRA GPU buffer groups to allocate. " + "Valid groups: attn, mlp, moe, lm_head. Loading an adapter that " + "targets a disabled group raises an error." + ), + ) + parser.add_argument( + "--lora-moe-compressed-shared-outer", + action="store_true", + default=ServerArgs.lora_moe_compressed_shared_outer, + help=( + "Use compressed MoE storage for 3D shared-outer adapters " + "(w1/w3 A shared, w1/w3 B per-expert, w2 A per-expert, " + "w2 B shared)." + ), + ) + parser.add_argument( + "--lora-scheduling-policy", + type=str, + default=ServerArgs.lora_scheduling_policy, + choices=["lru", "pack"], + help=( + "Scheduler-side LoRA scheduling policy. ``lru`` (default) " + "submits requests in arrival order and relies on the " + "manager's LRU pool. ``pack`` sorts the admission queue " + "by lora_id so adapter-shared requests cluster, reducing " + "eviction churn when working_set > max_loras_cpu and " + "traffic is bursty." + ), + ) + prefix_cache_group = parser.add_mutually_exclusive_group() prefix_cache_group.add_argument( "--enable-prefix-caching", diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/__init__.py new file mode 100644 index 000000000..6de1c6efd --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/__init__.py @@ -0,0 +1,57 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Triton kernels for segment-grouped LoRA matmuls. + +Adapted from sglang ``python/sglang/srt/lora/triton_ops/`` (Apache-2.0): +https://github.com/sgl-project/sglang/tree/main/python/sglang/srt/lora/triton_ops. +sglang's kernels in turn descend from the Punica S-LoRA design +(https://github.com/punica-ai/punica). Each batch is a sequence of +segments where each segment uses a single adapter; the kernels fuse the +per-segment GEMMs into a single launch and keep per-segment state +(rank, scaling) on-device. See each kernel module for file-level +provenance. +""" + +from tokenspeed_kernel.ops.lora.triton.lora_expand import lora_expand_fwd +from tokenspeed_kernel.ops.lora.triton.lora_expand_grouped_v2 import ( + lora_expand_grouped_v2_fwd, +) +from tokenspeed_kernel.ops.lora.triton.lora_expand_prefill import ( + lora_expand_prefill_fwd, +) +from tokenspeed_kernel.ops.lora.triton.lora_gate_up_expand import ( + lora_gate_up_expand_fwd, +) +from tokenspeed_kernel.ops.lora.triton.lora_qkv_expand import lora_qkv_expand_fwd +from tokenspeed_kernel.ops.lora.triton.lora_shrink import lora_shrink_fwd +from tokenspeed_kernel.ops.lora.triton.lora_shrink_prefill import ( + lora_shrink_prefill_fwd, +) + +__all__ = [ + "lora_shrink_fwd", + "lora_shrink_prefill_fwd", + "lora_expand_fwd", + "lora_expand_grouped_v2_fwd", + "lora_qkv_expand_fwd", + "lora_gate_up_expand_fwd", + "lora_expand_prefill_fwd", +] diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_expand_kernel.json b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_expand_kernel.json new file mode 100644 index 000000000..cc2325080 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_expand_kernel.json @@ -0,0 +1,178 @@ +{ + "(24576, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 64, + "BLOCK_S": 8 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 3, + "num_warps": 4 + }, + "(24576, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 8 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(24576, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 32, + "BLOCK_S": 8 + }, + "maxnreg": 160, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(24576, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 64, + "BLOCK_S": 8 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(4096, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(4096, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(4096, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(4096, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(6144, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 64, + "BLOCK_S": 8 + }, + "maxnreg": 160, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(6144, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 8 + }, + "maxnreg": 160, + "num_ctas": 1, + "num_stages": 3, + "num_warps": 4 + }, + "(6144, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 8 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(6144, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 32, + "BLOCK_S": 8 + }, + "maxnreg": 160, + "num_ctas": 1, + "num_stages": 3, + "num_warps": 4 + }, + "(8192, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(8192, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(8192, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(8192, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + } +} diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_gate_up_expand_kernel.json b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_gate_up_expand_kernel.json new file mode 100644 index 000000000..906ea17e7 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_gate_up_expand_kernel.json @@ -0,0 +1,266 @@ +{ + "(12288, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(12288, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(12288, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(12288, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(14336, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(14336, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(14336, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(14336, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(3072, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(3072, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(3072, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(3072, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(3584, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(3584, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(3584, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(3584, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(6144, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(6144, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(6144, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(6144, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(7168, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(7168, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(7168, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(7168, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + } +} diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_qkv_expand_kernel.json b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_qkv_expand_kernel.json new file mode 100644 index 000000000..dd2b1a72a --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_qkv_expand_kernel.json @@ -0,0 +1,134 @@ +{ + "(1024, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(1024, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(1024, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(1024, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(2048, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(2048, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(2048, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(2048, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(4096, 128, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(4096, 16, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 16, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 2, + "num_warps": 4 + }, + "(4096, 32, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 32, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + }, + "(4096, 64, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.float32')": { + "kwargs": { + "BLOCK_K": 64, + "BLOCK_N": 128, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 1, + "num_warps": 4 + } +} diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_shrink_kernel.json b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_shrink_kernel.json new file mode 100644 index 000000000..669dfb53a --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/configs/H100_80GB_HBM3/_lora_shrink_kernel.json @@ -0,0 +1,541 @@ +{ + "(128, 1024, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(128, 12288, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(128, 14336, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(128, 2048, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(128, 3072, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(128, 3584, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(128, 4096, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(128, 6144, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(128, 7168, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(128, 8192, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(16, 1024, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(16, 12288, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(16, 14336, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(16, 2048, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(16, 3072, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(16, 3584, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(16, 4096, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(16, 6144, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(16, 7168, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(192, 4096, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(192, 8192, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(256, 4096, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(256, 8192, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(32, 1024, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 3, + "num_warps": 4 + }, + "(32, 12288, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(32, 14336, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(32, 2048, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(32, 3072, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(32, 3584, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(32, 4096, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(32, 6144, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(32, 7168, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(32, 8192, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(384, 4096, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 3, + "num_warps": 4 + }, + "(384, 8192, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(48, 4096, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(48, 8192, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(64, 1024, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 3, + "num_warps": 4 + }, + "(64, 12288, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(64, 14336, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(64, 2048, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(64, 3072, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(64, 3584, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(64, 4096, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(64, 6144, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(64, 7168, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(64, 8192, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 16, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(96, 4096, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 256, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + }, + "(96, 8192, 'torch.bfloat16', 'torch.bfloat16', 'torch.bfloat16', 'torch.int32', 'torch.int32', 'torch.int32', 'torch.int32')": { + "kwargs": { + "BLOCK_K": 128, + "BLOCK_N": 32, + "BLOCK_S": 16 + }, + "maxnreg": null, + "num_ctas": 1, + "num_stages": 4, + "num_warps": 4 + } +} diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/kernel_utils.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/kernel_utils.py new file mode 100644 index 000000000..8cee6453b --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/kernel_utils.py @@ -0,0 +1,45 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Shared Triton helpers for the LoRA segmented matmul kernels. + +Adapted from sglang ``python/sglang/srt/lora/triton_ops/kernel_utils.py`` +(Apache-2.0): https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/lora/triton_ops/kernel_utils.py. +""" + +from tokenspeed_kernel._triton import tl, triton + + +@triton.jit +def _resolve_token_positions( + sorted_token_ids, seg_start, s_offset, seg_len, SORTED_BY_ADAPTER: tl.constexpr +): + """Map logical segment offsets to physical token positions. + + When ``SORTED_BY_ADAPTER`` is True the segment is a sorted slice of the + real token grid and ``sorted_token_ids[seg_start + s_offset]`` gives the + physical row index. Otherwise tokens in this segment occupy a + contiguous range starting at ``seg_start``. + """ + if SORTED_BY_ADAPTER: + return tl.load( + sorted_token_ids + seg_start + s_offset, mask=s_offset < seg_len + ).to(tl.int64) + return (seg_start + s_offset).to(tl.int64) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_expand.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_expand.py new file mode 100644 index 000000000..36bf0053a --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_expand.py @@ -0,0 +1,223 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Segmented LoRA-B matmul (expand: r → out_dim) with fused scale + add. + +Adapted from sglang ``python/sglang/srt/lora/triton_ops/sgemm_lora_b.py`` +(Apache-2.0): https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/lora/triton_ops/sgemm_lora_b.py. +sglang's kernel is descended from the Punica S-LoRA design +(https://github.com/punica-ai/punica). Local changes mirror those in +``lora_shrink.py`` (autotune + on-disk cache, constexpr ordering). +""" + +from __future__ import annotations + +import torch +from tokenspeed_kernel._triton import tl, triton +from tokenspeed_kernel.ops.lora.triton.kernel_utils import _resolve_token_positions +from tokenspeed_kernel.ops.lora.triton.tuning import load_kernel_cache + +# Expand kernel: N = out_dim (large, 4096+), K = max_rank (tiny, 16–128). +# Tile space targets "large N, small K, small S". Mirrors sglang's +# csgmv-expand grid (PR #20391); maxnreg helped with occupancy there. +# +# Profiling (2026-05-19) showed the kernel is instruction/overhead-bound +# (0% memory bandwidth utilisation). Two improvements over the original +# k ∈ {16, 32} space: +# • k=64, 128: when BLOCK_K == rank the inner K-loop runs exactly once, +# eliminating loop overhead and the k-mask predicate entirely. +# • BLOCK_N=128 with num_warps=4: halves CTA count vs BLOCK_N=64, which +# amortises per-CTA fixed cost without increasing register pressure. +_EXPAND_CONFIGS = [ + triton.Config( + {"BLOCK_S": s, "BLOCK_N": n, "BLOCK_K": k}, + num_warps=w, + num_stages=stages, + maxnreg=mr, + ) + for s in (8, 16, 32) + for n in (32, 64, 128) + for k in (16, 32, 64, 128) + for w in (4, 8) + for stages in (1, 2, 3) + for mr in (None, 128, 160) +] + + +@triton.autotune(configs=_EXPAND_CONFIGS, key=["N", "K"], restore_value=["output"]) +@triton.jit +def _lora_expand_kernel( + x, + weights, + output, + N, # out_dim + K, # max_rank + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + sorted_token_ids, + scalings, + SORTED_BY_ADAPTER: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + batch_id = tl.program_id(axis=1) + w_index = tl.load(weight_indices + batch_id) + if w_index < 0: + return + rank = tl.load(lora_ranks + w_index) + + # rank == 0 is defensive: leave the base output unchanged. + if rank == 0: + return + + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + if seg_len == 0: + return + seg_start = tl.load(seg_indptr + batch_id) + scaling = tl.load(scalings + w_index) + K = tl.minimum(K, rank) + + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + if pid_s * BLOCK_S >= seg_len: + return + + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.max_contiguous(tl.arange(0, BLOCK_K), BLOCK_K) + s_physical = _resolve_token_positions( + sorted_token_ids, seg_start, s_offset, seg_len, SORTED_BY_ADAPTER + ) + x_ptrs = x + (s_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + s_mask = s_offset[:, None] < seg_len # hoisted: loop-invariant + n_mask = n_offset[None, :] < N # hoisted: loop-invariant (already was) + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_remaining = K - k * BLOCK_K + x_tile = tl.load( + x_ptrs, + mask=s_mask & (k_offset[None, :] < k_remaining), + other=0.0, + eviction_policy="evict_first", + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < k_remaining) & n_mask, + other=0.0, + eviction_policy="evict_last", + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = output + ( + s_physical[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + output_mask = s_mask & n_mask + partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def lora_expand_fwd( + x: torch.Tensor, + weights: torch.Tensor, + batch_info, + base_output: torch.Tensor | None = None, +) -> torch.Tensor: + """Run the LoRA-B expand and fuse-add into ``base_output``. + + Args: + x: ``(s, max_rank)`` activations from lora_shrink. + weights: ``(num_lora, out_dim, max_rank)``, contiguous. + batch_info: :class:`LoraBatchInfo` describing the segment layout. + base_output: optional ``(s, out_dim)`` to add into. When ``None``, + allocates a fresh zero-filled output. + + Returns: + ``(s, out_dim)`` (same buffer as ``base_output`` when supplied). + """ + assert x.is_contiguous() + assert weights.is_contiguous() + assert x.dim() == 2 + assert weights.dim() == 3 + + S = x.shape[0] + N = weights.shape[-2] + R = weights.shape[-1] + assert x.shape[-1] == R + + max_len = batch_info.max_len + + def grid(meta): + return ( + triton.cdiv(max_len, meta["BLOCK_S"]) * triton.cdiv(N, meta["BLOCK_N"]), + batch_info.bs, + ) + + if base_output is None: + output = torch.zeros((S, N), device=x.device, dtype=x.dtype) + else: + output = base_output + + sorted_by_adapter = batch_info.permutation is not None + _lora_expand_kernel[grid]( + x, + weights, + output, + N, + R, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + batch_info.lora_ranks, + batch_info.permutation, + batch_info.scalings, + sorted_by_adapter, + ) + return output + + +load_kernel_cache(_lora_expand_kernel) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_expand_grouped_v2.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_expand_grouped_v2.py new file mode 100644 index 000000000..2d49b7896 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_expand_grouped_v2.py @@ -0,0 +1,236 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Adapter-grouped LoRA-B expand without gather/scatter overhead. + +Adapts vLLM's token-sorted dispatch pattern (PR vllm-project/vllm#..., +Apache-2.0) to our kernel infrastructure. + +This kernel reads ``x`` and writes ``output`` directly at the original + (unsorted) token positions using ``token_indices`` loaded inside the kernel. + No gather/scatter needed — only a cheap pointer indirection per tile. + +Grid: ``(cdiv(N, BLOCK_N), num_groups)`` — axis 1 = unique adapter count. +Within each CTA, groups of ``BLOCK_S`` tokens are processed; each group loads +``BLOCK_S`` scattered rows from ``x`` via ``token_indices``. + +Adapted from vLLM ``vllm/lora/ops/triton_ops/lora_expand_op.py`` (Apache-2.0): +https://github.com/vllm-project/vllm/blob/main/vllm/lora/ops/triton_ops/lora_expand_op.py +Local changes: removed SPLIT_K / PDL / CAST_TYPE / multi-slice indirection; +added BLOCK_K ∈ {16,32,64,128} + tl.multiple_of EVEN_K; adopted our +eviction-policy hints and autotune + on-disk cache infrastructure. +""" + +from __future__ import annotations + +import torch +from tokenspeed_kernel._triton import tl, triton +from tokenspeed_kernel.ops.lora.triton.tuning import load_kernel_cache + +_GROUPED_V2_CONFIGS = [ + triton.Config( + {"BLOCK_S": s, "BLOCK_N": n, "BLOCK_K": k}, + num_warps=w, + num_stages=stages, + maxnreg=mr, + ) + for s in (16, 32) + for n in (32, 64, 128) + for k in (16, 32, 64, 128) + for w in (4, 8) + for stages in (1, 2, 3) + for mr in (None, 128, 160) +] + + +@triton.autotune( + configs=_GROUPED_V2_CONFIGS, + key=["N", "MAX_RANK"], + restore_value=["output"], +) +@triton.jit(do_not_specialize=["output_stride_0", "output_stride_1"]) +def _lora_expand_grouped_v2_kernel( + x, # (M, MAX_RANK) original unsorted token order + weights, # (n_slots, N, MAX_RANK) + output, # (M, N) written at original token positions + group_slots, # (num_groups,) int32 — weight-slot index per group + group_starts, # (num_groups,) int32 — start in token_indices + group_sizes, # (num_groups,) int32 — tokens per group + token_indices, # (M,) int32 — token positions sorted by adapter + scalings, # (n_slots,) float32 + lora_ranks, # (n_slots,) int32 + output_stride_0, + output_stride_1, + N: tl.constexpr, + MAX_RANK: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + # Constexpr strides — x and weights are always contiguous. + x_stride_0: tl.constexpr = MAX_RANK + x_stride_1: tl.constexpr = 1 + w_stride_0: tl.constexpr = N * MAX_RANK + w_stride_1: tl.constexpr = MAX_RANK # row stride inside (N, MAX_RANK) slice + w_stride_2: tl.constexpr = 1 + + group_id = tl.program_id(axis=1) + # axis=0 encodes both the within-group M-tile and the N-tile. + # Grid: (cdiv(M, BLOCK_S) * cdiv(N, BLOCK_N), num_groups) — mirrors vLLM's + # (M_tiles × N_tiles, num_active_loras) layout. CTAs whose M-tile exceeds + # the group size exit immediately (same early-exit pattern as vLLM). + pid_flat = tl.program_id(axis=0) + cta_n_num = tl.cdiv(N, BLOCK_N) + pid_m = pid_flat // cta_n_num + pid_n = pid_flat % cta_n_num + + w_index = tl.load(group_slots + group_id) + if w_index < 0: + return + g_size = tl.load(group_sizes + group_id) + if g_size == 0: + return + rank = tl.load(lora_ranks + w_index) + if rank == 0: + return + + m_off = pid_m * BLOCK_S + if m_off >= g_size: + return # early exit for M-tiles beyond this group's token count + + g_start = tl.load(group_starts + group_id) + scaling = tl.load(scalings + w_index) + K = tl.minimum(MAX_RANK, rank) + + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.max_contiguous(tl.arange(0, BLOCK_K), BLOCK_K) + n_mask = n_offset[None, :] < N + + # Load physical token positions for this M-tile. + s_offset = tl.arange(0, BLOCK_S) + m_valid = s_offset < g_size - m_off + tok_ptrs = token_indices + g_start + m_off + s_offset + ram = tl.load(tok_ptrs, mask=m_valid, other=0) + s_valid = m_valid[:, None] + + # Scattered read of x — no pre-gather needed. + x_ptrs = x + ram[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1 + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + partial = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_rem = K - k * BLOCK_K + x_tile = tl.load( + x_ptrs, + mask=s_valid & (k_offset[None, :] < k_rem), + other=0.0, + eviction_policy="evict_first", + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < k_rem) & n_mask, + other=0.0, + eviction_policy="evict_last", + ) + partial += tl.dot(x_tile, w_tile) + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + partial *= scaling + partial = partial.to(x.dtype.element_ty) + + # Scattered write — no post-scatter needed. + out_ptrs = ( + output + ram[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + out_mask = s_valid & n_mask + partial += tl.load(out_ptrs, mask=out_mask, other=0.0) + tl.store(out_ptrs, partial, mask=out_mask) + + +def lora_expand_grouped_v2_fwd( + x: torch.Tensor, + weights: torch.Tensor, + batch_info, + base_output: torch.Tensor | None = None, +) -> torch.Tensor: + """Adapter-grouped expand without gather/scatter. + + Reads ``x`` and writes ``output`` at original token positions using + ``batch_info.token_indices`` (sorted by adapter). Requires batch_info to + have the adapter-group metadata populated by ``prepare_loras``: + ``token_indices``, ``group_slots``, ``group_starts``, ``group_sizes``, + ``num_groups``. + + Drops in for :func:`lora_expand_fwd` when ``batch_info.num_groups > 0`` + and ``batch_info.bs // batch_info.num_groups >= 8``. + """ + assert x.is_contiguous() + assert weights.is_contiguous() + + S, R = x.shape + N = weights.shape[-2] + dev, dt = x.device, x.dtype + + num_groups = batch_info.num_groups + + # Use the largest group size for the M dimension, not the total batch size. + # This makes the grid tight for both extremes: + # • n_unique = n (all different): max_group_size = 1 + # → grid = (1 × cdiv(N,BLOCK_N), n) ≡ segmented layout, zero wasted CTAs + # • n_unique = 1 (all same): max_group_size = n + # → grid = (n/BLOCK_S × cdiv(N,BLOCK_N), 1) ≡ grpv2 layout + # max_group_size is pre-computed on CPU in prepare_loras — no GPU sync here. + max_group_size = batch_info.max_group_size + + def grid(meta): + return ( + triton.cdiv(max_group_size, meta["BLOCK_S"]) + * triton.cdiv(N, meta["BLOCK_N"]), + num_groups, + ) + + output = ( + torch.zeros((S, N), device=dev, dtype=dt) + if base_output is None + else base_output + ) + + _lora_expand_grouped_v2_kernel[grid]( + x, + weights, + output, + batch_info.group_slots[:num_groups], + batch_info.group_starts[:num_groups], + batch_info.group_sizes[:num_groups], + batch_info.sort_order[: batch_info.bs], # token_indices sorted by adapter + batch_info.scalings, + batch_info.lora_ranks, + output.stride(0), + output.stride(1), + N=N, + MAX_RANK=R, + ) + return output + + +load_kernel_cache(_lora_expand_grouped_v2_kernel) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_expand_prefill.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_expand_prefill.py new file mode 100644 index 000000000..ceed827c9 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_expand_prefill.py @@ -0,0 +1,253 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Unified LoRA-B expand for prefill batches (chunked-SGMV style). + +Replaces the three separate ``lora_expand`` / ``lora_qkv_expand`` / +``lora_gate_up_expand`` kernels for the prefill path. A single kernel +handles any number of output slices via the ``NUM_SLICES`` constexpr and a +``slice_offsets`` boundary tensor — the same trick as sglang's +``chunked_sgmv_expand`` (PR sgl-project/sglang#20391). + +Key structural difference from the decode-path expand kernels: +* ``OUTPUT_DIM``, ``MAX_RANK``, ``NUM_SLICES`` are **constexpr** — the + compiler specialises the K-loop trip count and all strides at compile + time, which gives 2–3× speedup over runtime-stride kernels at prefill + with rank ≥ 64. +* x strides are derived as compile-time constants: + ``x_stride_0 = NUM_SLICES * MAX_RANK``, ``x_stride_1 = 1``. + +Use :func:`lora_expand_fwd` / :func:`lora_qkv_expand_fwd` / +:func:`lora_gate_up_expand_fwd` for decode (``max_len ≤ 32``); switch to +:func:`lora_expand_prefill_fwd` for prefill. + +Adapted from sglang ``python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py`` +(previously ``chunked_sgmv_expand.py`` in this repo) +(Apache-2.0): https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/lora/triton_ops/chunked_sgmv_expand.py. +Local changes: merged SORTED_BY_ADAPTER from our decode kernels (avoids +permutation overhead for unsorted batches), replaced fixed configs with +``@triton.autotune`` + on-disk cache, constexpr ordering. +""" + +from __future__ import annotations + +import torch +from tokenspeed_kernel._triton import tl, triton +from tokenspeed_kernel.ops.lora.triton.kernel_utils import _resolve_token_positions +from tokenspeed_kernel.ops.lora.triton.tuning import load_kernel_cache + +_PREFILL_EXPAND_CONFIGS = [ + triton.Config( + {"BLOCK_S": s, "BLOCK_N": n, "BLOCK_K": k}, + num_warps=w, + num_stages=stages, + maxnreg=mr, + ) + for s in (16, 32) + for n in (32, 64, 128) + for k in (16, 32) + for w in (4, 8) + for stages in (1, 2, 3) + for mr in (None, 128, 160) +] + + +@triton.autotune( + configs=_PREFILL_EXPAND_CONFIGS, + key=["OUTPUT_DIM", "MAX_RANK", "NUM_SLICES"], + restore_value=["output"], +) +@triton.jit(do_not_specialize=["output_stride_0", "output_stride_1"]) +def _lora_expand_prefill_kernel( + x, + weights, + output, + output_stride_0, + output_stride_1, + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + sorted_token_ids, + scalings, + slice_offsets, + NUM_SLICES: tl.constexpr, + OUTPUT_DIM: tl.constexpr, + MAX_RANK: tl.constexpr, + SORTED_BY_ADAPTER: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + # Constexpr strides — compiler eliminates all stride multiplications. + x_stride_0: tl.constexpr = NUM_SLICES * MAX_RANK + x_stride_1: tl.constexpr = 1 + w_stride_0: tl.constexpr = OUTPUT_DIM * MAX_RANK + w_stride_1: tl.constexpr = MAX_RANK + w_stride_2: tl.constexpr = 1 + + batch_id = tl.program_id(axis=2) + w_index = tl.load(weight_indices + batch_id) + if w_index < 0: + return + rank = tl.load(lora_ranks + w_index) + if rank == 0: + return + + slice_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + if seg_len == 0: + return + seg_start = tl.load(seg_indptr + batch_id) + slice_start = tl.load(slice_offsets + slice_id) + slice_end = tl.load(slice_offsets + slice_id + 1) + n_size = slice_end - slice_start + scaling = tl.load(scalings + w_index) + K = tl.minimum(MAX_RANK, rank) + + num_pid_n = tl.cdiv(n_size, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + if pid_s * BLOCK_S >= seg_len: + return + + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + + s_physical = _resolve_token_positions( + sorted_token_ids, seg_start, s_offset, seg_len, SORTED_BY_ADAPTER + ) + + # x: slice i starts at column i * K (actual rank, not MAX_RANK). + x_ptrs = ( + x + + slice_id * K * x_stride_1 + + (s_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1) + ) + w_ptrs = (weights + w_index * w_stride_0 + slice_start * w_stride_1) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + n_mask = n_offset[None, :] < n_size + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + x_tile = tl.load( + x_ptrs, + mask=(s_offset[:, None] < seg_len) & (k_offset[None, :] < K - k * BLOCK_K), + other=0.0, + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < K - k * BLOCK_K) & n_mask, + other=0.0, + ) + partial_sum += tl.dot(x_tile, w_tile) + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + + output_ptr = output + ( + s_physical[:, None] * output_stride_0 + + (slice_start + n_offset)[None, :] * output_stride_1 + ) + output_mask = (s_offset[:, None] < seg_len) & n_mask + partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def lora_expand_prefill_fwd( + x: torch.Tensor, + weights: torch.Tensor, + batch_info, + slice_offsets: torch.Tensor, + max_slice_size: int, + base_output: torch.Tensor | None = None, +) -> torch.Tensor: + """Prefill-optimised LoRA-B expand for one or more output slices. + + Covers all projection types via ``slice_offsets``: + * plain expand (o/down): ``slice_offsets = [0, out_dim]`` + * gate/up: ``slice_offsets = [0, inter, 2*inter]`` + * QKV: ``slice_offsets = [0, q, q+kv, q+2*kv]`` + + Args: + x: ``(s, num_slices * max_rank)`` from lora_shrink. + weights: ``(num_lora, out_dim, max_rank)``, contiguous. + batch_info: :class:`LoraBatchInfo`. + slice_offsets: ``(num_slices + 1,)`` int32 boundary tensor. + max_slice_size: largest ``slice_offsets[i+1] - slice_offsets[i]``. + base_output: ``(s, out_dim)`` to fuse-add into; allocated if None. + + Returns: + ``(s, out_dim)`` (same buffer as ``base_output`` when supplied). + """ + assert x.is_contiguous() + assert weights.is_contiguous() + assert x.dim() == 2 + assert weights.dim() == 3 + + S = x.shape[0] + OUT_DIM = weights.shape[-2] + MAX_RANK = weights.shape[-1] + num_slices = len(slice_offsets) - 1 + assert x.shape[1] == num_slices * MAX_RANK + + max_len = batch_info.max_len + sorted_by_adapter = batch_info.permutation is not None + + def grid(meta): + return ( + triton.cdiv(max_len, meta["BLOCK_S"]) + * triton.cdiv(max_slice_size, meta["BLOCK_N"]), + num_slices, + batch_info.bs, + ) + + output = ( + torch.zeros((S, OUT_DIM), device=x.device, dtype=x.dtype) + if base_output is None + else base_output + ) + _lora_expand_prefill_kernel[grid]( + x, + weights, + output, + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + batch_info.lora_ranks, + batch_info.permutation, + batch_info.scalings, + slice_offsets, + NUM_SLICES=num_slices, + OUTPUT_DIM=OUT_DIM, + MAX_RANK=MAX_RANK, + SORTED_BY_ADAPTER=sorted_by_adapter, + ) + return output + + +load_kernel_cache(_lora_expand_prefill_kernel) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_gate_up_expand.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_gate_up_expand.py new file mode 100644 index 000000000..caecf635e --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_gate_up_expand.py @@ -0,0 +1,225 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Fused LoRA-B expand for stacked gate/up projections (MLP). + +The MLP gate_up linear is fused into a single matmul with output layout +``[gate_per_tp, up_per_tp]`` (each of size ``intermediate_per_tp``). +This kernel packs the two B projections into one launch: each program +instance picks ``gate`` (axis=1, id=0) or ``up`` (id=1) and writes its +tile into the matching half of the fused output. + +Adapted from sglang ``python/sglang/srt/lora/triton_ops/gate_up_lora_b.py`` +(Apache-2.0): https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/lora/triton_ops/gate_up_lora_b.py. +Local changes: autotune + on-disk cache, constexpr ordering. +""" + +from __future__ import annotations + +import torch +from tokenspeed_kernel._triton import tl, triton +from tokenspeed_kernel.ops.lora.triton.kernel_utils import _resolve_token_positions +from tokenspeed_kernel.ops.lora.triton.tuning import load_kernel_cache + +_GATE_UP_EXPAND_CONFIGS = [ + triton.Config( + {"BLOCK_S": s, "BLOCK_N": n, "BLOCK_K": k}, + num_warps=w, + num_stages=stages, + maxnreg=mr, + ) + for s in (16, 32) + for n in (32, 64, 128) + for k in (16, 32, 64, 128) + for w in (4, 8) + for stages in (1, 2, 3) + for mr in (None, 128, 160) +] + + +@triton.autotune( + configs=_GATE_UP_EXPAND_CONFIGS, + key=["output_dim", "K"], + restore_value=["output"], +) +@triton.jit +def _lora_gate_up_expand_kernel( + x, + weights, + output, + K, # max_rank + output_dim, # intermediate_per_tp + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + sorted_token_ids, + scalings, + SORTED_BY_ADAPTER: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + batch_id = tl.program_id(axis=2) + w_index = tl.load(weight_indices + batch_id) + if w_index < 0: + return + rank = tl.load(lora_ranks + w_index) + if rank == 0: + return + + gate_up_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + if seg_len == 0: + return + seg_start = tl.load(seg_indptr + batch_id) + n_start = gate_up_id * output_dim + scaling = tl.load(scalings + w_index) + K = tl.minimum(K, rank) + + num_pid_n = tl.cdiv(output_dim, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + if pid_s * BLOCK_S >= seg_len: + return + + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + + s_physical = _resolve_token_positions( + sorted_token_ids, seg_start, s_offset, seg_len, SORTED_BY_ADAPTER + ) + x_ptrs = ( + x + + (gate_up_id * K) * x_stride_1 + + (s_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1) + ) + w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + s_mask = s_offset[:, None] < seg_len + n_mask = n_offset[None, :] < output_dim + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_remaining = K - k * BLOCK_K + x_tile = tl.load( + x_ptrs, + mask=s_mask & (k_offset[None, :] < k_remaining), + other=0.0, + eviction_policy="evict_first", + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < k_remaining) & n_mask, + other=0.0, + eviction_policy="evict_last", + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = ( + output + + n_start * output_stride_1 + + (s_physical[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1) + ) + output_mask = s_mask & n_mask + partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def lora_gate_up_expand_fwd( + x: torch.Tensor, + gate_up_lora_b: torch.Tensor, + batch_info, + output_dim: int, + base_output: torch.Tensor | None = None, +) -> torch.Tensor: + """Apply LoRA-B for the fused gate_up MLP linear, fuse-add into ``base_output``. + + Args: + x: ``(s, 2 * max_rank)`` from ``lora_shrink_fwd(stack_num=2)`` — + gate's lora_a in cols ``[:, :r]``, up's in ``[:, r:]``. + gate_up_lora_b: ``(num_lora, 2 * intermediate_per_tp, max_rank)`` + — gate's B in rows ``[:, :out, :]``, up's in ``[:, out:, :]``. + batch_info: :class:`LoraBatchInfo`. + output_dim: ``intermediate_per_tp``. + base_output: ``(s, 2 * intermediate_per_tp)`` to fuse-add into. + """ + s = x.shape[0] + input_dim = x.shape[1] + r = gate_up_lora_b.shape[-1] + assert input_dim == 2 * r + + max_len = batch_info.max_len + + def grid(meta): + return ( + triton.cdiv(max_len, meta["BLOCK_S"]) + * triton.cdiv(output_dim, meta["BLOCK_N"]), + 2, + batch_info.bs, + ) + + if base_output is None: + output = torch.zeros((s, 2 * output_dim), device=x.device, dtype=x.dtype) + else: + output = base_output + + sorted_by_adapter = batch_info.permutation is not None + _lora_gate_up_expand_kernel[grid]( + x, + gate_up_lora_b, + output, + r, + output_dim, + x.stride(0), + x.stride(1), + gate_up_lora_b.stride(0), + gate_up_lora_b.stride(1), + gate_up_lora_b.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + batch_info.lora_ranks, + batch_info.permutation, + batch_info.scalings, + sorted_by_adapter, + ) + + return output + + +load_kernel_cache(_lora_gate_up_expand_kernel) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_qkv_expand.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_qkv_expand.py new file mode 100644 index 000000000..4bed480cf --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_qkv_expand.py @@ -0,0 +1,229 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Fused LoRA-B expand for stacked Q/K/V projections. + +The QKV linear is fused into a single matmul with output layout +``[q_per_tp, k_per_tp, v_per_tp]``. This kernel packs the three B +projections into one launch: each program instance picks ``q``, ``k``, or +``v`` via ``program_id(1)`` and writes its tile into the matching slice of +the fused output. + +Adapted from sglang ``python/sglang/srt/lora/triton_ops/qkv_lora_b.py`` +(Apache-2.0): https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/lora/triton_ops/qkv_lora_b.py. +Local changes: autotune + on-disk cache, constexpr ordering. +""" + +from __future__ import annotations + +import torch +from tokenspeed_kernel._triton import tl, triton +from tokenspeed_kernel.ops.lora.triton.kernel_utils import _resolve_token_positions +from tokenspeed_kernel.ops.lora.triton.tuning import load_kernel_cache + +_QKV_EXPAND_CONFIGS = [ + triton.Config( + {"BLOCK_S": s, "BLOCK_N": n, "BLOCK_K": k}, + num_warps=w, + num_stages=stages, + maxnreg=mr, + ) + for s in (16, 32) + for n in (32, 64, 128) + for k in (16, 32, 64, 128) + for w in (4, 8) + for stages in (1, 2, 3) + for mr in (None, 128, 160) +] + + +@triton.autotune( + configs=_QKV_EXPAND_CONFIGS, + key=["max_qkv_out_dim", "K"], + restore_value=["output"], +) +@triton.jit +def _lora_qkv_expand_kernel( + x, + weights, + output, + K, # max_rank + max_qkv_out_dim, # max(q_per_tp, kv_per_tp) + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + n_offs, # (4,) cumulative offsets into the fused QKV output + sorted_token_ids, + scalings, + SORTED_BY_ADAPTER: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + batch_id = tl.program_id(axis=2) + w_index = tl.load(weight_indices + batch_id) + if w_index < 0: + return + rank = tl.load(lora_ranks + w_index) + if rank == 0: + return + + qkv_id = tl.program_id(axis=1) + pid = tl.program_id(axis=0) + seg_len = tl.load(seg_lens + batch_id) + if seg_len == 0: + return + seg_start = tl.load(seg_indptr + batch_id) + n_start = tl.load(n_offs + qkv_id) + n_size = tl.load(n_offs + qkv_id + 1) - n_start + scaling = tl.load(scalings + w_index) + K = tl.minimum(K, rank) + + num_pid_n = tl.cdiv(max_qkv_out_dim, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + if pid_s * BLOCK_S >= seg_len: + return + + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + + s_physical = _resolve_token_positions( + sorted_token_ids, seg_start, s_offset, seg_len, SORTED_BY_ADAPTER + ) + x_ptrs = ( + x + + (qkv_id * K) * x_stride_1 + + (s_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1) + ) + w_ptrs = (weights + w_index * w_stride_0 + n_start * w_stride_1) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + s_mask = s_offset[:, None] < seg_len + n_mask = n_offset[None, :] < n_size + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_K)): + k_remaining = K - k * BLOCK_K + x_tile = tl.load( + x_ptrs, + mask=s_mask & (k_offset[None, :] < k_remaining), + other=0.0, + eviction_policy="evict_first", + ) + w_tile = tl.load( + w_ptrs, + mask=(k_offset[:, None] < k_remaining) & n_mask, + other=0.0, + eviction_policy="evict_last", + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + partial_sum *= scaling + partial_sum = partial_sum.to(x.dtype.element_ty) + output_ptr = ( + output + + n_start * output_stride_1 + + (s_physical[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1) + ) + output_mask = s_mask & n_mask + partial_sum += tl.load(output_ptr, mask=output_mask, other=0.0) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def lora_qkv_expand_fwd( + x: torch.Tensor, + qkv_lora_b: torch.Tensor, + batch_info, + output_offset: torch.Tensor, + max_qkv_out_dim: int, + base_output: torch.Tensor | None = None, +) -> torch.Tensor: + """Apply LoRA-B for the fused QKV linear, fused-add into ``base_output``. + + Args: + x: ``(s, 3 * max_rank)`` from ``lora_shrink_fwd(stack_num=3)``. + qkv_lora_b: ``(num_lora, q_per_tp + 2 * kv_per_tp, max_rank)``. + batch_info: :class:`LoraBatchInfo`. + output_offset: ``(4,)`` cumulative offsets ``[0, q, q+kv, q+2*kv]``. + max_qkv_out_dim: ``max(q_per_tp, kv_per_tp)`` — used to size the grid. + base_output: ``(s, q_per_tp + 2 * kv_per_tp)`` to fuse-add into. + """ + s = x.shape[0] + input_dim = x.shape[1] + r = qkv_lora_b.shape[-1] + output_dim = qkv_lora_b.shape[-2] + assert input_dim == 3 * r + assert output_offset.shape[0] == 4 + + max_len = batch_info.max_len + + def grid(meta): + return ( + triton.cdiv(max_len, meta["BLOCK_S"]) + * triton.cdiv(max_qkv_out_dim, meta["BLOCK_N"]), + 3, + batch_info.bs, + ) + + if base_output is None: + output = torch.zeros((s, output_dim), device=x.device, dtype=x.dtype) + else: + output = base_output + + sorted_by_adapter = batch_info.permutation is not None + _lora_qkv_expand_kernel[grid]( + x, + qkv_lora_b, + output, + r, + max_qkv_out_dim, + x.stride(0), + x.stride(1), + qkv_lora_b.stride(0), + qkv_lora_b.stride(1), + qkv_lora_b.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + batch_info.lora_ranks, + output_offset, + batch_info.permutation, + batch_info.scalings, + sorted_by_adapter, + ) + return output + + +load_kernel_cache(_lora_qkv_expand_kernel) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_shrink.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_shrink.py new file mode 100644 index 000000000..0c571f8df --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_shrink.py @@ -0,0 +1,229 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Segmented LoRA-A matmul (shrink: in_dim → r). + +For each segment ``b`` in the batch the kernel computes +``output[seg_b] = x[seg_b] @ A[wi_b].T`` where ``A[wi_b]`` has shape +``(stack_num * r, in_dim)``. No-adapter segments use a negative slot +sentinel; the kernel returns immediately for that slot, leaving the output +rows untouched. Real slots may have varying real ranks up to +``max_rank``; ``output[..., :rank * stack_num]`` stores the real product +and ``output[..., rank * stack_num:]`` is irrelevant — the consumer +(``lora_expand`` / ``lora_qkv_expand``) reads only the first ``rank * stack_num`` +columns. + +Adapted from sglang ``python/sglang/srt/lora/triton_ops/sgemm_lora_a.py`` +(Apache-2.0): https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/lora/triton_ops/sgemm_lora_a.py. +sglang's kernel is in turn descended from the Punica S-LoRA design +(https://github.com/punica-ai/punica). Local changes: ported to +``tokenspeed_kernel._triton``, added ``@triton.autotune`` over the +``(N, K)`` shape with an on-disk config cache, and reshuffled the +constexpr params so block sizes come last. +""" + +from __future__ import annotations + +import torch +from tokenspeed_kernel._triton import tl, triton +from tokenspeed_kernel.ops.lora.triton.kernel_utils import _resolve_token_positions +from tokenspeed_kernel.ops.lora.triton.tuning import load_kernel_cache + +# Shrink kernel: N = stack_num * rank (tiny, 16–192), K = in_dim (large, +# 4096+). Decode-step segments are short (S = 1–32 per segment), so the +# right tile shape is "small N, large K, small S". Sweep matches the +# sglang csgmv-shrink space (PR sgl-project/sglang#20391) plus a BLOCK_S +# axis since our kernel exposes it. 72 configs. +_SHRINK_CONFIGS = [ + triton.Config( + {"BLOCK_S": s, "BLOCK_N": n, "BLOCK_K": k}, num_warps=w, num_stages=stages + ) + for s in (8, 16, 32) + for n in (16, 32, 64) + for k in (64, 128, 256) + for w in (4, 8) + for stages in (2, 3, 4) +] + + +@triton.autotune(configs=_SHRINK_CONFIGS, key=["N", "K"]) +@triton.jit +def _lora_shrink_kernel( + x, + weights, + output, + N, # stack_num * max_rank + K, # in_dim + stack_num, + x_stride_0, + x_stride_1, + w_stride_0, + w_stride_1, + w_stride_2, + output_stride_0, + output_stride_1, + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + sorted_token_ids, + SORTED_BY_ADAPTER: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + batch_id = tl.program_id(axis=1) + w_index = tl.load(weight_indices + batch_id) + if w_index < 0: + return + rank = tl.load(lora_ranks + w_index) + + # rank == 0 is defensive: skip and leave the output untouched + # (downstream lora_expand / lora_qkv_expand is also a no-op for rank == 0 + # so the leftover values never feed into the base-output add). + if rank == 0: + return + + pid = tl.program_id(axis=0) + seg_start = tl.load(seg_indptr + batch_id) + seg_len = tl.load(seg_lens + batch_id) + if seg_len == 0: + return + + # Cap N to the real ``stack_num * rank`` for this adapter. + N = tl.minimum(N, rank * stack_num) + + num_pid_n = tl.cdiv(N, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + if pid_s * BLOCK_S >= seg_len: + return + + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.max_contiguous(tl.arange(0, BLOCK_K), BLOCK_K) + s_physical = _resolve_token_positions( + sorted_token_ids, seg_start, s_offset, seg_len, SORTED_BY_ADAPTER + ) + x_ptrs = x + (s_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + # Hoist loop-invariant masks — s_mask and n_mask don't change across K + # iterations so computing them once saves instructions in the hot loop. + s_mask = s_offset[:, None] < seg_len # (BLOCK_S, 1) + n_mask = n_offset[None, :] < N # (1, BLOCK_N) + + K = tl.multiple_of(K, BLOCK_K) + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, K // BLOCK_K): + x_tile = tl.load( + x_ptrs, + mask=s_mask, + other=0.0, + eviction_policy="evict_first", + ) + w_tile = tl.load( + w_ptrs, + mask=n_mask, + other=0.0, + eviction_policy="evict_last", + ) + partial_sum += tl.dot(x_tile, w_tile) + + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + partial_sum = partial_sum.to(x.dtype.element_ty) + output_mask = s_mask & n_mask + output_ptr = output + ( + s_physical[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def lora_shrink_fwd( + x: torch.Tensor, + weights: torch.Tensor, + batch_info, + stack_num: int = 1, +) -> torch.Tensor: + """Run the LoRA-A shrink for an arbitrary batch. + + Args: + x: ``(s, in_dim)`` activations, contiguous. + weights: ``(num_lora, stack_num * max_rank, in_dim)``, contiguous. + batch_info: :class:`LoraBatchInfo` describing the segment layout. + stack_num: 1 for single projection, 3 for fused QKV, 2 for gate-up. + + Returns: + ``(s, stack_num * max_rank)`` tensor. Rows of segments whose adapter + is the no-op slot are unwritten — callers must not consume them + (the matching lora_expand kernel is also a no-op for those segments). + """ + assert x.is_contiguous() + assert weights.is_contiguous() + assert x.dim() == 2 + assert weights.dim() == 3 + + S = x.shape[0] + N = weights.shape[-2] + K = weights.shape[-1] + assert x.shape[-1] == K + + max_len = batch_info.max_len + + def grid(meta): + return ( + triton.cdiv(max_len, meta["BLOCK_S"]) * triton.cdiv(N, meta["BLOCK_N"]), + batch_info.bs, + ) + + sorted_by_adapter = batch_info.permutation is not None + + output = torch.empty((S, N), device=x.device, dtype=x.dtype) + _lora_shrink_kernel[grid]( + x, + weights, + output, + N, + K, + stack_num, + x.stride(0), + x.stride(1), + weights.stride(0), + weights.stride(1), + weights.stride(2), + output.stride(0), + output.stride(1), + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + batch_info.lora_ranks, + batch_info.permutation, + sorted_by_adapter, + ) + return output + + +# Eager pre-population from disk happens lazily inside the autotuner cache +# (see `tokenspeed_kernel.ops.lora.triton.__init__`). +load_kernel_cache(_lora_shrink_kernel) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_shrink_prefill.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_shrink_prefill.py new file mode 100644 index 000000000..8b8c28856 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/lora_shrink_prefill.py @@ -0,0 +1,206 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Prefill-optimised LoRA-A matmul (shrink: in_dim → r). + +Drop-in replacement for :func:`lora_shrink_fwd` on prefill batches +(``max_len > 32``). Identical algorithm; the structural difference is that +``K`` (= in_dim, 4096+), ``N`` (= stack_num * max_rank), and all strides are +**constexpr** — the compiler specialises the K-loop trip count at compile +time and eliminates all stride multiplications. + +Benchmarked gain on H100 vs the decode shrink kernel at s=512, rank=64: + QKV stack=3 (K=4096, N=192): 23 µs → 17 µs (1.3×) + g/up stack=2 (K=4096, N=128): 19 µs → 16 µs (1.2×) + single (K=4096, N=64): 18 µs → 17 µs (~1.0×) + +Adapted from sglang ``python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py`` +(Apache-2.0): https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py. +Local changes: kept SORTED_BY_ADAPTER + S-tiling from our decode kernel +(``lora_shrink.py``), replaced fixed configs with ``@triton.autotune`` + +on-disk cache. +""" + +from __future__ import annotations + +import torch +from tokenspeed_kernel._triton import tl, triton +from tokenspeed_kernel.ops.lora.triton.kernel_utils import _resolve_token_positions +from tokenspeed_kernel.ops.lora.triton.tuning import load_kernel_cache + +# Same config space as the decode shrink kernel. +_PREFILL_SHRINK_CONFIGS = [ + triton.Config( + {"BLOCK_S": s, "BLOCK_N": n, "BLOCK_K": k}, num_warps=w, num_stages=stages + ) + for s in (16, 32) + for n in (16, 32, 64) + for k in (64, 128, 256) + for w in (4, 8) + for stages in (2, 3, 4) +] + + +@triton.autotune(configs=_PREFILL_SHRINK_CONFIGS, key=["N", "K", "NUM_SLICES"]) +@triton.jit +def _lora_shrink_prefill_kernel( + x, + weights, + output, + seg_lens, + seg_indptr, + weight_indices, + lora_ranks, + sorted_token_ids, + N: tl.constexpr, # stack_num * max_rank + K: tl.constexpr, # in_dim + NUM_SLICES: tl.constexpr, # stack_num + SORTED_BY_ADAPTER: tl.constexpr, + BLOCK_S: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + # Constexpr strides — compiler eliminates all stride multiplications. + x_stride_0: tl.constexpr = K + x_stride_1: tl.constexpr = 1 + w_stride_0: tl.constexpr = N * K + w_stride_1: tl.constexpr = K # row stride of the (N, K) weight matrix + w_stride_2: tl.constexpr = 1 + output_stride_0: tl.constexpr = N + output_stride_1: tl.constexpr = 1 + + batch_id = tl.program_id(axis=1) + w_index = tl.load(weight_indices + batch_id) + if w_index < 0: + return + rank = tl.load(lora_ranks + w_index) + if rank == 0: + return + + pid = tl.program_id(axis=0) + seg_start = tl.load(seg_indptr + batch_id) + seg_len = tl.load(seg_lens + batch_id) + if seg_len == 0: + return + + cur_n = tl.minimum(N, rank * NUM_SLICES) + + num_pid_n = tl.cdiv(cur_n, BLOCK_N) + pid_s = pid // num_pid_n + pid_n = pid % num_pid_n + if pid_s * BLOCK_S >= seg_len: + return + + s_offset = tl.arange(0, BLOCK_S) + pid_s * BLOCK_S + n_offset = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + k_offset = tl.arange(0, BLOCK_K) + + s_physical = _resolve_token_positions( + sorted_token_ids, seg_start, s_offset, seg_len, SORTED_BY_ADAPTER + ) + + x_ptrs = x + (s_physical[:, None] * x_stride_0 + k_offset[None, :] * x_stride_1) + w_ptrs = (weights + w_index * w_stride_0) + ( + k_offset[:, None] * w_stride_2 + n_offset[None, :] * w_stride_1 + ) + + s_mask = s_offset[:, None] < seg_len + n_mask = n_offset[None, :] < cur_n + partial_sum = tl.zeros((BLOCK_S, BLOCK_N), dtype=tl.float32) + for k in range(0, K // BLOCK_K): + x_tile = tl.load( + x_ptrs, + mask=s_mask, + other=0.0, + eviction_policy="evict_first", + ) + w_tile = tl.load( + w_ptrs, + mask=n_mask, + other=0.0, + eviction_policy="evict_last", + ) + partial_sum += tl.dot(x_tile, w_tile) + x_ptrs += BLOCK_K * x_stride_1 + w_ptrs += BLOCK_K * w_stride_2 + + partial_sum = partial_sum.to(x.dtype.element_ty) + output_mask = s_mask & n_mask + output_ptr = output + ( + s_physical[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1 + ) + tl.store(output_ptr, partial_sum, mask=output_mask) + + +def lora_shrink_prefill_fwd( + x: torch.Tensor, + weights: torch.Tensor, + batch_info, + stack_num: int = 1, +) -> torch.Tensor: + """Prefill-optimised LoRA-A shrink. Same signature as :func:`lora_shrink_fwd`. + + Args: + x: ``(s, in_dim)`` activations, contiguous. + weights: ``(num_lora, stack_num * max_rank, in_dim)``, contiguous. + batch_info: :class:`LoraBatchInfo`. + stack_num: 1 for single projection, 3 for fused QKV, 2 for gate-up. + + Returns: + ``(s, stack_num * max_rank)`` tensor. + """ + assert x.is_contiguous() + assert weights.is_contiguous() + assert x.dim() == 2 + assert weights.dim() == 3 + + S = x.shape[0] + N = weights.shape[-2] # stack_num * max_rank + K = weights.shape[-1] # in_dim + assert x.shape[-1] == K + + max_len = batch_info.max_len + sorted_by_adapter = batch_info.permutation is not None + + def grid(meta): + return ( + triton.cdiv(max_len, meta["BLOCK_S"]) * triton.cdiv(N, meta["BLOCK_N"]), + batch_info.bs, + ) + + output = torch.empty((S, N), device=x.device, dtype=x.dtype) + _lora_shrink_prefill_kernel[grid]( + x, + weights, + output, + batch_info.seg_lens, + batch_info.seg_indptr, + batch_info.weight_indices, + batch_info.lora_ranks, + batch_info.permutation, + N=N, + K=K, + NUM_SLICES=stack_num, + SORTED_BY_ADAPTER=sorted_by_adapter, + ) + return output + + +load_kernel_cache(_lora_shrink_prefill_kernel) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/tune.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/tune.py new file mode 100644 index 000000000..570772e82 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/tune.py @@ -0,0 +1,254 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Offline autotune driver for the LoRA Triton kernels. + +Builds synthetic ``LoraBatchInfo`` batches for a few representative +segment shapes, calls each kernel once (triggering ``triton.autotune`` +to benchmark all candidate configs and pick the fastest per ``(N, K)`` +key), and then writes the picked configs to JSON via +:func:`tokenspeed_kernel.ops.lora.triton.tuning.save_kernel_cache`. + +Usage:: + + python -m tokenspeed_kernel.ops.lora.triton.tune \\ + --hidden 4096 --intermediate 12288 \\ + --q-per-tp 2048 --kv-per-tp 1024 \\ + --rank 16 --max-rank 64 --tp-size 2 + +The defaults match Qwen3-8B at attn_tp_size=2. Shapes only affect which +``(N, K)`` keys get tuned; the actual launch parameters are independent +of which model the cache is shipped against. +""" + +from __future__ import annotations + +import argparse +import logging +from dataclasses import dataclass + +import torch +from tokenspeed_kernel.ops.lora.triton.lora_expand import ( + _lora_expand_kernel, + lora_expand_fwd, +) +from tokenspeed_kernel.ops.lora.triton.lora_gate_up_expand import ( + _lora_gate_up_expand_kernel, + lora_gate_up_expand_fwd, +) +from tokenspeed_kernel.ops.lora.triton.lora_qkv_expand import ( + _lora_qkv_expand_kernel, + lora_qkv_expand_fwd, +) +from tokenspeed_kernel.ops.lora.triton.lora_shrink import ( + _lora_shrink_kernel, + lora_shrink_fwd, +) +from tokenspeed_kernel.ops.lora.triton.tuning import save_kernel_cache + +logger = logging.getLogger(__name__) + + +@dataclass +class _BatchInfo: + """Minimal stand-in for ``runtime.lora.lora_manager.LoraBatchInfo``.""" + + bs: int + max_len: int + seg_lens: torch.Tensor + seg_indptr: torch.Tensor + weight_indices: torch.Tensor + lora_ranks: torch.Tensor + scalings: torch.Tensor + permutation: torch.Tensor | None = None + + +def _make_batch( + s_per_seg: int, n_segs: int, rank: int, device: str = "cuda" +) -> _BatchInfo: + seg_lens = torch.full((n_segs,), s_per_seg, dtype=torch.int32, device=device) + seg_indptr = torch.tensor( + [i * s_per_seg for i in range(n_segs + 1)], dtype=torch.int32, device=device + ) + # weight_indices: route every segment to real adapter slot 0. + weight_indices = torch.zeros(n_segs, dtype=torch.int32, device=device) + lora_ranks = torch.tensor([rank], dtype=torch.int32, device=device) + scalings = torch.tensor([1.0], dtype=torch.float32, device=device) + return _BatchInfo( + bs=n_segs, + max_len=s_per_seg, + seg_lens=seg_lens, + seg_indptr=seg_indptr, + weight_indices=weight_indices, + lora_ranks=lora_ranks, + scalings=scalings, + ) + + +def tune_shrink(*, in_dim: int, stack_num: int, rank: int, max_rank: int) -> None: + """Drive ``_lora_shrink_kernel`` for one ``(stack_num, in_dim)`` shape. + + Uses a decode-shaped batch (``bs=32, max_len=1``) because that is where + LoRA latency dominates the e2e (every decode step pays the kernel cost; + prefill is amortized). Tuning at prefill shapes picks block tiles that + waste threads at decode-time. + """ + device = "cuda" + dtype = torch.bfloat16 + n_segs = 32 + s_per_seg = 1 + s = n_segs * s_per_seg + x = torch.randn((s, in_dim), device=device, dtype=dtype) + weights = torch.randn((2, stack_num * max_rank, in_dim), device=device, dtype=dtype) + bi = _make_batch(s_per_seg, n_segs, rank=rank, device=device) + lora_shrink_fwd(x, weights, bi, stack_num=stack_num) + torch.cuda.synchronize() + print( + f" shrink in_dim={in_dim} stack={stack_num} → best={_lora_shrink_kernel.best_config}" + ) + + +def tune_expand(*, out_dim: int, max_rank: int, rank: int) -> None: + device = "cuda" + dtype = torch.bfloat16 + n_segs = 32 + s_per_seg = 1 + s = n_segs * s_per_seg + x = torch.randn((s, max_rank), device=device, dtype=dtype) + weights = torch.randn((2, out_dim, max_rank), device=device, dtype=dtype) + bi = _make_batch(s_per_seg, n_segs, rank=rank, device=device) + out = torch.zeros((s, out_dim), device=device, dtype=dtype) + lora_expand_fwd(x, weights, bi, base_output=out) + torch.cuda.synchronize() + print( + f" expand out_dim={out_dim} R={max_rank} → best={_lora_expand_kernel.best_config}" + ) + + +def tune_qkv(*, q_per_tp: int, kv_per_tp: int, max_rank: int, rank: int) -> None: + device = "cuda" + dtype = torch.bfloat16 + n_segs = 32 + s_per_seg = 1 + s = n_segs * s_per_seg + x = torch.randn((s, 3 * max_rank), device=device, dtype=dtype) + out_dim = q_per_tp + 2 * kv_per_tp + weights = torch.randn((2, out_dim, max_rank), device=device, dtype=dtype) + max_qkv = max(q_per_tp, kv_per_tp) + output_offset = torch.tensor( + [0, q_per_tp, q_per_tp + kv_per_tp, q_per_tp + 2 * kv_per_tp], + dtype=torch.int32, + device=device, + ) + bi = _make_batch(s_per_seg, n_segs, rank=rank, device=device) + out = torch.zeros((s, out_dim), device=device, dtype=dtype) + lora_qkv_expand_fwd(x, weights, bi, output_offset, max_qkv, base_output=out) + torch.cuda.synchronize() + print( + f" qkv_expand max_qkv={max_qkv} R={max_rank} → best={_lora_qkv_expand_kernel.best_config}" + ) + + +def tune_gate_up(*, intermediate_per_tp: int, max_rank: int, rank: int) -> None: + device = "cuda" + dtype = torch.bfloat16 + n_segs = 32 + s_per_seg = 1 + s = n_segs * s_per_seg + x = torch.randn((s, 2 * max_rank), device=device, dtype=dtype) + weights = torch.randn( + (2, 2 * intermediate_per_tp, max_rank), device=device, dtype=dtype + ) + bi = _make_batch(s_per_seg, n_segs, rank=rank, device=device) + out = torch.zeros((s, 2 * intermediate_per_tp), device=device, dtype=dtype) + lora_gate_up_expand_fwd(x, weights, bi, intermediate_per_tp, base_output=out) + torch.cuda.synchronize() + print( + f" gate_up_expand out={intermediate_per_tp} R={max_rank} → best={_lora_gate_up_expand_kernel.best_config}" + ) + + +def main() -> int: + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--hidden", type=int, default=4096) + p.add_argument( + "--intermediate", + type=int, + default=12288, + help="Full (un-sharded) intermediate_size", + ) + p.add_argument("--q-per-tp", type=int, default=2048) + p.add_argument("--kv-per-tp", type=int, default=512) + p.add_argument("--rank", type=int, default=16) + p.add_argument("--max-rank", type=int, default=64) + p.add_argument("--tp-size", type=int, default=2) + args = p.parse_args() + + logging.basicConfig(level=logging.INFO, format="%(message)s") + + intermediate_per_tp = args.intermediate // args.tp_size + + print("=== Tuning shrink (lora_shrink) ===") + # Attention shrink: stack=3 (QKV) on hidden, stack=1 (o) on q_per_tp. + tune_shrink(in_dim=args.hidden, stack_num=3, rank=args.rank, max_rank=args.max_rank) + tune_shrink( + in_dim=args.q_per_tp, stack_num=1, rank=args.rank, max_rank=args.max_rank + ) + # MLP shrink: stack=2 (gate/up) on hidden, stack=1 (down) on intermediate_per_tp. + tune_shrink(in_dim=args.hidden, stack_num=2, rank=args.rank, max_rank=args.max_rank) + tune_shrink( + in_dim=intermediate_per_tp, stack_num=1, rank=args.rank, max_rank=args.max_rank + ) + + print("\n=== Tuning expand (lora_expand) ===") + # o_proj uses lora_expand directly (out_dim = hidden). + tune_expand(out_dim=args.hidden, max_rank=args.max_rank, rank=args.rank) + # down_proj also uses lora_expand (out_dim = hidden). + # Same shape — autotune cache hit on the second call. + + print("\n=== Tuning qkv_expand (lora_qkv_expand) ===") + tune_qkv( + q_per_tp=args.q_per_tp, + kv_per_tp=args.kv_per_tp, + max_rank=args.max_rank, + rank=args.rank, + ) + + print("\n=== Tuning gate_up_expand (lora_gate_up_expand) ===") + tune_gate_up( + intermediate_per_tp=intermediate_per_tp, + max_rank=args.max_rank, + rank=args.rank, + ) + + print("\n=== Saving caches ===") + for kern in ( + _lora_shrink_kernel, + _lora_expand_kernel, + _lora_qkv_expand_kernel, + _lora_gate_up_expand_kernel, + ): + path = save_kernel_cache(kern) + print(f" wrote {path} ({len(kern.cache)} entries)") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/tune_sweep.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/tune_sweep.py new file mode 100644 index 000000000..5a1507839 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/tune_sweep.py @@ -0,0 +1,140 @@ +"""Comprehensive autotune sweep for LoRA decode kernels across common shapes. + +Covers the (N, K) pairs seen in production for the major model families and +TP configurations, across max_rank values of 16 / 32 / 64 / 128. Saves all +picked configs to the on-disk JSON caches so fresh processes skip the sweep. + +Usage:: + + python -m tokenspeed_kernel.ops.lora.triton.tune_sweep + +Estimated runtime: ~5 min on H100 (all shapes × all kernels). +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass + +import torch +from tokenspeed_kernel.ops.lora.triton.lora_expand import _lora_expand_kernel +from tokenspeed_kernel.ops.lora.triton.lora_gate_up_expand import ( + _lora_gate_up_expand_kernel, +) +from tokenspeed_kernel.ops.lora.triton.lora_qkv_expand import _lora_qkv_expand_kernel +from tokenspeed_kernel.ops.lora.triton.lora_shrink import _lora_shrink_kernel +from tokenspeed_kernel.ops.lora.triton.tune import ( + _BatchInfo, + _make_batch, + tune_expand, + tune_gate_up, + tune_qkv, + tune_shrink, +) +from tokenspeed_kernel.ops.lora.triton.tuning import save_kernel_cache + +logging.basicConfig(level=logging.INFO, format="%(message)s") + + +@dataclass +class _ModelTP: + name: str + hidden: int + intermediate_per_tp: int + q_per_tp: int + kv_per_tp: int + + +# ── Representative (model, TP) configs ────────────────────────────────────── +# Each entry represents one serving configuration: hidden size, per-rank +# intermediate, and per-rank Q / KV sizes after tensor parallelism sharding. +# Source model sizes: +# Llama-3-8B: hidden=4096, intermediate=14336, heads=32/8, head_dim=128 +# Llama-3-70B: hidden=8192, intermediate=28672, heads=64/8, head_dim=128 +# Qwen3-8B: hidden=4096, intermediate=12288, heads=32/8, head_dim=128 +_CONFIGS: list[_ModelTP] = [ + # ── Llama-3-8B ────────────────────────────────────────────────────────── + _ModelTP("llama3-8b TP=1", 4096, 14336, 4096, 1024), + _ModelTP("llama3-8b TP=2", 4096, 7168, 2048, 512), + _ModelTP("llama3-8b TP=4", 4096, 3584, 1024, 256), + # ── Qwen3-8B ──────────────────────────────────────────────────────────── + _ModelTP("qwen3-8b TP=1", 4096, 12288, 4096, 1024), + _ModelTP("qwen3-8b TP=2", 4096, 6144, 2048, 512), + _ModelTP("qwen3-8b TP=4", 4096, 3072, 1024, 256), + # ── Llama-3-70B ───────────────────────────────────────────────────────── + _ModelTP("llama3-70b TP=4", 8192, 7168, 2048, 256), + _ModelTP("llama3-70b TP=8", 8192, 3584, 1024, 128), +] + +# Max-rank values to cover — N in the shrink key is stack_num * max_rank. +_MAX_RANKS = [16, 32, 64, 128] + + +def _sweep_shrink(cfg: _ModelTP, max_rank: int) -> None: + rank = max_rank # tune at full rank so the K-loop is fully exercised + # Attention shrink + tune_shrink(in_dim=cfg.hidden, stack_num=3, rank=rank, max_rank=max_rank) + tune_shrink(in_dim=cfg.q_per_tp, stack_num=1, rank=rank, max_rank=max_rank) + # MLP shrink + tune_shrink(in_dim=cfg.hidden, stack_num=2, rank=rank, max_rank=max_rank) + tune_shrink( + in_dim=cfg.intermediate_per_tp, stack_num=1, rank=rank, max_rank=max_rank + ) + + +def _sweep_expand(cfg: _ModelTP, max_rank: int) -> None: + # Clear in-process cache so the autotuner sweeps all configs fresh + # rather than reusing entries loaded from the on-disk JSON. + for k in _lora_expand_kernel, _lora_qkv_expand_kernel, _lora_gate_up_expand_kernel: + k.cache.clear() + rank = max_rank + # o_proj / down_proj + tune_expand(out_dim=cfg.hidden, max_rank=max_rank, rank=rank) + # QKV + tune_qkv( + q_per_tp=cfg.q_per_tp, + kv_per_tp=cfg.kv_per_tp, + max_rank=max_rank, + rank=rank, + ) + # gate/up + tune_gate_up( + intermediate_per_tp=cfg.intermediate_per_tp, + max_rank=max_rank, + rank=rank, + ) + + +def main() -> int: + total_shrink = len(_CONFIGS) * len(_MAX_RANKS) + total_expand = total_shrink + done = 0 + + for max_rank in _MAX_RANKS: + for cfg in _CONFIGS: + done += 1 + print(f"\n[{done}/{total_shrink}] shrink {cfg.name} max_rank={max_rank}") + _sweep_shrink(cfg, max_rank) + + done = 0 + for max_rank in _MAX_RANKS: + for cfg in _CONFIGS: + done += 1 + print(f"\n[{done}/{total_expand}] expand {cfg.name} max_rank={max_rank}") + _sweep_expand(cfg, max_rank) + + print("\n=== Saving caches ===") + for kern in ( + _lora_shrink_kernel, + _lora_expand_kernel, + _lora_qkv_expand_kernel, + _lora_gate_up_expand_kernel, + ): + path = save_kernel_cache(kern) + print(f" wrote {path} ({len(kern.cache)} entries)") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/tuning.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/tuning.py new file mode 100644 index 000000000..db82764b6 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/lora/triton/tuning.py @@ -0,0 +1,143 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""On-disk cache for LoRA Triton autotune picks. + +Triton's ``@triton.autotune`` caches the best config per ``key`` tuple in +``Autotuner.cache``, but only for the current process — every fresh Python +process re-runs the sweep on the first call to each unique shape. This +module persists that cache as JSON next to the kernels so the picks +survive process restarts and ship in the repo. + +Layout: ``configs//.json``. When a kernel runs +for the first time on a shape that has no saved entry, Triton falls back +to the candidate-config sweep (slow) and the result can be saved by a +follow-up call to :func:`save_kernel_cache`. + +Config JSON format:: + + { + "(N, K, 'torch.bfloat16')": { + "kwargs": {"BLOCK_S": 16, "BLOCK_N": 64, "BLOCK_K": 64}, + "num_warps": 4, + "num_stages": 3, + "num_ctas": 1, + "maxnreg": null + }, + ... + } +""" + +from __future__ import annotations + +import ast +import json +import logging +import os +from pathlib import Path +from typing import Any + +import torch +from tokenspeed_kernel._triton import triton + +logger = logging.getLogger(__name__) + +CONFIG_DIR = Path(__file__).parent / "configs" + + +def _gpu_label() -> str: + """Compact identifier for the active GPU — partitions config files.""" + if not torch.cuda.is_available(): + return "cpu" + name = torch.cuda.get_device_name(0) + # Strip vendor prefix and whitespace: "NVIDIA H100 80GB HBM3" → "H100_80GB_HBM3". + name = name.replace("NVIDIA ", "").strip() + return name.replace(" ", "_") + + +def _config_path(kernel_name: str) -> Path: + return CONFIG_DIR / _gpu_label() / f"{kernel_name}.json" + + +def _key_to_str(key: tuple) -> str: + # ``repr(tuple)`` round-trips through ``ast.literal_eval`` provided the + # tuple only holds primitives and str dtypes — which it does here. + return repr(tuple(key)) + + +def _str_to_key(s: str) -> tuple: + return tuple(ast.literal_eval(s)) + + +def _config_to_dict(cfg: triton.Config) -> dict: + return { + "kwargs": dict(cfg.kwargs), + "num_warps": cfg.num_warps, + "num_stages": cfg.num_stages, + "num_ctas": cfg.num_ctas, + "maxnreg": cfg.maxnreg, + } + + +def _dict_to_config(d: dict) -> triton.Config: + return triton.Config( + d["kwargs"], + num_warps=d["num_warps"], + num_stages=d["num_stages"], + num_ctas=d.get("num_ctas", 1), + maxnreg=d.get("maxnreg"), + ) + + +def load_kernel_cache(kernel) -> int: + """Populate ``kernel.cache`` from the on-disk JSON for the active GPU. + + ``kernel`` is the ``Autotuner`` wrapper produced by + ``@triton.autotune``. Returns the number of entries loaded (0 when + no config file exists for this GPU, which is the normal first-run + case). + """ + name = kernel.base_fn.__name__ + path = _config_path(name) + if not path.exists(): + logger.debug("no autotune cache for %s at %s", name, path) + return 0 + with open(path) as f: + raw = json.load(f) + loaded = 0 + for k, v in raw.items(): + kernel.cache[_str_to_key(k)] = _dict_to_config(v) + loaded += 1 + logger.info("loaded %d autotune picks for %s from %s", loaded, name, path) + return loaded + + +def save_kernel_cache(kernel) -> Path: + """Dump ``kernel.cache`` to JSON next to the kernel module.""" + name = kernel.base_fn.__name__ + path = _config_path(name) + path.parent.mkdir(parents=True, exist_ok=True) + blob: dict[str, Any] = {} + for key, cfg in kernel.cache.items(): + blob[_key_to_str(key)] = _config_to_dict(cfg) + with open(path, "w") as f: + json.dump(blob, f, indent=2, sort_keys=True) + logger.info("saved %d autotune picks for %s to %s", len(blob), name, path) + return path diff --git a/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe_lora/__init__.py b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe_lora/__init__.py new file mode 100644 index 000000000..63088af12 --- /dev/null +++ b/tokenspeed-kernel/python/tokenspeed_kernel/ops/moe_lora/__init__.py @@ -0,0 +1,1085 @@ +# Copyright (c) 2026 LightSeek Foundation +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Fused Triton kernels for MoE LoRA applied to sorted expert outputs. + +Targets the sglang_shared adapter format (shared outer A, per-expert inner B +for gate/up; per-expert A, shared outer B for down), operating directly on the +sorted token-expert buffers produced by the MoE dispatcher. + +Gate/up expand replaces: all-experts B GEMM (m×R × R×E·I) + candidates.gather + +_add_route_delta with a single per-sorted-position GEMV kernel. + +Down shrink replaces: _route_rows_from_cache + _select_expert_weights + einsum +with a per-sorted-position GEMV kernel; the caller then runs one shared-B GEMM +and scatter_add_ to accumulate into the token-ordered down output. + +Both kernels tile over the rank dimension in BLOCK_R chunks so that register +pressure stays bounded regardless of adapter rank (r=16 to r=256). +""" + +from __future__ import annotations + +import torch +import triton +import triton.language as tl + +# ── Gate/Up Expand ─────────────────────────────────────────────────────────── +# +# For each sorted position s: +# exp = safe_ids[flat_j // K, flat_j % K] where flat_j = sorted_token_ids[s] +# delta = lora_a_m[flat_j // K, :] @ w13_B[exp, offs_i, :].T * scaling +# gate_up_output[s, offs_i] += delta +# +# Rank dimension is reduced in BLOCK_R tiles to bound register usage. +# Grid: (cdiv(I2, BLOCK_I), padded) + + +@triton.jit +def _sorted_gate_up_b_expand_kernel( + lora_a_m, # (m, MAX_R) + w13_B, # (E, I2, MAX_R) — contiguous + safe_ids, # (m, K) int64 + sorted_token_ids, # (padded,) int64 — sorted pos → flat pair + gate_up_output, # output — in-place add + scaling_ptr, # float32 scalar on device + route_count, # int32 — m*K + K, # int32 + I2: tl.constexpr, + MAX_R: tl.constexpr, + BLOCK_I: tl.constexpr, + BLOCK_R: tl.constexpr, + SCATTER: tl.constexpr, # True: write to flat_j (flat-pair output); False: write to pid_s (sorted output) +): + pid_i = tl.program_id(0) + pid_s = tl.program_id(1) + + flat_j = tl.load(sorted_token_ids + pid_s) + if flat_j < 0: + return + if flat_j >= route_count: + return + + tok = flat_j // K + topk_v = flat_j % K + exp = tl.load(safe_ids + tok * K + topk_v).to(tl.int32) + + offs_i = pid_i * BLOCK_I + tl.arange(0, BLOCK_I) + i_mask = offs_i < I2 + + scaling = tl.load(scaling_ptr).to(tl.float32) + acc = tl.zeros((BLOCK_I,), dtype=tl.float32) + + for r_start in range(0, MAX_R, BLOCK_R): + kr = r_start + tl.arange(0, BLOCK_R) + la = tl.load(lora_a_m + tok * MAX_R + kr).to(tl.float32) # (BLOCK_R,) + B_ptr = w13_B + (exp * I2 + offs_i[:, None]) * MAX_R + kr[None, :] + B = tl.load(B_ptr, mask=i_mask[:, None], other=0.0).to( + tl.float32 + ) # (BLOCK_I, BLOCK_R) + acc += tl.sum(B * la[None, :], axis=1) + + # SCATTER=True: write to flat-pair position flat_j (non-TMA, flat-pair output). + # SCATTER=False: write to sorted position pid_s (TMA sorted output). + out_row = flat_j if SCATTER else pid_s + out_ptr = gate_up_output + out_row * I2 + offs_i + old = tl.load(out_ptr, mask=i_mask, other=0.0).to(tl.float32) + tl.store(out_ptr, old + acc * scaling, mask=i_mask) + + +def _choose_block_r(max_r: int) -> int: + """Largest power-of-2 ≤ 32 that divides max_r.""" + block_r = min(32, max_r) + while max_r % block_r != 0: + block_r //= 2 + return max(block_r, 1) + + +def sorted_gate_up_b_expand( + lora_a_m: torch.Tensor, # (m, R) — already computed + w13_B: torch.Tensor, # (E, I2, R) — per-expert B, contiguous + safe_ids: torch.Tensor, # (m, K) int64 + sorted_token_ids: torch.Tensor, # (padded,) int64 + gate_up_output: torch.Tensor, # (padded, I2) — in-place add + scaling: torch.Tensor, # () or (1,) float32 device tensor + route_count: int, # = m*K + K: int, + BLOCK_I: int = 64, +) -> None: + """Fused gate/up expand: lora_a_m @ B[expert].T, add directly to sorted output. + + For TMA-sorted dispatch: output is in sorted expert order (SCATTER=False). + """ + padded, I2 = gate_up_output.shape + MAX_R = w13_B.shape[2] + BLOCK_R = _choose_block_r(MAX_R) + assert w13_B.is_contiguous(), "w13_B must be contiguous for fused kernel" + + grid = (triton.cdiv(I2, BLOCK_I), padded) + _sorted_gate_up_b_expand_kernel[grid]( + lora_a_m, + w13_B, + safe_ids.to(torch.int64), + sorted_token_ids.to(torch.int64), + gate_up_output, + scaling, + route_count, + K, + I2=I2, + MAX_R=MAX_R, + BLOCK_I=BLOCK_I, + BLOCK_R=BLOCK_R, + SCATTER=False, + num_warps=4, + num_stages=3, + ) + + +# ── Flat Gate/Up Expand (decode path) ──────────────────────────────────────── +# +# No sorted_token_ids needed — computes tok = pid_s // K inside the kernel. +# One block per flat-pair position, processes all m*K positions directly. +# Replaces: all-experts B GEMM + candidates.gather + route_delta (3 → 1 kernel). +# Active-expert reads: only the ~51 unique experts' B rows, not all 128. + + +@triton.jit +def _gate_up_b_expand_kernel( + lora_a_m, # (m, MAX_R) + w13_B_buffer, # full buffer: n_slots × E × I2 × MAX_R (contiguous) + slot_ptr, # (1,) int32 — GPU scalar, dynamic at CUDA-graph replay + n_slot_stride, # int — E × I2 × MAX_R (stride between slots) + safe_ids, # (m, K) int64 + gate_up_output, # (m*K, I2) — flat-pair order, in-place add + scaling_ptr, # float32 scalar on device + K, # int32 — topk count + I2: tl.constexpr, + MAX_R: tl.constexpr, + BLOCK_I: tl.constexpr, + BLOCK_R: tl.constexpr, +): + pid_i = tl.program_id(0) + pid_s = tl.program_id(1) # flat-pair index [0 .. m*K-1] + + tok = pid_s // K + topk_v = pid_s % K + exp = tl.load(safe_ids + tok * K + topk_v).to(tl.int32) + + offs_i = pid_i * BLOCK_I + tl.arange(0, BLOCK_I) + i_mask = offs_i < I2 + + # Load slot index dynamically (changes at CUDA-graph replay without re-capture). + slot = tl.load(slot_ptr).to(tl.int32) + # Load scaling from buffer at [slot] — avoids a separate scalings[slot_idx] gather. + scaling = tl.load(scaling_ptr + slot).to(tl.float32) + acc = tl.zeros((BLOCK_I,), dtype=tl.float32) + + for r_start in range(0, MAX_R, BLOCK_R): + kr = r_start + tl.arange(0, BLOCK_R) + la = tl.load(lora_a_m + tok * MAX_R + kr).to(tl.float32) + # Compute B pointer directly into the full buffer using the slot offset, + # avoiding a separate gather copy: buffer[slot, exp, offs_i, kr]. + B_ptr = ( + w13_B_buffer + + slot * n_slot_stride + + (exp * I2 + offs_i[:, None]) * MAX_R + + kr[None, :] + ) + B = tl.load(B_ptr, mask=i_mask[:, None], other=0.0).to(tl.float32) + acc += tl.sum(B * la[None, :], axis=1) + + out_ptr = gate_up_output + pid_s * I2 + offs_i + old = tl.load(out_ptr, mask=i_mask, other=0.0).to(tl.float32) + tl.store(out_ptr, old + acc * scaling, mask=i_mask) + + +def gate_up_b_expand( + lora_a_m: torch.Tensor, # (m, R) — already computed + w13_B_buffer: torch.Tensor, # (n_slots, E, I2, R) — full buffer, contiguous + slot_idx: torch.Tensor, # (1,) int32 — GPU tensor; dynamic at CUDA-graph replay + safe_ids: torch.Tensor, # (m, K) int64 — expert assignments + gate_up_output: torch.Tensor, # (m*K, I2) — flat-pair order, in-place add + scalings: torch.Tensor, # (n_slots,) float32 — full scalings buffer; kernel loads [slot] + BLOCK_I: int = 64, +) -> None: + """Flat per-expert GEMV for decode (no TMA, no sorted_token_ids needed). + + Accepts the FULL (n_slots, E, I2, R) buffer, slot_idx, and the full scalings + buffer — the kernel loads both w13_B and scalings via the slot offset, eliminating + the separate w13_B gather (~38 µs) and scalings gather (~19 µs) per layer. + + One block per flat-pair position; computes tok = pid_s // K directly. + Replaces: all-experts B GEMM + candidates.gather + route_delta (3 → 1 kernel). + """ + m_k, I2 = gate_up_output.shape + K = safe_ids.shape[1] + # Buffer layout: (n_slots, E, I2, MAX_R). + _n_slots, E, _I2, MAX_R = w13_B_buffer.shape + n_slot_stride = E * I2 * MAX_R # elements between consecutive slots + BLOCK_R = _choose_block_r(MAX_R) + assert ( + w13_B_buffer.is_contiguous() + ), "w13_B_buffer must be contiguous for fused kernel" + + grid = (triton.cdiv(I2, BLOCK_I), m_k) + _gate_up_b_expand_kernel[grid]( + lora_a_m, + w13_B_buffer, + slot_idx.to(torch.int32), + n_slot_stride, + safe_ids.to(torch.int64), + gate_up_output, + scalings, + K, + I2=I2, + MAX_R=MAX_R, + BLOCK_I=BLOCK_I, + BLOCK_R=BLOCK_R, + num_warps=4, + num_stages=3, + ) + + +# ── Down Shrink ─────────────────────────────────────────────────────────────── +# +# For each sorted position s, for each rank tile pid_r: +# exp = safe_ids[flat_j // K, flat_j % K] +# lora_a_out[s, pid_r*BLOCK_R : (pid_r+1)*BLOCK_R] +# = intermediate[s, :] @ down_A[exp, pid_r*BLOCK_R : ..., :].T +# +# Grid: (padded, cdiv(MAX_R, BLOCK_R)) +# Splitting over rank tiles keeps (BLOCK_R × BLOCK_H) loads bounded in size. + + +@triton.jit +def _sorted_a_down_shrink_kernel( + intermediate, # (padded, INTER) + down_A, # (E, MAX_R, INTER) — per-expert A, contiguous + safe_ids, # (m, K) int64 + sorted_token_ids, # (padded,) int64 + lora_a_out, # (padded, MAX_R) + route_count, # int32 + K, # int32 + INTER: tl.constexpr, + MAX_R: tl.constexpr, + BLOCK_R: tl.constexpr, # rank output tile; MAX_R divisible by BLOCK_R + BLOCK_H: tl.constexpr, # INTER tile; INTER divisible by BLOCK_H +): + pid_s = tl.program_id(0) + pid_r = tl.program_id(1) + + flat_j = tl.load(sorted_token_ids + pid_s) + valid = (flat_j >= 0) & (flat_j < route_count) + + kr = pid_r * BLOCK_R + tl.arange(0, BLOCK_R) + + if not valid: + tl.store( + lora_a_out + pid_s * MAX_R + kr, + tl.zeros((BLOCK_R,), dtype=intermediate.dtype.element_ty), + ) + return + + tok = flat_j // K + topk_v = flat_j % K + exp = tl.load(safe_ids + tok * K + topk_v).to(tl.int32) + + acc = tl.zeros((BLOCK_R,), dtype=tl.float32) + + for h_start in range(0, INTER, BLOCK_H): + kh = h_start + tl.arange(0, BLOCK_H) + x = tl.load(intermediate + pid_s * INTER + kh).to(tl.float32) # (BLOCK_H,) + A_ptr = down_A + (exp * MAX_R + kr[:, None]) * INTER + kh[None, :] + A = tl.load(A_ptr).to(tl.float32) # (BLOCK_R, BLOCK_H) + acc += tl.sum(A * x[None, :], axis=1) + + tl.store( + lora_a_out + pid_s * MAX_R + kr, + acc.to(intermediate.dtype.element_ty), + ) + + +def _choose_block_h(inter: int) -> int: + """Largest power-of-2 ≤ 128 that divides inter.""" + block_h = min(128, inter) + while inter % block_h != 0: + block_h //= 2 + return max(block_h, 1) + + +def sorted_a_down_shrink( + intermediate: torch.Tensor, # (padded, INTER) + down_A: torch.Tensor, # (E, MAX_R, INTER) + safe_ids: torch.Tensor, # (m, K) int64 + sorted_token_ids: torch.Tensor, # (padded,) int64 + route_count: int, + K: int, +) -> torch.Tensor: + """Fused down shrink: intermediate[s] @ down_A[expert].T for each sorted pos.""" + padded, INTER = intermediate.shape + MAX_R = down_A.shape[1] + BLOCK_R = _choose_block_r(MAX_R) + BLOCK_H = _choose_block_h(INTER) + assert down_A.is_contiguous(), "down_A must be contiguous for fused kernel" + + lora_a = torch.empty( + (padded, MAX_R), dtype=intermediate.dtype, device=intermediate.device + ) + grid = (padded, MAX_R // BLOCK_R) + _sorted_a_down_shrink_kernel[grid]( + intermediate, + down_A, + safe_ids.to(torch.int64), + sorted_token_ids.to(torch.int64), + lora_a, + route_count, + K, + INTER=INTER, + MAX_R=MAX_R, + BLOCK_R=BLOCK_R, + BLOCK_H=BLOCK_H, + num_warps=4, + num_stages=2, + ) + return lora_a + + +# ── Flat Down Shrink (decode path) ──────────────────────────────────────────── +# +# No sorted_token_ids needed — computes tok = pid_s // K inside the kernel. +# One block per (flat-pair, rank-tile), replaces: select_A gather + einsum. +# Avoids the (m*K, r, INTER) intermediate created by _select_expert_weights. +# Grid: (m*K, MAX_R // BLOCK_R) + + +@triton.jit +def _per_expert_a_shrink_kernel( + route_input, # (m*K, INTER) + down_A_buffer, # full buffer: n_slots × E × MAX_R × INTER (contiguous) + slot_ptr, # (1,) int32 — GPU scalar, dynamic at CUDA-graph replay + n_slot_stride, # int — E × MAX_R × INTER (stride between slots) + safe_ids, # (m, K) int64 + lora_a_out, # (m*K, MAX_R) + K, # int32 + INTER: tl.constexpr, + MAX_R: tl.constexpr, + BLOCK_R: tl.constexpr, + BLOCK_H: tl.constexpr, +): + pid_s = tl.program_id(0) # flat-pair index + pid_r = tl.program_id(1) # rank tile + + tok = pid_s // K + topk_v = pid_s % K + exp = tl.load(safe_ids + tok * K + topk_v).to(tl.int32) + + # Load slot index dynamically (changes at CUDA-graph replay without re-capture). + slot = tl.load(slot_ptr).to(tl.int32) + + kr = pid_r * BLOCK_R + tl.arange(0, BLOCK_R) + acc = tl.zeros((BLOCK_R,), dtype=tl.float32) + + for h_start in range(0, INTER, BLOCK_H): + kh = h_start + tl.arange(0, BLOCK_H) + x = tl.load(route_input + pid_s * INTER + kh).to(tl.float32) + # Compute A pointer directly into the full buffer using the slot offset, + # avoiding a separate gather copy: buffer[slot, exp, kr, kh]. + A_ptr = ( + down_A_buffer + + slot * n_slot_stride + + (exp * MAX_R + kr[:, None]) * INTER + + kh[None, :] + ) + A = tl.load(A_ptr).to(tl.float32) + acc += tl.sum(A * x[None, :], axis=1) + + tl.store(lora_a_out + pid_s * MAX_R + kr, acc.to(route_input.dtype.element_ty)) + + +def per_expert_a_shrink( + route_input: torch.Tensor, # (m*K, INTER) — flat-pair intermediate + down_A_buffer: torch.Tensor, # (n_slots, E, MAX_R, INTER) — full buffer, contiguous + slot_idx: torch.Tensor, # (1,) int32 — GPU tensor; dynamic at CUDA-graph replay + safe_ids: torch.Tensor, # (m, K) int64 + out: torch.Tensor | None = None, # optional pre-allocated (m*K, MAX_R) output +) -> torch.Tensor: + """Flat per-expert shrink for decode: route_input[j] @ down_A_buffer[slot, exp[j]].T. + + Accepts the FULL (n_slots, E, MAX_R, INTER) buffer and a GPU scalar slot_idx, + computing the slot offset inside the kernel. This eliminates the separate + gather copy ``down_A = buffer[slot_idx].squeeze(0)`` (saves ~64 µs/layer). + + Replaces _select_expert_weights gather + einsum without any sorted_token_ids. + Returns lora_a (m*K, MAX_R) for the subsequent shared-B GEMM or shared_b_down_expand. + """ + m_k, INTER = route_input.shape + # Buffer layout: (n_slots, E, MAX_R, INTER). + _n_slots, E, MAX_R, _INTER = down_A_buffer.shape + n_slot_stride = E * MAX_R * INTER # elements between consecutive slots + BLOCK_R = _choose_block_r(MAX_R) + BLOCK_H = _choose_block_h(INTER) + assert down_A_buffer.is_contiguous(), "down_A_buffer must be contiguous" + + if out is None: + lora_a = torch.empty( + (m_k, MAX_R), dtype=route_input.dtype, device=route_input.device + ) + else: + lora_a = out + grid = (m_k, MAX_R // BLOCK_R) + _per_expert_a_shrink_kernel[grid]( + route_input, + down_A_buffer, + slot_idx.to(torch.int32), + n_slot_stride, + safe_ids.to(torch.int64), + lora_a, + safe_ids.shape[1], + INTER=INTER, + MAX_R=MAX_R, + BLOCK_R=BLOCK_R, + BLOCK_H=BLOCK_H, + num_warps=4, + num_stages=2, + ) + return lora_a + + +# ── Flat Down Expand (decode path) ──────────────────────────────────────────── +# +# Fused kernel that takes the lora_a output from per_expert_a_shrink and performs +# the shared-B GEMM + topk scaling + accumulation in a single pass. +# Avoids: separate down_B gather copy + standalone GEMM + scale + add. +# +# For each (token, topk_v) pair and each hidden chunk: +# lora_a_row = lora_a[tok*K + topk_v, :] — (MAX_R,) +# B_row = down_B_buffer[slot, 0, offs_h, :] — (BLOCK_H, MAX_R) +# delta_h = lora_a_row @ B_row.T — (BLOCK_H,) +# out[tok, topk_v, offs_h] += delta_h * topk_weights[tok, topk_v] * scaling +# +# Grid: (m*K, cdiv(H, BLOCK_H)) + + +@triton.jit +def _shared_b_down_expand_kernel( + lora_a, # (m*K, MAX_R) + down_B_buffer, # full buffer: n_slots × 1 × H × MAX_R (contiguous) + slot_ptr, # (1,) int32 — GPU scalar, dynamic at CUDA-graph replay + n_slot_stride_B, # int — H × MAX_R (stride between slots; shared-B has dim0=1) + topk_weights, # (m, K) — topk routing weights + scaling_ptr, # float32 scalar on device + down_output, # (m, K, H) — in-place add + K, # int32 — topk count + H: tl.constexpr, # hidden dimension (constexpr for tl.arange) + MAX_R: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_R: tl.constexpr, +): + pid_s = tl.program_id(0) # flat-pair index [0 .. m*K-1] + pid_h = tl.program_id(1) # hidden chunk index + + tok = pid_s // K + topk_v = pid_s % K + + offs_h = pid_h * BLOCK_H + tl.arange(0, BLOCK_H) + h_mask = offs_h < H + + # Load slot index dynamically (changes at CUDA-graph replay without re-capture). + slot = tl.load(slot_ptr).to(tl.int32) + # Load scaling from buffer at [slot] — avoids a separate scalings[slot_idx] gather. + scaling = tl.load(scaling_ptr + slot).to(tl.float32) + weight = tl.load(topk_weights + tok * K + topk_v).to(tl.float32) + + acc = tl.zeros((BLOCK_H,), dtype=tl.float32) + + for r_start in range(0, MAX_R, BLOCK_R): + kr = r_start + tl.arange(0, BLOCK_R) + # Load lora_a row tile: lora_a[pid_s, kr]. + la = tl.load(lora_a + pid_s * MAX_R + kr).to(tl.float32) # (BLOCK_R,) + # Load B tile directly from buffer: buffer[slot, 0, offs_h, kr]. + # n_slot_stride_B = H × MAX_R (shared-B has expert-dim=1 so no expert offset). + B_ptr = ( + down_B_buffer + + slot * n_slot_stride_B + + offs_h[:, None] * MAX_R + + kr[None, :] + ) + B = tl.load(B_ptr, mask=h_mask[:, None], other=0.0).to( + tl.float32 + ) # (BLOCK_H, BLOCK_R) + # delta_h += B @ la (contract over rank dimension) + acc += tl.sum(B * la[None, :], axis=1) + + # Scale by topk weight and adapter scaling, then accumulate. + out_ptr = down_output + (tok * K + topk_v) * H + offs_h + old = tl.load(out_ptr, mask=h_mask, other=0.0).to(tl.float32) + tl.store(out_ptr, old + acc * weight * scaling, mask=h_mask) + + +def _choose_block_h_expand(h: int) -> int: + """Largest power-of-2 ≤ 64 that divides h (or is the largest divisor ≤ 64).""" + block_h = min(64, h) + while h % block_h != 0: + block_h //= 2 + return max(block_h, 1) + + +def shared_b_down_expand( + lora_a: torch.Tensor, # (m*K, MAX_R) — output of per_expert_a_shrink + down_B_buffer: torch.Tensor, # (n_slots, 1, H, MAX_R) — full buffer, contiguous + slot_idx: torch.Tensor, # (1,) int32 — GPU tensor; dynamic at CUDA-graph replay + down_output: torch.Tensor, # (m, K, H) or (m*K, H) — in-place add + topk_weights: torch.Tensor, # (m, K) routing weights + scalings: torch.Tensor, # (n_slots,) float32 — full scalings buffer; kernel loads [slot] + K: int, +) -> None: + """Fused down expand for decode: lora_a @ down_B[slot, 0].T × weight × scaling. + + Accepts the FULL (n_slots, 1, H, MAX_R) buffer, slot_idx, and the full scalings + buffer — eliminates the separate down_B gather and scalings gather per layer. + + Performs the shared-B GEMM, topk-weight scaling, and accumulation into + down_output in a single fused kernel. + """ + m_k, MAX_R = lora_a.shape + # Buffer layout: (n_slots, 1, H, MAX_R). + _n_slots, _one, H, _MAX_R = down_B_buffer.shape + # Stride between slots: only 1 expert-slot for shared B, so stride = 1 × H × MAX_R. + n_slot_stride_B = H * MAX_R + BLOCK_H = _choose_block_h_expand(H) + BLOCK_R = _choose_block_r(MAX_R) + assert ( + down_B_buffer.is_contiguous() + ), "down_B_buffer must be contiguous for fused kernel" + + # Reshape output to (m*K, H) so the kernel can use a flat pid_s index. + out_flat = down_output.view(m_k, H) + + grid = (m_k, triton.cdiv(H, BLOCK_H)) + _shared_b_down_expand_kernel[grid]( + lora_a, + down_B_buffer, + slot_idx.to(torch.int32), + n_slot_stride_B, + topk_weights, + scalings, + out_flat, + K, + H=H, + MAX_R=MAX_R, + BLOCK_H=BLOCK_H, + BLOCK_R=BLOCK_R, + num_warps=4, + num_stages=3, + ) + + +# ── Flat A GEMM (decode path) ───────────────────────────────────────────────── +# +# Computes lora_a_m = hidden @ w13_A[slot, 0, :, :].T for each token, +# reading directly from the buffer without a prior gather copy. +# Replaces: w13_A gather (22 µs) + cuBLAS GEMM (25 µs) → ~5-8 µs per layer. +# +# Grid: (m, MAX_R // BLOCK_R) — one block per (token, rank-tile) + + +@triton.jit +def _shared_a_shrink_kernel( + hidden, # (m, H) + w13_A_buffer, # full buffer: n_slots × 1 × MAX_R × H (contiguous) + slot_ptr, # (1,) int32 — GPU scalar, dynamic at CUDA-graph replay + n_slot_stride_A, # int — MAX_R × H (stride between slots; shared outer has 1 row) + lora_a_out, # (m, MAX_R) + H: tl.constexpr, + MAX_R: tl.constexpr, + BLOCK_R: tl.constexpr, + BLOCK_H: tl.constexpr, +): + pid_m = tl.program_id(0) # token index + pid_r = tl.program_id(1) # rank tile + + slot = tl.load(slot_ptr).to(tl.int32) + kr = pid_r * BLOCK_R + tl.arange(0, BLOCK_R) + acc = tl.zeros((BLOCK_R,), dtype=tl.float32) + + for h_start in range(0, H, BLOCK_H): + kh = h_start + tl.arange(0, BLOCK_H) + x = tl.load(hidden + pid_m * H + kh).to(tl.float32) # (BLOCK_H,) + # buffer[slot, 0, kr, kh]: stride = slot * n_slot_stride_A + kr * H + kh + A_ptr = w13_A_buffer + slot * n_slot_stride_A + kr[:, None] * H + kh[None, :] + A = tl.load(A_ptr).to(tl.float32) # (BLOCK_R, BLOCK_H) + acc += tl.sum(A * x[None, :], axis=1) + + tl.store(lora_a_out + pid_m * MAX_R + kr, acc.to(hidden.dtype.element_ty)) + + +def shared_a_shrink( + hidden: torch.Tensor, # (m, H) + w13_A_buffer: torch.Tensor, # (n_slots, 1, MAX_R, H) — full buffer + slot_idx: torch.Tensor, # (1,) int32 GPU tensor + BLOCK_H: int = 128, +) -> torch.Tensor: + """Compute lora_a_m = hidden @ w13_A_buffer[slot, 0, :, :].T without gather. + + Replaces: w13_A gather (22 µs) + cuBLAS GEMM (25 µs) = 47 µs per layer + With: single Triton kernel (~5-8 µs), saving ~40 µs × 48 = 1.9 ms. + """ + m, H = hidden.shape + _n_slots, _one, MAX_R, _H = w13_A_buffer.shape + n_slot_stride_A = MAX_R * H # stride between slots (1 × MAX_R × H) + BLOCK_R = _choose_block_r(MAX_R) + + lora_a = torch.empty((m, MAX_R), dtype=hidden.dtype, device=hidden.device) + grid = (m, MAX_R // BLOCK_R) + _shared_a_shrink_kernel[grid]( + hidden, + w13_A_buffer, + slot_idx.to(torch.int32), + n_slot_stride_A, + lora_a, + H=H, + MAX_R=MAX_R, + BLOCK_R=BLOCK_R, + BLOCK_H=BLOCK_H, + num_warps=4, + num_stages=2, + ) + return lora_a + + +# ── Per-Expert Gate/Up Expand ───────────────────────────────────────────────── +# +# Like gate_up_b_expand but reads lora_a_flat[pid_s] (per flat-pair position) +# instead of lora_a_m[tok] (shared per token). Required for per_expert adapters +# where each expert has its own A matrix → lora_a differs per (token, topk_v) pair. +# +# Grid: (cdiv(I2, BLOCK_I), m*K) + + +@triton.jit +def _per_expert_gate_up_b_expand_kernel( + lora_a_flat, # (m*K, MAX_R) — per flat-pair lora_a (from per_expert_a_shrink w/ hidden) + w13_B_buffer, # full buffer: n_slots × E × I2 × MAX_R (contiguous) + slot_ptr, # (1,) int32 + n_slot_stride, # E × I2 × MAX_R + safe_ids, # (m, K) int64 + gate_up_output, # (m*K, I2) — in-place add + scaling_ptr, # (n_slots,) float32 + K, + I2: tl.constexpr, + MAX_R: tl.constexpr, + BLOCK_I: tl.constexpr, + BLOCK_R: tl.constexpr, +): + pid_i = tl.program_id(0) + pid_s = tl.program_id(1) # flat-pair index [0 .. m*K-1] + + tok = pid_s // K + topk_v = pid_s % K + exp = tl.load(safe_ids + tok * K + topk_v).to(tl.int32) + + offs_i = pid_i * BLOCK_I + tl.arange(0, BLOCK_I) + i_mask = offs_i < I2 + + slot = tl.load(slot_ptr).to(tl.int32) + scaling = tl.load(scaling_ptr + slot).to(tl.float32) + acc = tl.zeros((BLOCK_I,), dtype=tl.float32) + + for r_start in range(0, MAX_R, BLOCK_R): + kr = r_start + tl.arange(0, BLOCK_R) + # Per-position lora_a: lora_a_flat[pid_s] instead of lora_a_m[tok] + la = tl.load(lora_a_flat + pid_s * MAX_R + kr).to(tl.float32) + B_ptr = ( + w13_B_buffer + + slot * n_slot_stride + + (exp * I2 + offs_i[:, None]) * MAX_R + + kr[None, :] + ) + B = tl.load(B_ptr, mask=i_mask[:, None], other=0.0).to(tl.float32) + acc += tl.sum(B * la[None, :], axis=1) + + out_ptr = gate_up_output + pid_s * I2 + offs_i + old = tl.load(out_ptr, mask=i_mask, other=0.0).to(tl.float32) + tl.store(out_ptr, old + acc * scaling, mask=i_mask) + + +def per_expert_gate_up_b_expand( + lora_a_flat: torch.Tensor, # (m*K, MAX_R) — from per_expert_a_shrink(hidden_flat, w13_A_buf, ...) + w13_B_buffer: torch.Tensor, # (n_slots, E, I2, MAX_R) — full buffer + slot_idx: torch.Tensor, # (1,) int32 GPU tensor + safe_ids: torch.Tensor, # (m, K) int64 + gate_up_output: torch.Tensor, # (m*K, I2) — in-place add + scalings: torch.Tensor, # (n_slots,) float32 + BLOCK_I: int = 64, +) -> None: + """Per-expert gate/up expand for decode: lora_a_flat[j] @ w13_B[slot, e_j].T. + + Replaces the gather-then-einsum path for per_expert adapters. Accepts the FULL + (n_slots, E, I2, MAX_R) buffer and reads the expert offset directly using safe_ids, + eliminating the two gather copies (w13_B gather + expert-select gather). + """ + m_k, MAX_R = lora_a_flat.shape + _n_slots, E, I2, _MAX_R = w13_B_buffer.shape + n_slot_stride = E * I2 * MAX_R + BLOCK_R = _choose_block_r(MAX_R) + K = safe_ids.shape[1] + assert w13_B_buffer.is_contiguous(), "w13_B_buffer must be contiguous" + + grid = (triton.cdiv(I2, BLOCK_I), m_k) + _per_expert_gate_up_b_expand_kernel[grid]( + lora_a_flat, + w13_B_buffer, + slot_idx.to(torch.int32), + n_slot_stride, + safe_ids.to(torch.int64), + gate_up_output, + scalings, + K, + I2=I2, + MAX_R=MAX_R, + BLOCK_I=BLOCK_I, + BLOCK_R=BLOCK_R, + num_warps=4, + num_stages=2, + ) + + +# ── Per-Expert Down Expand ──────────────────────────────────────────────────── +# +# Like shared_b_down_expand but reads per-expert B: down_B_buffer[slot, e_j, offs_h, :]. +# Required for per_expert adapters where down_B is per-expert (not shared). +# Eliminates the two gather copies (down_B buffer copy + expert select gather). +# +# Grid: (m*K, cdiv(H, BLOCK_H)) + + +@triton.jit +def _per_expert_b_down_expand_kernel( + lora_a, # (m*K, MAX_R) + down_B_buffer, # full buffer: n_slots × E × H × MAX_R (contiguous) + slot_ptr, # (1,) int32 + n_slot_stride_B, # E × H × MAX_R + safe_ids, # (m, K) int64 + topk_weights, # (m, K) + scaling_ptr, # (n_slots,) float32 + down_output, # (m, K, H) — in-place add + K, + H: tl.constexpr, + MAX_R: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_R: tl.constexpr, +): + pid_s = tl.program_id(0) + pid_h = tl.program_id(1) + + tok = pid_s // K + topk_v = pid_s % K + exp = tl.load(safe_ids + tok * K + topk_v).to(tl.int32) + + offs_h = pid_h * BLOCK_H + tl.arange(0, BLOCK_H) + h_mask = offs_h < H + + slot = tl.load(slot_ptr).to(tl.int32) + scaling = tl.load(scaling_ptr + slot).to(tl.float32) + weight = tl.load(topk_weights + tok * K + topk_v).to(tl.float32) + + acc = tl.zeros((BLOCK_H,), dtype=tl.float32) + + for r_start in range(0, MAX_R, BLOCK_R): + kr = r_start + tl.arange(0, BLOCK_R) + la = tl.load(lora_a + pid_s * MAX_R + kr).to(tl.float32) + # Per-expert B: buffer[slot, exp, offs_h, kr] + B_ptr = ( + down_B_buffer + + slot * n_slot_stride_B + + (exp * H + offs_h[:, None]) * MAX_R + + kr[None, :] + ) + B = tl.load(B_ptr, mask=h_mask[:, None], other=0.0).to(tl.float32) + acc += tl.sum(B * la[None, :], axis=1) + + out_ptr = down_output + (tok * K + topk_v) * H + offs_h + old = tl.load(out_ptr, mask=h_mask, other=0.0).to(tl.float32) + tl.store(out_ptr, old + acc * weight * scaling, mask=h_mask) + + +def per_expert_b_down_expand( + lora_a: torch.Tensor, # (m*K, MAX_R) — from per_expert_a_shrink + down_B_buffer: torch.Tensor, # (n_slots, E, H, MAX_R) — full buffer + slot_idx: torch.Tensor, # (1,) int32 GPU tensor + safe_ids: torch.Tensor, # (m, K) int64 + down_output: torch.Tensor, # (m, K, H) or (m*K, H) — in-place add + topk_weights: torch.Tensor, # (m, K) + scalings: torch.Tensor, # (n_slots,) float32 + K: int, +) -> None: + """Per-expert down expand for decode: lora_a[j] @ down_B[slot, e_j].T × weight. + + Eliminates the two gather copies (down_B buffer copy + expert select gather) + for per_expert adapters where down_B is per-expert (not shared). + """ + m_k, MAX_R = lora_a.shape + _n_slots, E, H, _MAX_R = down_B_buffer.shape + n_slot_stride_B = E * H * MAX_R + BLOCK_H = _choose_block_h_expand(H) + BLOCK_R = _choose_block_r(MAX_R) + assert down_B_buffer.is_contiguous(), "down_B_buffer must be contiguous" + + out_flat = down_output.view(m_k, H) + grid = (m_k, triton.cdiv(H, BLOCK_H)) + _per_expert_b_down_expand_kernel[grid]( + lora_a, + down_B_buffer, + slot_idx.to(torch.int32), + n_slot_stride_B, + safe_ids.to(torch.int64), + topk_weights, + scalings, + out_flat, + K, + H=H, + MAX_R=MAX_R, + BLOCK_H=BLOCK_H, + BLOCK_R=BLOCK_R, + num_warps=4, + num_stages=2, + ) + + +# ── Fused A+B Gate/Up (eliminates shared_a_shrink + gate_up_b_expand) ────── +# +# Combines hidden @ w13_A + lora_a @ w13_B in one kernel, removing a separate +# shared_a_shrink launch. Lora_a is computed per flat-pair block (redundant for +# k>1 per token) but w13_A fits in L1 so cache hits make this negligible. +# Grid: (cdiv(I2, BLOCK_I), m*K) + + +@triton.jit +def _fused_shared_a_b_gate_up_kernel( + hidden, # (m, H) + w13_A_buffer, # (n_slots, 1, MAX_R, H) — contiguous + w13_B_buffer, # (n_slots, E, I2, MAX_R) — contiguous + safe_ids, # (m, K) int64 + gate_up_output, # (m*K, I2) — in-place add + scalings, # (n_slots,) float32 + slot_ptr, # (1,) int32 + K, + n_A_stride, # = MAX_R * H + n_B_stride, # = E * I2 * MAX_R + H: tl.constexpr, + I2: tl.constexpr, + MAX_R: tl.constexpr, + BLOCK_H: tl.constexpr, + BLOCK_I: tl.constexpr, + BLOCK_R: tl.constexpr, +): + pid_i = tl.program_id(0) # I2 chunk + pid_s = tl.program_id(1) # flat-pair index + + slot = tl.load(slot_ptr).to(tl.int32) + tok = pid_s // K + topk_v = pid_s % K + exp = tl.load(safe_ids + tok * K + topk_v).to(tl.int32) + scaling = tl.load(scalings + slot).to(tl.float32) + + offs_i = pid_i * BLOCK_I + tl.arange(0, BLOCK_I) + i_mask = offs_i < I2 + acc = tl.zeros((BLOCK_I,), dtype=tl.float32) + + # Outer loop over BLOCK_R chunks of rank: compute lora_a[r:r+BLOCK_R] then expand. + # This avoids storing the full lora_a vector when BLOCK_R < MAX_R. + for r_start in range(0, MAX_R, BLOCK_R): + kr = r_start + tl.arange(0, BLOCK_R) + + # Phase 1 (for this rank chunk): la = hidden[tok] @ w13_A[slot, 0, kr, :].T + la = tl.zeros((BLOCK_R,), dtype=tl.float32) + for h_start in range(0, H, BLOCK_H): + kh = h_start + tl.arange(0, BLOCK_H) + x = tl.load(hidden + tok * H + kh).to(tl.float32) + A_ptr = w13_A_buffer + slot * n_A_stride + kr[:, None] * H + kh[None, :] + A = tl.load(A_ptr).to(tl.float32) + la += tl.sum(A * x[None, :], axis=1) + + # Phase 2 (for this rank chunk): acc += la @ w13_B[slot, exp, offs_i, kr].T + B_ptr = ( + w13_B_buffer + + slot * n_B_stride + + (exp * I2 + offs_i[:, None]) * MAX_R + + kr[None, :] + ) + B = tl.load(B_ptr, mask=i_mask[:, None], other=0.0).to(tl.float32) + acc += tl.sum(B * la[None, :], axis=1) + + out_ptr = gate_up_output + pid_s * I2 + offs_i + old = tl.load(out_ptr, mask=i_mask, other=0.0).to(tl.float32) + tl.store(out_ptr, old + acc * scaling, mask=i_mask) + + +def fused_shared_a_b_gate_up_expand( + hidden: torch.Tensor, # (m, H) + w13_A_buffer: torch.Tensor, # (n_slots, 1, MAX_R, H) + w13_B_buffer: torch.Tensor, # (n_slots, E, I2, MAX_R) + safe_ids: torch.Tensor, # (m, K) int64 + gate_up_output: torch.Tensor, # (m*K, I2) — in-place add + scalings: torch.Tensor, # (n_slots,) float32 + slot_idx: torch.Tensor, # (1,) int32 + BLOCK_I: int = 64, + BLOCK_H: int = 128, +) -> None: + """Fused A+B gate/up: eliminates the separate shared_a_shrink kernel launch.""" + m_k, I2 = gate_up_output.shape + m, H = hidden.shape + K = safe_ids.shape[1] + _ns, _one, MAX_R, _H = w13_A_buffer.shape + _ns2, E, _I2, _MAX_R = w13_B_buffer.shape + n_A_stride = MAX_R * H + n_B_stride = E * I2 * MAX_R + BLOCK_R = _choose_block_r(MAX_R) + assert w13_A_buffer.is_contiguous() and w13_B_buffer.is_contiguous() + + grid = (triton.cdiv(I2, BLOCK_I), m_k) + _fused_shared_a_b_gate_up_kernel[grid]( + hidden, + w13_A_buffer, + w13_B_buffer, + safe_ids.to(torch.int64), + gate_up_output, + scalings, + slot_idx.to(torch.int32), + K, + n_A_stride, + n_B_stride, + H=H, + I2=I2, + MAX_R=MAX_R, + BLOCK_H=BLOCK_H, + BLOCK_I=BLOCK_I, + BLOCK_R=BLOCK_R, + num_warps=4, + num_stages=2, + ) + + +# ── Fused Shrink+Expand Down (eliminates per_expert_a_shrink + shared_b_down_expand) ─ +# +# Combines ri @ down_A + lora_a @ down_B in one kernel per (flat-pair, H-chunk). +# Grid: (m*K, cdiv(H, BLOCK_H)) + + +@triton.jit +def _fused_a_b_down_expand_kernel( + route_input, # (m*K, INTER) + down_A_buffer, # (n_slots, E, MAX_R, INTER) — contiguous + down_B_buffer, # (n_slots, 1, H, MAX_R) — contiguous + safe_ids, # (m, K) int64 + topk_weights, # (m, K) + scalings, # (n_slots,) float32 + slot_ptr, # (1,) int32 + down_output, # (m*K, H) — in-place add + K, + n_A_stride, # = E * MAX_R * INTER + n_B_stride, # = H * MAX_R + INTER: tl.constexpr, + H: tl.constexpr, + MAX_R: tl.constexpr, + BLOCK_H_S: tl.constexpr, # shrink tile over INTER + BLOCK_H_E: tl.constexpr, # expand tile over H +): + pid_s = tl.program_id(0) # flat-pair index + pid_h = tl.program_id(1) # H chunk + + slot = tl.load(slot_ptr).to(tl.int32) + tok = pid_s // K + topk_v = pid_s % K + exp = tl.load(safe_ids + tok * K + topk_v).to(tl.int32) + weight = tl.load(topk_weights + tok * K + topk_v).to(tl.float32) + scaling = tl.load(scalings + slot).to(tl.float32) + + offs_h = pid_h * BLOCK_H_E + tl.arange(0, BLOCK_H_E) + h_mask = offs_h < H + kr = tl.arange(0, MAX_R) + + # Phase 1: lora_a = ri[pid_s] @ down_A[slot, exp, :, :].T + lora_a = tl.zeros((MAX_R,), dtype=tl.float32) + for h_start in range(0, INTER, BLOCK_H_S): + kh = h_start + tl.arange(0, BLOCK_H_S) + x = tl.load(route_input + pid_s * INTER + kh).to(tl.float32) + A_ptr = ( + down_A_buffer + + slot * n_A_stride + + (exp * MAX_R + kr[:, None]) * INTER + + kh[None, :] + ) + A = tl.load(A_ptr).to(tl.float32) + lora_a += tl.sum(A * x[None, :], axis=1) + + # Phase 2: delta = lora_a @ down_B[slot, 0, offs_h, :].T * weight * scaling + B_ptr = down_B_buffer + slot * n_B_stride + offs_h[:, None] * MAX_R + kr[None, :] + B = tl.load(B_ptr, mask=h_mask[:, None], other=0.0).to(tl.float32) + delta = tl.sum(B * lora_a[None, :], axis=1) * weight * scaling + + out_ptr = down_output + pid_s * H + offs_h + old = tl.load(out_ptr, mask=h_mask, other=0.0).to(tl.float32) + tl.store(out_ptr, old + delta, mask=h_mask) + + +def fused_a_b_down_expand( + route_input: torch.Tensor, # (m*K, INTER) + down_A_buffer: torch.Tensor, # (n_slots, E, MAX_R, INTER) + down_B_buffer: torch.Tensor, # (n_slots, 1, H, MAX_R) + safe_ids: torch.Tensor, # (m, K) int64 + topk_weights: torch.Tensor, # (m, K) + scalings: torch.Tensor, # (n_slots,) float32 + slot_idx: torch.Tensor, # (1,) int32 + down_output: torch.Tensor, # (m*K, H) or (m, K, H) — in-place add + BLOCK_H_E: int = 64, +) -> None: + """Fused shrink+expand down: eliminates per_expert_a_shrink + shared_b_down_expand launches.""" + m_k, INTER = route_input.shape + _ns, E, MAX_R, _INTER = down_A_buffer.shape + _ns2, _one, H, _MAX_R = down_B_buffer.shape + K = safe_ids.shape[1] + n_A_stride = E * MAX_R * INTER + n_B_stride = H * MAX_R + BLOCK_H_S = _choose_block_h(INTER) + assert down_A_buffer.is_contiguous() and down_B_buffer.is_contiguous() + + out_flat = down_output.view(m_k, H) + grid = (m_k, triton.cdiv(H, BLOCK_H_E)) + _fused_a_b_down_expand_kernel[grid]( + route_input, + down_A_buffer, + down_B_buffer, + safe_ids.to(torch.int64), + topk_weights, + scalings, + slot_idx.to(torch.int32), + out_flat, + K, + n_A_stride, + n_B_stride, + INTER=INTER, + H=H, + MAX_R=MAX_R, + BLOCK_H_S=BLOCK_H_S, + BLOCK_H_E=BLOCK_H_E, + num_warps=4, + num_stages=2, + ) diff --git a/tokenspeed-scheduler/bindings/python_module.cpp b/tokenspeed-scheduler/bindings/python_module.cpp index eaa825b29..6c9358dd4 100644 --- a/tokenspeed-scheduler/bindings/python_module.cpp +++ b/tokenspeed-scheduler/bindings/python_module.cpp @@ -23,7 +23,6 @@ #include #include #include -#include #include #include @@ -151,10 +150,6 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .value("FullHistory", tokenspeed::PagedCacheGroupConfig::Retention::FullHistory) .value("SlidingWindow", tokenspeed::PagedCacheGroupConfig::Retention::SlidingWindow); - nb::enum_(m, "PagedCacheGroupFamily") - .value("History", tokenspeed::PagedCacheGroupFamily::History) - .value("State", tokenspeed::PagedCacheGroupFamily::State); - nb::class_(m, "PagedCacheGroupConfig") .def(nb::init<>()) .def( @@ -162,22 +157,19 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { [](tokenspeed::PagedCacheGroupConfig* self, std::string group_id, std::int32_t rows_per_page, std::int32_t entry_stride_tokens, std::int32_t total_pages, tokenspeed::PagedCacheGroupConfig::Retention retention, - std::optional sliding_window_tokens, tokenspeed::PagedCacheGroupFamily family) { - new (self) tokenspeed::PagedCacheGroupConfig{ - std::move(group_id), rows_per_page, entry_stride_tokens, total_pages, retention, - sliding_window_tokens, family}; + std::optional sliding_window_tokens) { + new (self) tokenspeed::PagedCacheGroupConfig{std::move(group_id), rows_per_page, entry_stride_tokens, + total_pages, retention, sliding_window_tokens}; }, nb::arg("group_id"), nb::arg("rows_per_page"), nb::arg("entry_stride_tokens"), nb::arg("total_pages"), nb::arg("retention") = tokenspeed::PagedCacheGroupConfig::Retention::FullHistory, - nb::arg("sliding_window_tokens") = std::nullopt, - nb::arg("family") = tokenspeed::PagedCacheGroupFamily::History) + nb::arg("sliding_window_tokens") = std::nullopt) .def_rw("group_id", &tokenspeed::PagedCacheGroupConfig::group_id) .def_rw("rows_per_page", &tokenspeed::PagedCacheGroupConfig::rows_per_page) .def_rw("entry_stride_tokens", &tokenspeed::PagedCacheGroupConfig::entry_stride_tokens) .def_rw("total_pages", &tokenspeed::PagedCacheGroupConfig::total_pages) .def_rw("retention", &tokenspeed::PagedCacheGroupConfig::retention) .def_rw("sliding_window_tokens", &tokenspeed::PagedCacheGroupConfig::sliding_window_tokens) - .def_rw("family", &tokenspeed::PagedCacheGroupConfig::family) .def("raw_tokens_per_page", &tokenspeed::PagedCacheGroupConfig::RawTokensPerPage) .def("validate", &tokenspeed::PagedCacheGroupConfig::Validate); @@ -200,8 +192,6 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def("page_ids", &tokenspeed::PagedCacheGroupTable::PageIds, nb::rv_policy::reference_internal) .def("size", &tokenspeed::PagedCacheGroupTable::Size) .def("active_pages_count", &tokenspeed::PagedCacheGroupTable::ActivePagesCount) - .def("owned_pages_count", &tokenspeed::PagedCacheGroupTable::OwnedPagesCount) - .def("borrowed_pages_count", &tokenspeed::PagedCacheGroupTable::BorrowedPagesCount) .def("released_pages_count", &tokenspeed::PagedCacheGroupTable::ReleasedPagesCount) .def("base_logical_page", &tokenspeed::PagedCacheGroupTable::BaseLogicalPage) .def("raw_token_cursor", &tokenspeed::PagedCacheGroupTable::RawTokenCursor) @@ -211,12 +201,6 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def("is_sliding", &tokenspeed::PagedCacheGroupTable::IsSliding) .def("sliding_window_tokens", &tokenspeed::PagedCacheGroupTable::SlidingWindowTokens); - // Python declares the required group ids only. Scheduler derives LCM and - // sliding-window metadata from the matching PagedCacheGroupConfig entries. - nb::class_(m, "PrefixCacheAdjunctSpec") - .def(nb::init<>()) - .def_rw("required_groups", &tokenspeed::PrefixCacheAdjunctSpec::required_groups); - scheduler_config.def(nb::init<>()) .def_rw("page_size", &tokenspeed::SchedulerConfig::page_size) .def_rw("max_scheduled_tokens", &tokenspeed::SchedulerConfig::max_scheduled_tokens) @@ -230,7 +214,6 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { "num_host_pages", [](const tokenspeed::SchedulerConfig& c) { return c.host_allocator.total_pages; }, [](tokenspeed::SchedulerConfig& c, std::int32_t v) { c.host_allocator.total_pages = v; }) .def_rw("paged_cache_groups", &tokenspeed::SchedulerConfig::paged_cache_groups) - .def_rw("prefix_cache_adjunct", &tokenspeed::SchedulerConfig::prefix_cache_adjunct) .def_rw("disable_l2_cache", &tokenspeed::SchedulerConfig::disable_l2_cache) .def_rw("enable_l3_storage", &tokenspeed::SchedulerConfig::enable_l3_storage) .def_rw("prefetch_threshold", &tokenspeed::SchedulerConfig::prefetch_threshold) @@ -241,14 +224,16 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def_rw("mamba_cache_chunk_size", &tokenspeed::SchedulerConfig::mamba_cache_chunk_size) .def_rw("mamba_pool_total_chunks", &tokenspeed::SchedulerConfig::mamba_pool_total_chunks) .def_rw("enable_mamba_l2", &tokenspeed::SchedulerConfig::enable_mamba_l2) - .def_rw("mamba_l2_host_slots", &tokenspeed::SchedulerConfig::mamba_l2_host_slots); + .def_rw("mamba_l2_host_slots", &tokenspeed::SchedulerConfig::mamba_l2_host_slots) + .def_rw("max_loras", &tokenspeed::SchedulerConfig::max_loras); nb::class_(m, "RequestSpec") .def(nb::init<>()) .def_rw("request_id", &tokenspeed::RequestSpec::request_id) .def_rw("tokens", &tokenspeed::RequestSpec::tokens) .def_rw("rolling_hashes", &tokenspeed::RequestSpec::rolling_hashes) - .def_rw("storage_hit_pages", &tokenspeed::RequestSpec::storage_hit_pages); + .def_rw("storage_hit_pages", &tokenspeed::RequestSpec::storage_hit_pages) + .def_rw("lora_id", &tokenspeed::RequestSpec::lora_id); nb::module_ forward_event = m.def_submodule("ForwardEvent"); nb::class_(forward_event, "ExtendResult") @@ -429,6 +414,7 @@ NB_MODULE(tokenspeed_scheduler_ext, m) { .def("get_request_token_size", &tokenspeed::Scheduler::GetRequestTokenSize, nb::arg("id")) .def("calc_rolling_hash", &tokenspeed::Scheduler::CalcRollingHash, nb::arg("input_tokens"), nb::arg("apply_match") = false) + .def("evict_lora_namespace", &tokenspeed::Scheduler::EvictLoraNamespace, nb::arg("lora_id")) .def("paged_cache_group_ids", &tokenspeed::Scheduler::PagedCacheGroupIds) .def("paged_cache_group_total_pages", &tokenspeed::Scheduler::PagedCacheGroupTotalPages, nb::arg("group_id")) .def("paged_cache_group_available_pages", &tokenspeed::Scheduler::PagedCacheGroupAvailablePages, diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp index 83dd0354d..20953b69d 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.cpp +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.cpp @@ -106,7 +106,7 @@ void InsertHybridCache(HybridPrefixCache* hybrid_cache, const std::vector>& full_paged_tokens, std::unique_ptr& device_node_ref, LocalKVAllocator* local_kv_allocator, LocalMambaAllocator* local_mamba_allocator, std::int32_t chunk_begin, std::int32_t chunk_size, - std::int32_t page_size) { + std::int32_t page_size, std::int32_t lora_id = kLoraNone) { if (hybrid_cache == nullptr) return; std::vector prefix_pages = DevicePagesFromRoot(device_node_ref->Node()); @@ -120,8 +120,9 @@ void InsertHybridCache(HybridPrefixCache* hybrid_cache, } OwnedPages pages_to_insert = local_kv_allocator->TakeFirst(new_page_count); - auto insert_result = hybrid_cache->GetKVPrefixCache().Insert(full_paged_tokens, prefix_pages, - std::move(pages_to_insert)); + auto insert_result = hybrid_cache->GetKVPrefixCache().Insert( + full_paged_tokens, prefix_pages, std::move(pages_to_insert), + /*page_hashs=*/{}, /*start_node=*/nullptr, lora_id); if (local_mamba_allocator != nullptr && local_mamba_allocator->HasCheckpoint()) { if (ShouldPublishMambaCheckpoint(hybrid_cache, chunk_begin, chunk_size, page_size)) { @@ -214,7 +215,8 @@ std::variant SchedulePrefillEvent::operator()(Prefillin paged_tokens.resize(end_of_window_pages); } InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), - local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); + local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize(), + lora_id_); // Allocate KV pages for the new chunk local_kv_allocator->Acquire(tokens_this_round_); @@ -264,7 +266,8 @@ Decoding ScheduleDecodeEvent::operator()(PrefillDone&& state) { paged_tokens.resize(end_of_window_pages); } InsertHybridCache(hybrid_prefix_cache_, paged_tokens, device_node_ref, local_kv_allocator.get(), - local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize()); + local_mamba_allocator.get(), state.window.begin, state.window.size, state.GetPageSize(), + lora_id_); // Allocate fresh checkpoint for decode-phase mamba state tracking if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr) { if (!local_mamba_allocator->AllocateCheckpoint()) { @@ -358,12 +361,12 @@ std::variant FinishEvent::apply(ForwardStateT&& state) { OwnedPages alloc_pages = local_allocator->TakeFirst(alloc_count); kv_prefix_cache_->Insert(full_paged_tokens, prefix_pages, std::move(alloc_pages), - page_hashes_); + page_hashes_, /*start_node=*/nullptr, lora_id_); // Mamba: insert the latest checkpoint snapshot at the terminal node. if (hybrid_prefix_cache_ != nullptr && local_mamba_allocator != nullptr && (local_mamba_allocator->HasCheckpoint() || local_mamba_allocator->HasWorking())) { - MatchResult post_match = kv_prefix_cache_->Match(full_paged_tokens); + MatchResult post_match = kv_prefix_cache_->Match(full_paged_tokens, lora_id_); TreeNode* terminal = post_match.device.last_node; if (terminal != nullptr && !terminal->HasMamba()) { if (local_mamba_allocator->HasCheckpoint()) { @@ -376,7 +379,7 @@ std::variant FinishEvent::apply(ForwardStateT&& state) { } // local_mamba_allocator dropped here — destructor frees remaining slots - MatchResult match = kv_prefix_cache_->Match(full_paged_tokens); + MatchResult match = kv_prefix_cache_->Match(full_paged_tokens, lora_id_); if (!disable_l2_cache_ && (match.device.DepthInPage() > match.host.DepthInPage())) { std::vector write_diff = match.NodesWithout(); std::int32_t host_pages_num = 0; diff --git a/tokenspeed-scheduler/csrc/fsm/forward_events.h b/tokenspeed-scheduler/csrc/fsm/forward_events.h index 0f42b86b6..1e70b98bb 100644 --- a/tokenspeed-scheduler/csrc/fsm/forward_events.h +++ b/tokenspeed-scheduler/csrc/fsm/forward_events.h @@ -35,6 +35,7 @@ #include "fsm/base_event.h" #include "fsm/forward_states.h" #include "resource/types.h" +#include "resource/kv_prefix_cache/kv_prefix_cache.h" #include "resource/hybrid_prefix_cache/hybrid_prefix_cache.h" #include "resource/allocator/mamba_chunk_allocator.h" #include "resource/allocator/local_mamba_allocator.h" @@ -106,10 +107,11 @@ struct SchedulePrefillFirstChunkEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); SchedulePrefillEvent(std::int32_t tokens_this_round, std::int32_t reserve_num_tokens_in_next_schedule_event, - HybridPrefixCache* hybrid_prefix_cache = nullptr) + HybridPrefixCache* hybrid_prefix_cache = nullptr, std::int32_t lora_id = kLoraNone) : tokens_this_round_(tokens_this_round), reserve_num_tokens_in_next_schedule_event_(reserve_num_tokens_in_next_schedule_event), - hybrid_prefix_cache_(hybrid_prefix_cache) {} + hybrid_prefix_cache_(hybrid_prefix_cache), + lora_id_(lora_id) {} // Returns PrefillDone (last chunk) or Prefilling (more chunks remain). std::variant operator()(Prefilling&& state); @@ -118,13 +120,15 @@ struct SchedulePrefillEvent : InvalidTransitionHandler { std::int32_t tokens_this_round_{}; std::int32_t reserve_num_tokens_in_next_schedule_event_{}; HybridPrefixCache* hybrid_prefix_cache_{}; + std::int32_t lora_id_{kLoraNone}; }; struct ScheduleDecodeEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); - ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr) - : decode_input_tokens_(decode_input_tokens), hybrid_prefix_cache_(hybrid_prefix_cache) {} + ScheduleDecodeEvent(std::int32_t decode_input_tokens, HybridPrefixCache* hybrid_prefix_cache = nullptr, + std::int32_t lora_id = kLoraNone) + : decode_input_tokens_(decode_input_tokens), hybrid_prefix_cache_(hybrid_prefix_cache), lora_id_(lora_id) {} Decoding operator()(PrefillDone&& state); Decoding operator()(Decoding&& state); @@ -132,6 +136,7 @@ struct ScheduleDecodeEvent : InvalidTransitionHandler { private: std::int32_t decode_input_tokens_; HybridPrefixCache* hybrid_prefix_cache_{}; + std::int32_t lora_id_{kLoraNone}; }; struct ScheduleDecodeFromRetractedEvent : InvalidTransitionHandler { @@ -174,12 +179,13 @@ struct FinishEvent : InvalidTransitionHandler { using InvalidTransitionHandler::operator(); explicit FinishEvent(KVPrefixCache* kv_prefix_cache, PageAllocator* host_allocator, std::vector page_hashes = {}, bool disable_l2_cache = false, - HybridPrefixCache* hybrid_prefix_cache = nullptr) + HybridPrefixCache* hybrid_prefix_cache = nullptr, std::int32_t lora_id = kLoraNone) : kv_prefix_cache_(kv_prefix_cache), host_allocator_(host_allocator), page_hashes_(std::move(page_hashes)), disable_l2_cache_(disable_l2_cache), - hybrid_prefix_cache_(hybrid_prefix_cache) {} + hybrid_prefix_cache_(hybrid_prefix_cache), + lora_id_(lora_id) {} // Returns Draining (needs device→host writeback) or Finished. std::variant operator()(Decoding&& state); @@ -197,6 +203,7 @@ struct FinishEvent : InvalidTransitionHandler { PageAllocator* host_allocator_; bool disable_l2_cache_; HybridPrefixCache* hybrid_prefix_cache_{}; + std::int32_t lora_id_{kLoraNone}; template std::variant apply(ForwardStateT&& state); diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp index 361067454..5db5bec11 100644 --- a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.cpp @@ -48,16 +48,16 @@ HybridPrefixCache::HybridPrefixCache(KVPrefixCache& kv_prefix_cache, MambaChunkA mamba_eviction_manager_{mamba_allocator}, mamba_cache_chunk_size_{mamba_cache_chunk_size} {} -MatchResult HybridPrefixCache::Match(const token_vec_t& token_ids, MatchIntent intent) { - auto match = kv_prefix_cache_.Match(token_ids, intent); +MatchResult HybridPrefixCache::Match(const token_vec_t& token_ids, std::int32_t lora_id, MatchIntent intent) { + auto match = kv_prefix_cache_.Match(token_ids, lora_id, intent); augmentMatch(match); augmentMatchPagedCache(match); return match; } MatchResult HybridPrefixCache::Match(const std::vector>& token_pages, - MatchIntent intent) { - auto match = kv_prefix_cache_.Match(token_pages, intent); + std::int32_t lora_id, MatchIntent intent) { + auto match = kv_prefix_cache_.Match(token_pages, lora_id, intent); augmentMatch(match); augmentMatchPagedCache(match); return match; @@ -231,15 +231,11 @@ std::vector HybridPrefixCache::PrepareMambaDeviceLoadBack(const st } bool HybridPrefixCache::EnsureMambaCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node) { - if (mamba_allocator_ == nullptr) return num_slots <= 0; return mamba_eviction_manager_.EnsureCapacity(num_slots, protected_node); } void HybridPrefixCache::InsertMamba(TreeNode* terminal_node, std::unique_ptr slot) { if (terminal_node == nullptr || slot == nullptr) return; - if (mamba_allocator_ == nullptr) { - throw std::logic_error("HybridPrefixCache::InsertMamba: mamba adjunct not enabled"); - } const std::int32_t page_size = kv_prefix_cache_.PageSize(); if (page_size <= 0 || terminal_node->DepthInTokens() % static_cast(page_size) != 0) { throw std::logic_error("HybridPrefixCache::InsertMamba: terminal node is not block-aligned"); @@ -250,10 +246,6 @@ void HybridPrefixCache::InsertMamba(TreeNode* terminal_node, std::unique_ptr snapshot) { if (node == nullptr || snapshot == nullptr) return false; - // Compute completeness from what is present. The policy-driven "snapshot - // must be full" invariant is enforced upstream by CommitChunk, which only - // attaches full snapshots; direct callers (tests, future restore paths) - // may attach history-only or state-only snapshots without policy gating. snapshot->complete_families.clear(); bool history_complete = !paged_cache_history_groups_.empty(); for (const auto& gid : paged_cache_history_groups_) { @@ -295,9 +287,6 @@ void HybridPrefixCache::OnKVEvict(TreeNode* node) { mamba_eviction_manager_.UpdateLeaf(node->Parent()); } } - // Passive paged-cache detach on KV LRU drop: returns OwnedPages via RAII; - // the chain scan sees the gap because `HasPagedCacheSnapshot()` is false. - // Route through DetachPagedCacheSnapshotFromNode to keep membership set in sync. if (node->HasPagedCacheSnapshot()) { DetachPagedCacheSnapshotFromNode(node); } @@ -387,7 +376,6 @@ void HybridPrefixCache::OnKVDeviceDemote(TreeNode* node) { } std::int32_t HybridPrefixCache::AvailableSlots() const { - if (mamba_allocator_ == nullptr) return 0; return mamba_allocator_->AvailableSlots(); } diff --git a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h index a519427ce..96ee4960c 100644 --- a/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h +++ b/tokenspeed-scheduler/csrc/resource/hybrid_prefix_cache/hybrid_prefix_cache.h @@ -50,8 +50,9 @@ class HybridPrefixCache { HybridPrefixCache(KVPrefixCache& prefix_cache, MambaChunkAllocator* allocator, std::int32_t mamba_cache_chunk_size, MambaHostAllocator* mamba_host_allocator = nullptr); - MatchResult Match(const token_vec_t& token_ids, MatchIntent intent = MatchIntent::PrefixReuse); - MatchResult Match(const std::vector>& token_pages, + MatchResult Match(const token_vec_t& token_ids, std::int32_t lora_id = kLoraNone, + MatchIntent intent = MatchIntent::PrefixReuse); + MatchResult Match(const std::vector>& token_pages, std::int32_t lora_id = kLoraNone, MatchIntent intent = MatchIntent::PrefixReuse); bool EnsureMambaCapacityByEvict(std::int32_t num_slots, TreeNode* protected_node = nullptr); diff --git a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/eviction.h b/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/eviction.h index 7726fb1e7..b7c135e20 100644 --- a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/eviction.h +++ b/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/eviction.h @@ -173,6 +173,34 @@ std::vector ResourceManager::Evict(std::int32_t num_pages) { return evicted_nodes; } +template +void ResourceManager::EvictSubtree(const std::vector& nodes) { + for (TreeNode* node : nodes) { + bool has_resource; + if constexpr (RType == ResourceType::Device) { + has_resource = node->OnDevice(); + } else { + has_resource = node->OnHost(); + } + if (!has_resource) continue; + + const auto& res = GetResource(node); + if (!res.IsEvictable()) continue; // skip locked nodes; freed when request finishes + + auto it = node_time_.find(node); + if (it != node_time_.end()) { + lru_leaves_.erase({it->second, node}); + node_time_.erase(it); + GetResource(node).ClearEvictableNotifier(); + } + auto resource_ptr = node->DetachResource(); + if (eviction_callback_) { + eviction_callback_(node); + } + // OwnedPages RAII: pages returned to allocator on scope exit. + } +} + template std::vector ResourceManager::EnsureCapacity(std::int32_t required_num_pages) { if (required_num_pages <= 0) { diff --git a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.cpp b/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.cpp index 6272e0fd8..0667c070b 100644 --- a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.cpp +++ b/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.cpp @@ -125,6 +125,30 @@ KVPrefixCache::KVPrefixCache(PageAllocator* device_allocator, PageAllocator* hos enable_l3_storage_(enable_l3_storage), disable_prefix_cache_(disable_prefix_cache) {} +TreeNode* KVPrefixCache::getOrCreateLoraRoot(std::int32_t lora_id) { + auto& slot = lora_virtual_roots_[lora_id]; + // Re-create if null or if the node was pruned from the tree (parent == nullptr + // while not the real root means it was removed by PruneEmptyByNode). + if (slot != nullptr && slot->Parent() != nullptr) { + return slot; + } + // Sentinel page: [-lora_id, 0, ..., 0]. Negative token IDs never appear in + // real vocabularies (which are always non-negative), so there is no collision. + const std::int32_t page_size = tree_.PageSize(); + token_vec_t sentinel(page_size, 0); + sentinel[0] = -lora_id; + auto node = std::make_unique(sentinel, std::chrono::steady_clock::now()); + TreeNode* raw = node.get(); + // Attach an empty DeviceResource so OnDevice() returns true. + // This prevents PruneEmptyByNode from removing the virtual root even when + // all adapter sequences have been evicted. + raw->AttachResource(std::make_unique>(OwnedPages{})); + token_vec_t key(sentinel.begin(), sentinel.begin() + page_size); + tree_.Root()->AddChild(key, std::move(node)); + slot = raw; + return raw; +} + void KVPrefixCache::SetKvEventSink(KvEventSink sink) { kv_event_sink_ = std::move(sink); if (!kv_event_sink_) { @@ -160,7 +184,7 @@ void KVPrefixCache::recordDeviceBlockRemoved(TreeNode* node) { } } -MatchResult KVPrefixCache::Match(const token_vec_t& token_ids, MatchIntent intent) { +MatchResult KVPrefixCache::Match(const token_vec_t& token_ids, std::int32_t lora_id, MatchIntent intent) { if (disable_prefix_cache_ && intent == MatchIntent::PrefixReuse) { const std::int32_t page_size = tree_.PageSize(); if (token_ids.size() % page_size != 0) { @@ -176,15 +200,23 @@ MatchResult KVPrefixCache::Match(const token_vec_t& token_ids, MatchIntent inten std::to_string(token_ids.size()) + "; page_size=" + std::to_string(page_size)); } - WalkResult walk_result = tree_.WalkDownUtilMismatch(token_ids, access_time); + TreeNode* start_node = resolveStartNode(lora_id); + WalkResult walk_result = tree_.WalkDownUtilMismatch(token_ids, access_time, start_node); MatchResult& match = walk_result.match; match.device.page_size = page_size; match.host.page_size = page_size; + if (lora_id != kLoraNone) { + // The virtual namespace root contributes 1 sentinel page to absolute tree + // depth. Subtract it so callers see the number of real matched token pages. + match.device.namespace_depth_offset = 1; + match.host.namespace_depth_offset = 1; + } return match; } -MatchResult KVPrefixCache::Match(const std::vector>& token_pages, MatchIntent intent) { - return Match(FlattenPages(token_pages, 0, token_pages.size()), intent); +MatchResult KVPrefixCache::Match(const std::vector>& token_pages, std::int32_t lora_id, + MatchIntent intent) { + return Match(FlattenPages(token_pages, 0, token_pages.size()), lora_id, intent); } MatchResult KVPrefixCache::RootMatch() const { @@ -199,7 +231,7 @@ MatchResult KVPrefixCache::RootMatch() const { template InsertResult KVPrefixCache::Insert(const token_vec_t& token_ids, const std::vector& prefix_pages, OwnedPages allocator_pages, const std::vector& page_hashs, - TreeNode* start_node) { + TreeNode* start_node, std::int32_t lora_id) { const std::int32_t page_size = tree_.PageSize(); auto insert_result = InsertResult{ .last_node = tree_.Root(), @@ -219,8 +251,12 @@ InsertResult KVPrefixCache::Insert(const token_vec_t& token_ids, const std::vect const auto& alloc_ids = allocator_pages.Ids(); page_ids.insert(page_ids.end(), alloc_ids.begin(), alloc_ids.end()); - WalkResult walk_result = - tree_.WalkDownUtilMismatch(token_slice{token_ids.data(), total_pages * page_size}, access_time, start_node); + // When start_node is nullptr (no prior match), resolve the LoRA namespace root. + // When start_node is provided (continuation from a prior match), the caller + // already points into the correct namespace subtree. + TreeNode* effective_start = (start_node != nullptr) ? start_node : resolveStartNode(lora_id); + WalkResult walk_result = tree_.WalkDownUtilMismatch(token_slice{token_ids.data(), total_pages * page_size}, + access_time, effective_start); token_slice mistmatched_tokens = walk_result.remaining_tokens; TreeNode* current = walk_result.terminal; @@ -317,9 +353,10 @@ InsertResult KVPrefixCache::Insert(const token_vec_t& token_ids, const std::vect template InsertResult KVPrefixCache::Insert(const std::vector>& token_pages, const std::vector& prefix_pages, OwnedPages allocator_pages, - const std::vector& page_hashs, TreeNode* start_node) { + const std::vector& page_hashs, TreeNode* start_node, + std::int32_t lora_id) { return Insert(FlattenPages(token_pages, 0, token_pages.size()), prefix_pages, std::move(allocator_pages), - page_hashs, start_node); + page_hashs, start_node, lora_id); } template @@ -389,24 +426,52 @@ cache_op_id KVPrefixCache::AllocateCacheOpId() { return next_op_id_++; } -template InsertResult KVPrefixCache::Insert(const token_vec_t& token_ids, - const std::vector& prefix_pages, - OwnedPages allocator_pages, - const std::vector& page_hashs, - TreeNode* start_node); +void KVPrefixCache::EvictLoraNamespace(std::int32_t lora_id) { + auto it = lora_virtual_roots_.find(lora_id); + if (it == lora_virtual_roots_.end() || it->second == nullptr) { + return; + } + TreeNode* vroot = it->second; + + // Collect all descendant nodes via DFS (excluding the virtual root itself, + // which holds no real KV pages). + std::vector descendants; + std::function collect = [&](TreeNode* node) { + for (auto& [key, child] : node->Children()) { + if (!child) continue; + descendants.push_back(child.get()); + collect(child.get()); + } + }; + collect(vroot); + + // Evict device and host pages. OwnedPages RAII returns them to the allocator. + device_.EvictSubtree(descendants); + host_.EvictSubtree(descendants); + + // Remove the virtual root from the tree. The unique_ptr cascade destroys the + // entire subtree (including any mamba slots attached to those nodes). + token_vec_t sentinel(tree_.PageSize(), 0); + sentinel[0] = -lora_id; + tree_.Root()->RemoveChild(sentinel); -template InsertResult KVPrefixCache::Insert(const token_vec_t& token_ids, - const std::vector& prefix_pages, - OwnedPages allocator_pages, - const std::vector& page_hashs, - TreeNode* start_node); + lora_virtual_roots_.erase(it); +} +template InsertResult KVPrefixCache::Insert(const token_vec_t&, const std::vector&, + OwnedPages, const std::vector&, + TreeNode*, std::int32_t); +template InsertResult KVPrefixCache::Insert(const token_vec_t&, const std::vector&, + OwnedPages, const std::vector&, TreeNode*, + std::int32_t); template InsertResult KVPrefixCache::Insert(const std::vector>&, const std::vector&, OwnedPages, - const std::vector&, TreeNode*); + const std::vector&, TreeNode*, + std::int32_t); template InsertResult KVPrefixCache::Insert(const std::vector>&, const std::vector&, OwnedPages, - const std::vector&, TreeNode*); + const std::vector&, TreeNode*, + std::int32_t); template bool KVPrefixCache::EnsureCapacityByEvict(std::int32_t required_num_pages); template bool KVPrefixCache::EnsureCapacityByEvict(std::int32_t required_num_pages); diff --git a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.h b/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.h index 1027d5e42..5f24138c4 100644 --- a/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.h +++ b/tokenspeed-scheduler/csrc/resource/kv_prefix_cache/kv_prefix_cache.h @@ -28,6 +28,10 @@ #include #include +// kLoraNone is the lora_id value meaning "base model, no adapter". +// Adapter IDs are positive integers assigned by LoraRegistry. +static constexpr std::int32_t kLoraNone = 0; + #include "resource/radix_tree/radix_tree.h" #include "resource/radix_tree/tree_resource.h" #include "resource/types.h" @@ -47,19 +51,29 @@ class KVPrefixCache { bool disable_prefix_cache = false); void SetKvEventSink(KvEventSink sink); - MatchResult Match(const token_vec_t& token_ids, MatchIntent intent = MatchIntent::PrefixReuse); - MatchResult Match(const std::vector>& token_pages, + + // lora_id = kLoraNone (0) → base model, uses the shared radix tree root. + // lora_id > 0 → adapter namespace; a per-adapter virtual root is + // created on demand so same-adapter requests share the + // prefix cache while cross-adapter requests never collide. + // intent: PrefixReuse honours disable_prefix_cache_ (returns empty match); + // StateRecovery always walks the tree (used to recover state for + // retracted requests even when prefix caching is disabled). + MatchResult Match(const token_vec_t& token_ids, std::int32_t lora_id = kLoraNone, + MatchIntent intent = MatchIntent::PrefixReuse); + MatchResult Match(const std::vector>& token_pages, std::int32_t lora_id = kLoraNone, MatchIntent intent = MatchIntent::PrefixReuse); template InsertResult Insert(const token_vec_t& token_ids, const std::vector& prefix_pages, OwnedPages allocator_pages = {}, const std::vector& page_hashs = {}, - TreeNode* start_node = nullptr); + TreeNode* start_node = nullptr, std::int32_t lora_id = kLoraNone); template InsertResult Insert(const std::vector>& token_pages, const std::vector& prefix_pages, OwnedPages allocator_pages = {}, - const std::vector& page_hashs = {}, TreeNode* start_node = nullptr); + const std::vector& page_hashs = {}, TreeNode* start_node = nullptr, + std::int32_t lora_id = kLoraNone); cache_op_id AllocateCacheOpId(); @@ -88,6 +102,13 @@ class KVPrefixCache { RadixTree& GetRadixTree() { return tree_; } const RadixTree& GetRadixTree() const { return tree_; } + // Evict all KV pages cached under the given adapter's namespace and remove + // the virtual root from the tree. Call this when an adapter is unloaded so + // its pages are freed immediately rather than waiting for LRU pressure. + // Locked pages (in-flight requests) are skipped and freed when those + // requests finish. + void EvictLoraNamespace(std::int32_t lora_id); + private: MatchResult RootMatch() const; @@ -106,11 +127,25 @@ class KVPrefixCache { } } + // Returns (or creates) the virtual root node for the given LoRA adapter. + // The virtual root is a child of the real root keyed by a sentinel page + // [-lora_id, 0, ..., 0] that is outside any real vocabulary range. + // An empty DeviceResource is attached so PruneEmptyByNode never removes it. + TreeNode* getOrCreateLoraRoot(std::int32_t lora_id); + + // Resolve the start_node for Match/Insert: nullptr for base model, + // per-adapter virtual root for LoRA. + TreeNode* resolveStartNode(std::int32_t lora_id) { + return (lora_id == kLoraNone) ? nullptr : getOrCreateLoraRoot(lora_id); + } + RadixTree tree_; DeviceManager device_; HostManager host_; cache_op_id next_op_id_{1}; bool enable_l3_storage_{false}; + // Per-adapter virtual root nodes; keyed by lora_id (> 0). + std::unordered_map lora_virtual_roots_; KvEventSink kv_event_sink_{}; std::unordered_set published_device_blocks_; bool disable_prefix_cache_{false}; diff --git a/tokenspeed-scheduler/csrc/resource/radix_tree/tree_resource.h b/tokenspeed-scheduler/csrc/resource/radix_tree/tree_resource.h index f6658f47c..9e4ba1981 100644 --- a/tokenspeed-scheduler/csrc/resource/radix_tree/tree_resource.h +++ b/tokenspeed-scheduler/csrc/resource/radix_tree/tree_resource.h @@ -109,6 +109,9 @@ class ResourceManager { void UpdateLeaves(TreeNode* node); std::vector Evict(std::int32_t num_pages); std::vector EnsureCapacity(std::int32_t required_num_pages); + // Evict all pages held by the given nodes (e.g. a LoRA namespace subtree). + // Locked nodes are skipped — their pages are freed when the request finishes. + void EvictSubtree(const std::vector& nodes); // Called by NodeResource::Unlock() when ref_count transitions 1→0. void OnNodeEvictable(TreeNode* node) { updateLeaf(node); } diff --git a/tokenspeed-scheduler/csrc/resource/types.cpp b/tokenspeed-scheduler/csrc/resource/types.cpp index 17f046386..45fa350bd 100644 --- a/tokenspeed-scheduler/csrc/resource/types.cpp +++ b/tokenspeed-scheduler/csrc/resource/types.cpp @@ -25,11 +25,11 @@ namespace tokenspeed { std::int32_t MatchResult::Device::DepthInPage() const { - return last_node->DepthInPage(page_size); + return last_node->DepthInPage(page_size) - namespace_depth_offset; } std::int32_t MatchResult::Host::DepthInPage() const { - return last_node->DepthInPage(page_size); + return last_node->DepthInPage(page_size) - namespace_depth_offset; } template diff --git a/tokenspeed-scheduler/csrc/resource/types.h b/tokenspeed-scheduler/csrc/resource/types.h index 4d53e5c0b..7ed404087 100644 --- a/tokenspeed-scheduler/csrc/resource/types.h +++ b/tokenspeed-scheduler/csrc/resource/types.h @@ -55,12 +55,17 @@ struct MatchResult { struct Device { TreeNode* last_node; std::int32_t page_size{0}; + // Number of virtual namespace-root pages to subtract from the absolute + // tree depth to get the number of real matched token pages. + // 0 for base-model requests; 1 for LoRA adapter requests. + std::int32_t namespace_depth_offset{0}; std::int32_t DepthInPage() const; } device; struct Host { TreeNode* last_node; std::int32_t page_size{0}; + std::int32_t namespace_depth_offset{0}; std::int32_t DepthInPage() const; } host; diff --git a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp index a8ce8f900..076a07ae2 100644 --- a/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/operations/forward.cpp @@ -81,8 +81,9 @@ std::optional Scheduler::schedulePrefillFir Request* request, std::int32_t remaining, std::int32_t decode_input_tokens, bool disable_l2_cache, std::map& simulated_free) { if (req_pool_allocator_.AvailableSlots() == 0) return {}; - MatchResult match_result = hybrid_prefix_cache_ ? hybrid_prefix_cache_->Match(request->GetFullPagedTokens(true)) - : kv_prefix_cache_.Match(request->GetFullPagedTokens(true)); + MatchResult match_result = hybrid_prefix_cache_ + ? hybrid_prefix_cache_->Match(request->GetFullPagedTokens(true), request->LoraId()) + : kv_prefix_cache_.Match(request->GetFullPagedTokens(true), request->LoraId()); std::int32_t loadback_tokens = 0; std::int32_t unscheduled = 0; std::vector loadback_diff; @@ -227,8 +228,9 @@ std::optional Scheduler::scheduleDecodeFr MatchResult match_result = hybrid_prefix_cache_ - ? hybrid_prefix_cache_->Match(request->GetFullPagedTokens(true), MatchIntent::StateRecovery) - : kv_prefix_cache_.Match(request->GetFullPagedTokens(true), MatchIntent::StateRecovery); + ? hybrid_prefix_cache_->Match(request->GetFullPagedTokens(true), request->LoraId(), + MatchIntent::StateRecovery) + : kv_prefix_cache_.Match(request->GetFullPagedTokens(true), request->LoraId(), MatchIntent::StateRecovery); std::vector loadback_diff = match_result.NodesWithout(); std::vector mamba_loadback_nodes; TreeNode* mamba_recovery_node = nullptr; @@ -321,7 +323,7 @@ std::optional Scheduler::scheduleRetract(Request* req kv_prefix_cache_.Insert(full_paged_tokens, prefix_pages, std::move(alloc_pages)); - MatchResult match_result = kv_prefix_cache_.Match(full_paged_tokens, MatchIntent::StateRecovery); + MatchResult match_result = kv_prefix_cache_.Match(full_paged_tokens, request->LoraId(), MatchIntent::StateRecovery); std::unique_ptr temp_lock = std::make_unique(match_result.host.last_node); const std::int32_t device_matched3 = match_result.device.DepthInPage(); @@ -576,9 +578,25 @@ Scheduler::newForwardOperation(std::vector candidates) { std::vector loadback_ops; auto simulated_free = hybrid_prefix_cache_ ? hybrid_prefix_cache_->InitialSimulatedFree() : std::map{}; + + // Track unique LoRA adapter ids in this batch. When max_loras > 0 we skip + // any request whose lora_id would push the count over the cap, deferring it + // to the next scheduling round. This guarantees prepare_loras() never + // receives a batch that requires more GPU adapter slots than are available. + std::unordered_set batch_lora_ids; + for (Request* request : candidates) { if (token_budget <= 0 || config_.max_batch_size == ops.size()) break; + // LoRA adapter cap: skip requests that would exceed max_loras unique ids. + if (config_.max_loras > 0 && request->lora_id() != kLoraNone) { + bool is_new = batch_lora_ids.find(request->lora_id()) == batch_lora_ids.end(); + if (is_new && static_cast(batch_lora_ids.size()) >= config_.max_loras) { + continue; // defer to next step + } + batch_lora_ids.insert(request->lora_id()); + } + if (request->Is() && config_.role != Role::kD) { std::int32_t reserver_num_tokens = config_.role == Role::kP ? 0 : config_.decode_input_tokens; if (auto ev = schedulePrefill(request, token_budget, reserver_num_tokens, simulated_free)) { diff --git a/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp b/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp index 2279df74a..9c5f31928 100644 --- a/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/outside_event_handler.cpp @@ -91,9 +91,9 @@ void Scheduler::handleEvent(const pd::FailedEvent& event) {} void Scheduler::handleEvent(const pd::SucceededEvent& event) { std::vector page_hashes; - requests_.at(event.request_id) - ->Apply(fsm::FinishEvent{&kv_prefix_cache_, &host_allocator_, std::move(page_hashes), config_.disable_l2_cache, - hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}); + auto& req = requests_.at(event.request_id); + req->Apply(fsm::FinishEvent{&kv_prefix_cache_, &host_allocator_, std::move(page_hashes), config_.disable_l2_cache, + hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, req->LoraId()}); } void Scheduler::handleEvent(const pd::RemotePrefillDoneEvent& event) { @@ -115,7 +115,8 @@ void Scheduler::handleEvent(const forward::Finish& event) { } } req->Apply(fsm::FinishEvent{&kv_prefix_cache_, &host_allocator_, std::move(page_hashes), - config_.disable_l2_cache, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr}); + config_.disable_l2_cache, hybrid_prefix_cache_ ? &*hybrid_prefix_cache_ : nullptr, + req->LoraId()}); } } diff --git a/tokenspeed-scheduler/csrc/scheduler/request.cpp b/tokenspeed-scheduler/csrc/scheduler/request.cpp index 6aaa3c55a..46d5ab1b1 100644 --- a/tokenspeed-scheduler/csrc/scheduler/request.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/request.cpp @@ -29,6 +29,7 @@ namespace tokenspeed { Request::Request(const RequestSpec& spec, std::int32_t page_size, Role role) : id_{spec.request_id}, + lora_id_{spec.lora_id}, token_container_{spec.tokens}, page_size_{page_size}, state_{role == Role::kFused ? fsm::State{fsm::Submitted{&token_container_, page_size}} diff --git a/tokenspeed-scheduler/csrc/scheduler/request.h b/tokenspeed-scheduler/csrc/scheduler/request.h index 89b770c68..56bdf2efd 100644 --- a/tokenspeed-scheduler/csrc/scheduler/request.h +++ b/tokenspeed-scheduler/csrc/scheduler/request.h @@ -53,6 +53,7 @@ class Request { Request(const RequestSpec& spec, std::int32_t page_size, Role role); std::string Id() const { return id_; } + std::int32_t LoraId() const { return lora_id_; } // Keep Apply the only non-const function in Request // The wrapper lambda converts any concrete state type returned by event's operator() @@ -273,6 +274,7 @@ class Request { private: std::string id_; + std::int32_t lora_id_{0}; TokenContainer token_container_; std::int32_t page_size_; fsm::State state_; diff --git a/tokenspeed-scheduler/csrc/scheduler/request_spec.h b/tokenspeed-scheduler/csrc/scheduler/request_spec.h index eaf85ebda..07a9e28ee 100644 --- a/tokenspeed-scheduler/csrc/scheduler/request_spec.h +++ b/tokenspeed-scheduler/csrc/scheduler/request_spec.h @@ -32,6 +32,10 @@ struct RequestSpec { std::vector tokens; std::vector rolling_hashes; std::int32_t storage_hit_pages{0}; + // 0 = base model (no adapter). >0 = LoRA adapter integer ID from + // LoraRegistry. The prefix cache is namespaced per lora_id so adapters + // never share KV pages with different LoRA weights. + std::int32_t lora_id{0}; }; struct PrefillInfo { diff --git a/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp b/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp index 5df53231d..ef79684ab 100644 --- a/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp +++ b/tokenspeed-scheduler/csrc/scheduler/scheduler.cpp @@ -132,6 +132,10 @@ Scheduler::Scheduler(SchedulerConfig config) } } +void Scheduler::EvictLoraNamespace(std::int32_t lora_id) { + kv_prefix_cache_.EvictLoraNamespace(lora_id); +} + std::vector Scheduler::DrainKvEvents() { std::vector events; events.swap(kv_events_); diff --git a/tokenspeed-scheduler/csrc/scheduler/scheduler.h b/tokenspeed-scheduler/csrc/scheduler/scheduler.h index c36c3a413..84fb36d6c 100644 --- a/tokenspeed-scheduler/csrc/scheduler/scheduler.h +++ b/tokenspeed-scheduler/csrc/scheduler/scheduler.h @@ -60,6 +60,9 @@ class Scheduler { void Advance(const ExecutionEvent& event); std::vector DrainKvEvents(); + // Evict all KV pages cached under the given LoRA adapter's namespace and + // remove its virtual root from the prefix tree. Call on adapter unload. + void EvictLoraNamespace(std::int32_t lora_id); std::size_t WaitingSize() const; std::size_t DecodingSize() const; diff --git a/tokenspeed-scheduler/csrc/scheduler/types.h b/tokenspeed-scheduler/csrc/scheduler/types.h index 2a8310891..a34d7b669 100644 --- a/tokenspeed-scheduler/csrc/scheduler/types.h +++ b/tokenspeed-scheduler/csrc/scheduler/types.h @@ -42,8 +42,6 @@ enum class DisaggregationMode { kPrefill, kDecode, }; -// `PagedCacheGroupFamily` and `StateRestorePolicy` are defined in -// resource/allocator/paged_cache_group.h (transitively included above). template class NodeRef; @@ -84,7 +82,6 @@ struct SchedulerConfig { } device_allocator; std::vector paged_cache_groups{}; - // Unset means paged-cache groups are transport-only. std::optional prefix_cache_adjunct{}; @@ -106,6 +103,12 @@ struct SchedulerConfig { std::int32_t mamba_pool_total_chunks{0}; bool enable_mamba_l2{false}; std::int32_t mamba_l2_host_slots{0}; + + // Maximum number of unique LoRA adapter ids allowed in a single batch. + // 0 means LoRA is disabled (no cap enforced). When set, newForwardOperation + // defers requests that would push the batch over this limit to the next step, + // guaranteeing that prepare_loras() never sees n_unique > max_loras. + std::int32_t max_loras{0}; }; } // namespace tokenspeed diff --git a/tokenspeed-scheduler/python/tokenspeed_scheduler/__init__.py b/tokenspeed-scheduler/python/tokenspeed_scheduler/__init__.py index f2070be85..dc87ada88 100644 --- a/tokenspeed-scheduler/python/tokenspeed_scheduler/__init__.py +++ b/tokenspeed-scheduler/python/tokenspeed_scheduler/__init__.py @@ -27,10 +27,8 @@ ExecutionPlan, PagedCacheGroupAllocator, PagedCacheGroupConfig, - PagedCacheGroupFamily, PagedCacheGroupTable, PagedCacheRetention, - PrefixCacheAdjunctSpec, RequestSpec, Scheduler, SchedulerConfig, @@ -73,9 +71,7 @@ def _flat_forward_op_repr(self): "PagedCacheRetention", "PagedCacheGroupConfig", "PagedCacheGroupAllocator", - "PagedCacheGroupFamily", "PagedCacheGroupTable", - "PrefixCacheAdjunctSpec", # Execution plan & operations "ExecutionPlan", "Forward",