diff --git a/packages/toolbox-adk/src/toolbox_adk/tool.py b/packages/toolbox-adk/src/toolbox_adk/tool.py index db41f81b9..0d94821b6 100644 --- a/packages/toolbox-adk/src/toolbox_adk/tool.py +++ b/packages/toolbox-adk/src/toolbox_adk/tool.py @@ -17,11 +17,7 @@ from typing import Any, Awaitable, Callable, Dict, Mapping, Optional import toolbox_core -from fastapi.openapi.models import ( - OAuth2, - OAuthFlowAuthorizationCode, - OAuthFlows, -) +from fastapi.openapi.models import OAuth2, OAuthFlowAuthorizationCode, OAuthFlows from google.adk.auth.auth_credential import ( AuthCredential, AuthCredentialTypes, diff --git a/packages/toolbox-adk/tests/unit/test_credentials.py b/packages/toolbox-adk/tests/unit/test_credentials.py index 749d28667..9e94d5638 100644 --- a/packages/toolbox-adk/tests/unit/test_credentials.py +++ b/packages/toolbox-adk/tests/unit/test_credentials.py @@ -109,10 +109,7 @@ def test_from_adk_credentials_http_bearer(self): def test_from_adk_credentials_api_key(self): from fastapi.openapi.models import APIKey, APIKeyIn - from google.adk.auth.auth_credential import ( - AuthCredential, - AuthCredentialTypes, - ) + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes auth_credential = AuthCredential( auth_type=AuthCredentialTypes.API_KEY, api_key="abc" @@ -129,10 +126,7 @@ def test_from_adk_credentials_api_key(self): def test_from_adk_credentials_api_key_default_location(self): from fastapi.openapi.models import APIKey - from google.adk.auth.auth_credential import ( - AuthCredential, - AuthCredentialTypes, - ) + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes auth_credential = AuthCredential( auth_type=AuthCredentialTypes.API_KEY, api_key="abc" @@ -157,10 +151,7 @@ class MockScheme: def test_from_adk_credentials_api_key_query_fail(self): import pytest from fastapi.openapi.models import APIKey, APIKeyIn - from google.adk.auth.auth_credential import ( - AuthCredential, - AuthCredentialTypes, - ) + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes cred = AuthCredential(auth_type=AuthCredentialTypes.API_KEY, api_key="abc") scheme = APIKey(type="apiKey", name="key", **{"in": APIKeyIn.query}) @@ -172,10 +163,7 @@ def test_from_adk_credentials_api_key_query_fail(self): def test_from_adk_credentials_api_key_no_scheme_raises(self): import pytest - from google.adk.auth.auth_credential import ( - AuthCredential, - AuthCredentialTypes, - ) + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes auth_credential = AuthCredential( auth_type=AuthCredentialTypes.API_KEY, api_key="my-key" @@ -187,10 +175,7 @@ def test_from_adk_credentials_api_key_no_scheme_raises(self): def test_from_adk_credentials_unsupported(self): import pytest - from google.adk.auth.auth_credential import ( - AuthCredential, - AuthCredentialTypes, - ) + from google.adk.auth.auth_credential import AuthCredential, AuthCredentialTypes auth_credential = AuthCredential( auth_type=AuthCredentialTypes.OAUTH2 diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 7bb32cdf9..b2224eedc 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -63,7 +63,7 @@ def __init__( def _create_transport(self, protocol: Protocol) -> ITransport: match protocol: - case Protocol.MCP_v20260618: + case Protocol.MCP_DRAFT: return McpHttpTransportV20260618( self._url, self._session, @@ -188,12 +188,6 @@ def __init__( telemetry_enabled: Whether to enable OpenTelemetry tracing and metrics. (Default: False) """ - if protocol != Protocol.MCP_LATEST: - logging.warning( - f"A newer version of MCP ({Protocol.MCP_LATEST.value}) is available. " - "Please use Protocol.MCP_LATEST to use the latest features." - ) - self.__transport = _McpTransportProxy( url, session, diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20241105/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20241105/mcp.py index ec6a2475d..7f03dcee0 100644 --- a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20241105/mcp.py +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20241105/mcp.py @@ -18,6 +18,7 @@ from pydantic import BaseModel from ... import version +from ...exceptions import ProtocolNegotiationError from ...protocol import ManifestSchema, TelemetryAttributes from .. import telemetry from ..transport_base import _McpHttpTransportBase @@ -131,9 +132,7 @@ async def _initialize_session( self._server_version = result.serverInfo.version if result.protocolVersion != self._protocol_version: - raise RuntimeError( - f"MCP version mismatch: client does not support server version {result.protocolVersion}" - ) + raise ProtocolNegotiationError(result.protocolVersion) if not result.capabilities.tools: if self._manage_session: diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250326/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250326/mcp.py index 0194bd282..7bf749e84 100644 --- a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250326/mcp.py +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250326/mcp.py @@ -13,11 +13,12 @@ # limitations under the License. import time -from typing import Mapping, Optional, TypeVar +from typing import Any, Mapping, Optional, TypeVar from pydantic import BaseModel from ... import version +from ...exceptions import ProtocolNegotiationError from ...protocol import ManifestSchema, TelemetryAttributes from .. import telemetry from ..transport_base import _McpHttpTransportBase @@ -29,7 +30,7 @@ class McpHttpTransportV20250326(_McpHttpTransportBase): """Transport for the MCP v2025-03-26 protocol.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self._session_id: Optional[str] = None @@ -147,10 +148,7 @@ async def _initialize_session( self._server_version = result.serverInfo.version if result.protocolVersion != self._protocol_version: - raise RuntimeError( - "MCP version mismatch: client does not support server version" - f" {result.protocolVersion}" - ) + raise ProtocolNegotiationError(result.protocolVersion) if not result.capabilities.tools: if self._manage_session: diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250618/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250618/mcp.py index d7a626ed1..2ea3ee636 100644 --- a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250618/mcp.py +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20250618/mcp.py @@ -18,7 +18,8 @@ from pydantic import BaseModel from ... import version -from ...protocol import ManifestSchema, TelemetryAttributes +from ...exceptions import ProtocolNegotiationError +from ...protocol import ManifestSchema, Protocol, TelemetryAttributes from .. import telemetry from ..transport_base import _McpHttpTransportBase from . import types @@ -71,6 +72,21 @@ async def _send_request( # Check for JSON-RPC Error if "error" in json_resp: + err_val = json_resp["error"] + if isinstance(err_val, dict) and err_val.get("code") == -32004: + server_supported = err_val.get("data", {}).get("supported", []) + client_supported = Protocol.get_supported_mcp_versions() + mutually_supported = [ + v for v in client_supported if v in server_supported + ] + if mutually_supported: + raise ProtocolNegotiationError(mutually_supported[0]) + else: + raise RuntimeError( + "No mutually supported protocol version. " + f"Client supports: {client_supported}, " + f"Server supports: {server_supported}" + ) try: err = types.JSONRPCError.model_validate(json_resp).error raise RuntimeError( @@ -138,10 +154,7 @@ async def _initialize_session( self._server_version = result.serverInfo.version if result.protocolVersion != self._protocol_version: - raise RuntimeError( - "MCP version mismatch: client does not support server version" - f" {result.protocolVersion}" - ) + raise ProtocolNegotiationError(result.protocolVersion) if not result.capabilities.tools: if self._manage_session: diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20251125/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20251125/mcp.py index 999301552..4a0e8d876 100644 --- a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20251125/mcp.py +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20251125/mcp.py @@ -18,7 +18,8 @@ from pydantic import BaseModel from ... import version -from ...protocol import ManifestSchema, TelemetryAttributes +from ...exceptions import ProtocolNegotiationError +from ...protocol import ManifestSchema, Protocol, TelemetryAttributes from .. import telemetry from ..transport_base import _McpHttpTransportBase from . import types @@ -71,6 +72,21 @@ async def _send_request( # Check for JSON-RPC Error if "error" in json_resp: + err_val = json_resp["error"] + if isinstance(err_val, dict) and err_val.get("code") == -32004: + server_supported = err_val.get("data", {}).get("supported", []) + client_supported = Protocol.get_supported_mcp_versions() + mutually_supported = [ + v for v in client_supported if v in server_supported + ] + if mutually_supported: + raise ProtocolNegotiationError(mutually_supported[0]) + else: + raise RuntimeError( + "No mutually supported protocol version. " + f"Client supports: {client_supported}, " + f"Server supports: {server_supported}" + ) try: err = types.JSONRPCError.model_validate(json_resp).error raise RuntimeError( @@ -138,10 +154,7 @@ async def _initialize_session( self._server_version = result.serverInfo.version if result.protocolVersion != self._protocol_version: - raise RuntimeError( - "MCP version mismatch: client does not support server version" - f" {result.protocolVersion}" - ) + raise ProtocolNegotiationError(result.protocolVersion) if not result.capabilities.tools: if self._manage_session: diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/mcp.py index fe9f45fb3..866249733 100644 --- a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/mcp.py +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/mcp.py @@ -42,13 +42,16 @@ async def _send_request( # Inject SEP-2243 routing headers req_headers["Mcp-Method"] = request.method - if ( - request.method == "tools/call" - and hasattr(request, "params") - and request.params is not None - ): - if hasattr(request.params, "name"): - req_headers["Mcp-Name"] = request.params.name + params = getattr(request, "params", None) + if params is not None: + if request.method in ("tools/call", "prompts/get"): + name = getattr(params, "name", None) + if name is not None: + req_headers["Mcp-Name"] = str(name) + elif request.method == "resources/read": + uri = getattr(params, "uri", None) + if uri is not None: + req_headers["Mcp-Name"] = str(uri) # Dynamically update the _meta protocol version in the parameters model if hasattr(request, "params") and request.params is not None: @@ -125,6 +128,21 @@ async def _send_request( # Check for JSON-RPC Error if "error" in json_resp: + err_val = json_resp["error"] + if isinstance(err_val, dict) and err_val.get("code") == -32004: + server_supported = err_val.get("data", {}).get("supported", []) + client_supported = Protocol.get_supported_mcp_versions() + mutually_supported = [ + v for v in client_supported if v in server_supported + ] + if mutually_supported: + raise ProtocolNegotiationError(mutually_supported[0]) + else: + raise RuntimeError( + "No mutually supported protocol version. " + f"Client supports: {client_supported}, " + f"Server supports: {server_supported}" + ) try: err = types.JSONRPCError.model_validate(json_resp).error raise RuntimeError( diff --git a/packages/toolbox-core/src/toolbox_core/protocol.py b/packages/toolbox-core/src/toolbox_core/protocol.py index 472580a94..350e0251f 100644 --- a/packages/toolbox-core/src/toolbox_core/protocol.py +++ b/packages/toolbox-core/src/toolbox_core/protocol.py @@ -47,19 +47,21 @@ def _empty_string_to_none(cls, value: Any) -> Any: class Protocol(str, Enum): """Defines how the client should choose between communication protocols.""" - MCP_v20260618 = "DRAFT-2026-v1" MCP_v20250618 = "2025-06-18" MCP_v20250326 = "2025-03-26" MCP_v20241105 = "2024-11-05" MCP_v20251125 = "2025-11-25" - MCP = MCP_v20250618 - MCP_LATEST = MCP_v20260618 + MCP_v2026_DRAFT = "DRAFT-2026-v1" + + MCP = MCP_v20251125 + MCP_LATEST = MCP_v20251125 + MCP_DRAFT = MCP_v2026_DRAFT @staticmethod def get_supported_mcp_versions() -> list[str]: """Returns a list of supported MCP protocol versions.""" return [ - Protocol.MCP_v20260618.value, + Protocol.MCP_DRAFT.value, Protocol.MCP_v20251125.value, Protocol.MCP_v20250618.value, Protocol.MCP_v20250326.value, diff --git a/packages/toolbox-core/tests/conformance/client.py b/packages/toolbox-core/tests/conformance/client.py index 5d59d593a..33902563e 100644 --- a/packages/toolbox-core/tests/conformance/client.py +++ b/packages/toolbox-core/tests/conformance/client.py @@ -50,7 +50,7 @@ async def main(): protocol = Protocol.MCP if scenario == "request-metadata": - protocol = Protocol.MCP_v20260618 + protocol = Protocol.MCP_LATEST async with ToolboxClient( server_url, client_headers=client_headers, protocol=protocol diff --git a/packages/toolbox-core/tests/mcp_transport/test_v20241105.py b/packages/toolbox-core/tests/mcp_transport/test_v20241105.py index 1ce2f39bb..eadb0e264 100644 --- a/packages/toolbox-core/tests/mcp_transport/test_v20241105.py +++ b/packages/toolbox-core/tests/mcp_transport/test_v20241105.py @@ -18,6 +18,7 @@ import pytest_asyncio from aiohttp import ClientSession +from toolbox_core.exceptions import ProtocolNegotiationError from toolbox_core.mcp_transport.v20241105 import types from toolbox_core.mcp_transport.v20241105.mcp import McpHttpTransportV20241105 from toolbox_core.protocol import ManifestSchema, Protocol @@ -248,7 +249,7 @@ async def test_initialize_session_protocol_mismatch(self, transport, mocker): ), ) - with pytest.raises(RuntimeError, match="MCP version mismatch"): + with pytest.raises(ProtocolNegotiationError): await transport._initialize_session() async def test_initialize_session_missing_tools_capability(self, transport, mocker): diff --git a/packages/toolbox-core/tests/mcp_transport/test_v20250618.py b/packages/toolbox-core/tests/mcp_transport/test_v20250618.py index fd2e50f7c..b5a1a0a94 100644 --- a/packages/toolbox-core/tests/mcp_transport/test_v20250618.py +++ b/packages/toolbox-core/tests/mcp_transport/test_v20250618.py @@ -18,6 +18,7 @@ import pytest_asyncio from aiohttp import ClientSession +from toolbox_core.exceptions import ProtocolNegotiationError from toolbox_core.mcp_transport.v20250618 import types from toolbox_core.mcp_transport.v20250618.mcp import McpHttpTransportV20250618 from toolbox_core.protocol import ManifestSchema, Protocol, TelemetryAttributes @@ -256,7 +257,7 @@ async def test_initialize_session_protocol_mismatch(self, transport, mocker): ), ) - with pytest.raises(RuntimeError, match="MCP version mismatch"): + with pytest.raises(ProtocolNegotiationError): await transport._initialize_session() async def test_initialize_session_missing_tools_capability(self, transport, mocker): diff --git a/packages/toolbox-core/tests/mcp_transport/test_v20251125.py b/packages/toolbox-core/tests/mcp_transport/test_v20251125.py index 9942041aa..f575b81bc 100644 --- a/packages/toolbox-core/tests/mcp_transport/test_v20251125.py +++ b/packages/toolbox-core/tests/mcp_transport/test_v20251125.py @@ -18,6 +18,7 @@ import pytest_asyncio from aiohttp import ClientSession +from toolbox_core.exceptions import ProtocolNegotiationError from toolbox_core.mcp_transport.v20251125 import types from toolbox_core.mcp_transport.v20251125.mcp import McpHttpTransportV20251125 from toolbox_core.protocol import ManifestSchema, Protocol @@ -256,7 +257,7 @@ async def test_initialize_session_protocol_mismatch(self, transport, mocker): ), ) - with pytest.raises(RuntimeError, match="MCP version mismatch"): + with pytest.raises(ProtocolNegotiationError): await transport._initialize_session() async def test_initialize_session_missing_tools_capability(self, transport, mocker): diff --git a/packages/toolbox-core/tests/mcp_transport/test_v20260618.py b/packages/toolbox-core/tests/mcp_transport/test_v20260618.py index 852112262..e400fdf99 100644 --- a/packages/toolbox-core/tests/mcp_transport/test_v20260618.py +++ b/packages/toolbox-core/tests/mcp_transport/test_v20260618.py @@ -70,7 +70,7 @@ async def transport(request, mocker): transport = McpHttpTransportV20260618( "http://fake-server.com", session=mock_session, - protocol=Protocol.MCP_v20260618, + protocol=Protocol.MCP_DRAFT, telemetry_enabled=request.param, ) yield transport @@ -131,7 +131,7 @@ def get_result_model(self): assert headers["Mcp-Method"] == "method" assert "Mcp-Name" not in headers - async def test_send_request_adds_mcp_name_header(self, transport): + async def test_send_request_adds_mcp_name_header_for_tools_call(self, transport): """Test that the Mcp-Name header is added for tools/call.""" mock_response = AsyncMock() mock_response.ok = True @@ -162,6 +162,70 @@ def get_result_model(self): assert headers["Mcp-Method"] == "tools/call" assert headers["Mcp-Name"] == "test_tool" + async def test_send_request_adds_mcp_name_header_for_prompts_get(self, transport): + """Test that the Mcp-Name header is added for prompts/get.""" + mock_response = AsyncMock() + mock_response.ok = True + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = {"jsonrpc": "2.0", "id": "1", "result": {}} + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestParams(types.BaseModel): + name: str + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "prompts/get" + params: TestParams + + def get_result_model(self): + return TestResult + + await transport._send_request( + "url", TestRequest(params=TestParams(name="test_prompt")) + ) + + call_args = transport._session.post.call_args + headers = call_args.kwargs["headers"] + assert headers["Mcp-Method"] == "prompts/get" + assert headers["Mcp-Name"] == "test_prompt" + + async def test_send_request_adds_mcp_name_header_for_resources_read( + self, transport + ): + """Test that the Mcp-Name header is added for resources/read.""" + mock_response = AsyncMock() + mock_response.ok = True + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = {"jsonrpc": "2.0", "id": "1", "result": {}} + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestParams(types.BaseModel): + uri: str + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "resources/read" + params: TestParams + + def get_result_model(self): + return TestResult + + await transport._send_request( + "url", TestRequest(params=TestParams(uri="file:///test.txt")) + ) + + call_args = transport._session.post.call_args + headers = call_args.kwargs["headers"] + assert headers["Mcp-Method"] == "resources/read" + assert headers["Mcp-Name"] == "file:///test.txt" + # --- Version Negotiation Tests --- async def test_version_negotiation_raises_fallback(self, transport): @@ -201,6 +265,44 @@ def get_result_model(self): assert exc_info.value.negotiated_version == "DRAFT-2026-v1" assert transport._session.post.call_count == 1 + async def test_version_negotiation_raises_fallback_200_ok(self, transport): + """Tests that the client raises ProtocolNegotiationError when the server returns 200 OK with -32004.""" + from toolbox_core.exceptions import ProtocolNegotiationError + + mock_response_reject = AsyncMock() + mock_response_reject.ok = True + mock_response_reject.status = 200 + mock_response_reject.content.at_eof = MagicMock(return_value=False) + mock_response_reject.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "error": { + "code": -32004, + "message": "Unsupported protocol version", + "data": {"supported": ["DRAFT-2026-v1"]}, + }, + } + + transport._session.post.return_value.__aenter__.return_value = ( + mock_response_reject + ) + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises(ProtocolNegotiationError) as exc_info: + await transport._send_request("url", TestRequest()) + + assert exc_info.value.negotiated_version == "DRAFT-2026-v1" + assert transport._session.post.call_count == 1 + async def test_version_negotiation_empty_intersection(self, transport): """Tests that the client errors immediately without retrying when there is no mutual version.""" mock_response_reject = AsyncMock() diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index d41b3f581..e28bb03a1 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -276,9 +276,7 @@ async def test_load_tool_protocol_fallback_success(test_tool_str): mock_2025.tool_invoke.return_value = "ok_from_fallback" mock_2025_cls.return_value = mock_2025 - async with ToolboxClient( - TEST_BASE_URL, protocol=Protocol.MCP_v20260618 - ) as client: + async with ToolboxClient(TEST_BASE_URL, protocol=Protocol.MCP_DRAFT) as client: # This should trigger the fallback loaded_tool = await client.load_tool(TOOL_NAME) @@ -322,9 +320,7 @@ async def test_load_tool_protocol_fallback_infinite_loop_prevention(test_tool_st mock_2025.tool_get.side_effect = ProtocolNegotiationError("2024-11-05") mock_2025_cls.return_value = mock_2025 - async with ToolboxClient( - TEST_BASE_URL, protocol=Protocol.MCP_v20260618 - ) as client: + async with ToolboxClient(TEST_BASE_URL, protocol=Protocol.MCP_DRAFT) as client: with pytest.raises( ProtocolNegotiationError, match="Server requires protocol fallback to 2024-11-05", @@ -814,7 +810,7 @@ async def test_client_init_with_client_info(): def test_toolbox_client_no_warning_on_mcp(): """Test that initializing ToolboxClient with Protocol.MCP issues NO DeprecationWarning.""" # Mock the transport to avoid actual connection attempts or MCP version warnings - with patch("toolbox_core.client.McpHttpTransportV20250618") as mock_transport: + with patch("toolbox_core.client.McpHttpTransportV20251125") as mock_transport: with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") diff --git a/packages/toolbox-core/tests/test_e2e_mcp.py b/packages/toolbox-core/tests/test_e2e_mcp.py index c35695ffa..7ffb3ddfa 100644 --- a/packages/toolbox-core/tests/test_e2e_mcp.py +++ b/packages/toolbox-core/tests/test_e2e_mcp.py @@ -24,11 +24,9 @@ from toolbox_core.tool import ToolboxTool -# TODO: Include draft versions in E2E integration tests once the server -# supports SEP-2575 (stateless MCP / Request-Metadata). @pytest_asyncio.fixture( scope="function", - params=[v for v in Protocol.get_supported_mcp_versions() if "DRAFT" not in v], + params=[v for v in Protocol.get_supported_mcp_versions()], ) async def toolbox(request): """Creates a ToolboxClient instance shared by all tests in this module.""" @@ -100,20 +98,27 @@ async def test_run_tool_missing_params(self, get_n_rows_tool: ToolboxTool): with pytest.raises(TypeError, match="missing a required argument: 'num_rows'"): await get_n_rows_tool() - async def test_protocol_fallback_e2e(self): - """Tests that a client using MCP_LATEST can fallback to an older protocol against a server that doesn't support the latest version.""" - # The E2E server currently does not support DRAFT 2026, so this will trigger a fallback. + async def test_protocol_fallback_e2e(self, toolbox_server_url: str): + """Tests that a client using MCP_DRAFT can fallback to an older protocol against a server that doesn't support the draft version.""" + # The E2E server currently does not support DRAFT 2026 on port 5000, so this will trigger a fallback. + # However, port 5001 does support DRAFT 2026. async with ToolboxClient( - "http://localhost:5000", protocol=Protocol.MCP_LATEST + toolbox_server_url, protocol=Protocol.MCP_DRAFT ) as client: tool = await client.load_tool("get-n-rows") response = await tool(num_rows="1") assert "row1" in response # Verify that fallback occurred by checking the transport's final protocol version - assert ( - client._ToolboxClient__transport._protocol_version - != Protocol.MCP_LATEST.value - ) + if "5001" in toolbox_server_url: + assert ( + client._ToolboxClient__transport._protocol_version + == Protocol.MCP_DRAFT.value + ) + else: + assert ( + client._ToolboxClient__transport._protocol_version + != Protocol.MCP_DRAFT.value + ) async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): """Invoke a tool with wrong param type.""" @@ -463,3 +468,25 @@ async def test_run_tool_with_wrong_map_value_type(self, toolbox: ToolboxClient): execution_context={"env": "staging"}, user_scores={"user4": "not-an-integer"}, ) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +async def test_mcp_default_protocol(): + """Verify that omitting the protocol argument defaults correctly and works.""" + async with ToolboxClient("http://localhost:5000") as client: + tool = await client.load_tool("get-n-rows") + response = await tool(num_rows="1") + assert "row1" in response + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("toolbox_server") +async def test_mcp_draft_fallback(): + """Verify that explicitly using MCP_DRAFT against a server that doesn't support it falls back successfully.""" + async with ToolboxClient( + "http://localhost:5000", protocol=Protocol.MCP_DRAFT + ) as client: + tool = await client.load_tool("get-n-rows") + response = await tool(num_rows="1") + assert "row1" in response diff --git a/packages/toolbox-core/tests/test_tool.py b/packages/toolbox-core/tests/test_tool.py index 1b23db3ae..97576916e 100644 --- a/packages/toolbox-core/tests/test_tool.py +++ b/packages/toolbox-core/tests/test_tool.py @@ -409,9 +409,10 @@ def test_tool_init_basic(http_session, sample_tool_params, sample_tool_descripti bound_params={}, client_headers={}, ) + warnings_list = [w for w in record if not issubclass(w.category, ResourceWarning)] assert ( - len(record) == 0 - ), f"ToolboxTool instantiation unexpectedly warned: {[f'{w.category.__name__}: {w.message}' for w in record]}" + len(warnings_list) == 0 + ), f"ToolboxTool instantiation unexpectedly warned: {[f'{w.category.__name__}: {w.message}' for w in warnings_list]}" assert tool_instance.__name__ == TEST_TOOL_NAME assert inspect.iscoroutinefunction(tool_instance.__call__)