Make _serializing_socket_cls used by internal WebsocketRPCEndpoint configurable from PubSubEndpoint #96
Description
Hello everyone!
First, I'd like to say that this project has been a great help. I was looking to implement a websocket endpoint that works across multiple instances of an application when I found this library. It took a while to get the hang of things, but I eventually got to understand it.
For my use case, a client connects to the websocket and sends a request to fetch specific data. This request (with a unique channel ID attached) is published to a queue that is consumed by another service, which processes the request and responds with the data. The response is then published/broadcast to all clients subscribed to that channel on the pub/sub endpoint, so the client that initially made the request is also notified if it is still connected.
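To make the flow above concrete, here is a small in-memory sketch of the channel mechanics. This is illustration only: it uses stdlib `asyncio` queues in place of the library's Redis-backed broadcaster, and every name in it (`ToyBroker`, the channel string, the payload) is made up for this example.

```python
import asyncio
from collections import defaultdict


class ToyBroker:
    """Toy stand-in for a pub/sub broker: one queue per subscribed client."""

    def __init__(self):
        self.subscribers = defaultdict(list)  # channel id -> client queues

    def subscribe(self, channel: str) -> asyncio.Queue:
        q = asyncio.Queue()
        self.subscribers[channel].append(q)
        return q

    async def publish(self, channel: str, message: dict) -> None:
        # Broadcast to every client subscribed on this channel.
        for q in self.subscribers[channel]:
            await q.put(message)


async def main():
    broker = ToyBroker()
    # Unique channel id attached to the request, as described above.
    channel = "job_fetch:ws:user-1"
    inbox = broker.subscribe(channel)  # the requesting client subscribes
    # The worker service processes the request and publishes the result:
    await broker.publish(channel, {"status": "ok", "job": {"id": 42}})
    return await inbox.get()


result = asyncio.run(main())
```

The real setup replaces `ToyBroker` with `PubSubEndpoint` plus a Redis broadcaster, which is what lets the broadcast cross application instances.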
The Problem
All my endpoints respond with a specific response schema, and after looking at the source code I noticed that the underlying `WebsocketRPCEndpoint` used by the `PubSubEndpoint` instance uses a `_serializing_socket_cls` to serialize and deserialize RPC messages, so that's most likely the best place to intercept the message and format the response. I then wrote a custom `_serializing_socket_cls`, but there was no way to pass it to the `PubSubEndpoint`, so I had to override `pubsub_endpoint.endpoint._serializing_socket_cls` manually to achieve this.
Here's a sample of the implementation I had:
The custom socket serializer
```python
import typing

import orjson
import pydantic
from fastapi import WebSocket
from fastapi_websocket_rpc.schemas import RpcMessage, RpcRequest
from fastapi_websocket_rpc.simplewebsocket import SimpleWebSocket
from fastapi_websocket_rpc.utils import pydantic_serialize

# Application-specific imports (FetchExternalJobSchema, the `response` module,
# logger) are omitted for brevity; the library import paths above may vary
# slightly by version.


class JobFetchWebSocketProxy(SimpleWebSocket):
    data_schema = FetchExternalJobSchema

    def __init__(self, websocket: SimpleWebSocket):
        self.socket = websocket
        self.rabbitmq_connection = getattr(
            self.root.app.state, "rabbitmq_connection", None
        )
        if self.rabbitmq_connection is None:
            raise ValueError("RabbitMQ connection not found in app state")
        self.pubsub_topics = getattr(self.root.state, "pubsub_topics", None)
        self.db_session = getattr(self.root.state, "db_session", None)

    @property
    def root(self) -> WebSocket:
        root = self.socket
        while socket := (
            getattr(root, "socket", None) or getattr(root, "websocket", None)
        ):
            root = socket
        if not isinstance(root, WebSocket):
            raise ValueError("Could not find root `starlette.WebSocket` instance")
        return root

    async def connect(self, uri: str, **connect_kwargs: typing.Any):  # type: ignore
        await self.socket.connect(uri, **connect_kwargs)  # type: ignore

    def serialize(self, msg: pydantic.BaseModel) -> str:
        return pydantic_serialize(msg)

    def deserialize(self, buffer: str) -> typing.Dict[str, typing.Any]:
        data = orjson.loads(buffer)
        if not isinstance(data, dict):
            raise ValueError("Invalid data format, expected a JSON object")
        return data

    async def send(self, msg: pydantic.BaseModel):  # type: ignore
        # Convert the `RpcMessage` to a `response.Schema`
        if isinstance(msg, RpcMessage) and msg.response is not None:
            # Custom response schema applied here
            response_msg = response.Schema.model_validate(msg.response.result)
        else:
            response_msg = msg
        await self.socket.send(self.serialize(response_msg))  # type: ignore

    async def recv(self) -> typing.Dict[str, typing.Any]:  # type: ignore
        msg = await self.socket.recv()  # type: ignore
        if msg is None:
            return {"request": None}
        try:
            msg = self.deserialize(msg)
            # Data schema validated here
            data = self.data_schema.model_validate(msg)
        except pydantic.ValidationError as exc:
            logger.error(f"Failed to validate job fetch data: {exc}")
            await self.send(
                response.Schema(
                    status=response.Status.ERROR,
                    message="Invalid data",
                    detail="Failed to validate job fetch data",
                    errors=[e["msg"] for e in exc.errors()],
                )
            )
            return {"request": None}
        # Convert the data schema to a `RpcMessage`
        rpc_msg = RpcMessage(
            request=RpcRequest(
                method="fetch_job",
                arguments={
                    "db": self.db_session,
                    "job_url": str(data.job_url),
                    "metadata": data.metadata,
                    "rabbitmq_connection": self.rabbitmq_connection,
                    "pubsub_topics": self.pubsub_topics,
                },
            )
        )
        return rpc_msg.model_dump()
```

The endpoint
```python
job_fetch_pubsub = PubSubEndpoint(
    methods_class=JobFetchRPCMethods,
    broadcaster=settings.REDIS_URL,
    on_connect=[on_job_fetch_socket_connect],  # type: ignore
    on_disconnect=[on_job_fetch_socket_disconnect],  # type: ignore
    ignore_broadcaster_disconnected=False,
)

# Manually overriding the `_serializing_socket_cls`
job_fetch_pubsub.endpoint._serializing_socket_cls = JobFetchWebSocketProxy
```
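As a side note on why this manual override works at all: the endpoint only consults `_serializing_socket_cls` when a socket actually connects, so replacing the attribute on the instance before serving takes effect. A minimal stand-in sketch of that pattern (these classes are illustrative, not the library's real ones):

```python
# Illustration: a class attribute used as a lazy factory can be swapped
# per-instance, and the swap is picked up on each later lookup.
class PlainSocket:
    def wrap(self, raw: str) -> str:
        return f"plain:{raw}"


class CustomSocket(PlainSocket):
    def wrap(self, raw: str) -> str:
        return f"custom:{raw}"


class Endpoint:
    _serializing_socket_cls = PlainSocket  # class-level default

    def on_connect(self, raw: str) -> str:
        # The attribute is resolved lazily, once per connection.
        return self._serializing_socket_cls().wrap(raw)


endpoint = Endpoint()
endpoint._serializing_socket_cls = CustomSocket  # instance-level override
print(endpoint.on_connect("msg"))  # custom:msg
```

The instance attribute shadows the class default, which is exactly what the `job_fetch_pubsub.endpoint` override above relies on.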
```python
@router.websocket("/ws/job-fetch", dependencies=[authe.authentication_required])
async def job_fetch_socket(ws: WebSocket, user: authe.VerifiedUser):
    user_id = str(user.id)
    ws_channel_id = f"milo:core:job_fetch:ws:{user_id}"
    logger.info(
        f"User {user_id!r} is connecting to job fetch WebSocket with client ID: {ws_channel_id!r}",
        extra={"user_id": user_id, "ws_channel_id": ws_channel_id},
    )
    # Let the topics be the same as the channel ID for simplicity
    pubsub_topics = [ws_channel_id]
    ws.state.pubsub_topics = pubsub_topics
    await job_fetch_pubsub.main_loop(
        ws, channel_id=ws_channel_id, pubsub_topics=pubsub_topics
    )
```

Proposed Improvement/Solution
It'd be a good improvement to have the `serializing_socket_cls` passed as an argument when instantiating the `PubSubEndpoint`, and then forwarded to the `WebsocketRPCEndpoint` created internally. So we'd have something like this instead:
```python
job_fetch_pubsub = PubSubEndpoint(
    methods_class=JobFetchRPCMethods,
    broadcaster=settings.REDIS_URL,
    on_connect=[on_job_fetch_socket_connect],  # type: ignore
    on_disconnect=[on_job_fetch_socket_disconnect],  # type: ignore
    ignore_broadcaster_disconnected=False,
    serializing_socket_cls=JobFetchWebSocketProxy,
)
```

I'd be happy to make a PR for this change. Thanks!
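For reference, the forwarding inside the library could look roughly like this stdlib-only sketch. The class names below are stand-ins for `PubSubEndpoint`/`WebsocketRPCEndpoint`, not the library's real signatures:

```python
from typing import Optional, Type


class JsonSerializingWebSocket:
    """Stand-in for the library's default serializing socket."""


class RPCEndpointSketch:
    """Stand-in for the internally created WebsocketRPCEndpoint."""

    def __init__(self, serializing_socket_cls: Type = JsonSerializingWebSocket):
        self._serializing_socket_cls = serializing_socket_cls


class PubSubEndpointSketch:
    def __init__(self, serializing_socket_cls: Optional[Type] = None):
        # Forward the caller's class, falling back to the library default.
        self.endpoint = RPCEndpointSketch(
            serializing_socket_cls=serializing_socket_cls
            or JsonSerializingWebSocket
        )


class MyProxySocket(JsonSerializingWebSocket):
    pass


ep = PubSubEndpointSketch(serializing_socket_cls=MyProxySocket)
assert ep.endpoint._serializing_socket_cls is MyProxySocket
```

Keeping the parameter optional preserves backwards compatibility: existing callers get the default serializing socket unchanged.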