diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 13d5ecb1e..47e5b845a 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -73,9 +73,7 @@ async def sse_client( event_source.response.raise_for_status() logger.debug("SSE connection established") - async def sse_reader( - task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, - ): + async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED): try: async for sse in event_source.aiter_sse(): # pragma: no branch logger.debug(f"Received SSE event: {sse.event}") @@ -108,7 +106,7 @@ async def sse_reader( if not sse.data: continue try: - message = types.JSONRPCMessage.model_validate_json( # noqa: E501 + message = types.jsonrpc_message_adapter.validate_json( sse.data, by_name=False ) logger.debug(f"Received server message: {message}") diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 5ab541da8..19fdec5a3 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -150,7 +150,7 @@ async def stdout_reader(): for line in lines: try: - message = types.JSONRPCMessage.model_validate_json(line, by_name=False) + message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) except Exception as exc: # pragma: no cover logger.exception("Failed to parse JSONRPC message from server") await read_stream_writer.send(exc) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 75dcd5e89..555dd1290 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -25,6 +25,7 @@ JSONRPCRequest, JSONRPCResponse, RequestId, + jsonrpc_message_adapter, ) logger = logging.getLogger(__name__) @@ -95,11 +96,11 @@ def _prepare_headers(self) -> dict[str, str]: def _is_initialization_request(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialization request.""" - return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" + return isinstance(message, JSONRPCRequest) and message.method == "initialize" def _is_initialized_notification(self, message: JSONRPCMessage) -> bool: """Check if the message is an initialized notification.""" - return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized" + return isinstance(message, JSONRPCNotification) and message.method == "notifications/initialized" def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None: """Extract and store session ID from response headers.""" @@ -110,15 +111,15 @@ def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> N def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) -> None: """Extract protocol version from initialization response message.""" - if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch + if isinstance(message, JSONRPCResponse) and message.result: # pragma: no branch try: # Parse the result as InitializeResult for type safety - init_result = InitializeResult.model_validate(message.root.result, by_name=False) + init_result = InitializeResult.model_validate(message.result, by_name=False) self.protocol_version = str(init_result.protocol_version) logger.info(f"Negotiated protocol version: {self.protocol_version}") except Exception: # pragma: no cover logger.warning("Failed to parse initialization response as InitializeResult", exc_info=True) - logger.warning(f"Raw result: {message.root.result}") + logger.warning(f"Raw result: {message.result}") async def _handle_sse_event( self, @@ -137,7 +138,7 @@ async def _handle_sse_event( await resumption_callback(sse.id) return False try: - message = JSONRPCMessage.model_validate_json(sse.data, by_name=False) + message = jsonrpc_message_adapter.validate_json(sse.data, by_name=False) logger.debug(f"SSE message: {message}") # Extract protocol version from initialization response @@ -145,8 +146,8 @@ async def _handle_sse_event( self._maybe_extract_protocol_version_from_message(message) # If this is a response and we have original_request_id, replace it - if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError): - message.root.id = original_request_id + if original_request_id is not None and isinstance(message, JSONRPCResponse | JSONRPCError): + message.id = original_request_id session_message = SessionMessage(message) await read_stream_writer.send(session_message) @@ -157,7 +158,7 @@ async def _handle_sse_event( # If this is a response or error return True indicating completion # Otherwise, return False to continue listening - return isinstance(message.root, JSONRPCResponse | JSONRPCError) + return isinstance(message, JSONRPCResponse | JSONRPCError) except Exception as exc: # pragma: no cover logger.exception("Error parsing SSE message") @@ -222,8 +223,8 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None: # Extract original request ID to map responses original_request_id = None - if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch - original_request_id = ctx.session_message.message.root.id + if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch + original_request_id = ctx.session_message.message.id async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source: event_source.response.raise_for_status() @@ -257,12 +258,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: return if response.status_code == 404: # pragma: no branch - if isinstance(message.root, JSONRPCRequest): - await self._send_session_terminated_error( # pragma: no cover - ctx.read_stream_writer, # pragma: no cover - message.root.id, # pragma: no cover - ) # pragma: no cover - return # pragma: no cover + if isinstance(message, JSONRPCRequest): # pragma: no branch + await self._send_session_terminated_error(ctx.read_stream_writer, message.id) + return response.raise_for_status() if is_initialization: @@ -270,7 +268,7 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: # Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications: # The server MUST NOT send a response to notifications. - if isinstance(message.root, JSONRPCRequest): + if isinstance(message, JSONRPCRequest): content_type = response.headers.get("content-type", "").lower() if content_type.startswith("application/json"): await self._handle_json_response(response, ctx.read_stream_writer, is_initialization) @@ -291,7 +289,7 @@ async def _handle_json_response( """Handle JSON response from the server.""" try: content = await response.aread() - message = JSONRPCMessage.model_validate_json(content, by_name=False) + message = jsonrpc_message_adapter.validate_json(content, by_name=False) # Extract protocol version from initialization response if is_initialization: @@ -365,8 +363,8 @@ async def _handle_reconnection( # Extract original request ID to map responses original_request_id = None - if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch - original_request_id = ctx.session_message.message.root.id + if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch + original_request_id = ctx.session_message.message.id try: async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source: @@ -416,7 +414,7 @@ async def _send_session_terminated_error(self, read_stream_writer: StreamWriter, id=request_id, error=ErrorData(code=32600, message="Session terminated"), ) - session_message = SessionMessage(JSONRPCMessage(jsonrpc_error)) + session_message = SessionMessage(jsonrpc_error) await read_stream_writer.send(session_message) async def post_writer( @@ -463,7 +461,7 @@ async def handle_request_async(): await self._handle_post_request(ctx) # If this is a request, start a new task to handle it - if isinstance(message.root, JSONRPCRequest): + if isinstance(message, JSONRPCRequest): tg.start_soon(handle_request_async) else: await handle_request_async() diff --git a/src/mcp/client/websocket.py b/src/mcp/client/websocket.py index 71860be00..d9d0aa497 100644 --- a/src/mcp/client/websocket.py +++ b/src/mcp/client/websocket.py @@ -51,7 +51,7 @@ async def ws_reader(): async with read_stream_writer: async for raw_text in ws: try: - message = types.JSONRPCMessage.model_validate_json(raw_text, by_name=False) + message = types.jsonrpc_message_adapter.validate_json(raw_text, by_name=False) session_message = SessionMessage(message) await read_stream_writer.send(session_message) except ValidationError as exc: # pragma: no cover diff --git a/src/mcp/server/experimental/task_result_handler.py b/src/mcp/server/experimental/task_result_handler.py index 078de6628..4d763ef0e 100644 --- a/src/mcp/server/experimental/task_result_handler.py +++ b/src/mcp/server/experimental/task_result_handler.py @@ -26,7 +26,6 @@ ErrorData, GetTaskPayloadRequest, GetTaskPayloadResult, - JSONRPCMessage, RelatedTaskMetadata, RequestId, ) @@ -107,12 +106,7 @@ async def handle( while True: task = await self._store.get_task(task_id) if task is None: - raise McpError( - ErrorData( - code=INVALID_PARAMS, - message=f"Task not found: {task_id}", - ) - ) + raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Task not found: {task_id}")) await self._deliver_queued_messages(task_id, session, request_id) @@ -161,7 +155,7 @@ async def _deliver_queued_messages( # Send the message with relatedRequestId for routing session_message = SessionMessage( - message=JSONRPCMessage(message.message), + message=message.message, metadata=ServerMessageMetadata(related_request_id=request_id), ) await self.send_message(session, session_message) diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index 46849eb82..ea0c8db4a 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -227,7 +227,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send) logger.debug(f"Received JSON: {body}") try: - message = types.JSONRPCMessage.model_validate_json(body, by_name=False) + message = types.jsonrpc_message_adapter.validate_json(body, by_name=False) logger.debug(f"Validated client message: {message}") except ValidationError as err: logger.exception("Failed to parse message") diff --git a/src/mcp/server/stdio.py b/src/mcp/server/stdio.py index d494d075f..531404f21 100644 --- a/src/mcp/server/stdio.py +++ b/src/mcp/server/stdio.py @@ -60,7 +60,7 @@ async def stdin_reader(): async with read_stream_writer: async for line in stdin: try: - message = types.JSONRPCMessage.model_validate_json(line, by_name=False) + message = types.jsonrpc_message_adapter.validate_json(line, by_name=False) except Exception as exc: # pragma: no cover await read_stream_writer.send(exc) continue diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 137a7da39..6b16b1554 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -42,6 +42,7 @@ JSONRPCRequest, JSONRPCResponse, RequestId, + jsonrpc_message_adapter, ) logger = logging.getLogger(__name__) @@ -301,10 +302,7 @@ def _create_error_response( error_response = JSONRPCError( jsonrpc="2.0", id="server-error", # We don't have a request ID for general errors - error=ErrorData( - code=error_code, - message=error_message, - ), + error=ErrorData(code=error_code, message=error_message), ) return Response( @@ -455,6 +453,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re body = await request.body() try: + # TODO(Marcelo): Replace `json.loads` with `pydantic_core.from_json`. raw_message = json.loads(body) except json.JSONDecodeError as e: response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR) @@ -462,7 +461,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return try: # pragma: no cover - message = JSONRPCMessage.model_validate(raw_message, by_name=False) + message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False) except ValidationError as e: # pragma: no cover response = self._create_error_response( f"Validation error: {str(e)}", @@ -473,9 +472,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return # Check if this is an initialization request - is_initialization_request = ( - isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize" - ) # pragma: no cover + is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize" if is_initialization_request: # pragma: no cover # Check if the server already has an established session @@ -495,7 +492,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re return # For notifications and responses only, return 202 Accepted - if not isinstance(message.root, JSONRPCRequest): # pragma: no cover + if not isinstance(message, JSONRPCRequest): # pragma: no cover # Create response object and send it response = self._create_json_response( None, @@ -514,13 +511,13 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # For initialize requests, get from request params. # For other requests, get from header (already validated). protocol_version = ( - str(message.root.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION)) - if is_initialization_request and message.root.params + str(message.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION)) + if is_initialization_request and message.params else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION) ) # Extract the request ID outside the try block for proper scope - request_id = str(message.root.id) # pragma: no cover + request_id = str(message.id) # pragma: no cover # Register this stream for the request ID self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) # pragma: no cover request_stream_reader = self._request_streams[request_id][1] # pragma: no cover @@ -538,12 +535,12 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re # Use similar approach to SSE writer for consistency async for event_message in request_stream_reader: # If it's a response, this is what we're waiting for - if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError): + if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): response_message = event_message.message break # For notifications and request, keep waiting else: - logger.debug(f"received: {event_message.message.root.method}") + logger.debug(f"received: {event_message.message.method}") # At this point we should have a response if response_message: @@ -589,10 +586,7 @@ async def sse_writer(): await sse_stream_writer.send(event_data) # If response, remove from pending streams and close - if isinstance( - event_message.message.root, - JSONRPCResponse | JSONRPCError, - ): + if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): break except anyio.ClosedResourceError: # Expected when close_sse_stream() is called @@ -984,8 +978,8 @@ async def message_router(): # pragma: no cover message = session_message.message target_request_id = None # Check if this is a response - if isinstance(message.root, JSONRPCResponse | JSONRPCError): - response_id = str(message.root.id) + if isinstance(message, JSONRPCResponse | JSONRPCError): + response_id = str(message.id) # If this response is for an existing request stream, # send it there target_request_id = response_id @@ -1022,7 +1016,7 @@ async def message_router(): # pragma: no cover self._request_streams.pop(request_stream_id, None) else: logger.debug( - f"""Request stream {request_stream_id} not found + f"""Request stream {request_stream_id} not found for message. Still processing message as the client might reconnect and replay.""" ) diff --git a/src/mcp/server/websocket.py b/src/mcp/server/websocket.py index 9dde5e016..9df3e25c8 100644 --- a/src/mcp/server/websocket.py +++ b/src/mcp/server/websocket.py @@ -36,7 +36,7 @@ async def ws_reader(): async with read_stream_writer: async for msg in websocket.iter_text(): try: - client_message = types.JSONRPCMessage.model_validate_json(msg, by_name=False) + client_message = types.jsonrpc_message_adapter.validate_json(msg, by_name=False) except ValidationError as exc: await read_stream_writer.send(exc) continue diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 800693354..be1990d61 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -24,7 +24,6 @@ ClientResult, ErrorData, JSONRPCError, - JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, @@ -271,7 +270,7 @@ async def send_request( **request_data, ) - await self._write_stream.send(SessionMessage(message=JSONRPCMessage(jsonrpc_request), metadata=metadata)) + await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata)) # request read timeout takes precedence over session read timeout timeout = None @@ -321,7 +320,7 @@ async def send_notification( **notification.model_dump(by_alias=True, mode="json", exclude_none=True), ) session_message = SessionMessage( # pragma: no cover - message=JSONRPCMessage(jsonrpc_notification), + message=jsonrpc_notification, metadata=ServerMessageMetadata(related_request_id=related_request_id) if related_request_id else None, ) await self._write_stream.send(session_message) @@ -329,7 +328,7 @@ async def send_notification( async def _send_response(self, request_id: RequestId, response: SendResultT | ErrorData) -> None: if isinstance(response, ErrorData): jsonrpc_error = JSONRPCError(jsonrpc="2.0", id=request_id, error=response) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error)) + session_message = SessionMessage(message=jsonrpc_error) await self._write_stream.send(session_message) else: jsonrpc_response = JSONRPCResponse( @@ -337,7 +336,7 @@ async def _send_response(self, request_id: RequestId, response: SendResultT | Er id=request_id, result=response.model_dump(by_alias=True, mode="json", exclude_none=True), ) - session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_response)) + session_message = SessionMessage(message=jsonrpc_response) await self._write_stream.send(session_message) async def _receive_loop(self) -> None: @@ -349,14 +348,14 @@ async def _receive_loop(self) -> None: async for message in self._read_stream: if isinstance(message, Exception): # pragma: no cover await self._handle_incoming(message) - elif isinstance(message.message.root, JSONRPCRequest): + elif isinstance(message.message, JSONRPCRequest): try: validated_request = self._receive_request_type.model_validate( - message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True), + message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) responder = RequestResponder( - request_id=message.message.root.id, + request_id=message.message.id, request_meta=validated_request.root.params.meta if validated_request.root.params else None, @@ -374,23 +373,23 @@ async def _receive_loop(self) -> None: # For request validation errors, send a proper JSON-RPC error # response instead of crashing the server logging.warning(f"Failed to validate request: {e}") - logging.debug(f"Message that failed validation: {message.message.root}") + logging.debug(f"Message that failed validation: {message.message}") error_response = JSONRPCError( jsonrpc="2.0", - id=message.message.root.id, + id=message.message.id, error=ErrorData( code=INVALID_PARAMS, message="Invalid request parameters", data="", ), ) - session_message = SessionMessage(message=JSONRPCMessage(error_response)) + session_message = SessionMessage(message=error_response) await self._write_stream.send(session_message) - elif isinstance(message.message.root, JSONRPCNotification): + elif isinstance(message.message, JSONRPCNotification): try: notification = self._receive_notification_type.model_validate( - message.message.root.model_dump(by_alias=True, mode="json", exclude_none=True), + message.message.model_dump(by_alias=True, mode="json", exclude_none=True), by_name=False, ) # Handle cancellation notifications @@ -419,10 +418,11 @@ async def _receive_loop(self) -> None: ) await self._received_notification(notification) await self._handle_incoming(notification) - except Exception as e: # pragma: no cover + except Exception: # pragma: no cover # For other validation errors, log and continue logging.warning( - f"Failed to validate notification: {e}. Message was: {message.message.root}" + f"Failed to validate notification:. Message was: {message.message}", + exc_info=True, ) else: # Response or error await self._handle_response(message) @@ -475,27 +475,25 @@ async def _handle_response(self, message: SessionMessage) -> None: Checks response routers first (e.g., for task-related responses), then falls back to the normal response stream mechanism. """ - root = message.message.root - # This check is always true at runtime: the caller (_receive_loop) only invokes # this method in the else branch after checking for JSONRPCRequest and # JSONRPCNotification. However, the type checker can't infer this from the # method signature, so we need this guard for type narrowing. - if not isinstance(root, JSONRPCResponse | JSONRPCError): + if not isinstance(message.message, JSONRPCResponse | JSONRPCError): return # pragma: no cover # Normalize response ID to handle type mismatches (e.g., "0" vs 0) - response_id = self._normalize_request_id(root.id) + response_id = self._normalize_request_id(message.message.id) # First, check response routers (e.g., TaskResultHandler) - if isinstance(root, JSONRPCError): + if isinstance(message.message, JSONRPCError): # Route error to routers for router in self._response_routers: - if router.route_error(response_id, root.error): + if router.route_error(response_id, message.message.error): return # Handled else: # Route success response to routers - response_data: dict[str, Any] = root.result or {} + response_data: dict[str, Any] = message.message.result or {} for router in self._response_routers: if router.route_response(response_id, response_data): return # Handled @@ -503,7 +501,7 @@ async def _handle_response(self, message: SessionMessage) -> None: # Fall back to normal response streams stream = self._response_streams.pop(response_id, None) if stream: # pragma: no cover - await stream.send(root) + await stream.send(message.message) else: # pragma: no cover await self._handle_incoming(RuntimeError(f"Received response with an unknown request ID: {message}")) diff --git a/src/mcp/types.py b/src/mcp/types.py index b2afd977d..4c886680a 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -4,7 +4,7 @@ from datetime import datetime from typing import Annotated, Any, Final, Generic, Literal, TypeAlias, TypeVar -from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel +from pydantic import BaseModel, ConfigDict, Field, FileUrl, RootModel, TypeAdapter from pydantic.alias_generators import to_camel LATEST_PROTOCOL_VERSION = "2025-11-25" @@ -197,8 +197,8 @@ class JSONRPCError(MCPModel): error: ErrorData -class JSONRPCMessage(RootModel[JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError]): - pass +JSONRPCMessage = JSONRPCRequest | JSONRPCNotification | JSONRPCResponse | JSONRPCError +jsonrpc_message_adapter = TypeAdapter[JSONRPCMessage](JSONRPCMessage) class EmptyResult(Result): diff --git a/tests/client/conftest.py b/tests/client/conftest.py index dfcad8215..7314a3735 100644 --- a/tests/client/conftest.py +++ b/tests/client/conftest.py @@ -43,35 +43,33 @@ def clear(self) -> None: def get_client_requests(self, method: str | None = None) -> list[JSONRPCRequest]: # pragma: no cover """Get client-sent requests, optionally filtered by method.""" return [ - req.message.root + req.message for req in self.client.sent_messages - if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method) + if isinstance(req.message, JSONRPCRequest) and (method is None or req.message.method == method) ] def get_server_requests(self, method: str | None = None) -> list[JSONRPCRequest]: # pragma: no cover """Get server-sent requests, optionally filtered by method.""" return [ # pragma: no cover - req.message.root + req.message for req in self.server.sent_messages - if isinstance(req.message.root, JSONRPCRequest) and (method is None or req.message.root.method == method) + if isinstance(req.message, JSONRPCRequest) and (method is None or req.message.method == method) ] def get_client_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: # pragma: no cover """Get client-sent notifications, optionally filtered by method.""" return [ - notif.message.root + notif.message for notif in self.client.sent_messages - if isinstance(notif.message.root, JSONRPCNotification) - and (method is None or notif.message.root.method == method) + if isinstance(notif.message, JSONRPCNotification) and (method is None or notif.message.method == method) ] def get_server_notifications(self, method: str | None = None) -> list[JSONRPCNotification]: # pragma: no cover """Get server-sent notifications, optionally filtered by method.""" return [ - notif.message.root + notif.message for notif in self.server.sent_messages - if isinstance(notif.message.root, JSONRPCNotification) - and (method is None or notif.message.root.method == method) + if isinstance(notif.message, JSONRPCNotification) and (method is None or notif.message.method == method) ] diff --git a/tests/client/test_session.py b/tests/client/test_session.py index 78df8ed19..9512a0a7c 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -18,7 +18,6 @@ InitializedNotification, InitializeRequest, InitializeResult, - JSONRPCMessage, JSONRPCNotification, JSONRPCRequest, JSONRPCResponse, @@ -41,7 +40,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -65,18 +64,16 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) session_notification = await client_to_server_receive.receive() jsonrpc_notification = session_notification.message - assert isinstance(jsonrpc_notification.root, JSONRPCNotification) + assert isinstance(jsonrpc_notification, JSONRPCNotification) initialized_notification = ClientNotification.model_validate( jsonrpc_notification.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -128,7 +125,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -146,12 +143,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -189,7 +184,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -207,12 +202,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -220,10 +213,7 @@ async def mock_server(): await client_to_server_receive.receive() async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, + ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, @@ -247,7 +237,7 @@ async def test_client_session_version_negotiation_success(): async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -268,12 +258,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -281,10 +269,7 @@ async def mock_server(): await client_to_server_receive.receive() async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, + ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, @@ -309,7 +294,7 @@ async def test_client_session_version_negotiation_failure(): async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -327,21 +312,16 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, + ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, @@ -368,7 +348,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -386,12 +366,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -399,10 +377,7 @@ async def mock_server(): await client_to_server_receive.receive() async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, + ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, @@ -446,7 +421,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -464,12 +439,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -529,7 +502,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -547,12 +520,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -600,7 +571,7 @@ async def test_get_server_capabilities(): async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -617,12 +588,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -669,7 +638,7 @@ async def mock_server(): # Receive initialization request from client session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -686,12 +655,10 @@ async def mock_server(): # Answer initialization request await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -702,14 +669,14 @@ async def mock_server(): # Wait for the client to send a 'tools/call' request session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) - assert jsonrpc_request.root.method == "tools/call" + assert jsonrpc_request.method == "tools/call" if meta is not None: - assert jsonrpc_request.root.params - assert "_meta" in jsonrpc_request.root.params - assert jsonrpc_request.root.params["_meta"] == meta + assert jsonrpc_request.params + assert "_meta" in jsonrpc_request.params + assert jsonrpc_request.params["_meta"] == meta result = ServerResult( CallToolResult(content=[TextContent(type="text", text="Called successfully")], is_error=False) @@ -718,12 +685,10 @@ async def mock_server(): # Send the tools/call result await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -732,20 +697,18 @@ async def mock_server(): # The client requires this step to validate the tool output schema session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) - assert jsonrpc_request.root.method == "tools/list" + assert jsonrpc_request.method == "tools/list" result = types.ListToolsResult(tools=[mocked_tool]) await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -753,10 +716,7 @@ async def mock_server(): server_to_client_send.close() async with ( - ClientSession( - server_to_client_receive, - client_to_server_send, - ) as session, + ClientSession(server_to_client_receive, client_to_server_send) as session, anyio.create_task_group() as tg, client_to_server_send, client_to_server_receive, diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 61b7ce4fa..4059a9268 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -47,8 +47,8 @@ async def test_stdio_client(): async with stdio_client(server_parameters) as (read_stream, write_stream): # Test sending and receiving messages messages = [ - JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")), - JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})), + JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"), + JSONRPCResponse(jsonrpc="2.0", id=2, result={}), ] async with write_stream: @@ -67,8 +67,8 @@ async def test_stdio_client(): break assert len(read_messages) == 2 - assert read_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) - assert read_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) + assert read_messages[0] == JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + assert read_messages[1] == JSONRPCResponse(jsonrpc="2.0", id=2, result={}) @pytest.mark.anyio diff --git a/tests/experimental/tasks/client/test_capabilities.py b/tests/experimental/tasks/client/test_capabilities.py index de73b8c06..be3547801 100644 --- a/tests/experimental/tasks/client/test_capabilities.py +++ b/tests/experimental/tasks/client/test_capabilities.py @@ -15,7 +15,6 @@ Implementation, InitializeRequest, InitializeResult, - JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, ServerCapabilities, @@ -36,7 +35,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -54,12 +53,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -110,7 +107,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -128,12 +125,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -194,7 +189,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -212,12 +207,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) @@ -274,7 +267,7 @@ async def mock_server(): session_message = await client_to_server_receive.receive() jsonrpc_request = session_message.message - assert isinstance(jsonrpc_request.root, JSONRPCRequest) + assert isinstance(jsonrpc_request, JSONRPCRequest) request = ClientRequest.model_validate( jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) ) @@ -292,12 +285,10 @@ async def mock_server(): async with server_to_client_send: await server_to_client_send.send( SessionMessage( - JSONRPCMessage( - JSONRPCResponse( - jsonrpc="2.0", - id=jsonrpc_request.root.id, - result=result.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + JSONRPCResponse( + jsonrpc="2.0", + id=jsonrpc_request.id, + result=result.model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) diff --git a/tests/experimental/tasks/client/test_handlers.py b/tests/experimental/tasks/client/test_handlers.py index 0e4e8f45a..0cac3c736 100644 --- a/tests/experimental/tasks/client/test_handlers.py +++ b/tests/experimental/tasks/client/test_handlers.py @@ -151,15 +151,11 @@ async def run_client() -> None: await client_ready.wait() typed_request = GetTaskRequest(params=GetTaskRequestParams(task_id="test-task-123")) - request = types.JSONRPCRequest( - jsonrpc="2.0", - id="req-1", - **typed_request.model_dump(by_alias=True), - ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + request = types.JSONRPCRequest(jsonrpc="2.0", id="req-1", **typed_request.model_dump(by_alias=True)) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) assert response.id == "req-1" @@ -219,10 +215,10 @@ async def run_client() -> None: id="req-2", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) assert isinstance(response.result, dict) @@ -277,10 +273,10 @@ async def run_client() -> None: id="req-3", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) result = ListTasksResult.model_validate(response.result) @@ -340,10 +336,10 @@ async def run_client() -> None: id="req-4", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) result = CancelTaskResult.model_validate(response.result) @@ -448,11 +444,11 @@ async def run_client() -> None: id="req-sampling", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) # Step 2: Client responds with CreateTaskResult response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) task_result = CreateTaskResult.model_validate(response.result) @@ -469,10 +465,10 @@ async def run_client() -> None: id="req-poll", **typed_poll.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(poll_request))) + await client_streams.server_send.send(SessionMessage(poll_request)) poll_response_msg = await client_streams.server_receive.receive() - poll_response = poll_response_msg.message.root + poll_response = poll_response_msg.message assert isinstance(poll_response, types.JSONRPCResponse) status = GetTaskResult.model_validate(poll_response.result) @@ -485,10 +481,10 @@ async def run_client() -> None: id="req-result", **typed_result_req.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(result_request))) + await client_streams.server_send.send(SessionMessage(result_request)) result_response_msg = await client_streams.server_receive.receive() - result_response = result_response_msg.message.root + result_response = result_response_msg.message assert isinstance(result_response, types.JSONRPCResponse) assert isinstance(result_response.result, dict) @@ -588,11 +584,11 @@ async def run_client() -> None: id="req-elicit", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) # Step 2: Client responds with CreateTaskResult response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCResponse) task_result = CreateTaskResult.model_validate(response.result) @@ -609,10 +605,10 @@ async def run_client() -> None: id="req-poll", **typed_poll.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(poll_request))) + await client_streams.server_send.send(SessionMessage(poll_request)) poll_response_msg = await client_streams.server_receive.receive() - poll_response = poll_response_msg.message.root + poll_response = poll_response_msg.message assert isinstance(poll_response, types.JSONRPCResponse) status = GetTaskResult.model_validate(poll_response.result) @@ -625,10 +621,10 @@ async def run_client() -> None: id="req-result", **typed_result_req.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(result_request))) + await client_streams.server_send.send(SessionMessage(result_request)) result_response_msg = await client_streams.server_receive.receive() - result_response = result_response_msg.message.root + result_response = result_response_msg.message assert isinstance(result_response, types.JSONRPCResponse) # Verify the elicitation result @@ -667,10 +663,10 @@ async def run_client() -> None: id="req-unhandled", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert ( "not supported" in response.error.message.lower() @@ -706,10 +702,10 @@ async def run_client() -> None: id="req-result", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert "not supported" in response.error.message.lower() @@ -742,10 +738,10 @@ async def run_client() -> None: id="req-list", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert "not supported" in response.error.message.lower() @@ -778,10 +774,10 @@ async def run_client() -> None: id="req-cancel", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert "not supported" in response.error.message.lower() @@ -822,10 +818,10 @@ async def run_client() -> None: id="req-sampling", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert "not supported" in response.error.message.lower() @@ -868,10 +864,10 @@ async def run_client() -> None: id="req-elicit", **typed_request.model_dump(by_alias=True), ) - await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) + await client_streams.server_send.send(SessionMessage(request)) response_msg = await client_streams.server_receive.receive() - response = response_msg.message.root + response = response_msg.message assert isinstance(response, types.JSONRPCError) assert "not supported" in response.error.message.lower() diff --git a/tests/experimental/tasks/server/test_server.py b/tests/experimental/tasks/server/test_server.py index 94b37e6d0..6b0bbfef3 100644 --- a/tests/experimental/tasks/server/test_server.py +++ b/tests/experimental/tasks/server/test_server.py @@ -36,7 +36,6 @@ GetTaskRequestParams, GetTaskResult, JSONRPCError, - JSONRPCMessage, JSONRPCNotification, JSONRPCResponse, ListTasksRequest, @@ -724,7 +723,7 @@ async def test_send_message() -> None: # Create a test message notification = JSONRPCNotification(jsonrpc="2.0", method="test/notification") message = SessionMessage( - message=JSONRPCMessage(notification), + message=notification, metadata=ServerMessageMetadata(related_request_id="test-req-1"), ) @@ -733,8 +732,8 @@ async def test_send_message() -> None: # Verify it was sent to the stream received = await server_to_client_receive.receive() - assert isinstance(received.message.root, JSONRPCNotification) - assert received.message.root.method == "test/notification" + assert isinstance(received.message, JSONRPCNotification) + assert received.message.method == "test/notification" finally: # pragma: no cover await server_to_client_send.aclose() await server_to_client_receive.aclose() @@ -776,7 +775,7 @@ def route_error(self, request_id: str | int, error: ErrorData) -> bool: # Simulate receiving a response from client response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) - message = SessionMessage(message=JSONRPCMessage(response)) + message = SessionMessage(message=response) # Send from "client" side await client_to_server_send.send(message) @@ -831,7 +830,7 @@ def route_error(self, request_id: str | int, error: ErrorData) -> bool: # Simulate receiving an error response from client error_data = ErrorData(code=INVALID_REQUEST, message="Test error") error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) - message = SessionMessage(message=JSONRPCMessage(error_response)) + message = SessionMessage(message=error_response) # Send from "client" side await client_to_server_send.send(message) @@ -894,7 +893,7 @@ def route_error(self, request_id: str | int, error: ErrorData) -> bool: # Send a response - should skip first router and be handled by second response = JSONRPCResponse(jsonrpc="2.0", id="test-req-1", result={"status": "ok"}) - message = SessionMessage(message=JSONRPCMessage(response)) + message = SessionMessage(message=response) await client_to_server_send.send(message) with anyio.fail_after(5): @@ -953,7 +952,7 @@ def route_error(self, request_id: str | int, error: ErrorData) -> bool: # Send an error - should skip first router and be handled by second error_data = ErrorData(code=INVALID_REQUEST, message="Test error") error_response = JSONRPCError(jsonrpc="2.0", id="test-req-2", error=error_data) - message = SessionMessage(message=JSONRPCMessage(error_response)) + message = SessionMessage(message=error_response) await client_to_server_send.send(message) with anyio.fail_after(5): diff --git a/tests/issues/test_192_request_id.py b/tests/issues/test_192_request_id.py index ca4a95e5d..de96dbe23 100644 --- a/tests/issues/test_192_request_id.py +++ b/tests/issues/test_192_request_id.py @@ -66,7 +66,7 @@ async def run_server(): jsonrpc="2.0", ) - await client_writer.send(SessionMessage(JSONRPCMessage(root=init_req))) + await client_writer.send(SessionMessage(init_req)) response = await server_reader.receive() # Get init response but don't need to check it # Send initialized notification @@ -75,12 +75,12 @@ async def run_server(): params=NotificationParams().model_dump(by_alias=True, exclude_none=True), jsonrpc="2.0", ) - await client_writer.send(SessionMessage(JSONRPCMessage(root=initialized_notification))) + await client_writer.send(SessionMessage(initialized_notification)) # Send ping request with custom ID ping_request = JSONRPCRequest(id=custom_request_id, method="ping", params={}, jsonrpc="2.0") - await client_writer.send(SessionMessage(JSONRPCMessage(root=ping_request))) + await client_writer.send(SessionMessage(ping_request)) # Read response response = await server_reader.receive() @@ -88,8 +88,8 @@ async def run_server(): # Verify response ID matches request ID assert isinstance(response, SessionMessage) assert isinstance(response.message, JSONRPCMessage) - assert isinstance(response.message.root, JSONRPCResponse) - assert response.message.root.id == custom_request_id, "Response ID should match request ID" + assert isinstance(response.message, JSONRPCResponse) + assert response.message.id == custom_request_id, "Response ID should match request ID" # Cancel server task tg.cancel_scope.cancel() diff --git a/tests/issues/test_malformed_input.py b/tests/issues/test_malformed_input.py index 34498ba74..cb60ca42a 100644 --- a/tests/issues/test_malformed_input.py +++ b/tests/issues/test_malformed_input.py @@ -1,21 +1,13 @@ # Claude Debug """Test for HackerOne vulnerability report #3156202 - malformed input DOS.""" -from typing import Any - import anyio import pytest from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared.message import SessionMessage -from mcp.types import ( - INVALID_PARAMS, - JSONRPCError, - JSONRPCMessage, - JSONRPCRequest, - ServerCapabilities, -) +from mcp.types import INVALID_PARAMS, JSONRPCError, JSONRPCMessage, JSONRPCRequest, ServerCapabilities @pytest.mark.anyio @@ -37,7 +29,7 @@ async def test_malformed_initialize_request_does_not_crash_server(): ) # Wrap in session message - request_message = SessionMessage(message=JSONRPCMessage(malformed_request)) + request_message = SessionMessage(message=malformed_request) # Start a server session async with ServerSession( @@ -58,7 +50,7 @@ async def test_malformed_initialize_request_does_not_crash_server(): # Check that we received an error response instead of a crash try: response_message = write_receive_stream.receive_nowait() - response = response_message.message.root + response = response_message.message # Verify it's a proper JSON-RPC error response assert isinstance(response, JSONRPCError) @@ -75,14 +67,14 @@ async def test_malformed_initialize_request_does_not_crash_server(): method="tools/call", # params=None # Missing required params ) - another_request_message = SessionMessage(message=JSONRPCMessage(another_malformed_request)) + another_request_message = SessionMessage(message=another_malformed_request) await read_send_stream.send(another_request_message) await anyio.sleep(0.1) # Should get another error response, not a crash second_response_message = write_receive_stream.receive_nowait() - second_response = second_response_message.message.root + second_response = second_response_message.message assert isinstance(second_response, JSONRPCError) assert second_response.id == "test_id_2" @@ -125,7 +117,7 @@ async def test_multiple_concurrent_malformed_requests(): method="initialize", # params=None # Missing required params ) - request_message = SessionMessage(message=JSONRPCMessage(malformed_request)) + request_message = SessionMessage(message=malformed_request) malformed_requests.append(request_message) # Send all requests @@ -136,11 +128,11 @@ async def test_multiple_concurrent_malformed_requests(): await anyio.sleep(0.2) # Verify we get error responses for all requests - error_responses: list[Any] = [] + error_responses: list[JSONRPCMessage] = [] try: while True: response_message = write_receive_stream.receive_nowait() - error_responses.append(response_message.message.root) + error_responses.append(response_message.message) except anyio.WouldBlock: pass # No more messages diff --git a/tests/server/test_lifespan.py b/tests/server/test_lifespan.py index 138278594..caeb0530d 100644 --- a/tests/server/test_lifespan.py +++ b/tests/server/test_lifespan.py @@ -82,13 +82,11 @@ async def run_server(): ) await send_stream1.send( SessionMessage( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), - ) + JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), ) ) ) @@ -96,27 +94,16 @@ async def run_server(): response = response.message # Send initialized notification - await send_stream1.send( - SessionMessage( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) - ) - ) - ) + await send_stream1.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized"))) # Call the tool to verify lifespan context await send_stream1.send( SessionMessage( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, - ) + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, ) ) ) @@ -125,8 +112,8 @@ async def run_server(): response = await receive_stream2.receive() response = response.message assert isinstance(response, JSONRPCMessage) - assert isinstance(response.root, JSONRPCResponse) - assert response.root.result["content"][0]["text"] == "true" + assert isinstance(response, JSONRPCResponse) + assert response.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() @@ -162,13 +149,7 @@ def check_lifespan(ctx: Context[ServerSession, None]) -> bool: return True # Run server in background task - async with ( - anyio.create_task_group() as tg, - send_stream1, - receive_stream1, - send_stream2, - receive_stream2, - ): + async with anyio.create_task_group() as tg, send_stream1, receive_stream1, send_stream2, receive_stream2: async def run_server(): await server._mcp_server.run( @@ -188,13 +169,11 @@ async def run_server(): ) await send_stream1.send( SessionMessage( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=TypeAdapter(InitializeRequestParams).dump_python(params), - ) + JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=TypeAdapter(InitializeRequestParams).dump_python(params), ) ) ) @@ -202,27 +181,16 @@ async def run_server(): response = response.message # Send initialized notification - await send_stream1.send( - SessionMessage( - JSONRPCMessage( - root=JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) - ) - ) - ) + await send_stream1.send(SessionMessage(JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized"))) # Call the tool to verify lifespan context await send_stream1.send( SessionMessage( - JSONRPCMessage( - root=JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/call", - params={"name": "check_lifespan", "arguments": {}}, - ) + JSONRPCRequest( + jsonrpc="2.0", + id=2, + method="tools/call", + params={"name": "check_lifespan", "arguments": {}}, ) ) ) @@ -231,8 +199,8 @@ async def run_server(): response = await receive_stream2.receive() response = response.message assert isinstance(response, JSONRPCMessage) - assert isinstance(response.root, JSONRPCResponse) - assert response.root.result["content"][0]["text"] == "true" + assert isinstance(response, JSONRPCResponse) + assert response.result["content"][0]["text"] == "true" # Cancel server task tg.cancel_scope.cancel() diff --git a/tests/server/test_session.py b/tests/server/test_session.py index ced1d92ff..3c1e96c12 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -169,25 +169,23 @@ async def mock_client(): # Send initialization request with older protocol version (2024-11-05) await client_to_server_send.send( SessionMessage( - types.JSONRPCMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=types.InitializeRequestParams( - protocol_version="2024-11-05", - capabilities=types.ClientCapabilities(), - client_info=types.Implementation(name="test-client", version="1.0.0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True), - ) + types.JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=types.InitializeRequestParams( + protocol_version="2024-11-05", + capabilities=types.ClientCapabilities(), + client_info=types.Implementation(name="test-client", version="1.0.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) # Wait for the initialize response init_response_message = await server_to_client_receive.receive() - assert isinstance(init_response_message.message.root, types.JSONRPCResponse) - result_data = init_response_message.message.root.result + assert isinstance(init_response_message.message, types.JSONRPCResponse) + result_data = init_response_message.message.result init_result = types.InitializeResult.model_validate(result_data) # Check that the server responded with the requested protocol version @@ -196,14 +194,7 @@ async def mock_client(): # Send initialized notification await client_to_server_send.send( - SessionMessage( - types.JSONRPCMessage( - types.JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) - ) - ) + SessionMessage(types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) ) async with ( @@ -256,24 +247,14 @@ async def mock_client(): nonlocal ping_response_received, ping_response_id # Send ping request before any initialization - await client_to_server_send.send( - SessionMessage( - types.JSONRPCMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=42, - method="ping", - ) - ) - ) - ) + await client_to_server_send.send(SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=42, method="ping"))) # Wait for the ping response ping_response_message = await server_to_client_receive.receive() - assert isinstance(ping_response_message.message.root, types.JSONRPCResponse) + assert isinstance(ping_response_message.message, types.JSONRPCResponse) ping_response_received = True - ping_response_id = ping_response_message.message.root.id + ping_response_id = ping_response_message.message.id async with ( client_to_server_send, @@ -493,22 +474,14 @@ async def mock_client(): # Try to send a non-ping request before initialization await client_to_server_send.send( - SessionMessage( - types.JSONRPCMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="prompts/list", - ) - ) - ) + SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=1, method="prompts/list")) ) # Wait for the error response error_message = await server_to_client_receive.receive() - if isinstance(error_message.message.root, types.JSONRPCError): # pragma: no branch + if isinstance(error_message.message, types.JSONRPCError): # pragma: no branch error_response_received = True - error_code = error_message.message.root.error.code + error_code = error_message.message.error.code async with ( client_to_server_send, diff --git a/tests/server/test_session_race_condition.py b/tests/server/test_session_race_condition.py index aa256f5b0..bc6145aca 100644 --- a/tests/server/test_session_race_condition.py +++ b/tests/server/test_session_race_condition.py @@ -87,54 +87,35 @@ async def mock_client(): # Step 1: Send InitializeRequest await client_to_server_send.send( SessionMessage( - types.JSONRPCMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=1, - method="initialize", - params=types.InitializeRequestParams( - protocol_version=types.LATEST_PROTOCOL_VERSION, - capabilities=types.ClientCapabilities(), - client_info=types.Implementation(name="test-client", version="1.0.0"), - ).model_dump(by_alias=True, mode="json", exclude_none=True), - ) + types.JSONRPCRequest( + jsonrpc="2.0", + id=1, + method="initialize", + params=types.InitializeRequestParams( + protocol_version=types.LATEST_PROTOCOL_VERSION, + capabilities=types.ClientCapabilities(), + client_info=types.Implementation(name="test-client", version="1.0.0"), + ).model_dump(by_alias=True, mode="json", exclude_none=True), ) ) ) # Step 2: Wait for InitializeResult init_msg = await server_to_client_receive.receive() - assert isinstance(init_msg.message.root, types.JSONRPCResponse) + assert isinstance(init_msg.message, types.JSONRPCResponse) # Step 3: Immediately send tools/list BEFORE InitializedNotification # This is the race condition scenario - await client_to_server_send.send( - SessionMessage( - types.JSONRPCMessage( - types.JSONRPCRequest( - jsonrpc="2.0", - id=2, - method="tools/list", - ) - ) - ) - ) + await client_to_server_send.send(SessionMessage(types.JSONRPCRequest(jsonrpc="2.0", id=2, method="tools/list"))) # Step 4: Check the response tools_msg = await server_to_client_receive.receive() - if isinstance(tools_msg.message.root, types.JSONRPCError): # pragma: no cover - error_received = tools_msg.message.root.error.message + if isinstance(tools_msg.message, types.JSONRPCError): # pragma: no cover + error_received = tools_msg.message.error.message # Step 5: Send InitializedNotification await client_to_server_send.send( - SessionMessage( - types.JSONRPCMessage( - types.JSONRPCNotification( - jsonrpc="2.0", - method="notifications/initialized", - ) - ) - ) + SessionMessage(types.JSONRPCNotification(jsonrpc="2.0", method="notifications/initialized")) ) async with ( diff --git a/tests/server/test_stdio.py b/tests/server/test_stdio.py index 13cdde3d6..9a7ddaab4 100644 --- a/tests/server/test_stdio.py +++ b/tests/server/test_stdio.py @@ -5,7 +5,7 @@ from mcp.server.stdio import stdio_server from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse +from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, jsonrpc_message_adapter @pytest.mark.anyio @@ -14,8 +14,8 @@ async def test_stdio_server(): stdout = io.StringIO() messages = [ - JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")), - JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})), + JSONRPCRequest(jsonrpc="2.0", id=1, method="ping"), + JSONRPCResponse(jsonrpc="2.0", id=2, result={}), ] for message in messages: @@ -37,13 +37,13 @@ async def test_stdio_server(): # Verify received messages assert len(received_messages) == 2 - assert received_messages[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")) - assert received_messages[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=2, result={})) + assert received_messages[0] == JSONRPCRequest(jsonrpc="2.0", id=1, method="ping") + assert received_messages[1] == JSONRPCResponse(jsonrpc="2.0", id=2, result={}) # Test sending responses from the server responses = [ - JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")), - JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})), + JSONRPCRequest(jsonrpc="2.0", id=3, method="ping"), + JSONRPCResponse(jsonrpc="2.0", id=4, result={}), ] async with write_stream: @@ -55,7 +55,7 @@ async def test_stdio_server(): output_lines = stdout.readlines() assert len(output_lines) == 2 - received_responses = [JSONRPCMessage.model_validate_json(line.strip()) for line in output_lines] + received_responses = [jsonrpc_message_adapter.validate_json(line.strip()) for line in output_lines] assert len(received_responses) == 2 - assert received_responses[0] == JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id=3, method="ping")) - assert received_responses[1] == JSONRPCMessage(root=JSONRPCResponse(jsonrpc="2.0", id=4, result={})) + assert received_responses[0] == JSONRPCRequest(jsonrpc="2.0", id=3, method="ping") + assert received_responses[1] == JSONRPCResponse(jsonrpc="2.0", id=4, result={}) diff --git a/tests/shared/test_session.py b/tests/shared/test_session.py index 8b4ebd81f..77bec4aa3 100644 --- a/tests/shared/test_session.py +++ b/tests/shared/test_session.py @@ -18,7 +18,6 @@ EmptyResult, ErrorData, JSONRPCError, - JSONRPCMessage, JSONRPCRequest, JSONRPCResponse, TextContent, @@ -130,7 +129,7 @@ async def mock_server(): """Receive a request and respond with a string ID instead of integer.""" message = await server_read.receive() assert isinstance(message, SessionMessage) - root = message.message.root + root = message.message assert isinstance(root, JSONRPCRequest) # Get the original request ID (which is an integer) request_id = root.id @@ -142,7 +141,7 @@ async def mock_server(): id=str(request_id), # Convert to string to simulate mismatch result={}, ) - await server_write.send(SessionMessage(message=JSONRPCMessage(response))) + await server_write.send(SessionMessage(message=response)) async def make_request(client_session: ClientSession): nonlocal result_holder @@ -185,7 +184,7 @@ async def mock_server(): """Receive a request and respond with an error using a string ID.""" message = await server_read.receive() assert isinstance(message, SessionMessage) - root = message.message.root + root = message.message assert isinstance(root, JSONRPCRequest) request_id = root.id assert isinstance(request_id, int) @@ -196,7 +195,7 @@ async def mock_server(): id=str(request_id), # Convert to string to simulate mismatch error=ErrorData(code=-32600, message="Test error"), ) - await server_write.send(SessionMessage(message=JSONRPCMessage(error_response))) + await server_write.send(SessionMessage(message=error_response)) async def make_request(client_session: ClientSession): nonlocal error_holder @@ -247,7 +246,7 @@ async def mock_server(): id="not_a_number", # Non-numeric string result={}, ) - await server_write.send(SessionMessage(message=JSONRPCMessage(response))) + await server_write.send(SessionMessage(message=response)) async def make_request(client_session: ClientSession): try: diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index ad198e627..fb006424c 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -503,12 +503,12 @@ def test_sse_message_id_coercion(): See for more details. """ json_message = '{"jsonrpc": "2.0", "id": "123", "method": "ping", "params": null}' - msg = types.JSONRPCMessage.model_validate_json(json_message) - assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id="123"))) + msg = types.JSONRPCRequest.model_validate_json(json_message) + assert msg == snapshot(types.JSONRPCRequest(method="ping", jsonrpc="2.0", id="123")) json_message = '{"jsonrpc": "2.0", "id": 123, "method": "ping", "params": null}' - msg = types.JSONRPCMessage.model_validate_json(json_message) - assert msg == snapshot(types.JSONRPCMessage(root=types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123))) + msg = types.JSONRPCRequest.model_validate_json(json_message) + assert msg == snapshot(types.JSONRPCRequest(method="ping", jsonrpc="2.0", id=123)) @pytest.mark.parametrize( @@ -601,5 +601,5 @@ async def mock_aiter_sse() -> AsyncGenerator[ServerSentEvent, None]: msg = await read_stream.receive() # If we get here without error, the empty message was skipped successfully assert not isinstance(msg, Exception) - assert isinstance(msg.message.root, types.JSONRPCResponse) - assert msg.message.root.id == 1 + assert isinstance(msg.message, types.JSONRPCResponse) + assert msg.message.id == 1 diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 8838eb62b..0c702dce2 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -49,7 +49,6 @@ from mcp.shared.session import RequestResponder from mcp.types import ( InitializeResult, - JSONRPCMessage, JSONRPCRequest, TextContent, TextResourceContents, @@ -1859,7 +1858,7 @@ async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() ) # Create a mock message and request - mock_message = JSONRPCMessage(root=JSONRPCRequest(jsonrpc="2.0", id="test-1", method="tools/list")) + mock_message = JSONRPCRequest(jsonrpc="2.0", id="test-1", method="tools/list") mock_request = MagicMock() # Call _create_session_message with OLD protocol version diff --git a/tests/test_types.py b/tests/test_types.py index 7a9576c0b..454bac34b 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -12,7 +12,6 @@ Implementation, InitializeRequest, InitializeRequestParams, - JSONRPCMessage, JSONRPCRequest, ListToolsResult, SamplingCapability, @@ -22,6 +21,7 @@ ToolChoice, ToolResultContent, ToolUseContent, + jsonrpc_message_adapter, ) @@ -38,15 +38,15 @@ async def test_jsonrpc_request(): }, } - request = JSONRPCMessage.model_validate(json_data) - assert isinstance(request.root, JSONRPCRequest) + request = jsonrpc_message_adapter.validate_python(json_data) + assert isinstance(request, JSONRPCRequest) ClientRequest.model_validate(request.model_dump(by_alias=True, exclude_none=True)) - assert request.root.jsonrpc == "2.0" - assert request.root.id == 1 - assert request.root.method == "initialize" - assert request.root.params is not None - assert request.root.params["protocolVersion"] == LATEST_PROTOCOL_VERSION + assert request.jsonrpc == "2.0" + assert request.id == 1 + assert request.method == "initialize" + assert request.params is not None + assert request.params["protocolVersion"] == LATEST_PROTOCOL_VERSION @pytest.mark.anyio