Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 42 additions & 2 deletions crates/grpc_client/proto/tokenspeed_scheduler.proto
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ import "google/protobuf/timestamp.proto";
import "google/protobuf/struct.proto";

// TokenSpeed scheduler gRPC service. Fully self-contained wire definition.
// Trimmed to text-generation only (no embed, no multimodal, no
// PD-disaggregated, no LoRA, no hidden-state forwarding).
// Trimmed to text+image generation (no embed, no PD-disaggregated, no
// LoRA, no hidden-state forwarding). Multimodal carries preprocessed
// tensors only — image fetch + per-model preprocess happen in the
// gateway (see crates/multimodal).
service TokenSpeedScheduler {
rpc Generate(GenerateRequest) returns (stream GenerateResponse);
rpc HealthCheck(HealthCheckRequest) returns (HealthCheckResponse);
Expand Down Expand Up @@ -85,13 +87,46 @@ message GenerateRequest {

// Whether the client wants stream chunks (otherwise: complete-only).
bool stream = 8;

// Preprocessed multimodal payload. Absent for text-only requests.
MultimodalInputs mm_inputs = 9;
}

message TokenizedInput {
repeated uint32 input_ids = 1;
string original_text = 2; // cosmetic, for worker logs
}

// A typed tensor: raw little-endian bytes + shape + dtype.
message TensorData {
bytes data = 1; // Raw little-endian bytes (f32/i64/u32)
repeated uint32 shape = 2; // Dimension sizes
string dtype = 3; // "float32", "int64", "uint32"
}

// Where a multimodal item's tokens sit inside input_ids.
message PlaceholderRange {
uint32 offset = 1;
uint32 length = 2;
}

// Multimodal inputs for vision/audio models. Tensors are produced by the
// gateway's per-model preprocessor (crates/multimodal); the servicer
// only reconstructs them and hands them to the engine — no preprocess.
message MultimodalInputs {
// Preprocessed pixel values tensor.
TensorData pixel_values = 1;

// Model-specific tensors (image_grid_thw, aspect_ratios, etc.).
map<string, TensorData> model_specific_tensors = 2;

// Image token id used for placeholder expansion.
optional uint32 im_token_id = 3;

// Placeholder offsets: where each image's tokens are in input_ids.
repeated PlaceholderRange mm_placeholders = 4;
}

message GenerateResponse {
string request_id = 1;

Expand Down Expand Up @@ -191,6 +226,11 @@ message GetModelInfoResponse {
// ``{"temperature": 0.6, "top_p": 0.9}``). Empty = no overrides. Surfaced
// to the router via the ``default_sampling_params_json`` worker label.
string default_sampling_params_json = 13;

// True when the model can consume image/video inputs. Drives the
// router's mm_inputs handling (router rejects mm requests for
// non-vision workers before tokenization).
bool supports_vision = 14;
}

message GetServerInfoRequest {}
Expand Down
5 changes: 5 additions & 0 deletions crates/grpc_client/src/tokenspeed_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ impl TokenSpeedSchedulerClient {
body: &ChatCompletionRequest,
processed_text: String,
token_ids: Vec<u32>,
multimodal_inputs: Option<tokenspeed_proto::MultimodalInputs>,
tool_call_constraint: Option<(String, String)>,
) -> Result<tokenspeed_proto::GenerateRequest, String> {
let sampling_params = Self::build_sampling_params_from_chat(body, tool_call_constraint)?;
Expand All @@ -193,6 +194,7 @@ impl TokenSpeedSchedulerClient {
logprob_start_len: Some(-1),
top_logprobs_num: body.top_logprobs.unwrap_or(0) as i32,
stream: body.stream,
mm_inputs: multimodal_inputs,
..Default::default()
})
}
Expand Down Expand Up @@ -222,6 +224,7 @@ impl TokenSpeedSchedulerClient {
top_logprobs_num: body.top_logprobs_num.unwrap_or(0),
token_ids_logprob: body.token_ids_logprob.clone().unwrap_or_default(),
stream: body.stream,
mm_inputs: None,
})
}

