Skip to content
6 changes: 1 addition & 5 deletions packages/toolbox-adk/src/toolbox_adk/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@
from typing import Any, Awaitable, Callable, Dict, Mapping, Optional

import toolbox_core
from fastapi.openapi.models import (
OAuth2,
OAuthFlowAuthorizationCode,
OAuthFlows,
)
from fastapi.openapi.models import OAuth2, OAuthFlowAuthorizationCode, OAuthFlows
from google.adk.auth.auth_credential import (
AuthCredential,
AuthCredentialTypes,
Expand Down
25 changes: 5 additions & 20 deletions packages/toolbox-adk/tests/unit/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,7 @@ def test_from_adk_credentials_http_bearer(self):

def test_from_adk_credentials_api_key(self):
from fastapi.openapi.models import APIKey, APIKeyIn
from google.adk.auth.auth_credential import (
AuthCredential,
AuthCredentialTypes,
)
from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes

auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="abc"
Expand All @@ -129,10 +126,7 @@ def test_from_adk_credentials_api_key(self):

def test_from_adk_credentials_api_key_default_location(self):
from fastapi.openapi.models import APIKey
from google.adk.auth.auth_credential import (
AuthCredential,
AuthCredentialTypes,
)
from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes

auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="abc"
Expand All @@ -157,10 +151,7 @@ class MockScheme:
def test_from_adk_credentials_api_key_query_fail(self):
import pytest
from fastapi.openapi.models import APIKey, APIKeyIn
from google.adk.auth.auth_credential import (
AuthCredential,
AuthCredentialTypes,
)
from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes

cred = AuthCredential(auth_type=AuthCredentialTypes.API_KEY, api_key="abc")
scheme = APIKey(type="apiKey", name="key", **{"in": APIKeyIn.query})
Expand All @@ -172,10 +163,7 @@ def test_from_adk_credentials_api_key_query_fail(self):

def test_from_adk_credentials_api_key_no_scheme_raises(self):
import pytest
from google.adk.auth.auth_credential import (
AuthCredential,
AuthCredentialTypes,
)
from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes

auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.API_KEY, api_key="my-key"
Expand All @@ -187,10 +175,7 @@ def test_from_adk_credentials_api_key_no_scheme_raises(self):

def test_from_adk_credentials_unsupported(self):
import pytest
from google.adk.auth.auth_credential import (
AuthCredential,
AuthCredentialTypes,
)
from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes

auth_credential = AuthCredential(
auth_type=AuthCredentialTypes.OAUTH2
Expand Down
8 changes: 1 addition & 7 deletions packages/toolbox-core/src/toolbox_core/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(

def _create_transport(self, protocol: Protocol) -> ITransport:
match protocol:
case Protocol.MCP_v20260618:
case Protocol.MCP_DRAFT:
return McpHttpTransportV20260618(
self._url,
self._session,
Expand Down Expand Up @@ -188,12 +188,6 @@ def __init__(
telemetry_enabled: Whether to enable OpenTelemetry tracing and metrics. (Default: False)
"""

if protocol != Protocol.MCP_LATEST:
logging.warning(
f"A newer version of MCP ({Protocol.MCP_LATEST.value}) is available. "
"Please use Protocol.MCP_LATEST to use the latest features."
)

self.__transport = _McpTransportProxy(
url,
session,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from pydantic import BaseModel

from ... import version
from ...exceptions import ProtocolNegotiationError
from ...protocol import ManifestSchema, TelemetryAttributes
from .. import telemetry
from ..transport_base import _McpHttpTransportBase
Expand Down Expand Up @@ -131,9 +132,7 @@ async def _initialize_session(
self._server_version = result.serverInfo.version

if result.protocolVersion != self._protocol_version:
raise RuntimeError(
f"MCP version mismatch: client does not support server version {result.protocolVersion}"
)
raise ProtocolNegotiationError(result.protocolVersion)

if not result.capabilities.tools:
if self._manage_session:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
# limitations under the License.

import time
from typing import Mapping, Optional, TypeVar
from typing import Any, Mapping, Optional, TypeVar

from pydantic import BaseModel

from ... import version
from ...exceptions import ProtocolNegotiationError
from ...protocol import ManifestSchema, TelemetryAttributes
from .. import telemetry
from ..transport_base import _McpHttpTransportBase
Expand All @@ -29,7 +30,7 @@
class McpHttpTransportV20250326(_McpHttpTransportBase):
"""Transport for the MCP v2025-03-26 protocol."""

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._session_id: Optional[str] = None

Expand Down Expand Up @@ -147,10 +148,7 @@ async def _initialize_session(
self._server_version = result.serverInfo.version

if result.protocolVersion != self._protocol_version:
raise RuntimeError(
"MCP version mismatch: client does not support server version"
f" {result.protocolVersion}"
)
raise ProtocolNegotiationError(result.protocolVersion)

if not result.capabilities.tools:
if self._manage_session:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from pydantic import BaseModel

from ... import version
from ...protocol import ManifestSchema, TelemetryAttributes
from ...exceptions import ProtocolNegotiationError
from ...protocol import ManifestSchema, Protocol, TelemetryAttributes
from .. import telemetry
from ..transport_base import _McpHttpTransportBase
from . import types
Expand Down Expand Up @@ -71,6 +72,21 @@ async def _send_request(

# Check for JSON-RPC Error
if "error" in json_resp:
err_val = json_resp["error"]
if isinstance(err_val, dict) and err_val.get("code") == -32004:
server_supported = err_val.get("data", {}).get("supported", [])
client_supported = Protocol.get_supported_mcp_versions()
mutually_supported = [
v for v in client_supported if v in server_supported
]
if mutually_supported:
raise ProtocolNegotiationError(mutually_supported[0])
Comment thread
anubhav756 marked this conversation as resolved.
else:
raise RuntimeError(
"No mutually supported protocol version. "
f"Client supports: {client_supported}, "
f"Server supports: {server_supported}"
)
try:
err = types.JSONRPCError.model_validate(json_resp).error
raise RuntimeError(
Expand Down Expand Up @@ -138,10 +154,7 @@ async def _initialize_session(
self._server_version = result.serverInfo.version

if result.protocolVersion != self._protocol_version:
raise RuntimeError(
"MCP version mismatch: client does not support server version"
f" {result.protocolVersion}"
)
raise ProtocolNegotiationError(result.protocolVersion)

if not result.capabilities.tools:
if self._manage_session:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from pydantic import BaseModel

from ... import version
from ...protocol import ManifestSchema, TelemetryAttributes
from ...exceptions import ProtocolNegotiationError
from ...protocol import ManifestSchema, Protocol, TelemetryAttributes
from .. import telemetry
from ..transport_base import _McpHttpTransportBase
from . import types
Expand Down Expand Up @@ -71,6 +72,21 @@ async def _send_request(

# Check for JSON-RPC Error
if "error" in json_resp:
err_val = json_resp["error"]
if isinstance(err_val, dict) and err_val.get("code") == -32004:
server_supported = err_val.get("data", {}).get("supported", [])
client_supported = Protocol.get_supported_mcp_versions()
mutually_supported = [
v for v in client_supported if v in server_supported
]
if mutually_supported:
raise ProtocolNegotiationError(mutually_supported[0])
else:
raise RuntimeError(
"No mutually supported protocol version. "
f"Client supports: {client_supported}, "
f"Server supports: {server_supported}"
)
try:
err = types.JSONRPCError.model_validate(json_resp).error
raise RuntimeError(
Expand Down Expand Up @@ -138,10 +154,7 @@ async def _initialize_session(
self._server_version = result.serverInfo.version

if result.protocolVersion != self._protocol_version:
raise RuntimeError(
"MCP version mismatch: client does not support server version"
f" {result.protocolVersion}"
)
raise ProtocolNegotiationError(result.protocolVersion)

if not result.capabilities.tools:
if self._manage_session:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,16 @@ async def _send_request(

# Inject SEP-2243 routing headers
req_headers["Mcp-Method"] = request.method
if (
request.method == "tools/call"
and hasattr(request, "params")
and request.params is not None
):
if hasattr(request.params, "name"):
req_headers["Mcp-Name"] = request.params.name
params = getattr(request, "params", None)
if params is not None:
if request.method in ("tools/call", "prompts/get"):
name = getattr(params, "name", None)
if name is not None:
req_headers["Mcp-Name"] = str(name)
elif request.method == "resources/read":
uri = getattr(params, "uri", None)
if uri is not None:
req_headers["Mcp-Name"] = str(uri)

# Dynamically update the _meta protocol version in the parameters model
if hasattr(request, "params") and request.params is not None:
Expand Down Expand Up @@ -125,6 +128,21 @@ async def _send_request(

# Check for JSON-RPC Error
if "error" in json_resp:
err_val = json_resp["error"]
if isinstance(err_val, dict) and err_val.get("code") == -32004:
server_supported = err_val.get("data", {}).get("supported", [])
client_supported = Protocol.get_supported_mcp_versions()
mutually_supported = [
v for v in client_supported if v in server_supported
]
if mutually_supported:
raise ProtocolNegotiationError(mutually_supported[0])
else:
raise RuntimeError(
"No mutually supported protocol version. "
f"Client supports: {client_supported}, "
f"Server supports: {server_supported}"
)
try:
err = types.JSONRPCError.model_validate(json_resp).error
raise RuntimeError(
Expand Down
10 changes: 6 additions & 4 deletions packages/toolbox-core/src/toolbox_core/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,21 @@ def _empty_string_to_none(cls, value: Any) -> Any:
class Protocol(str, Enum):
"""Defines how the client should choose between communication protocols."""

MCP_v20260618 = "DRAFT-2026-v1"
MCP_v20250618 = "2025-06-18"
MCP_v20250326 = "2025-03-26"
MCP_v20241105 = "2024-11-05"
MCP_v20251125 = "2025-11-25"
MCP = MCP_v20250618
MCP_LATEST = MCP_v20260618
MCP_v2026_DRAFT = "DRAFT-2026-v1"

MCP = MCP_v20251125
MCP_LATEST = MCP_v20251125
MCP_DRAFT = MCP_v2026_DRAFT

@staticmethod
def get_supported_mcp_versions() -> list[str]:
"""Returns a list of supported MCP protocol versions."""
return [
Protocol.MCP_v20260618.value,
Protocol.MCP_DRAFT.value,
Protocol.MCP_v20251125.value,
Protocol.MCP_v20250618.value,
Protocol.MCP_v20250326.value,
Expand Down
2 changes: 1 addition & 1 deletion packages/toolbox-core/tests/conformance/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ async def main():

protocol = Protocol.MCP
if scenario == "request-metadata":
protocol = Protocol.MCP_v20260618
protocol = Protocol.MCP_LATEST

async with ToolboxClient(
server_url, client_headers=client_headers, protocol=protocol
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest_asyncio
from aiohttp import ClientSession

from toolbox_core.exceptions import ProtocolNegotiationError
from toolbox_core.mcp_transport.v20241105 import types
from toolbox_core.mcp_transport.v20241105.mcp import McpHttpTransportV20241105
from toolbox_core.protocol import ManifestSchema, Protocol
Expand Down Expand Up @@ -248,7 +249,7 @@ async def test_initialize_session_protocol_mismatch(self, transport, mocker):
),
)

with pytest.raises(RuntimeError, match="MCP version mismatch"):
with pytest.raises(ProtocolNegotiationError):
await transport._initialize_session()

async def test_initialize_session_missing_tools_capability(self, transport, mocker):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest_asyncio
from aiohttp import ClientSession

from toolbox_core.exceptions import ProtocolNegotiationError
from toolbox_core.mcp_transport.v20250618 import types
from toolbox_core.mcp_transport.v20250618.mcp import McpHttpTransportV20250618
from toolbox_core.protocol import ManifestSchema, Protocol, TelemetryAttributes
Expand Down Expand Up @@ -256,7 +257,7 @@ async def test_initialize_session_protocol_mismatch(self, transport, mocker):
),
)

with pytest.raises(RuntimeError, match="MCP version mismatch"):
with pytest.raises(ProtocolNegotiationError):
await transport._initialize_session()

async def test_initialize_session_missing_tools_capability(self, transport, mocker):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pytest_asyncio
from aiohttp import ClientSession

from toolbox_core.exceptions import ProtocolNegotiationError
from toolbox_core.mcp_transport.v20251125 import types
from toolbox_core.mcp_transport.v20251125.mcp import McpHttpTransportV20251125
from toolbox_core.protocol import ManifestSchema, Protocol
Expand Down Expand Up @@ -256,7 +257,7 @@ async def test_initialize_session_protocol_mismatch(self, transport, mocker):
),
)

with pytest.raises(RuntimeError, match="MCP version mismatch"):
with pytest.raises(ProtocolNegotiationError):
await transport._initialize_session()

async def test_initialize_session_missing_tools_capability(self, transport, mocker):
Expand Down
Loading