
Make _serializing_socket_cls used by internal WebsocketRPCEndpoint configurable from PubSubEndpoint #96

@Taiwo-Sh


Hello everyone!

First, I'd like to say that this project has been a great help. I was looking to implement a websocket endpoint that works across multiple instances of an application when I found this library. It took a while to get the hang of things, but I eventually understood how it works.

For my use case, a client connects to the websocket and sends a request to fetch specific data. This request, tagged with a unique channel ID, is published to a queue consumed by another service. That service processes the request and responds with the data, which is then published/broadcast to all clients subscribed to that channel on the pub/sub endpoint. The client that made the original request is also notified if it is still connected.
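The request/response flow above can be sketched in a self-contained way. This is only an illustration under stated assumptions: an asyncio.Queue stands in for the RabbitMQ queue, and the hypothetical InMemoryPubSub class stands in for the Redis-backed PubSubEndpoint; neither is the library's API.

```python
from __future__ import annotations

import asyncio
from collections import defaultdict


class InMemoryPubSub:
    """Hypothetical stand-in for the broadcaster-backed PubSubEndpoint."""

    def __init__(self) -> None:
        self._subscribers: defaultdict[str, list[asyncio.Queue]] = defaultdict(list)

    def subscribe(self, channel: str) -> asyncio.Queue:
        # Each connected client gets its own inbox for the channel
        inbox: asyncio.Queue = asyncio.Queue()
        self._subscribers[channel].append(inbox)
        return inbox

    async def publish(self, channel: str, data: dict) -> None:
        # Fan the message out to every client subscribed to the channel
        for inbox in self._subscribers[channel]:
            await inbox.put(data)


async def worker(jobs: asyncio.Queue, pubsub: InMemoryPubSub) -> None:
    """Hypothetical stand-in for the service consuming the job queue."""
    while True:
        request = await jobs.get()
        result = {"job_url": request["job_url"], "status": "fetched"}
        # Respond on the channel ID that was attached to the request
        await pubsub.publish(request["channel_id"], result)
        jobs.task_done()


async def main() -> list[dict]:
    jobs: asyncio.Queue = asyncio.Queue()  # stands in for the RabbitMQ queue
    pubsub = InMemoryPubSub()
    channel_id = "job_fetch:ws:user-1"
    client_a = pubsub.subscribe(channel_id)
    client_b = pubsub.subscribe(channel_id)

    consumer = asyncio.create_task(worker(jobs, pubsub))
    # A client sends a request tagged with its channel ID
    await jobs.put({"channel_id": channel_id, "job_url": "https://example.com/job/1"})
    await jobs.join()  # wait until the worker has processed the request
    consumer.cancel()
    return [await client_a.get(), await client_b.get()]


if __name__ == "__main__":
    print(asyncio.run(main()))
```

Both subscribers on the channel receive the same result, which mirrors the broadcast behaviour described above.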

The Problem
All my endpoints respond with a shared response schema. After reading the source code, I noticed that the WebsocketRPCEndpoint used internally by PubSubEndpoint relies on a _serializing_socket_cls to serialize and deserialize RPC messages, which makes it the natural place to intercept messages and format responses. I wrote a custom serializing socket class, but PubSubEndpoint offers no way to pass it through, so I had to override pubsub_endpoint.endpoint._serializing_socket_cls manually.
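To illustrate why the serializing socket is a good interception point, here is a minimal sketch of the wrap-and-transform pattern: every incoming frame and every outgoing message passes through the wrapper's recv and send. FakeSocket, EnvelopeSerializingSocket, the "fetch_job" mapping, and the {"status": ..., "data": ...} envelope are hypothetical; plain dicts stand in for the library's pydantic RpcMessage models, and the library's socket interface also has connect and close, which are omitted here.

```python
from __future__ import annotations

import asyncio
import json
from typing import Any


class FakeSocket:
    """Hypothetical inner socket: records sent text and replays canned input."""

    def __init__(self, incoming: list[str]) -> None:
        self.sent: list[str] = []
        self._incoming = incoming

    async def send(self, text: str) -> None:
        self.sent.append(text)

    async def recv(self) -> str:
        return self._incoming.pop(0)


class EnvelopeSerializingSocket:
    """Hypothetical wrapper: every message in or out passes through here."""

    def __init__(self, websocket: FakeSocket) -> None:
        self._websocket = websocket

    async def send(self, msg: dict[str, Any]) -> None:
        # Replace RPC responses with an application-wide response envelope
        response = msg.get("response")
        if response is not None:
            msg = {"status": "success", "data": response.get("result")}
        await self._websocket.send(json.dumps(msg))

    async def recv(self) -> dict[str, Any]:
        data = json.loads(await self._websocket.recv())
        # Translate a plain client payload into an RPC request shape
        return {"request": {"method": "fetch_job", "arguments": data}}


async def demo() -> tuple[dict[str, Any], list[str]]:
    inner = FakeSocket(['{"job_url": "https://example.com/job/1"}'])
    sock = EnvelopeSerializingSocket(inner)
    request = await sock.recv()
    await sock.send({"response": {"call_id": "1", "result": {"title": "Engineer"}}})
    return request, inner.sent


if __name__ == "__main__":
    print(asyncio.run(demo()))
```

Because the RPC layer only ever talks to the wrapper, neither the client-facing format nor the RPC message format has to change to accommodate the other.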

Here's a sample of my implementation:
The custom socket serializer

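import logging
import typing

import orjson
import pydantic
from fastapi import WebSocket
from fastapi_websocket_rpc.schemas import RpcMessage, RpcRequest
from fastapi_websocket_rpc.simplewebsocket import SimpleWebSocket
from fastapi_websocket_rpc.utils import pydantic_serialize

# `response` and `FetchExternalJobSchema` are application modules (not shown)

logger = logging.getLogger(__name__)


class JobFetchWebSocketProxy(SimpleWebSocket):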
    data_schema = FetchExternalJobSchema

    def __init__(self, websocket: SimpleWebSocket):
        self.socket = websocket
        self.rabbitmq_connection = getattr(
            self.root.app.state, "rabbitmq_connection", None
        )
        if self.rabbitmq_connection is None:
            raise ValueError("RabbitMQ connection not found in app state")
        self.pubsub_topics = getattr(self.root.state, "pubsub_topics", None)
        self.db_session = getattr(self.root.state, "db_session", None)

    @property
    def root(self) -> WebSocket:
        root = self.socket
        while socket := (
            getattr(root, "socket", None) or getattr(root, "websocket", None)
        ):
            root = socket

        if not isinstance(root, WebSocket):
            raise ValueError("Could not find root `starlette.WebSocket` instance")
        return root

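    async def connect(self, uri: str, **connect_kwargs: typing.Any):  # type: ignore
        await self.socket.connect(uri, **connect_kwargs)  # type: ignore

    async def close(self, code: int = 1000):  # type: ignore
        # `close` is abstract on SimpleWebSocket, so delegate it to the wrapped socket
        await self.socket.close(code)  # type: ignore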

    def serialize(self, msg: pydantic.BaseModel) -> str:
        return pydantic_serialize(msg)

    def deserialize(self, buffer: str) -> typing.Dict[str, typing.Any]:
        data = orjson.loads(buffer)
        if not isinstance(data, dict):
            raise ValueError("Invalid data format, expected a JSON object")
        return data

    async def send(self, msg: pydantic.BaseModel):  # type: ignore
        # Convert the `RpcMessage` to a `response.Schema`
        if isinstance(msg, RpcMessage) and msg.response is not None:
            # Custom response schema applied here
            response_msg = response.Schema.model_validate(msg.response.result)
        else:
            response_msg = msg
        await self.socket.send(self.serialize(response_msg))  # type: ignore

    async def recv(self) -> typing.Dict[str, typing.Any]:  # type: ignore
        msg = await self.socket.recv()  # type: ignore
        if msg is None:
            return {"request": None}

        try:
            msg = self.deserialize(msg)
            # Data schema validated here
            data = self.data_schema.model_validate(msg)
        except pydantic.ValidationError as exc:
            logger.error(f"Failed to validate job fetch data: {exc}")
            await self.send(
                response.Schema(
                    status=response.Status.ERROR,
                    message="Invalid data",
                    detail="Failed to validate job fetch data",
                    errors=[e["msg"] for e in exc.errors()],
                )
            )
            return {"request": None}

        # Convert the data schema to a `RpcMessage`
        rpc_msg = RpcMessage(
            request=RpcRequest(
                method="fetch_job",
                arguments={
                    "db": self.db_session,
                    "job_url": str(data.job_url),
                    "metadata": data.metadata,
                    "rabbitmq_connection": self.rabbitmq_connection,
                    "pubsub_topics": self.pubsub_topics,
                },
            )
        )
        return rpc_msg.model_dump()

The endpoint

job_fetch_pubsub = PubSubEndpoint(
    methods_class=JobFetchRPCMethods,
    broadcaster=settings.REDIS_URL,
    on_connect=[on_job_fetch_socket_connect],  # type: ignore
    on_disconnect=[on_job_fetch_socket_disconnect],  # type: ignore
    ignore_broadcaster_disconnected=False,
)
# Manually override `_serializing_socket_cls` on the internal WebsocketRPCEndpoint
job_fetch_pubsub.endpoint._serializing_socket_cls = JobFetchWebSocketProxy

@router.websocket("/ws/job-fetch", dependencies=[authe.authentication_required])
async def job_fetch_socket(ws: WebSocket, user: authe.VerifiedUser):
    user_id = str(user.id)
    ws_channel_id = f"milo:core:job_fetch:ws:{user_id}"
    logger.info(
        f"User {user_id!r} is connecting to job fetch WebSocket with client ID: {ws_channel_id!r}",
        extra={"user_id": user_id, "ws_channel_id": ws_channel_id},
    )
    # Let the topics be the same as the channel ID for simplicity
    pubsub_topics = [ws_channel_id]
    ws.state.pubsub_topics = pubsub_topics
    await job_fetch_pubsub.main_loop(
        ws, channel_id=ws_channel_id, pubsub_topics=pubsub_topics
    )

Proposed Improvement/Solution
It would be a good improvement to allow serializing_socket_cls to be passed as an argument when instantiating PubSubEndpoint, and then forwarded to the internally created WebsocketRPCEndpoint. We would then have something like this instead:

job_fetch_pubsub = PubSubEndpoint(
    methods_class=JobFetchRPCMethods,
    broadcaster=settings.REDIS_URL,
    on_connect=[on_job_fetch_socket_connect],  # type: ignore
    on_disconnect=[on_job_fetch_socket_disconnect],  # type: ignore
    ignore_broadcaster_disconnected=False,
    serializing_socket_cls=JobFetchWebSocketProxy,
)
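The proposed change essentially amounts to forwarding one keyword argument, with a default, to the inner endpoint. The sketch below uses simplified, hypothetical stand-in classes (PubSubEndpointSketch, RpcEndpointSketch, DefaultSerializingSocket), not the library's real code, and assumes the internal endpoint can take the socket class at construction time and instantiates it once per connection.

```python
from __future__ import annotations

from typing import Any


class DefaultSerializingSocket:
    """Hypothetical stand-in for the library's default serializing socket."""

    def __init__(self, websocket: Any) -> None:
        self.websocket = websocket


class RpcEndpointSketch:
    """Hypothetical, simplified stand-in for WebsocketRPCEndpoint."""

    def __init__(self, serializing_socket_cls: type = DefaultSerializingSocket) -> None:
        self._serializing_socket_cls = serializing_socket_cls

    def wrap(self, websocket: Any) -> Any:
        # Assumed to happen once per incoming connection
        return self._serializing_socket_cls(websocket)


class PubSubEndpointSketch:
    """Hypothetical stand-in showing only the proposed keyword."""

    def __init__(self, serializing_socket_cls: type = DefaultSerializingSocket) -> None:
        # Forward the class instead of requiring callers to patch a private attribute
        self.endpoint = RpcEndpointSketch(serializing_socket_cls=serializing_socket_cls)


class CustomSerializingSocket(DefaultSerializingSocket):
    """Hypothetical custom socket class supplied by the caller."""


default_endpoint = PubSubEndpointSketch()
custom_endpoint = PubSubEndpointSketch(serializing_socket_cls=CustomSerializingSocket)
```

Keeping the default equal to the current serializer means existing callers see no change in behaviour, while callers with custom response schemas no longer need to touch a private attribute.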

I'd be happy to make a PR for this change. Thanks!