Expand Down Expand Up @@ -260,6 +263,7 @@ impl TokenSpeedSchedulerClient {
body: &CreateMessageRequest,
processed_text: String,
token_ids: Vec<u32>,
multimodal_inputs: Option<tokenspeed_proto::MultimodalInputs>,
tool_call_constraint: Option<(String, String)>,
) -> Result<tokenspeed_proto::GenerateRequest, String> {
let sampling_params =
Expand All @@ -272,6 +276,7 @@ impl TokenSpeedSchedulerClient {
}),
sampling_params: Some(sampling_params),
stream: body.stream.unwrap_or(false),
mm_inputs: multimodal_inputs,
..Default::default()
})
}
Expand Down
49 changes: 45 additions & 4 deletions grpc_servicer/smg_grpc_servicer/tokenspeed/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,58 @@

logger = logging.getLogger(__name__)

# Match the other SMG servicers' 256 MiB default — a single oversized or
# malformed gRPC frame can otherwise trigger a multi-GiB transient
# allocation. VLM deployments that genuinely need bigger ``pixel_values``
# raise it via ``TOKENSPEED_GRPC_MAX_MESSAGE_BYTES``; the env value is
# clamped to gRPC's hard ceiling (``INT32_MAX`` = 2 GiB - 1).
_GRPC_DEFAULT_MAX_BYTES = 256 * 1024 * 1024
_GRPC_HARD_CEILING_BYTES = (1 << 31) - 1


def _grpc_max_message_bytes() -> int:
"""Return the configured gRPC message ceiling (send + receive use the same)."""
raw = os.getenv("TOKENSPEED_GRPC_MAX_MESSAGE_BYTES")
if not raw:
return _GRPC_DEFAULT_MAX_BYTES
try:
value = int(raw)
except ValueError:
logger.warning(
"TOKENSPEED_GRPC_MAX_MESSAGE_BYTES=%r is not an int; falling back to %d",
raw,
_GRPC_DEFAULT_MAX_BYTES,
)
return _GRPC_DEFAULT_MAX_BYTES
if value <= 0:
logger.warning(
"TOKENSPEED_GRPC_MAX_MESSAGE_BYTES=%d must be positive; falling back to %d",
value,
_GRPC_DEFAULT_MAX_BYTES,
)
return _GRPC_DEFAULT_MAX_BYTES
if value > _GRPC_HARD_CEILING_BYTES:
logger.warning(
"TOKENSPEED_GRPC_MAX_MESSAGE_BYTES=%d exceeds gRPC ceiling %d; clamping",
value,
_GRPC_HARD_CEILING_BYTES,
)
return _GRPC_HARD_CEILING_BYTES
return value


async def serve_grpc(server_args: ServerArgs) -> None:
"""Run the TokenSpeed gRPC server until a shutdown signal is received."""

logger.info("Launching TokenSpeed scheduler + AsyncLLM...")
async_llm, scheduler_info = launch_engine(server_args)

