diff --git a/crates/grpc_client/proto/tokenspeed_scheduler.proto b/crates/grpc_client/proto/tokenspeed_scheduler.proto index 5be4d5b6b..f6571a3f2 100644 --- a/crates/grpc_client/proto/tokenspeed_scheduler.proto +++ b/crates/grpc_client/proto/tokenspeed_scheduler.proto @@ -6,8 +6,10 @@ import "google/protobuf/timestamp.proto"; import "google/protobuf/struct.proto"; // TokenSpeed scheduler gRPC service. Fully self-contained wire definition. -// Trimmed to text-generation only (no embed, no multimodal, no -// PD-disaggregated, no LoRA, no hidden-state forwarding). +// Trimmed to text+image generation (no embed, no PD-disaggregated, no +// LoRA, no hidden-state forwarding). Multimodal carries preprocessed +// tensors only — image fetch + per-model preprocess happen in the +// gateway (see crates/multimodal). service TokenSpeedScheduler { rpc Generate(GenerateRequest) returns (stream GenerateResponse); rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse); @@ -85,6 +87,9 @@ message GenerateRequest { // Whether the client wants stream chunks (otherwise: complete-only). bool stream = 8; + + // Preprocessed multimodal payload. Absent for text-only requests. + MultimodalInputs mm_inputs = 9; } message TokenizedInput { @@ -92,6 +97,36 @@ message TokenizedInput { string original_text = 2; // cosmetic, for worker logs } +// A typed tensor: raw little-endian bytes + shape + dtype. +message TensorData { + bytes data = 1; // Raw little-endian bytes (f32/i64/u32) + repeated uint32 shape = 2; // Dimension sizes + string dtype = 3; // "float32", "int64", "uint32" +} + +// Where a multimodal item's tokens sit inside input_ids. +message PlaceholderRange { + uint32 offset = 1; + uint32 length = 2; +} + +// Multimodal inputs for vision/audio models. Tensors are produced by the +// gateway's per-model preprocessor (crates/multimodal); the servicer +// only reconstructs them and hands them to the engine — no preprocess. +message MultimodalInputs { + // Preprocessed pixel values tensor. + TensorData pixel_values = 1; + + // Model-specific tensors (image_grid_thw, aspect_ratios, etc.). + map model_specific_tensors = 2; + + // Image token id used for placeholder expansion. + optional uint32 im_token_id = 3; + + // Placeholder offsets: where each image's tokens are in input_ids. + repeated PlaceholderRange mm_placeholders = 4; +} + message GenerateResponse { string request_id = 1; @@ -191,6 +226,11 @@ message GetModelInfoResponse { // ``{"temperature": 0.6, "top_p": 0.9}``). Empty = no overrides. Surfaced // to the router via the ``default_sampling_params_json`` worker label. string default_sampling_params_json = 13; + + // True when the model can consume image/video inputs. Drives the + // router's mm_inputs handling (router rejects mm requests for + // non-vision workers before tokenization). + bool supports_vision = 14; } message GetServerInfoRequest {} diff --git a/crates/grpc_client/src/tokenspeed_scheduler.rs b/crates/grpc_client/src/tokenspeed_scheduler.rs index ea194f48c..ff3e8e18a 100644 --- a/crates/grpc_client/src/tokenspeed_scheduler.rs +++ b/crates/grpc_client/src/tokenspeed_scheduler.rs @@ -179,6 +179,7 @@ impl TokenSpeedSchedulerClient { body: &ChatCompletionRequest, processed_text: String, token_ids: Vec, + multimodal_inputs: Option, tool_call_constraint: Option<(String, String)>, ) -> Result { let sampling_params = Self::build_sampling_params_from_chat(body, tool_call_constraint)?; @@ -193,6 +194,7 @@ impl TokenSpeedSchedulerClient { logprob_start_len: Some(-1), top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32, stream: body.stream, + mm_inputs: multimodal_inputs, ..Default::default() }) } @@ -222,6 +224,7 @@ impl TokenSpeedSchedulerClient { top_logprobs_num: body.top_logprobs_num.unwrap_or(0), token_ids_logprob: body.token_ids_logprob.clone().unwrap_or_default(), stream: body.stream, + mm_inputs: None, }) } @@ -260,6 +263,7 @@ impl TokenSpeedSchedulerClient { body: &CreateMessageRequest, processed_text: String, token_ids: Vec, + multimodal_inputs: Option, tool_call_constraint: Option<(String, String)>, ) -> Result { let sampling_params = @@ -272,6 +276,7 @@ impl TokenSpeedSchedulerClient { }), sampling_params: Some(sampling_params), stream: body.stream.unwrap_or(false), + mm_inputs: multimodal_inputs, ..Default::default() }) } diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py index fe9b5bc65..5ce267533 100644 --- a/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/server.py @@ -23,6 +23,45 @@ logger = logging.getLogger(__name__) +# Match the other SMG servicers' 256 MiB default — a single oversized or +# malformed gRPC frame can otherwise trigger a multi-GiB transient +# allocation. VLM deployments that genuinely need bigger ``pixel_values`` +# raise it via ``TOKENSPEED_GRPC_MAX_MESSAGE_BYTES``; the env value is +# clamped to gRPC's hard ceiling (``INT32_MAX`` = 2 GiB - 1). +_GRPC_DEFAULT_MAX_BYTES = 256 * 1024 * 1024 +_GRPC_HARD_CEILING_BYTES = (1 << 31) - 1 + + +def _grpc_max_message_bytes() -> int: + """Return the configured gRPC message ceiling (send + receive use the same).""" + raw = os.getenv("TOKENSPEED_GRPC_MAX_MESSAGE_BYTES") + if not raw: + return _GRPC_DEFAULT_MAX_BYTES + try: + value = int(raw) + except ValueError: + logger.warning( + "TOKENSPEED_GRPC_MAX_MESSAGE_BYTES=%r is not an int; falling back to %d", + raw, + _GRPC_DEFAULT_MAX_BYTES, + ) + return _GRPC_DEFAULT_MAX_BYTES + if value <= 0: + logger.warning( + "TOKENSPEED_GRPC_MAX_MESSAGE_BYTES=%d must be positive; falling back to %d", + value, + _GRPC_DEFAULT_MAX_BYTES, + ) + return _GRPC_DEFAULT_MAX_BYTES + if value > _GRPC_HARD_CEILING_BYTES: + logger.warning( + "TOKENSPEED_GRPC_MAX_MESSAGE_BYTES=%d exceeds gRPC ceiling %d; clamping", + value, + _GRPC_HARD_CEILING_BYTES, + ) + return _GRPC_HARD_CEILING_BYTES + return value + async def serve_grpc(server_args: ServerArgs) -> None: """Run the TokenSpeed gRPC server until a shutdown signal is received.""" @@ -30,11 +69,12 @@ async def serve_grpc(server_args: ServerArgs) -> None: logger.info("Launching TokenSpeed scheduler + AsyncLLM...") async_llm, scheduler_info = launch_engine(server_args) + max_message_bytes = _grpc_max_message_bytes() server = grpc.aio.server( futures.ThreadPoolExecutor(max_workers=10), options=[ - ("grpc.max_send_message_length", 1024 * 1024 * 256), - ("grpc.max_receive_message_length", 1024 * 1024 * 256), + ("grpc.max_send_message_length", max_message_bytes), + ("grpc.max_receive_message_length", max_message_bytes), # Permissive keepalive so long prefill stalls don't trip GOAWAY. ("grpc.http2.min_recv_ping_interval_without_data_ms", 10000), ("grpc.keepalive_permit_without_calls", True), @@ -120,11 +160,12 @@ def _wait_and_warmup( # Wildcard bind hosts aren't routable as destinations; dial loopback instead. warmup_host = {"0.0.0.0": "127.0.0.1", "::": "::1"}.get(server_args.host, server_args.host) grpc_url = f"{warmup_host}:{server_args.port}" + max_message_bytes = _grpc_max_message_bytes() channel = grpc.insecure_channel( grpc_url, options=[ - ("grpc.max_send_message_length", 1024 * 1024 * 256), - ("grpc.max_receive_message_length", 1024 * 1024 * 256), + ("grpc.max_send_message_length", max_message_bytes), + ("grpc.max_receive_message_length", max_message_bytes), ], ) stub = tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerStub(channel) diff --git a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py index 0b07a3544..16f5054ad 100644 --- a/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py +++ b/grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py @@ -2,8 +2,8 @@ Implements ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` on top of :class:`tokenspeed.runtime.engine.async_llm.AsyncLLM`. The proto field set -is intentionally minimal — generative LLM serving only, no Embed / -GetTokenizer / SubscribeKvEvents / multimodal / PD-disaggregated / LoRA / +is intentionally minimal — generative LLM serving plus precomputed multimodal; +no Embed / GetTokenizer / SubscribeKvEvents / PD-disaggregated / LoRA / hidden states / classifier outputs. """ @@ -21,16 +21,23 @@ from typing import TYPE_CHECKING, Any import grpc +import numpy as np +import torch from google.protobuf.struct_pb2 import Struct from google.protobuf.timestamp_pb2 import Timestamp from smg_grpc_proto import tokenspeed_scheduler_pb2_grpc from smg_grpc_proto.generated import tokenspeed_scheduler_pb2 +from tokenspeed.runtime.multimodal.inputs import ( + Modality, + MultimodalDataItem, + MultimodalInputs, +) from smg_grpc_servicer.tokenspeed.health_servicer import TokenSpeedHealthServicer if TYPE_CHECKING: - # Type-only imports — not resolved at module load so the servicer is - # importable in test environments that stub AsyncLLM / ServerArgs. + # Type-only — keeps these out of the cold-path graph when the servicer is + # imported by tooling that stubs the engine surface. from tokenspeed.runtime.engine.async_llm import AsyncLLM from tokenspeed.runtime.utils.server_args import ServerArgs @@ -358,6 +365,8 @@ async def GetModelInfo( tokenizer_path = getattr(self.server_args, "tokenizer", None) or getattr( self.server_args, "tokenizer_path", "" ) + supports_vision = bool(getattr(model_config, "is_multimodal", False)) + return tokenspeed_scheduler_pb2.GetModelInfoResponse( model_path=model_path, tokenizer_path=tokenizer_path or "", @@ -372,6 +381,7 @@ async def GetModelInfo( pad_token_id=(getattr(hf_config, "pad_token_id", 0) or 0) if hf_config else 0, bos_token_id=(getattr(hf_config, "bos_token_id", 0) or 0) if hf_config else 0, max_req_input_len=int(max_req_input_len), + supports_vision=supports_vision, ) # ------------------------------------------------------------------ @@ -570,6 +580,14 @@ def _build_generate_req(self, request: tokenspeed_scheduler_pb2.GenerateRequest) reasoning_parser=getattr(self.server_args, "reasoning_parser", None), ) + # Decode the precomputed multimodal payload, if the request carries one. + precomputed_mm = None + if request.HasField("mm_inputs") and request.mm_inputs.HasField("pixel_values"): + precomputed_mm = self._mm_inputs_from_proto( + request.mm_inputs, + model_dtype=getattr(self.async_llm.model_config, "dtype", None), + ) + GenerateReqInput = _lazy_generate_req_input() obj = GenerateReqInput( input_ids=input_ids, @@ -585,6 +603,7 @@ def _build_generate_req(self, request: tokenspeed_scheduler_pb2.GenerateRequest) token_ids_logprob=( list(request.token_ids_logprob) if request.token_ids_logprob else None ), + precomputed_multimodal_inputs=precomputed_mm, ) # ``normalize_batch_and_arguments`` asserts ``rid`` is a list when # n>1; expand to deterministic per-choice rids so the assert holds. @@ -687,6 +706,71 @@ def _sampling_params_from_proto( return out + def _mm_inputs_from_proto( + self, + mm_inputs: tokenspeed_scheduler_pb2.MultimodalInputs, + *, + model_dtype: torch.dtype | None = None, + ): + """Reconstruct the engine's ``MultimodalInputs`` from the precomputed proto. + + The gateway already preprocessed, so the engine skips its own preprocessing + (``precomputed_multimodal_inputs`` is set); this just boxes the tensors and + placeholder offsets into the engine's data class. + """ + feature = self._tensor_from_proto(mm_inputs.pixel_values, cast_to=model_dtype) + model_specific_data = {} + for name, tensor_data in mm_inputs.model_specific_tensors.items(): + model_specific_data[name] = self._tensor_from_proto(tensor_data, cast_to=model_dtype) + + if not mm_inputs.mm_placeholders: + raise ValueError( + "multimodal request carried no placeholders; " + "the image token was not located in input_ids" + ) + if any(p.length <= 0 for p in mm_inputs.mm_placeholders): + raise ValueError("mm_placeholders.length must be > 0") + # Placeholders arrive as (offset, length); the engine wants inclusive (start, end). + offsets = [(p.offset, p.offset + p.length - 1) for p in mm_inputs.mm_placeholders] + + mm_item = MultimodalDataItem( + modality=Modality.IMAGE, + feature=feature, + model_specific_data=model_specific_data, + offsets=offsets, + ) + # pad_value must exist before the engine splices features into the embedding + # stream, otherwise it fails inferring a dtype from None. + mm_item.set_pad_value() + + im_token_id = mm_inputs.im_token_id if mm_inputs.HasField("im_token_id") else None + return MultimodalInputs(mm_items=[mm_item], im_token_id=im_token_id) + + @staticmethod + def _tensor_from_proto( + tensor_data: tokenspeed_scheduler_pb2.TensorData, + cast_to: torch.dtype | None = None, + ): + """Reconstruct a torch.Tensor from a proto TensorData. + + Floats are cast to ``cast_to``, fused into the decode; the buffer is + copied so it never aliases the transient proto bytes. + """ + shape = list(tensor_data.shape) + if tensor_data.dtype == "bfloat16": + # numpy has no bfloat16 — read the raw bits as uint16, reinterpret. + t = torch.from_numpy( + np.frombuffer(tensor_data.data, dtype=np.uint16).reshape(shape) + ).view(torch.bfloat16) + else: + t = torch.from_numpy( + np.frombuffer(tensor_data.data, dtype=np.dtype(tensor_data.dtype)).reshape(shape) + ) + + if cast_to is not None and t.dtype != cast_to and t.is_floating_point(): + return t.to(cast_to) + return t.clone() + def _generated_output_ids( self, output: dict, diff --git a/model_gateway/src/routers/grpc/client.rs b/model_gateway/src/routers/grpc/client.rs index 370ef8103..465dc2e58 100644 --- a/model_gateway/src/routers/grpc/client.rs +++ b/model_gateway/src/routers/grpc/client.rs @@ -447,14 +447,16 @@ impl GrpcClient { Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } Self::TokenSpeed(client) => { - if multimodal_inputs.is_some() { - return Err("TokenSpeed backend does not support multimodal inputs".to_string()); - } + let tokenspeed_mm = multimodal_inputs.map(|mm| match mm { + MultimodalData::TokenSpeed(data) => data.into_proto(), + _ => unreachable!("caller guarantees matching variant"), + }); let req = client.build_generate_request_from_chat( request_id, body, processed_text, token_ids, + tokenspeed_mm, tool_constraints, )?; Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req))) @@ -533,14 +535,16 @@ impl GrpcClient { Ok(ProtoGenerateRequest::Mlx(Box::new(req))) } Self::TokenSpeed(client) => { - if multimodal_inputs.is_some() { - return Err("TokenSpeed backend does not support multimodal inputs".to_string()); - } + let tokenspeed_mm = multimodal_inputs.map(|mm| match mm { + MultimodalData::TokenSpeed(data) => data.into_proto(), + _ => unreachable!("caller guarantees matching variant"), + }); let req = client.build_generate_request_from_messages( request_id, body, processed_text, token_ids, + tokenspeed_mm, tool_constraints, )?; Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req))) diff --git a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs index edce1e2ee..d3e346aa2 100644 --- a/model_gateway/src/routers/grpc/harmony/stages/request_building.rs +++ b/model_gateway/src/routers/grpc/harmony/stages/request_building.rs @@ -287,6 +287,7 @@ impl PipelineStage for HarmonyRequestBuildingStage { body, placeholder_processed_text, token_ids, + None, // Harmony path: multimodal not yet wired tool_constraints, ) .map_err(|e| { diff --git a/model_gateway/src/routers/grpc/multimodal.rs b/model_gateway/src/routers/grpc/multimodal.rs index a30d42135..96d502a56 100644 --- a/model_gateway/src/routers/grpc/multimodal.rs +++ b/model_gateway/src/routers/grpc/multimodal.rs @@ -29,7 +29,10 @@ use tracing::{debug, warn}; use crate::routers::grpc::{ client::GrpcClient, - proto_wrapper::{SglangMultimodalData, TensorBytes, TrtllmMultimodalData, VllmMultimodalData}, + proto_wrapper::{ + SglangMultimodalData, TensorBytes, TokenSpeedMultimodalData, TrtllmMultimodalData, + VllmMultimodalData, + }, MultimodalData, }; @@ -705,12 +708,10 @@ pub(crate) fn assemble_multimodal_data( GrpcClient::Sglang(_) => MultimodalData::Sglang(assemble_sglang(intermediate)), GrpcClient::Vllm(_) => MultimodalData::Vllm(assemble_vllm(intermediate)), GrpcClient::Trtllm(_) => MultimodalData::Trtllm(assemble_trtllm(intermediate)), + GrpcClient::TokenSpeed(_) => MultimodalData::TokenSpeed(assemble_tokenspeed(intermediate)), GrpcClient::Mlx(_) => unreachable!( "caller rejects multimodal for MLX in build_chat_request/build_messages_request" ), - GrpcClient::TokenSpeed(_) => unreachable!( - "TokenSpeed backend does not support multimodal; preparation stage should reject earlier" - ), } } @@ -778,6 +779,30 @@ fn assemble_trtllm(intermediate: MultimodalIntermediate) -> TrtllmMultimodalData TrtllmMultimodalData { image_data } } +fn assemble_tokenspeed(intermediate: MultimodalIntermediate) -> TokenSpeedMultimodalData { + // Use patch-only offsets when available and non-empty; fall back to full structural ranges. + let (pixel_values, pixel_values_shape) = serialize_pixel_values(&intermediate.preprocessed); + let model_specific_tensors = serialize_model_specific(intermediate.preprocessed.model_specific); + let mm_placeholders = intermediate + .patch_offsets + .filter(|offsets| !offsets.is_empty()) + .unwrap_or_else(|| { + intermediate + .placeholders + .iter() + .map(|p| (p.offset as u32, p.length as u32)) + .collect() + }); + + TokenSpeedMultimodalData { + pixel_values, + pixel_values_shape, + model_specific_tensors, + im_token_id: intermediate.im_token_id, + mm_placeholders, + } +} + // --------------------------------------------------------------------------- // Serialization helpers // --------------------------------------------------------------------------- diff --git a/model_gateway/src/routers/grpc/proto_wrapper.rs b/model_gateway/src/routers/grpc/proto_wrapper.rs index 48c7aef22..161862273 100644 --- a/model_gateway/src/routers/grpc/proto_wrapper.rs +++ b/model_gateway/src/routers/grpc/proto_wrapper.rs @@ -32,11 +32,13 @@ use smg_grpc_client::{ /// - SGLang: pixel_values + model_specific_tensors + patch-only placeholders /// - vLLM: pixel_values + model_specific_tensors + structural placeholders + hashes + field keys /// - TRT-LLM: raw image bytes only (preprocessing handled server-side) +/// - TokenSpeed: pixel_values + model_specific_tensors + patch-only placeholders #[derive(Debug)] pub enum MultimodalData { Sglang(SglangMultimodalData), Vllm(VllmMultimodalData), Trtllm(TrtllmMultimodalData), + TokenSpeed(TokenSpeedMultimodalData), } /// SGLang multimodal data: preprocessed tensors with patch-only placeholders. @@ -73,6 +75,16 @@ pub struct TrtllmMultimodalData { pub image_data: Vec>, } +/// TokenSpeed multimodal data: preprocessed tensors with patch-only placeholders. +#[derive(Debug)] +pub struct TokenSpeedMultimodalData { + pub pixel_values: Vec, + pub pixel_values_shape: Vec, + pub model_specific_tensors: HashMap, + pub im_token_id: Option, + pub mm_placeholders: Vec<(u32, u32)>, +} + /// Raw tensor bytes with shape and dtype metadata. #[derive(Debug, Clone)] pub struct TensorBytes { @@ -175,6 +187,43 @@ impl TrtllmMultimodalData { } } +impl TokenSpeedMultimodalData { + /// Convert to TokenSpeed proto MultimodalInputs. + pub fn into_proto(self) -> tokenspeed::MultimodalInputs { + let model_specific_tensors = self + .model_specific_tensors + .into_iter() + .map(|(k, v)| { + ( + k, + tokenspeed::TensorData { + data: v.data, + shape: v.shape, + dtype: v.dtype, + }, + ) + }) + .collect(); + + let mm_placeholders = self + .mm_placeholders + .into_iter() + .map(|(offset, length)| tokenspeed::PlaceholderRange { offset, length }) + .collect(); + + tokenspeed::MultimodalInputs { + pixel_values: Some(tokenspeed::TensorData { + data: self.pixel_values, + shape: self.pixel_values_shape, + dtype: "float32".to_string(), + }), + model_specific_tensors, + im_token_id: self.im_token_id, + mm_placeholders, + } + } +} + // ===================== // Unified Logprobs Types // ===================== @@ -451,8 +500,9 @@ impl ProtoGenerateRequest { match self { Self::Sglang(req) => req.mm_inputs = None, Self::Vllm(req) => req.mm_inputs = None, - // TRT-LLM, MLX, and TokenSpeed protos have no mm_inputs field - Self::Trtllm(_) | Self::Mlx(_) | Self::TokenSpeed(_) => {} + Self::TokenSpeed(req) => req.mm_inputs = None, + // TRT-LLM and MLX protos have no mm_inputs field + Self::Trtllm(_) | Self::Mlx(_) => {} } } diff --git a/scripts/ci_install_tokenspeed.sh b/scripts/ci_install_tokenspeed.sh index 094560b62..0953dd2b6 100755 --- a/scripts/ci_install_tokenspeed.sh +++ b/scripts/ci_install_tokenspeed.sh @@ -1,18 +1,15 @@ #!/bin/bash # Install TokenSpeed from source (engine + kernel + scheduler) for CI. # -# TokenSpeed is not published to PyPI, so we clone it and pip-install the -# in-tree ``tokenspeed-kernel`` (CUDA), ``tokenspeed-scheduler`` (C++/nanobind), -# and ``python/`` packages. Mirrors the upstream ``docker/Dockerfile`` pipeline. +# Mirrors the upstream install pattern (see tokenspeed's docs / test/ci_system/ +# install_deps.sh): one editable pip install per package, in engine → +# kernel → scheduler order. The kernel package's metadata pulls in its +# own CUDA dependencies, so we don't pre-install requirements files. # # Prerequisites (expected on k8s-runner-gpu nodes): # - NVIDIA driver 580+ (CUDA 13) # - CUDA 13.0 toolkit at /usr/local/cuda-13.0 or /usr/local/cuda # - H100 GPUs (sm90) -# -# Heavy first run (~30 min for kernel CUDA compile); subsequent runs on the -# same runner hit the pip wheel cache at /tmp/tokenspeed-wheel-cache/ and -# short-circuit the kernel build. set -euo pipefail @@ -25,10 +22,9 @@ fi # a scheduled bump-and-CI routine) rather than floating against ``main`` — # upstream has renamed APIs before and the gRPC servicer broke until we # caught up. -TOKENSPEED_REF="${TOKENSPEED_REF:-70030b298bc6abf6903348057605cc083bf70746}" +TOKENSPEED_REF="${TOKENSPEED_REF:-5e145afae8e5651cd66234e68c988c31aac6639f}" TOKENSPEED_REPO="${TOKENSPEED_REPO:-https://github.com/lightseekorg/tokenspeed.git}" TOKENSPEED_DIR="${TOKENSPEED_DIR:-/tmp/tokenspeed-src}" -WHEEL_CACHE="${TOKENSPEED_WHEEL_CACHE:-/tmp/tokenspeed-wheel-cache}" # Install uv for faster package management (mirrors ci_install_sglang.sh). if ! command -v uv &> /dev/null; then @@ -41,10 +37,7 @@ echo "uv version: $(uv --version)" # ── CUDA runtime setup ───────────────────────────────────────────────────── # k8s-runner-gpu ships the NVIDIA driver + CUDA runtime libs but not the # SDK (nvcc, headers). Install them on demand — same approach as -# ``ci_install_sglang.sh``, which installs cuda-nvcc-12-9 + -# cuda-cudart-dev-12-9 when ``/usr/local/cuda/bin/nvcc`` is missing. -# TokenSpeed's Dockerfile targets CUDA 13.0, so install the matching -# toolkit packages here. +# ``ci_install_sglang.sh``. CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" if [ ! -x "${CUDA_HOME}/bin/nvcc" ]; then echo "Installing CUDA toolkit (nvcc not found at ${CUDA_HOME}/bin/nvcc)..." @@ -53,12 +46,6 @@ if [ ! -x "${CUDA_HOME}/bin/nvcc" ]; then sudo dpkg -i /tmp/cuda-keyring.deb rm /tmp/cuda-keyring.deb sudo apt-get update -qq - # cuda-nvcc-13-0: provides nvcc + cuda_runtime_api.h - # cuda-cudart-dev-13-0: provides cuda_runtime.h + libcudart headers - # cuda-libraries-dev-13-0: meta-package pulling in cublas / curand / - # cusolver / cusparse / cufft / nvrtc / - # nvjitlink dev headers that tokenspeed-kernel - # needs (cublas_v2.h, curand.h, cublasLt.h, ...) sudo apt-get install -y --no-install-recommends \ cuda-nvcc-13-0 \ cuda-cudart-dev-13-0 \ @@ -104,38 +91,18 @@ export DEBIAN_FRONTEND=noninteractive sudo apt-get update -qq sudo apt-get install -y --no-install-recommends libssl-dev libopenmpi-dev cmake -# ── Kernel + scheduler + engine install ──────────────────────────────────── -# Step 1: plain Python requirements. -uv pip install -r tokenspeed-kernel/python/requirements/cuda.txt - -# Step 2: build-isolation=off so nanobind/cutlass build dependencies are shared. -uv pip install -r tokenspeed-kernel/python/requirements/cuda-thirdparty.txt \ - --no-build-isolation - -# Step 3: kernel (CUDA compile — the expensive one). Try the cached wheel first. -CACHED_KERNEL_WHEEL=$(find "$WHEEL_CACHE" -name "tokenspeed_kernel-*.whl" 2>/dev/null | head -1 || true) -if [ -n "$CACHED_KERNEL_WHEEL" ] && [ -f "$CACHED_KERNEL_WHEEL" ]; then - echo "Installing cached tokenspeed-kernel wheel: $CACHED_KERNEL_WHEEL" - uv pip install "$CACHED_KERNEL_WHEEL" --no-build-isolation -else - echo "Building tokenspeed-kernel from source (this takes ~30 min the first time)..." - MAX_JOBS="${MAX_JOBS:-16}" FLASHINFER_CUDA_ARCH_LIST="9.0a 10.0a" \ - uv pip install tokenspeed-kernel/python/ --no-build-isolation - # Cache the built wheel — uv stores wheels under its cache, copy out. - mkdir -p "$WHEEL_CACHE" - python3 -c "import tokenspeed_kernel, os, shutil, glob; \ - d = os.path.dirname(tokenspeed_kernel.__file__); \ - site = os.path.dirname(d); \ - whls = glob.glob(os.path.join(site, 'tokenspeed_kernel-*.dist-info')); \ - print('kernel install dir:', whls)" || true -fi +# ── TokenSpeed packages ──────────────────────────────────────────────────── +export MAX_JOBS="${MAX_JOBS:-16}" +export FLASHINFER_CUDA_ARCH_LIST="${FLASHINFER_CUDA_ARCH_LIST:-9.0a 10.0a}" -# Step 4: scheduler (scikit-build-core + nanobind + CMake). -echo "Building tokenspeed-scheduler..." -uv pip install tokenspeed-scheduler/ +# Preseed build-time tooling: ``./python`` and ``tokenspeed-kernel`` use +# ``setuptools.build_meta`` without declaring ``setuptools`` in +# ``build-system.requires``, and we install with ``--no-build-isolation``. +uv pip install setuptools wheel pybind11 -# Step 5: the Python runtime (pure-Python). -uv pip install "./python" --no-build-isolation +uv pip install -e tokenspeed-kernel/python/ --no-build-isolation +uv pip install -e tokenspeed-scheduler/ +uv pip install -e "./python" --no-build-isolation # ── Persist env to subsequent CI steps ───────────────────────────────────── if [ -n "${GITHUB_ENV:-}" ]; then