Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/mcp/client/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ async def sse_client(
event_source.response.raise_for_status()
logger.debug("SSE connection established")

async def sse_reader(
task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED,
):
async def sse_reader(task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED):
try:
async for sse in event_source.aiter_sse(): # pragma: no branch
logger.debug(f"Received SSE event: {sse.event}")
Expand Down Expand Up @@ -108,7 +106,7 @@ async def sse_reader(
if not sse.data:
continue
try:
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
message = types.jsonrpc_message_adapter.validate_json(
sse.data, by_name=False
)
logger.debug(f"Received server message: {message}")
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/client/stdio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ async def stdout_reader():

for line in lines:
try:
message = types.JSONRPCMessage.model_validate_json(line, by_name=False)
message = types.jsonrpc_message_adapter.validate_json(line, by_name=False)
except Exception as exc: # pragma: no cover
logger.exception("Failed to parse JSONRPC message from server")
await read_stream_writer.send(exc)
Expand Down
44 changes: 21 additions & 23 deletions src/mcp/client/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
JSONRPCRequest,
JSONRPCResponse,
RequestId,
jsonrpc_message_adapter,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -95,11 +96,11 @@ def _prepare_headers(self) -> dict[str, str]:

def _is_initialization_request(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialization request."""
return isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
return isinstance(message, JSONRPCRequest) and message.method == "initialize"

def _is_initialized_notification(self, message: JSONRPCMessage) -> bool:
"""Check if the message is an initialized notification."""
return isinstance(message.root, JSONRPCNotification) and message.root.method == "notifications/initialized"
return isinstance(message, JSONRPCNotification) and message.method == "notifications/initialized"

def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> None:
"""Extract and store session ID from response headers."""
Expand All @@ -110,15 +111,15 @@ def _maybe_extract_session_id_from_response(self, response: httpx.Response) -> N

def _maybe_extract_protocol_version_from_message(self, message: JSONRPCMessage) -> None:
"""Extract protocol version from initialization response message."""
if isinstance(message.root, JSONRPCResponse) and message.root.result: # pragma: no branch
if isinstance(message, JSONRPCResponse) and message.result: # pragma: no branch
try:
# Parse the result as InitializeResult for type safety
init_result = InitializeResult.model_validate(message.root.result, by_name=False)
init_result = InitializeResult.model_validate(message.result, by_name=False)
self.protocol_version = str(init_result.protocol_version)
logger.info(f"Negotiated protocol version: {self.protocol_version}")
except Exception: # pragma: no cover
logger.warning("Failed to parse initialization response as InitializeResult", exc_info=True)
logger.warning(f"Raw result: {message.root.result}")
logger.warning(f"Raw result: {message.result}")

async def _handle_sse_event(
self,
Expand All @@ -137,16 +138,16 @@ async def _handle_sse_event(
await resumption_callback(sse.id)
return False
try:
message = JSONRPCMessage.model_validate_json(sse.data, by_name=False)
message = jsonrpc_message_adapter.validate_json(sse.data, by_name=False)
logger.debug(f"SSE message: {message}")

# Extract protocol version from initialization response
if is_initialization:
self._maybe_extract_protocol_version_from_message(message)

# If this is a response and we have original_request_id, replace it
if original_request_id is not None and isinstance(message.root, JSONRPCResponse | JSONRPCError):
message.root.id = original_request_id
if original_request_id is not None and isinstance(message, JSONRPCResponse | JSONRPCError):
message.id = original_request_id

session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
Expand All @@ -157,7 +158,7 @@ async def _handle_sse_event(

# If this is a response or error return True indicating completion
# Otherwise, return False to continue listening
return isinstance(message.root, JSONRPCResponse | JSONRPCError)
return isinstance(message, JSONRPCResponse | JSONRPCError)

except Exception as exc: # pragma: no cover
logger.exception("Error parsing SSE message")
Expand Down Expand Up @@ -222,8 +223,8 @@ async def _handle_resumption_request(self, ctx: RequestContext) -> None:

# Extract original request ID to map responses
original_request_id = None
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
original_request_id = ctx.session_message.message.root.id
if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch
original_request_id = ctx.session_message.message.id

async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
event_source.response.raise_for_status()
Expand Down Expand Up @@ -257,20 +258,17 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
return

if response.status_code == 404: # pragma: no branch
if isinstance(message.root, JSONRPCRequest):
await self._send_session_terminated_error( # pragma: no cover
ctx.read_stream_writer, # pragma: no cover
message.root.id, # pragma: no cover
) # pragma: no cover
return # pragma: no cover
if isinstance(message, JSONRPCRequest): # pragma: no branch
await self._send_session_terminated_error(ctx.read_stream_writer, message.id)
return

response.raise_for_status()
if is_initialization:
self._maybe_extract_session_id_from_response(response)

# Per https://modelcontextprotocol.io/specification/2025-06-18/basic#notifications:
# The server MUST NOT send a response to notifications.
if isinstance(message.root, JSONRPCRequest):
if isinstance(message, JSONRPCRequest):
content_type = response.headers.get("content-type", "").lower()
if content_type.startswith("application/json"):
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
Expand All @@ -291,7 +289,7 @@ async def _handle_json_response(
"""Handle JSON response from the server."""
try:
content = await response.aread()
message = JSONRPCMessage.model_validate_json(content, by_name=False)
message = jsonrpc_message_adapter.validate_json(content, by_name=False)

# Extract protocol version from initialization response
if is_initialization:
Expand Down Expand Up @@ -365,8 +363,8 @@ async def _handle_reconnection(

# Extract original request ID to map responses
original_request_id = None
if isinstance(ctx.session_message.message.root, JSONRPCRequest): # pragma: no branch
original_request_id = ctx.session_message.message.root.id
if isinstance(ctx.session_message.message, JSONRPCRequest): # pragma: no branch
original_request_id = ctx.session_message.message.id

try:
async with aconnect_sse(ctx.client, "GET", self.url, headers=headers) as event_source:
Expand Down Expand Up @@ -416,7 +414,7 @@ async def _send_session_terminated_error(self, read_stream_writer: StreamWriter,
id=request_id,
error=ErrorData(code=32600, message="Session terminated"),
)
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
session_message = SessionMessage(jsonrpc_error)
await read_stream_writer.send(session_message)

async def post_writer(
Expand Down Expand Up @@ -463,7 +461,7 @@ async def handle_request_async():
await self._handle_post_request(ctx)

# If this is a request, start a new task to handle it
if isinstance(message.root, JSONRPCRequest):
if isinstance(message, JSONRPCRequest):
tg.start_soon(handle_request_async)
else:
await handle_request_async()
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/client/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ async def ws_reader():
async with read_stream_writer:
async for raw_text in ws:
try:
message = types.JSONRPCMessage.model_validate_json(raw_text, by_name=False)
message = types.jsonrpc_message_adapter.validate_json(raw_text, by_name=False)
session_message = SessionMessage(message)
await read_stream_writer.send(session_message)
except ValidationError as exc: # pragma: no cover
Expand Down
10 changes: 2 additions & 8 deletions src/mcp/server/experimental/task_result_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
ErrorData,
GetTaskPayloadRequest,
GetTaskPayloadResult,
JSONRPCMessage,
RelatedTaskMetadata,
RequestId,
)
Expand Down Expand Up @@ -107,12 +106,7 @@ async def handle(
while True:
task = await self._store.get_task(task_id)
if task is None:
raise McpError(
ErrorData(
code=INVALID_PARAMS,
message=f"Task not found: {task_id}",
)
)
raise McpError(ErrorData(code=INVALID_PARAMS, message=f"Task not found: {task_id}"))

await self._deliver_queued_messages(task_id, session, request_id)

Expand Down Expand Up @@ -161,7 +155,7 @@ async def _deliver_queued_messages(

# Send the message with relatedRequestId for routing
session_message = SessionMessage(
message=JSONRPCMessage(message.message),
message=message.message,
metadata=ServerMessageMetadata(related_request_id=request_id),
)
await self.send_message(session, session_message)
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ async def handle_post_message(self, scope: Scope, receive: Receive, send: Send)
logger.debug(f"Received JSON: {body}")

try:
message = types.JSONRPCMessage.model_validate_json(body, by_name=False)
message = types.jsonrpc_message_adapter.validate_json(body, by_name=False)
logger.debug(f"Validated client message: {message}")
except ValidationError as err:
logger.exception("Failed to parse message")
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ async def stdin_reader():
async with read_stream_writer:
async for line in stdin:
try:
message = types.JSONRPCMessage.model_validate_json(line, by_name=False)
message = types.jsonrpc_message_adapter.validate_json(line, by_name=False)
except Exception as exc: # pragma: no cover
await read_stream_writer.send(exc)
continue
Expand Down
36 changes: 15 additions & 21 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
JSONRPCRequest,
JSONRPCResponse,
RequestId,
jsonrpc_message_adapter,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -301,10 +302,7 @@ def _create_error_response(
error_response = JSONRPCError(
jsonrpc="2.0",
id="server-error", # We don't have a request ID for general errors
error=ErrorData(
code=error_code,
message=error_message,
),
error=ErrorData(code=error_code, message=error_message),
)

return Response(
Expand Down Expand Up @@ -455,14 +453,15 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
body = await request.body()

try:
# TODO(Marcelo): Replace `json.loads` with `pydantic_core.from_json`.
raw_message = json.loads(body)
except json.JSONDecodeError as e:
response = self._create_error_response(f"Parse error: {str(e)}", HTTPStatus.BAD_REQUEST, PARSE_ERROR)
await response(scope, receive, send)
return

try: # pragma: no cover
message = JSONRPCMessage.model_validate(raw_message, by_name=False)
message = jsonrpc_message_adapter.validate_python(raw_message, by_name=False)
except ValidationError as e: # pragma: no cover
response = self._create_error_response(
f"Validation error: {str(e)}",
Expand All @@ -473,9 +472,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
return

# Check if this is an initialization request
is_initialization_request = (
isinstance(message.root, JSONRPCRequest) and message.root.method == "initialize"
) # pragma: no cover
is_initialization_request = isinstance(message, JSONRPCRequest) and message.method == "initialize"

if is_initialization_request: # pragma: no cover
# Check if the server already has an established session
Expand All @@ -495,7 +492,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
return

# For notifications and responses only, return 202 Accepted
if not isinstance(message.root, JSONRPCRequest): # pragma: no cover
if not isinstance(message, JSONRPCRequest): # pragma: no cover
# Create response object and send it
response = self._create_json_response(
None,
Expand All @@ -514,13 +511,13 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
# For initialize requests, get from request params.
# For other requests, get from header (already validated).
protocol_version = (
str(message.root.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION))
if is_initialization_request and message.root.params
str(message.params.get("protocolVersion", DEFAULT_NEGOTIATED_VERSION))
if is_initialization_request and message.params
else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
)

# Extract the request ID outside the try block for proper scope
request_id = str(message.root.id) # pragma: no cover
request_id = str(message.id) # pragma: no cover
# Register this stream for the request ID
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) # pragma: no cover
request_stream_reader = self._request_streams[request_id][1] # pragma: no cover
Expand All @@ -538,12 +535,12 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re
# Use similar approach to SSE writer for consistency
async for event_message in request_stream_reader:
# If it's a response, this is what we're waiting for
if isinstance(event_message.message.root, JSONRPCResponse | JSONRPCError):
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
response_message = event_message.message
break
# For notifications and request, keep waiting
else:
logger.debug(f"received: {event_message.message.root.method}")
logger.debug(f"received: {event_message.message.method}")

# At this point we should have a response
if response_message:
Expand Down Expand Up @@ -589,10 +586,7 @@ async def sse_writer():
await sse_stream_writer.send(event_data)

# If response, remove from pending streams and close
if isinstance(
event_message.message.root,
JSONRPCResponse | JSONRPCError,
):
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
break
except anyio.ClosedResourceError:
# Expected when close_sse_stream() is called
Expand Down Expand Up @@ -984,8 +978,8 @@ async def message_router(): # pragma: no cover
message = session_message.message
target_request_id = None
# Check if this is a response
if isinstance(message.root, JSONRPCResponse | JSONRPCError):
response_id = str(message.root.id)
if isinstance(message, JSONRPCResponse | JSONRPCError):
response_id = str(message.id)
# If this response is for an existing request stream,
# send it there
target_request_id = response_id
Expand Down Expand Up @@ -1022,7 +1016,7 @@ async def message_router(): # pragma: no cover
self._request_streams.pop(request_stream_id, None)
else:
logger.debug(
f"""Request stream {request_stream_id} not found
f"""Request stream {request_stream_id} not found
for message. Still processing message as the client
might reconnect and replay."""
)
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/server/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ async def ws_reader():
async with read_stream_writer:
async for msg in websocket.iter_text():
try:
client_message = types.JSONRPCMessage.model_validate_json(msg, by_name=False)
client_message = types.jsonrpc_message_adapter.validate_json(msg, by_name=False)
except ValidationError as exc:
await read_stream_writer.send(exc)
continue
Expand Down
Loading