max_message_bytes = _grpc_max_message_bytes()
server = grpc.aio.server(
futures.ThreadPoolExecutor(max_workers=10),
options=[
("grpc.max_send_message_length", 1024 * 1024 * 256),
("grpc.max_receive_message_length", 1024 * 1024 * 256),
("grpc.max_send_message_length", max_message_bytes),
("grpc.max_receive_message_length", max_message_bytes),
# Permissive keepalive so long prefill stalls don't trip GOAWAY.
("grpc.http2.min_recv_ping_interval_without_data_ms", 10000),
("grpc.keepalive_permit_without_calls", True),
Comment on lines 73 to 80
Expand Down Expand Up @@ -120,11 +160,12 @@ def _wait_and_warmup(
# Wildcard bind hosts aren't routable as destinations; dial loopback instead.
warmup_host = {"0.0.0.0": "127.0.0.1", "::": "::1"}.get(server_args.host, server_args.host)
grpc_url = f"{warmup_host}:{server_args.port}"
max_message_bytes = _grpc_max_message_bytes()
channel = grpc.insecure_channel(
grpc_url,
options=[
("grpc.max_send_message_length", 1024 * 1024 * 256),
("grpc.max_receive_message_length", 1024 * 1024 * 256),
("grpc.max_send_message_length", max_message_bytes),
("grpc.max_receive_message_length", max_message_bytes),
],
)
stub = tokenspeed_scheduler_pb2_grpc.TokenSpeedSchedulerStub(channel)
Expand Down
92 changes: 88 additions & 4 deletions grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

Implements ``tokenspeed.grpc.scheduler.TokenSpeedScheduler`` on top of
:class:`tokenspeed.runtime.engine.async_llm.AsyncLLM`. The proto field set
is intentionally minimal — generative LLM serving only, no Embed /
GetTokenizer / SubscribeKvEvents / multimodal / PD-disaggregated / LoRA /
is intentionally minimal — generative LLM serving plus precomputed multimodal;
no Embed / GetTokenizer / SubscribeKvEvents / PD-disaggregated / LoRA /
hidden states / classifier outputs.
"""

Expand All @@ -21,16 +21,23 @@
from typing import TYPE_CHECKING, Any

import grpc
import numpy as np
import torch
from google.protobuf.struct_pb2 import Struct
from google.protobuf.timestamp_pb2 import Timestamp
from smg_grpc_proto import tokenspeed_scheduler_pb2_grpc
from smg_grpc_proto.generated import tokenspeed_scheduler_pb2
from tokenspeed.runtime.multimodal.inputs import (
Modality,
MultimodalDataItem,
MultimodalInputs,
)
Comment on lines +30 to +34
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Defer TokenSpeed multimodal imports to runtime

Loading smg_grpc_servicer.tokenspeed.servicer now eagerly imports tokenspeed.runtime.multimodal.inputs (and its transitive deps) at module import time, which breaks the existing stub-based test/tooling scenario where TokenSpeed runtime is intentionally absent and only protobuf/request-conversion paths are exercised. This regresses the module’s previous lazy-import behavior and can fail fast with ModuleNotFoundError before any RPC handling; moving these imports into the multimodal conversion path (similar to _lazy_generate_req_input) keeps non-TokenSpeed environments importable.

Useful? React with 👍 / 👎.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@CatherineSue any thoughts?


from smg_grpc_servicer.tokenspeed.health_servicer import TokenSpeedHealthServicer

if TYPE_CHECKING:
# Type-only imports — not resolved at module load so the servicer is
# importable in test environments that stub AsyncLLM / ServerArgs.
# Type-only — keeps these out of the cold-path graph when the servicer is
# imported by tooling that stubs the engine surface.
from tokenspeed.runtime.engine.async_llm import AsyncLLM
from tokenspeed.runtime.utils.server_args import ServerArgs

Expand Down Expand Up @@ -358,6 +365,8 @@ async def GetModelInfo(
tokenizer_path = getattr(self.server_args, "tokenizer", None) or getattr(
self.server_args, "tokenizer_path", ""
)
supports_vision = bool(getattr(model_config, "is_multimodal", False))

return tokenspeed_scheduler_pb2.GetModelInfoResponse(
model_path=model_path,
tokenizer_path=tokenizer_path or "",
Expand All @@ -372,6 +381,7 @@ async def GetModelInfo(
pad_token_id=(getattr(hf_config, "pad_token_id", 0) or 0) if hf_config else 0,
bos_token_id=(getattr(hf_config, "bos_token_id", 0) or 0) if hf_config else 0,
max_req_input_len=int(max_req_input_len),
supports_vision=supports_vision,
)

# ------------------------------------------------------------------
Expand Down Expand Up @@ -570,6 +580,14 @@ def _build_generate_req(self, request: tokenspeed_scheduler_pb2.GenerateRequest)
reasoning_parser=getattr(self.server_args, "reasoning_parser", None),
)

# Decode the precomputed multimodal payload, if the request carries one.
precomputed_mm = None
if request.HasField("mm_inputs") and request.mm_inputs.HasField("pixel_values"):
precomputed_mm = self._mm_inputs_from_proto(
request.mm_inputs,
model_dtype=getattr(self.async_llm.model_config, "dtype", None),
)
Comment on lines +585 to +589
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Reject partial multimodal payloads instead of silently dropping them.

When mm_inputs is present but pixel_values is missing, the request currently falls back to text-only execution. That hides malformed payloads and can produce incorrect results; fail fast with INVALID_ARGUMENT by raising ValueError.

Proposed fix
-        precomputed_mm = None
-        if request.HasField("mm_inputs") and request.mm_inputs.HasField("pixel_values"):
-            precomputed_mm = self._mm_inputs_from_proto(
-                request.mm_inputs,
-                model_dtype=getattr(self.async_llm.model_config, "dtype", None),
-            )
+        precomputed_mm = None
+        if request.HasField("mm_inputs"):
+            if not request.mm_inputs.HasField("pixel_values"):
+                raise ValueError(
+                    "GenerateRequest.mm_inputs.pixel_values is required when mm_inputs is set"
+                )
+            precomputed_mm = self._mm_inputs_from_proto(
+                request.mm_inputs,
+                model_dtype=getattr(self.async_llm.model_config, "dtype", None),
+            )
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@grpc_servicer/smg_grpc_servicer/tokenspeed/servicer.py` around lines 579 -
583, The handler currently treats a present mm_inputs without pixel_values as
text-only; instead, in the block that checks request.HasField("mm_inputs") you
should detect the malformed partial multimodal payload (i.e., request.mm_inputs
is present but not request.mm_inputs.HasField("pixel_values")) and raise a
ValueError so the RPC returns INVALID_ARGUMENT; otherwise, continue to call
_mm_inputs_from_proto(request.mm_inputs,
model_dtype=getattr(self.async_llm.model_config, "dtype", None)) as before.


GenerateReqInput = _lazy_generate_req_input()
obj = GenerateReqInput(
input_ids=input_ids,
Expand All @@ -585,6 +603,7 @@ def _build_generate_req(self, request: tokenspeed_scheduler_pb2.GenerateRequest)
token_ids_logprob=(
list(request.token_ids_logprob) if request.token_ids_logprob else None
),
precomputed_multimodal_inputs=precomputed_mm,
)
# ``normalize_batch_and_arguments`` asserts ``rid`` is a list when
# n>1; expand to deterministic per-choice rids so the assert holds.
Expand Down Expand Up @@ -687,6 +706,71 @@ def _sampling_params_from_proto(

return out

def _mm_inputs_from_proto(
self,
mm_inputs: tokenspeed_scheduler_pb2.MultimodalInputs,
*,
model_dtype: torch.dtype | None = None,
):
"""Reconstruct the engine's ``MultimodalInputs`` from the precomputed proto.

The gateway already preprocessed, so the engine skips its own preprocessing
(``precomputed_multimodal_inputs`` is set); this just boxes the tensors and
placeholder offsets into the engine's data class.
"""
feature = self._tensor_from_proto(mm_inputs.pixel_values, cast_to=model_dtype)
model_specific_data = {}
for name, tensor_data in mm_inputs.model_specific_tensors.items():
model_specific_data[name] = self._tensor_from_proto(tensor_data, cast_to=model_dtype)

if not mm_inputs.mm_placeholders:
raise ValueError(
"multimodal request carried no placeholders; "
"the image token was not located in input_ids"
)
if any(p.length <= 0 for p in mm_inputs.mm_placeholders):
raise ValueError("mm_placeholders.length must be > 0")
# Placeholders arrive as (offset, length); the engine wants inclusive (start, end).
offsets = [(p.offset, p.offset + p.length - 1) for p in mm_inputs.mm_placeholders]

mm_item = MultimodalDataItem(
modality=Modality.IMAGE,
feature=feature,
model_specific_data=model_specific_data,
offsets=offsets,
)
# pad_value must exist before the engine splices features into the embedding
# stream, otherwise it fails inferring a dtype from None.
mm_item.set_pad_value()

im_token_id = mm_inputs.im_token_id if mm_inputs.HasField("im_token_id") else None
return MultimodalInputs(mm_items=[mm_item], im_token_id=im_token_id)

@staticmethod
def _tensor_from_proto(
tensor_data: tokenspeed_scheduler_pb2.TensorData,
cast_to: torch.dtype | None = None,
):
"""Reconstruct a torch.Tensor from a proto TensorData.

Floats are cast to ``cast_to``, fused into the decode; the buffer is
copied so it never aliases the transient proto bytes.
"""
shape = list(tensor_data.shape)
if tensor_data.dtype == "bfloat16":
# numpy has no bfloat16 — read the raw bits as uint16, reinterpret.
t = torch.from_numpy(
np.frombuffer(tensor_data.data, dtype=np.uint16).reshape(shape)
).view(torch.bfloat16)
else:
t = torch.from_numpy(
np.frombuffer(tensor_data.data, dtype=np.dtype(tensor_data.dtype)).reshape(shape)
)
Comment on lines +766 to +768
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Handle uint32 TensorData without torch.from_numpy

_tensor_from_proto decodes non-bfloat tensors via np.frombuffer(..., dtype=np.dtype(tensor_data.dtype)) and then calls torch.from_numpy(...). In this commit, TensorData explicitly allows "uint32", and the gateway can forward uint32 model-specific tensors, but torch.from_numpy does not accept np.uint32; that path raises a TypeError and the RPC is surfaced as INTERNAL instead of successfully building the request. Any multimodal model that includes unsigned tensor metadata will fail at request conversion unless uint32 is converted through a supported torch dtype path.

Useful? React with 👍 / 👎.


if cast_to is not None and t.dtype != cast_to and t.is_floating_point():
return t.to(cast_to)
return t.clone()
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def _generated_output_ids(
self,
output: dict,
Expand Down
16 changes: 10 additions & 6 deletions model_gateway/src/routers/grpc/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -447,14 +447,16 @@ impl GrpcClient {
Ok(ProtoGenerateRequest::Mlx(Box::new(req)))
}
Self::TokenSpeed(client) => {
if multimodal_inputs.is_some() {
return Err("TokenSpeed backend does not support multimodal inputs".to_string());
}
let tokenspeed_mm = multimodal_inputs.map(|mm| match mm {
MultimodalData::TokenSpeed(data) => data.into_proto(),
_ => unreachable!("caller guarantees matching variant"),
});
let req = client.build_generate_request_from_chat(
request_id,
body,
processed_text,
token_ids,
tokenspeed_mm,
tool_constraints,
)?;
Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req)))
Expand Down Expand Up @@ -533,14 +535,16 @@ impl GrpcClient {
Ok(ProtoGenerateRequest::Mlx(Box::new(req)))
}
Self::TokenSpeed(client) => {
if multimodal_inputs.is_some() {
return Err("TokenSpeed backend does not support multimodal inputs".to_string());
}
let tokenspeed_mm = multimodal_inputs.map(|mm| match mm {
MultimodalData::TokenSpeed(data) => data.into_proto(),
_ => unreachable!("caller guarantees matching variant"),
});
let req = client.build_generate_request_from_messages(
request_id,
body,
processed_text,
token_ids,
tokenspeed_mm,
tool_constraints,
)?;
Ok(ProtoGenerateRequest::TokenSpeed(Box::new(req)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ impl PipelineStage for HarmonyRequestBuildingStage {
body,
placeholder_processed_text,
token_ids,
None, // Harmony path: multimodal not yet wired
tool_constraints,
)
.map_err(|e| {
Expand Down
Loading
Loading