Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
156 changes: 117 additions & 39 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,16 @@
from aiohttp import ClientSession
from deprecated import deprecated

from toolbox_core.exceptions import ProtocolNegotiationError

from . import version
from .itransport import ITransport
from .mcp_transport import (
McpHttpTransportV20241105,
McpHttpTransportV20250326,
McpHttpTransportV20250618,
McpHttpTransportV20251125,
McpHttpTransportV20260618,
)
from .protocol import Protocol, ToolSchema
from .tool import ToolboxTool
Expand All @@ -39,6 +42,112 @@
)


class _McpTransportProxy(ITransport):
"""A proxy transport that transparently handles protocol fallback negotiation."""

def __init__(
self,
url: str,
session: Optional[ClientSession],
protocol: Protocol,
client_name: Optional[str],
client_version: Optional[str],
telemetry_enabled: bool,
):
self._url = url
self._session = session
self._client_name = client_name
self._client_version = client_version
self._telemetry_enabled = telemetry_enabled
self._active_transport = self._create_transport(protocol)

def _create_transport(self, protocol: Protocol) -> ITransport:
match protocol:
case Protocol.MCP_v20260618:
return McpHttpTransportV20260618(
self._url,
self._session,
protocol,
self._client_name,
self._client_version,
telemetry_enabled=self._telemetry_enabled,
)
case Protocol.MCP_v20251125:
return McpHttpTransportV20251125(
self._url,
self._session,
protocol,
self._client_name,
self._client_version,
telemetry_enabled=self._telemetry_enabled,
)
case Protocol.MCP_v20250618:
return McpHttpTransportV20250618(
self._url,
self._session,
protocol,
self._client_name,
self._client_version,
telemetry_enabled=self._telemetry_enabled,
)
case Protocol.MCP_v20250326:
return McpHttpTransportV20250326(
self._url,
self._session,
protocol,
self._client_name,
self._client_version,
telemetry_enabled=self._telemetry_enabled,
)
case Protocol.MCP_v20241105:
return McpHttpTransportV20241105(
self._url,
self._session,
protocol,
self._client_name,
self._client_version,
telemetry_enabled=self._telemetry_enabled,
)
case _:
raise ValueError(f"Unsupported MCP protocol version: {protocol}")

@property
def base_url(self) -> str:
return self._active_transport.base_url

@property
def _protocol_version(self) -> str:
# We must expose this for tests asserting the current protocol version.
return getattr(self._active_transport, "_protocol_version", "")

async def _execute_with_fallback(
self, method_name: str, *args: Any, **kwargs: Any
) -> Any:
try:
return await getattr(self._active_transport, method_name)(*args, **kwargs)
except ProtocolNegotiationError as e:
fallback_protocol = Protocol(e.negotiated_version)
logging.warning(
f"Protocol fallback required. Switching from "
f"{self._protocol_version} to {fallback_protocol.value}"
)
await self._active_transport.close()
self._active_transport = self._create_transport(fallback_protocol)
return await getattr(self._active_transport, method_name)(*args, **kwargs)

async def tool_get(self, *args: Any, **kwargs: Any) -> Any:
return await self._execute_with_fallback("tool_get", *args, **kwargs)

async def tools_list(self, *args: Any, **kwargs: Any) -> Any:
return await self._execute_with_fallback("tools_list", *args, **kwargs)

async def tool_invoke(self, *args: Any, **kwargs: Any) -> Any:
return await self._execute_with_fallback("tool_invoke", *args, **kwargs)

async def close(self) -> None:
await self._active_transport.close()


class ToolboxClient:
"""
An asynchronous client for interacting with a Toolbox service.
Expand Down Expand Up @@ -85,45 +194,14 @@ def __init__(
"Please use Protocol.MCP_LATEST to use the latest features."
)

match protocol:
case Protocol.MCP_v20251125:
self.__transport = McpHttpTransportV20251125(
url,
session,
protocol,
client_name,
client_version,
telemetry_enabled=telemetry_enabled,
)
case Protocol.MCP_v20250618:
self.__transport = McpHttpTransportV20250618(
url,
session,
protocol,
client_name,
client_version,
telemetry_enabled=telemetry_enabled,
)
case Protocol.MCP_v20250326:
self.__transport = McpHttpTransportV20250326(
url,
session,
protocol,
client_name,
client_version,
telemetry_enabled=telemetry_enabled,
)
case Protocol.MCP_v20241105:
self.__transport = McpHttpTransportV20241105(
url,
session,
protocol,
client_name,
client_version,
telemetry_enabled=telemetry_enabled,
)
case _:
raise ValueError(f"Unsupported MCP protocol version: {protocol}")
self.__transport = _McpTransportProxy(
url,
session,
protocol,
client_name,
client_version,
telemetry_enabled,
)

self.__client_headers = client_headers if client_headers is not None else {}
warn_if_http_and_headers(url, self.__client_headers)
Expand Down
27 changes: 27 additions & 0 deletions packages/toolbox-core/src/toolbox_core/exceptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class ToolboxError(Exception):
"""Base exception for all MCP Toolbox errors."""

pass


class ProtocolNegotiationError(ToolboxError):
"""Raised when the server requires a different protocol version during a stateless request."""

def __init__(self, negotiated_version: str):
self.negotiated_version = negotiated_version
super().__init__(f"Server requires protocol fallback to {negotiated_version}")
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
from .v20250326.mcp import McpHttpTransportV20250326
from .v20250618.mcp import McpHttpTransportV20250618
from .v20251125.mcp import McpHttpTransportV20251125
from .v20260618.mcp import McpHttpTransportV20260618

__all__ = [
"McpHttpTransportV20241105",
"McpHttpTransportV20250326",
"McpHttpTransportV20250618",
"McpHttpTransportV20251125",
"McpHttpTransportV20260618",
]
Loading
Loading