From 6df8f2cd6beb6dee2cd03a3a9059ab33fb59aadb Mon Sep 17 00:00:00 2001 From: Stenal P Jolly Date: Fri, 26 Jun 2026 23:15:39 +0530 Subject: [PATCH] refactor: Restructure SDK into subpackages & implement versioned HTTP transport layers Restructures the Java SDK package layout into domain-specific subpackages: - client - transport - auth - tool - exception Refactors HTTP transport to delegate version-specific initialization and requests to dedicated subclass routers: - HttpMcpTransportV20241105 - HttpMcpTransportV20250326 - HttpMcpTransportV20250618 - HttpMcpTransportV20251125 This establishes the directory layout and class structure for managing versioned transports. TAG=agy CONV=2e1a2106-f882-4b08-92df-a27ba37c6fae --- .../com/google/cloud/mcp/TelemetryHelper.java | 336 +++++ .../google/cloud/mcp/auth/AuthMethods.java | 51 + .../google/cloud/mcp/auth/AuthResolver.java | 55 + .../cloud/mcp/auth/AuthTokenGetter.java | 30 + .../cloud/mcp/auth/CredentialsProvider.java | 30 + .../mcp/auth/GoogleCredentialsProvider.java | 84 ++ .../google/cloud/mcp/auth/ResolvedAuth.java | 84 ++ .../mcp/client/McpToolboxClientBuilder.java | 144 +++ .../mcp/client/McpToolboxClientImpl.java | 414 +++++++ .../cloud/mcp/exception/McpException.java | 40 + .../java/com/google/cloud/mcp/tool/Tool.java | 261 ++++ .../google/cloud/mcp/tool/ToolDefinition.java | 84 ++ .../cloud/mcp/tool/ToolPostProcessor.java | 33 + .../cloud/mcp/tool/ToolPreProcessor.java | 34 + .../com/google/cloud/mcp/tool/ToolResult.java | 40 + .../cloud/mcp/transport/BaseMcpTransport.java | 507 ++++++++ .../cloud/mcp/transport/HttpMcpTransport.java | 175 +++ .../google/cloud/mcp/transport/Transport.java | 56 + .../mcp/transport/TransportManifest.java | 43 + .../mcp/transport/TransportResponse.java | 52 + .../v20241105/HttpMcpTransportV20241105.java | 137 +++ .../v20250326/HttpMcpTransportV20250326.java | 153 +++ .../v20250618/HttpMcpTransportV20250618.java | 139 +++ .../v20251125/HttpMcpTransportV20251125.java | 139 +++ .../com/google/cloud/mcp/McpCoverageTest.java | 114 ++ .../com/google/cloud/mcp/TelemetryTest.java | 245 ++++ .../cloud/mcp/auth/AuthMethodsTest.java | 215 ++++ .../mcp/client/HttpMcpToolboxClientTest.java | 342 ++++++ .../client/McpToolboxClientBuilderTest.java | 202 +++ .../McpToolboxClientImplErrorsTest.java | 276 +++++ .../McpToolboxClientImplHeadersTest.java | 319 +++++ .../McpToolboxClientImplJsonRpcTest.java | 622 ++++++++++ .../mcp/client/McpToolboxClientImplTest.java | 1088 +++++++++++++++++ .../com/google/cloud/mcp/tool/ToolTest.java | 561 +++++++++ .../cloud/mcp/tool/ToolValidationTest.java | 387 ++++++ .../mcp/transport/HttpMcpTransportTest.java | 503 ++++++++ 36 files changed, 7995 insertions(+) create mode 100644 src/main/java/com/google/cloud/mcp/TelemetryHelper.java create mode 100644 src/main/java/com/google/cloud/mcp/auth/AuthMethods.java create mode 100644 src/main/java/com/google/cloud/mcp/auth/AuthResolver.java create mode 100644 src/main/java/com/google/cloud/mcp/auth/AuthTokenGetter.java create mode 100644 src/main/java/com/google/cloud/mcp/auth/CredentialsProvider.java create mode 100644 src/main/java/com/google/cloud/mcp/auth/GoogleCredentialsProvider.java create mode 100644 src/main/java/com/google/cloud/mcp/auth/ResolvedAuth.java create mode 100644 src/main/java/com/google/cloud/mcp/client/McpToolboxClientBuilder.java create mode 100644 src/main/java/com/google/cloud/mcp/client/McpToolboxClientImpl.java create mode 100644 src/main/java/com/google/cloud/mcp/exception/McpException.java create mode 100644 src/main/java/com/google/cloud/mcp/tool/Tool.java create mode 100644 src/main/java/com/google/cloud/mcp/tool/ToolDefinition.java create mode 100644 src/main/java/com/google/cloud/mcp/tool/ToolPostProcessor.java create mode 100644 src/main/java/com/google/cloud/mcp/tool/ToolPreProcessor.java create mode 100644 src/main/java/com/google/cloud/mcp/tool/ToolResult.java create mode 100644 src/main/java/com/google/cloud/mcp/transport/BaseMcpTransport.java create mode 100644 src/main/java/com/google/cloud/mcp/transport/HttpMcpTransport.java create mode 100644 src/main/java/com/google/cloud/mcp/transport/Transport.java create mode 100644 src/main/java/com/google/cloud/mcp/transport/TransportManifest.java create mode 100644 src/main/java/com/google/cloud/mcp/transport/TransportResponse.java create mode 100644 src/main/java/com/google/cloud/mcp/transport/v20241105/HttpMcpTransportV20241105.java create mode 100644 src/main/java/com/google/cloud/mcp/transport/v20250326/HttpMcpTransportV20250326.java create mode 100644 src/main/java/com/google/cloud/mcp/transport/v20250618/HttpMcpTransportV20250618.java create mode 100644 src/main/java/com/google/cloud/mcp/transport/v20251125/HttpMcpTransportV20251125.java create mode 100644 src/test/java/com/google/cloud/mcp/McpCoverageTest.java create mode 100644 src/test/java/com/google/cloud/mcp/TelemetryTest.java create mode 100644 src/test/java/com/google/cloud/mcp/auth/AuthMethodsTest.java create mode 100644 src/test/java/com/google/cloud/mcp/client/HttpMcpToolboxClientTest.java create mode 100644 src/test/java/com/google/cloud/mcp/client/McpToolboxClientBuilderTest.java create mode 100644 src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplErrorsTest.java create mode 100644 src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplHeadersTest.java create mode 100644 src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplJsonRpcTest.java create mode 100644 src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplTest.java create mode 100644 src/test/java/com/google/cloud/mcp/tool/ToolTest.java create mode 100644 src/test/java/com/google/cloud/mcp/tool/ToolValidationTest.java create mode 100644 src/test/java/com/google/cloud/mcp/transport/HttpMcpTransportTest.java diff --git a/src/main/java/com/google/cloud/mcp/TelemetryHelper.java b/src/main/java/com/google/cloud/mcp/TelemetryHelper.java new file mode 100644 index 0000000..1f52ae7 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/TelemetryHelper.java @@ -0,0 +1,336 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp; + +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.common.AttributesBuilder; +import io.opentelemetry.api.metrics.DoubleHistogram; +import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.TextMapPropagator; +import java.net.URI; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** Helper class for OpenTelemetry metrics and tracing instrumentation. */ +public final class TelemetryHelper { + /** Bucket boundary 0.01. */ + private static final double B_0_01 = 0.01; + + /** Bucket boundary 0.02. */ + private static final double B_0_02 = 0.02; + + /** Bucket boundary 0.05. */ + private static final double B_0_05 = 0.05; + + /** Bucket boundary 0.1. */ + private static final double B_0_1 = 0.1; + + /** Bucket boundary 0.2. */ + private static final double B_0_2 = 0.2; + + /** Bucket boundary 0.5. */ + private static final double B_0_5 = 0.5; + + /** Bucket boundary 1.0. */ + private static final double B_1 = 1.0; + + /** Bucket boundary 2.0. */ + private static final double B_2 = 2.0; + + /** Bucket boundary 5.0. */ + private static final double B_5 = 5.0; + + /** Bucket boundary 10.0. */ + private static final double B_10 = 10.0; + + /** Bucket boundary 30.0. */ + private static final double B_30 = 30.0; + + /** Bucket boundary 60.0. */ + private static final double B_60 = 60.0; + + /** Bucket boundary 120.0. */ + private static final double B_120 = 120.0; + + /** Bucket boundary 300.0. */ + private static final double B_300 = 300.0; + + /** Conversion factor from nanoseconds to seconds. */ + static final double NANOS_IN_SECOND = 1e9; + + /** Name of the instrumentation library. */ + private static final String INSTRUMENTATION_NAME = "toolbox.mcp.sdk"; + + // Dynamic / lazy OpenTelemetry binding cache + private static io.opentelemetry.api.OpenTelemetry lastOtel = null; + private static DoubleHistogram cachedOperationDuration = null; + private static DoubleHistogram cachedSessionDuration = null; + + private static synchronized void checkRebind() { + io.opentelemetry.api.OpenTelemetry currentOtel = GlobalOpenTelemetry.get(); + if (currentOtel != lastOtel) { + lastOtel = currentOtel; + Meter meter = currentOtel.getMeter(INSTRUMENTATION_NAME); + cachedOperationDuration = + meter + .histogramBuilder("mcp.client.operation.duration") + .setUnit("s") + .setDescription( + "Duration of MCP client operations (requests/notifications) from the time it was" + + " sent until the response or ack is received.") + .setExplicitBucketBoundariesAdvice( + Arrays.asList( + B_0_01, B_0_02, B_0_05, B_0_1, B_0_2, B_0_5, B_1, B_2, B_5, B_10, B_30, B_60, + B_120, B_300)) + .build(); + cachedSessionDuration = + meter + .histogramBuilder("mcp.client.session.duration") + .setUnit("s") + .setDescription("Total duration of MCP client sessions") + .setExplicitBucketBoundariesAdvice( + Arrays.asList( + B_0_01, B_0_02, B_0_05, B_0_1, B_0_2, B_0_5, B_1, B_2, B_5, B_10, B_30, B_60, + B_120, B_300)) + .build(); + } + } + + private static DoubleHistogram operationDuration() { + checkRebind(); + return cachedOperationDuration; + } + + private static DoubleHistogram sessionDuration() { + checkRebind(); + return cachedSessionDuration; + } + + private static Tracer tracer() { + return GlobalOpenTelemetry.getTracer(INSTRUMENTATION_NAME); + } + + private static TextMapPropagator propagator() { + return GlobalOpenTelemetry.getPropagators().getTextMapPropagator(); + } + + private TelemetryHelper() {} + + /** + * Helper record to extract ServerInfo. + * + * @param address The server host address. + * @param port The server port. + * @param protocol The network protocol (e.g. http). + */ + record ServerInfo(String address, Integer port, String protocol) {} + + static ServerInfo extractServerInfo(final String urlStr) { + try { + URI uri = new URI(urlStr); + String host = uri.getHost(); + if (host == null) { + host = uri.getAuthority(); + if (host != null && host.contains(":")) { + host = host.substring(0, host.indexOf(':')); + } + } + int port = uri.getPort(); + if (port == -1 && uri.getAuthority() != null && uri.getAuthority().contains(":")) { + try { + String auth = uri.getAuthority(); + port = Integer.parseInt(auth.substring(auth.indexOf(':') + 1)); + } catch (NumberFormatException e) { + // ignore + } + } + String protocol = uri.getScheme(); + if (protocol == null) { + protocol = "http"; + } + return new ServerInfo(host != null ? host : "", port != -1 ? port : null, protocol); + } catch (Exception e) { + return new ServerInfo("", null, "http"); + } + } + + /** Wrapper for recording client operation metrics and tracing spans. */ + public static class OperationSpan implements AutoCloseable { + /** The OpenTelemetry span. */ + private final Span span; + + /** The scope for the current span context. */ + private final Scope scope; + + /** Start time of the span in nanoseconds. */ + private final long startTimeNanos; + + /** Name of the MCP method. */ + private final String methodName; + + /** Protocol version of MCP. */ + private final String protocolVersion; + + /** Server base URL. */ + private final String serverUrl; + + /** Name of the tool. */ + private final String toolName; + + /** Class name of the error if an error occurred. */ + private String errorType = null; + + /** + * Constructs a new OperationSpan. + * + * @param method The MCP method name. + * @param version The protocol version. + * @param url The server base URL. + * @param tool The tool name, or null. + */ + public OperationSpan( + final String method, final String version, final String url, final String tool) { + this.methodName = method; + this.protocolVersion = version; + this.serverUrl = url; + this.toolName = tool; + this.startTimeNanos = System.nanoTime(); + + String spanName = tool != null ? method + " " + tool : method; + this.span = tracer().spanBuilder(spanName).setSpanKind(SpanKind.CLIENT).startSpan(); + this.scope = span.makeCurrent(); + + // Set standard span attributes + span.setAttribute("mcp.method.name", method); + span.setAttribute("mcp.protocol.version", version); + ServerInfo info = extractServerInfo(url); + span.setAttribute("server.address", info.address()); + span.setAttribute("network.protocol.name", info.protocol()); + span.setAttribute("network.transport", "tcp"); + if (info.port() != null) { + span.setAttribute("server.port", (long) info.port()); + } + if (tool != null) { + span.setAttribute("gen_ai.tool.name", tool); + } + if ("tools/call".equals(method)) { + span.setAttribute("gen_ai.operation.name", "execute_tool"); + } + } + + /** + * Gets W3C context headers to inject into the request. + * + * @return A map containing trace context headers. + */ + public Map getTraceContextHeaders() { + Map carrier = new HashMap<>(); + propagator().inject(Context.current(), carrier, Map::put); + return carrier; + } + + /** + * Records a throwable error on the span. + * + * @param t The error thrown. + */ + public void recordError(final Throwable t) { + span.recordException(t); + span.setStatus(StatusCode.ERROR, t.getMessage()); + this.errorType = t.getClass().getName(); + span.setAttribute("error.type", errorType); + } + + /** + * Records a JSON-RPC error on the span. + * + * @param code The JSON-RPC error code. + * @param message The error message. + */ + public void recordError(final int code, final String message) { + span.setStatus(StatusCode.ERROR, message); + this.errorType = "jsonrpc.error." + code; + span.setAttribute("error.type", errorType); + } + + @Override + public void close() { + scope.close(); + span.end(); + + // Record operation duration metric + double durationSeconds = (System.nanoTime() - startTimeNanos) / NANOS_IN_SECOND; + AttributesBuilder attrs = + Attributes.builder() + .put("mcp.method.name", methodName) + .put("mcp.protocol.version", protocolVersion); + ServerInfo info = extractServerInfo(serverUrl); + attrs.put("server.address", info.address()); + attrs.put("network.protocol.name", info.protocol()); + attrs.put("network.transport", "tcp"); + if (info.port() != null) { + attrs.put("server.port", (long) info.port()); + } + if (toolName != null) { + attrs.put("gen_ai.tool.name", toolName); + } + if ("tools/call".equals(methodName)) { + attrs.put("gen_ai.operation.name", "execute_tool"); + } + if (errorType != null) { + attrs.put("error.type", errorType); + } + + operationDuration().record(durationSeconds, attrs.build()); + } + } + + /** + * Records the duration of an MCP session. + * + * @param durationSeconds The duration of the session in seconds. + * @param protocolVersion The negotiated protocol version. + * @param serverUrl The server base URL. + * @param error The session error, or null if successful. + */ + public static void recordSessionDuration( + final double durationSeconds, + final String protocolVersion, + final String serverUrl, + final Throwable error) { + AttributesBuilder attrs = Attributes.builder().put("mcp.protocol.version", protocolVersion); + ServerInfo info = extractServerInfo(serverUrl); + attrs.put("server.address", info.address()); + attrs.put("network.protocol.name", info.protocol()); + attrs.put("network.transport", "tcp"); + if (info.port() != null) { + attrs.put("server.port", (long) info.port()); + } + if (error != null) { + attrs.put("error.type", error.getClass().getName()); + } + sessionDuration().record(durationSeconds, attrs.build()); + } +} diff --git a/src/main/java/com/google/cloud/mcp/auth/AuthMethods.java b/src/main/java/com/google/cloud/mcp/auth/AuthMethods.java new file mode 100644 index 0000000..a287009 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/auth/AuthMethods.java @@ -0,0 +1,51 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.auth; + +import com.google.auth.oauth2.GoogleCredentials; +import com.google.auth.oauth2.IdTokenProvider; +import java.io.IOException; +import java.util.Collections; + +/** Utility methods for fetching OIDC credentials. */ +public final class AuthMethods { + private AuthMethods() {} + + /** + * Fetches a Google ID token for the given audience using the provided credentials. + * + * @param credentials The Google credentials. + * @param audience The audience for the ID token. + * @return The token prefixed with "Bearer ". + * @throws IOException If credentials refresh fails. + */ + public static String getGoogleIdToken(GoogleCredentials credentials, String audience) + throws IOException { + if (credentials == null) { + throw new IllegalArgumentException("Credentials must not be null"); + } + credentials.refreshIfExpired(); + if (credentials instanceof IdTokenProvider) { + String token = + ((IdTokenProvider) credentials) + .idTokenWithAudience(audience, Collections.emptyList()) + .getTokenValue(); + return token.startsWith("Bearer ") ? token : "Bearer " + token; + } + throw new IllegalArgumentException("Credentials are not an instance of IdTokenProvider"); + } +} diff --git a/src/main/java/com/google/cloud/mcp/auth/AuthResolver.java b/src/main/java/com/google/cloud/mcp/auth/AuthResolver.java new file mode 100644 index 0000000..0fccadd --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/auth/AuthResolver.java @@ -0,0 +1,55 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.auth; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** Handles concurrent resolution of token getters into a {@link ResolvedAuth} instance. */ +public final class AuthResolver { + private AuthResolver() {} + + /** + * Concurrently resolves all registered token getters. + * + * @param getters The map of service name to token getter. + * @return A CompletableFuture containing the resolved auth object. + */ + public static CompletableFuture resolve(Map getters) { + if (getters.isEmpty()) { + return CompletableFuture.completedFuture(new ResolvedAuth(Map.of())); + } + + var entries = List.copyOf(getters.entrySet()); + var futures = entries.stream().map(entry -> entry.getValue().getToken()).toList(); + + return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])) + .thenApply( + v -> { + Map resolved = new HashMap<>(); + for (int i = 0; i < entries.size(); i++) { + String token = futures.get(i).join(); + if (token != null) { + resolved.put(entries.get(i).getKey(), token); + } + } + return new ResolvedAuth(resolved); + }); + } +} diff --git a/src/main/java/com/google/cloud/mcp/auth/AuthTokenGetter.java b/src/main/java/com/google/cloud/mcp/auth/AuthTokenGetter.java new file mode 100644 index 0000000..6067352 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/auth/AuthTokenGetter.java @@ -0,0 +1,30 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.auth; + +import java.util.concurrent.CompletableFuture; + +/** Functional interface for retrieving authentication tokens dynamically. */ +@FunctionalInterface +public interface AuthTokenGetter { + /** + * Retrieves an authentication token. + * + * @return A CompletableFuture containing the token string. + */ + CompletableFuture getToken(); +} diff --git a/src/main/java/com/google/cloud/mcp/auth/CredentialsProvider.java b/src/main/java/com/google/cloud/mcp/auth/CredentialsProvider.java new file mode 100644 index 0000000..a9c7ca4 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/auth/CredentialsProvider.java @@ -0,0 +1,30 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.auth; + +import java.util.concurrent.CompletableFuture; + +/** Functional interface for supplying the Authorization header. */ +@FunctionalInterface +public interface CredentialsProvider { + /** + * Retrieves the Authorization header value (e.g. "Bearer {@code }"). + * + * @return A CompletableFuture containing the full Authorization header value. + */ + CompletableFuture getAuthorizationHeader(); +} diff --git a/src/main/java/com/google/cloud/mcp/auth/GoogleCredentialsProvider.java b/src/main/java/com/google/cloud/mcp/auth/GoogleCredentialsProvider.java new file mode 100644 index 0000000..bb1490e --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/auth/GoogleCredentialsProvider.java @@ -0,0 +1,84 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.auth; + +import com.google.auth.oauth2.GoogleCredentials; +import java.io.IOException; +import java.util.concurrent.CompletableFuture; + +/** + * An implementation of CredentialsProvider that uses Google Application Default Credentials to + * fetch OIDC ID tokens. + */ +public class GoogleCredentialsProvider implements CredentialsProvider { + private final String audience; + private final CredentialsLoader credentialsLoader; + private volatile GoogleCredentials credentials; + + @FunctionalInterface + interface CredentialsLoader { + GoogleCredentials load() throws IOException; + } + + /** + * Constructs a new GoogleCredentialsProvider with a specified audience. + * + * @param audience The OIDC token audience (typically the service URL). + */ + public GoogleCredentialsProvider(String audience) { + this(audience, GoogleCredentials::getApplicationDefault); + } + + // Package-private constructor for unit testing + GoogleCredentialsProvider(String audience, CredentialsLoader credentialsLoader) { + if (audience == null || audience.isEmpty()) { + throw new IllegalArgumentException("Audience must not be null or empty"); + } + this.audience = audience; + this.credentialsLoader = credentialsLoader; + } + + private GoogleCredentials getCredentials() throws IOException { + GoogleCredentials localRef = credentials; + if (localRef == null) { + synchronized (this) { + localRef = credentials; + if (localRef == null) { + credentials = localRef = credentialsLoader.load(); + } + } + } + return localRef; + } + + @Override + public CompletableFuture getAuthorizationHeader() { + return CompletableFuture.supplyAsync( + () -> { + try { + GoogleCredentials creds = getCredentials(); + if (creds == null) { + return null; + } + return AuthMethods.getGoogleIdToken(creds, audience); + } catch (Exception e) { + // ADC not available or not OIDC-compatible. Proceed without global auth. + return null; + } + }); + } +} diff --git a/src/main/java/com/google/cloud/mcp/auth/ResolvedAuth.java b/src/main/java/com/google/cloud/mcp/auth/ResolvedAuth.java new file mode 100644 index 0000000..e79ecfb --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/auth/ResolvedAuth.java @@ -0,0 +1,84 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.auth; + +import com.google.cloud.mcp.tool.ToolDefinition; +import java.util.Map; + +/** Represents a resolved set of authentication credentials for a tool execution. */ +public final class ResolvedAuth { + private final Map tokens; + + /** + * Constructs a new ResolvedAuth. + * + * @param tokens The map of resolved auth tokens. + */ + public ResolvedAuth(Map tokens) { + Map copy = new java.util.HashMap<>(); + if (tokens != null) { + for (Map.Entry entry : tokens.entrySet()) { + if (entry.getKey() != null && entry.getValue() != null) { + copy.put(entry.getKey(), entry.getValue()); + } + } + } + this.tokens = Map.copyOf(copy); + } + + /** + * Applies the resolved credentials to the outgoing request parameters and headers. + * + * @param finalArgs The map of arguments for the tool execution. + * @param extraHeaders The map of extra headers for the tool execution. + * @param definition The tool definition to inspect for parameter auth mappings. + */ + public void applyTo( + Map finalArgs, Map extraHeaders, ToolDefinition definition) { + + for (Map.Entry entry : tokens.entrySet()) { + String serviceName = entry.getKey(); + String token = entry.getValue(); + if (token == null || token.isEmpty()) { + continue; + } + + // A. Parameter mapping + String paramName = findParameterForService(definition, serviceName); + if (paramName != null) { + finalArgs.put(paramName, token); + } + + // B. Header mapping + // Normalize to prevent double-prefixing if the provider already prefixed the token + String authorizationHeaderValue = + token.regionMatches(true, 0, "Bearer ", 0, 7) ? token : "Bearer " + token; + extraHeaders.put("Authorization", authorizationHeaderValue); + extraHeaders.put(serviceName + "_token", token); + } + } + + private static String findParameterForService(ToolDefinition definition, String serviceName) { + if (definition.parameters() == null) return null; + for (ToolDefinition.Parameter param : definition.parameters()) { + if (param.authSources() != null && param.authSources().contains(serviceName)) { + return param.name(); + } + } + return null; + } +} diff --git a/src/main/java/com/google/cloud/mcp/client/McpToolboxClientBuilder.java b/src/main/java/com/google/cloud/mcp/client/McpToolboxClientBuilder.java new file mode 100644 index 0000000..0b66c13 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/client/McpToolboxClientBuilder.java @@ -0,0 +1,144 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.client; + +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.tool.ToolPostProcessor; +import com.google.cloud.mcp.tool.ToolPreProcessor; +import com.google.cloud.mcp.transport.HttpMcpTransport; +import com.google.cloud.mcp.transport.Transport; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** Implementation of the {@link McpToolboxClient.Builder} interface. */ +public final class McpToolboxClientBuilder implements McpToolboxClient.Builder { + private String baseUrl; + private String apiKey; + private Map headers = new HashMap<>(); + private CredentialsProvider credentialsProvider; + private final List preProcessors = new ArrayList<>(); + private final List postProcessors = new ArrayList<>(); + private ProtocolVersion protocolVersion; + private java.net.http.HttpClient httpClient; + private java.util.concurrent.Executor executor; + + /** Constructs a new McpToolboxClientBuilder. */ + public McpToolboxClientBuilder() {} + + @Override + public McpToolboxClient.Builder baseUrl(String baseUrl) { + this.baseUrl = baseUrl; + return this; + } + + @Override + public McpToolboxClient.Builder apiKey(String apiKey) { + this.apiKey = apiKey; + return this; + } + + @Override + public McpToolboxClient.Builder headers(Map headers) { + if (headers != null) { + this.headers.putAll(headers); + } + return this; + } + + @Override + public McpToolboxClient.Builder credentialsProvider(CredentialsProvider credentialsProvider) { + this.credentialsProvider = credentialsProvider; + return this; + } + + @Override + public McpToolboxClient.Builder preProcessor(ToolPreProcessor preProcessor) { + if (preProcessor != null) { + this.preProcessors.add(preProcessor); + } + return this; + } + + @Override + public McpToolboxClient.Builder postProcessor(ToolPostProcessor postProcessor) { + if (postProcessor != null) { + this.postProcessors.add(postProcessor); + } + return this; + } + + @Override + public McpToolboxClient.Builder protocolVersion(ProtocolVersion protocolVersion) { + this.protocolVersion = protocolVersion; + return this; + } + + @Override + public McpToolboxClient.Builder httpClient(java.net.http.HttpClient httpClient) { + this.httpClient = httpClient; + return this; + } + + @Override + public McpToolboxClient.Builder executor(java.util.concurrent.Executor executor) { + this.executor = executor; + return this; + } + + @Override + public McpToolboxClient build() { + if (baseUrl == null || baseUrl.isEmpty()) { + throw new IllegalArgumentException("Base URL must be provided"); + } + // Normalize URL: remove trailing slash if present + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.substring(0, baseUrl.length() - 1); + } + + CredentialsProvider resolvedProvider = this.credentialsProvider; + boolean hasStaticAuth = false; + for (String key : this.headers.keySet()) { + if ("Authorization".equalsIgnoreCase(key)) { + hasStaticAuth = true; + break; + } + } + if (resolvedProvider == null + && !hasStaticAuth + && this.apiKey != null + && !this.apiKey.isEmpty()) { + String bearerKey = this.apiKey.startsWith("Bearer ") ? this.apiKey : "Bearer " + this.apiKey; + resolvedProvider = () -> CompletableFuture.completedFuture(bearerKey); + } + + Transport transport = + new HttpMcpTransport( + baseUrl, + this.headers, + resolvedProvider, + this.protocolVersion, + this.httpClient, + this.executor); + return new McpToolboxClientImpl( + transport, this.headers, resolvedProvider, preProcessors, postProcessors); + } +} diff --git a/src/main/java/com/google/cloud/mcp/client/McpToolboxClientImpl.java b/src/main/java/com/google/cloud/mcp/client/McpToolboxClientImpl.java new file mode 100644 index 0000000..5f586f5 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/client/McpToolboxClientImpl.java @@ -0,0 +1,414 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.client; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.AuthTokenGetter; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.tool.Tool; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolPostProcessor; +import com.google.cloud.mcp.tool.ToolPreProcessor; +import com.google.cloud.mcp.tool.ToolResult; +import com.google.cloud.mcp.transport.HttpMcpTransport; +import com.google.cloud.mcp.transport.Transport; +import com.google.cloud.mcp.transport.TransportManifest; +import com.google.cloud.mcp.transport.TransportResponse; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.logging.Logger; + +/** Default implementation using Java 11 HttpClient. */ +public final class McpToolboxClientImpl implements McpToolboxClient { + + /** Logger for logging messages. */ + private static final Logger LOGGER = Logger.getLogger(McpToolboxClientImpl.class.getName()); + + /** Warning message for non-HTTPS URL usage. */ + private static final String HTTP_WARNING = + "This connection is using HTTP. To prevent credential exposure, " + + "please ensure all communication is sent over HTTPS."; + + /** The transport layer. */ + private final Transport transport; + + /** Client headers. */ + private final Map headers; + + /** Credentials provider. */ + private final CredentialsProvider credentialsProvider; + + /** Jackson ObjectMapper for JSON parsing. */ + private final ObjectMapper objectMapper; + + private final List preProcessors; + private final List postProcessors; + + /** + * Constructs a new McpToolboxClientImpl. + * + * @param clientTransport The underlying MCP transport layer. + */ + public McpToolboxClientImpl(final Transport clientTransport) { + this(clientTransport, java.util.Collections.emptyMap(), null); + } + + /** + * Constructs a new McpToolboxClientImpl. + * + * @param transport The underlying MCP transport layer. + * @param headers Fallback headers for deprecated constructor compatibility. + * @param credentialsProvider Fallback provider for deprecated constructor compatibility. + */ + @Deprecated + public McpToolboxClientImpl( + Transport transport, Map headers, CredentialsProvider credentialsProvider) { + this(transport, headers, credentialsProvider, null, null); + } + + /** + * Deprecated constructor. Use the constructor accepting {@link CredentialsProvider} instead. + * + * @param baseUrl The base URL. + * @param apiKey The static API key. + */ + @Deprecated + public McpToolboxClientImpl(final String baseUrl, final String apiKey) { + this( + new HttpMcpTransport(baseUrl, Collections.emptyMap(), apiKeyToProvider(apiKey)), + Collections.emptyMap(), + apiKeyToProvider(apiKey)); + } + + /** + * Constructs a new McpToolboxClientImpl with generic headers. + * + * @param baseUrl The base URL of the MCP Toolbox Server. + * @param clientHeaders The HTTP headers to include in requests. + */ + @Deprecated + public McpToolboxClientImpl(final String baseUrl, final Map clientHeaders) { + this(new HttpMcpTransport(baseUrl, clientHeaders), clientHeaders, null); + } + + /** + * Constructs a new McpToolboxClientImpl. + * + * @param baseUrl The base URL of the MCP Toolbox Server. + * @param clientHeaders The HTTP headers to include in requests. + * @param provider The provider for authentication headers (optional). + */ + @Deprecated + public McpToolboxClientImpl( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider provider) { + this(new HttpMcpTransport(baseUrl, clientHeaders, provider), clientHeaders, provider); + } + + /** + * Deprecated constructor. Use the constructor accepting {@link CredentialsProvider} instead. + * + * @param baseUrl The base URL. + * @param provider The provider for auth headers. + */ + @Deprecated + public McpToolboxClientImpl(final String baseUrl, final CredentialsProvider provider) { + this( + new HttpMcpTransport(baseUrl, Collections.emptyMap(), provider), + Collections.emptyMap(), + provider); + } + + /** + * Deprecated constructor. Use the constructor accepting {@link Transport} instead. + * + * @param clientTransport The underlying transport. + * @param provider The provider for auth headers. + */ + @Deprecated + public McpToolboxClientImpl(final Transport clientTransport, final CredentialsProvider provider) { + this(clientTransport, Collections.emptyMap(), provider); + } + + private static CredentialsProvider apiKeyToProvider(final String apiKey) { + if (apiKey == null || apiKey.isEmpty()) { + return null; + } + String bearerKey = apiKey.startsWith("Bearer ") ? apiKey : "Bearer " + apiKey; + return () -> CompletableFuture.completedFuture(bearerKey); + } + + /** + * Primary constructor for McpToolboxClientImpl. + * + * @param transport The underlying MCP transport layer. + * @param headers Default HTTP headers. + * @param credentialsProvider Provider for credentials. + * @param preProcessors List of pre-processors. + * @param postProcessors List of post-processors. + */ + public McpToolboxClientImpl( + Transport transport, + Map headers, + CredentialsProvider credentialsProvider, + List preProcessors, + List postProcessors) { + this.transport = transport; + this.headers = + headers != null + ? java.util.Collections.unmodifiableMap(new java.util.HashMap<>(headers)) + : java.util.Collections.emptyMap(); + this.credentialsProvider = credentialsProvider; + this.preProcessors = preProcessors != null ? List.copyOf(preProcessors) : List.of(); + this.postProcessors = postProcessors != null ? List.copyOf(postProcessors) : List.of(); + this.objectMapper = new ObjectMapper(); + } + + private CompletableFuture> getMergedMetadata( + final Map extraMetadata) { + if (this.transport instanceof HttpMcpTransport) { + return CompletableFuture.completedFuture( + extraMetadata != null ? extraMetadata : java.util.Collections.emptyMap()); + } + if (this.credentialsProvider == null && this.headers.isEmpty()) { + return CompletableFuture.completedFuture( + extraMetadata != null ? extraMetadata : java.util.Collections.emptyMap()); + } + return getAuthorizationHeader() + .thenApply( + authHeader -> { + Map merged = new HashMap<>(this.headers); + if (extraMetadata != null) { + extraMetadata.forEach( + (k, v) -> { + if (!"Authorization".equalsIgnoreCase(k)) { + merged.put(k, v); + } + }); + } + String finalAuthHeader = null; + if (extraMetadata != null) { + finalAuthHeader = + extraMetadata.keySet().stream() + .filter(k -> "Authorization".equalsIgnoreCase(k)) + .findFirst() + .map(extraMetadata::get) + .orElse(null); + } + if (finalAuthHeader == null) { + finalAuthHeader = authHeader; + } + if (finalAuthHeader != null) { + merged.put("Authorization", finalAuthHeader); + } + return merged; + }); + } + + @Override + public CompletableFuture> listTools() { + return loadToolset(""); + } + + @Override + public CompletableFuture> loadToolset(final String toolsetName) { + return getMergedMetadata(java.util.Collections.emptyMap()) + .thenCompose( + mergedMetadata -> + transport + .listTools(toolsetName, mergedMetadata) + .thenApply(TransportManifest::getTools)); + } + + @Override + public CompletableFuture> loadToolset( + final String toolsetName, + final Map> paramBinds, + final Map> authBinds, + final boolean strict) { + + if (this.transport.getBaseUrl().toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && authBinds != null + && !authBinds.isEmpty()) { + LOGGER.warning(HTTP_WARNING); + } + + CompletableFuture> definitionsFuture = loadToolset(toolsetName); + + return definitionsFuture.thenApply( + defs -> { + if (strict) { + Set unknownTools = new HashSet<>(); + if (paramBinds != null) { + unknownTools.addAll(paramBinds.keySet()); + } + if (authBinds != null) { + unknownTools.addAll(authBinds.keySet()); + } + unknownTools.removeAll(defs.keySet()); + if (!unknownTools.isEmpty()) { + throw new IllegalArgumentException( + "Strict mode error: Bindings provided for unknown tools: " + unknownTools); + } + } + + Map tools = new HashMap<>(); + for (Map.Entry entry : defs.entrySet()) { + String toolName = entry.getKey(); + Tool tool = new Tool(toolName, entry.getValue(), this); + if (paramBinds != null && paramBinds.containsKey(toolName)) { + paramBinds.get(toolName).forEach(tool::bindParam); + } + if (authBinds != null && authBinds.containsKey(toolName)) { + authBinds.get(toolName).forEach(tool::addAuthTokenGetter); + } + for (ToolPreProcessor preProcessor : this.preProcessors) { + tool.addPreProcessor(preProcessor); + } + for (ToolPostProcessor postProcessor : this.postProcessors) { + tool.addPostProcessor(postProcessor); + } + tools.put(toolName, tool); + } + return tools; + }); + } + + @Override + public CompletableFuture loadTool(final String toolName) { + return loadTool(toolName, Collections.emptyMap()); + } + + @Override + public CompletableFuture loadTool( + final String toolName, final Map authTokenGetters) { + if (this.transport.getBaseUrl().toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && authTokenGetters != null + && !authTokenGetters.isEmpty()) { + LOGGER.warning(HTTP_WARNING); + } + return listTools() + .thenApply( + tools -> { + if (!tools.containsKey(toolName)) { + throw new RuntimeException("Tool not found: " + toolName); + } + Tool tool = new Tool(toolName, tools.get(toolName), this); + if (authTokenGetters != null) { + authTokenGetters.forEach(tool::addAuthTokenGetter); + } + for (ToolPreProcessor preProcessor : this.preProcessors) { + tool.addPreProcessor(preProcessor); + } + for (ToolPostProcessor postProcessor : this.postProcessors) { + tool.addPostProcessor(postProcessor); + } + return tool; + }); + } + + @Override + public CompletableFuture invokeTool( + final String toolName, final Map arguments) { + return invokeTool(toolName, arguments, Collections.emptyMap()); + } + + @Override + public CompletableFuture invokeTool( + final String toolName, + final Map arguments, + final Map extraHeaders) { + if (this.transport.getBaseUrl().toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && extraHeaders != null + && !extraHeaders.isEmpty()) { + LOGGER.warning(HTTP_WARNING); + } + return getMergedMetadata(extraHeaders) + .thenCompose( + mergedMetadata -> + transport + .invokeTool(toolName, arguments, mergedMetadata) + .thenApply(res -> handleInvokeResponse(res, toolName))); + } + + private CompletableFuture getAuthorizationHeader() { + if (this.credentialsProvider != null) { + return this.credentialsProvider.getAuthorizationHeader(); + } + for (Map.Entry entry : this.headers.entrySet()) { + if ("Authorization".equalsIgnoreCase(entry.getKey())) { + return CompletableFuture.completedFuture(entry.getValue()); + } + } + return CompletableFuture.completedFuture(null); + } + + private ToolResult handleInvokeResponse(final TransportResponse response, final String toolName) { + String body = response.getBody(); + if (response.getStatusCode() != java.net.HttpURLConnection.HTTP_OK) { + return new ToolResult( + java.util.List.of( + new ToolResult.Content("text", "Error " + response.getStatusCode() + ": " + body)), + true); + } + try { + JsonNode root = objectMapper.readTree(body); + if (root.has("error")) { + JsonNode errNode = root.get("error"); + int code = errNode.has("code") ? errNode.get("code").asInt() : -1; + String msg = errNode.has("message") ? errNode.get("message").asText() : errNode.toString(); + return new ToolResult( + java.util.List.of(new ToolResult.Content("text", "MCP Error: " + msg)), true); + } + + boolean isError = root.has("isError") && root.get("isError").asBoolean(); + + JsonNode result = root.get("result"); + if (result != null) { + ToolResult parsedResult = objectMapper.treeToValue(result, ToolResult.class); + if (parsedResult.content() == null) { + return new ToolResult( + java.util.List.of(new ToolResult.Content("text", result.asText())), + isError || parsedResult.isError()); + } + return parsedResult; + } + + return new ToolResult(java.util.List.of(new ToolResult.Content("text", body)), isError); + } catch (Exception e) { + return new ToolResult(java.util.List.of(new ToolResult.Content("text", body)), false); + } + } + + @Override + public void close() { + try { + transport.close(); + } catch (Exception e) { + throw new McpException("Failed to close transport", e); + } + } +} diff --git a/src/main/java/com/google/cloud/mcp/exception/McpException.java b/src/main/java/com/google/cloud/mcp/exception/McpException.java new file mode 100644 index 0000000..9016674 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/exception/McpException.java @@ -0,0 +1,40 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.exception; + +/** Unchecked exception thrown for MCP Toolbox Client operations and protocol failures. */ +public class McpException extends RuntimeException { + + /** + * Constructs a new McpException with the specified detail message. + * + * @param message The detail message. + */ + public McpException(String message) { + super(message); + } + + /** + * Constructs a new McpException with the specified detail message and cause. + * + * @param message The detail message. + * @param cause The cause. + */ + public McpException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/src/main/java/com/google/cloud/mcp/tool/Tool.java b/src/main/java/com/google/cloud/mcp/tool/Tool.java new file mode 100644 index 0000000..4fb0229 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/tool/Tool.java @@ -0,0 +1,261 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.tool; + +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.AuthResolver; +import com.google.cloud.mcp.auth.AuthTokenGetter; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.function.Supplier; + +/** + * Represents a loaded tool ready to be invoked. Handles parameter binding, authentication token + * resolution, and input validation. + */ +public class Tool { + private final String name; + private final ToolDefinition definition; + private final McpToolboxClient client; + + private final Map boundParameters = new HashMap<>(); + private final Map authGetters = new HashMap<>(); + private final List preProcessors = new ArrayList<>(); + private final List postProcessors = new ArrayList<>(); + + /** + * Constructs a new Tool. + * + * @param name The name of the tool. + * @param definition The definition of the tool. + * @param client The client used to invoke the tool. + */ + public Tool(String name, ToolDefinition definition, McpToolboxClient client) { + this.name = name; + this.definition = definition; + this.client = client; + } + + /** + * Returns the name of the tool. + * + * @return The tool name. + */ + public String name() { + return name; + } + + /** + * Returns the definition of the tool. + * + * @return The tool definition. + */ + public ToolDefinition definition() { + return definition; + } + + /** + * Binds a static value to a parameter. + * + * @param key The parameter name. + * @param value The value to bind. + * @return The tool instance. + */ + public Tool bindParam(String key, Object value) { + this.boundParameters.put(key, value); + return this; + } + + /** + * Binds a dynamic value supplier to a parameter. + * + * @param key The parameter name. + * @param valueSupplier The supplier that provides the value at execution time. + * @return The tool instance. + */ + public Tool bindParam(String key, Supplier valueSupplier) { + this.boundParameters.put(key, valueSupplier); + return this; + } + + /** + * Registers an authentication token getter for a specific service. + * + * @param serviceName The name of the service. + * @param getter The token getter. + * @return The tool instance. + */ + public Tool addAuthTokenGetter(String serviceName, AuthTokenGetter getter) { + this.authGetters.put(serviceName, getter); + return this; + } + + /** + * Adds a pre-processor to the tool. + * + * @param processor The pre-processor to add. + * @return The tool instance. + */ + public Tool addPreProcessor(ToolPreProcessor processor) { + this.preProcessors.add(processor); + return this; + } + + /** + * Adds a post-processor to the tool. + * + * @param processor The post-processor to add. + * @return The tool instance. + */ + public Tool addPostProcessor(ToolPostProcessor processor) { + this.postProcessors.add(processor); + return this; + } + + /** + * Executes the tool with the provided arguments, applying any bound parameters and resolving + * authentication tokens. + * + * @param args The arguments for the tool invocation. + * @return A CompletableFuture containing the result of the tool execution. + */ + public CompletableFuture execute(Map args) { + CompletableFuture> argsFuture = + CompletableFuture.completedFuture(new HashMap<>(args)); + + for (ToolPreProcessor preProcessor : preProcessors) { + argsFuture = argsFuture.thenCompose(currentArgs -> preProcessor.process(name, currentArgs)); + } + + CompletableFuture resultFuture = + argsFuture.thenCompose( + processedArgs -> { + Map finalArgs = + java.util.Collections.synchronizedMap(new HashMap<>(processedArgs)); + Map extraHeaders = + java.util.Collections.synchronizedMap(new HashMap<>()); + + // 1. Apply Bound Parameters + for (Map.Entry entry : boundParameters.entrySet()) { + Object val = entry.getValue(); + if (val instanceof Supplier) { + finalArgs.put(entry.getKey(), ((Supplier) val).get()); + } else { + finalArgs.put(entry.getKey(), val); + } + } + + // 2. Resolve Auth & Execute + return AuthResolver.resolve(authGetters) + .thenCompose( + resolvedAuth -> { + try { + // Apply credential parameter bindings and extra headers + resolvedAuth.applyTo(finalArgs, extraHeaders, definition); + + // Validation & Cleanup + validateAndSanitizeArgs(finalArgs); + return client.invokeTool(name, finalArgs, extraHeaders); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + }); + }); + + for (ToolPostProcessor postProcessor : postProcessors) { + resultFuture = resultFuture.thenCompose(res -> postProcessor.process(name, res)); + } + + return resultFuture; + } + + /** Validates arguments against the tool definition and removes null values. */ + private void validateAndSanitizeArgs(Map args) { + // Remove nulls first (filtering none values) + args.values().removeIf(Objects::isNull); + + if (definition.parameters() == null) return; + + for (ToolDefinition.Parameter param : definition.parameters()) { + Object value = args.get(param.name()); + + if (value == null && param.defaultValue() != null) { + value = deepCopy(param.defaultValue()); + args.put(param.name(), value); + } + + // A. Check Required Parameters + if (param.required() && value == null) { + throw new IllegalArgumentException( + String.format( + "Missing required parameter '%s' for tool '%s'.", param.name(), this.name)); + } + + // B. Check Parameter Types (only if value is present) + if (value != null && param.type() != null) { + if (!isTypeMatch(value, param.type())) { + throw new IllegalArgumentException( + String.format( + "Parameter '%s' expected type '%s' but got '%s'.", + param.name(), param.type(), value.getClass().getSimpleName())); + } + } + } + } + + private Object deepCopy(Object value) { + if (value instanceof Map) { + Map map = (Map) value; + Map copy = new HashMap<>(); + for (Map.Entry entry : map.entrySet()) { + copy.put(deepCopy(entry.getKey()), deepCopy(entry.getValue())); + } + return copy; + } else if (value instanceof List) { + List list = (List) value; + List copy = new ArrayList<>(); + for (Object item : list) { + copy.add(deepCopy(item)); + } + return copy; + } + return value; + } + + private boolean isTypeMatch(Object value, String type) { + switch (type.toLowerCase()) { + case "string": + return value instanceof String; + case "integer": + return value instanceof Integer || value instanceof Long; + case "number": + return value instanceof Number; // Covers Integer, Long, Float, Double + case "boolean": + return value instanceof Boolean; + case "array": + return value instanceof java.util.List || value.getClass().isArray(); + case "object": + return value instanceof Map; + default: + return true; + } + } +} diff --git a/src/main/java/com/google/cloud/mcp/tool/ToolDefinition.java b/src/main/java/com/google/cloud/mcp/tool/ToolDefinition.java new file mode 100644 index 0000000..d73d098 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/tool/ToolDefinition.java @@ -0,0 +1,84 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.tool; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * Represents the definition of a tool, including its description and parameters. + * + * @param description A description of what the tool does. + * @param parameters A list of parameters the tool accepts. + * @param authRequired A list of authentication sources required by the tool. + * @param readOnlyHint Hint indicating whether the tool is read-only. + * @param destructiveHint Hint indicating whether the tool is destructive. + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public record ToolDefinition( + String description, + List parameters, + List authRequired, + Boolean readOnlyHint, + Boolean destructiveHint) { + + /** + * Backward-compatible constructor. + * + * @param description A description of what the tool does. + * @param parameters A list of parameters the tool accepts. + * @param authRequired List of auth services required. + */ + public ToolDefinition(String description, List parameters, List authRequired) { + this(description, parameters, authRequired, null, null); + } + + /** + * Represents a parameter of a tool. + * + * @param name The name of the parameter. + * @param type The type of the parameter (e.g., "string", "number"). + * @param required Whether the parameter is required. + * @param description A description of the parameter. + * @param authSources A list of authentication sources for this parameter. + * @param defaultValue The default value for the parameter. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record Parameter( + String name, + String type, + boolean required, + String description, + List authSources, // Maps services to parameters + @JsonProperty("default") Object defaultValue) { + + /** + * Backward-compatible constructor. + * + * @param name The name of the parameter. + * @param type The type of the parameter. + * @param required Whether the parameter is required. + * @param description A description of the parameter. + * @param authSources Authentication sources list. + */ + public Parameter( + String name, String type, boolean required, String description, List authSources) { + this(name, type, required, description, authSources, null); + } + } +} diff --git a/src/main/java/com/google/cloud/mcp/tool/ToolPostProcessor.java b/src/main/java/com/google/cloud/mcp/tool/ToolPostProcessor.java new file mode 100644 index 0000000..61280ea --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/tool/ToolPostProcessor.java @@ -0,0 +1,33 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.tool; + +import java.util.concurrent.CompletableFuture; + +/** A functional interface for post-processing tool results after invocation. */ +@FunctionalInterface +public interface ToolPostProcessor { + + /** + * Processes the result of a tool after it has been invoked. + * + * @param toolName The name of the tool that was invoked. + * @param result The original tool result. + * @return A CompletableFuture containing the processed tool result. + */ + CompletableFuture process(String toolName, ToolResult result); +} diff --git a/src/main/java/com/google/cloud/mcp/tool/ToolPreProcessor.java b/src/main/java/com/google/cloud/mcp/tool/ToolPreProcessor.java new file mode 100644 index 0000000..190e38d --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/tool/ToolPreProcessor.java @@ -0,0 +1,34 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.tool; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** A functional interface for pre-processing tool inputs before invocation. */ +@FunctionalInterface +public interface ToolPreProcessor { + + /** + * Processes the input arguments for a tool before it is invoked. + * + * @param toolName The name of the tool being invoked. + * @param arguments The original arguments provided to the tool. + * @return A CompletableFuture containing the processed arguments. + */ + CompletableFuture> process(String toolName, Map arguments); +} diff --git a/src/main/java/com/google/cloud/mcp/tool/ToolResult.java b/src/main/java/com/google/cloud/mcp/tool/ToolResult.java new file mode 100644 index 0000000..29b6a61 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/tool/ToolResult.java @@ -0,0 +1,40 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.tool; + +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.util.List; + +/** + * Represents the result of a tool invocation. + * + * @param content A list of content items returned by the tool. + * @param isError Whether the invocation resulted in an error. + */ +@JsonIgnoreProperties(ignoreUnknown = true) +public record ToolResult( + @JsonProperty("content") List content, @JsonProperty("isError") boolean isError) { + /** + * Represents a single content item in a tool result. + * + * @param type The type of content (e.g., "text"). + * @param text The text content. + */ + @JsonIgnoreProperties(ignoreUnknown = true) + public record Content(@JsonProperty("type") String type, @JsonProperty("text") String text) {} +} diff --git a/src/main/java/com/google/cloud/mcp/transport/BaseMcpTransport.java b/src/main/java/com/google/cloud/mcp/transport/BaseMcpTransport.java new file mode 100644 index 0000000..68342a8 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/BaseMcpTransport.java @@ -0,0 +1,507 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.TelemetryHelper; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.tool.ToolDefinition; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.time.Duration; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.logging.Logger; + +public abstract class BaseMcpTransport implements Transport { + + protected static final Logger logger = Logger.getLogger(BaseMcpTransport.class.getName()); + protected static final String HTTP_WARNING = + "This connection is using HTTP. To prevent credential exposure, please ensure all" + + " communication is sent over HTTPS."; + + protected final String baseUrl; + protected final Map clientHeaders; + protected final CredentialsProvider credentialsProvider; + protected final HttpClient httpClient; + protected final ObjectMapper objectMapper; + protected final ProtocolVersion preferredProtocolVersion; + protected final Object initLock = new Object(); + protected CompletableFuture initFuture; + + /** The start time of the session in nanoseconds. */ + protected Long sessionStartTime; + + /** The error that occurred during the session, if any. */ + protected Throwable sessionError; + + /** The negotiated protocol version. */ + protected ProtocolVersion negotiatedProtocolVersion; + + /** + * Constructs a new BaseMcpTransport. + * + * @param baseUrl The base URL. + * @param clientHeaders The client headers. + * @param credentialsProvider The credentials provider. + * @param preferredProtocolVersion The preferred protocol version. + * @param httpClient The HTTP client. + * @param executor The executor. + */ + protected BaseMcpTransport( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final ProtocolVersion preferredProtocolVersion, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + if (baseUrl == null || baseUrl.isEmpty()) { + throw new IllegalArgumentException("Base URL must be provided"); + } + this.baseUrl = baseUrl.endsWith("/") ? baseUrl.substring(0, baseUrl.length() - 1) : baseUrl; + this.clientHeaders = + clientHeaders != null + ? java.util.Collections.unmodifiableMap(new java.util.HashMap<>(clientHeaders)) + : java.util.Collections.emptyMap(); + this.credentialsProvider = credentialsProvider; + this.preferredProtocolVersion = + preferredProtocolVersion != null + ? preferredProtocolVersion + : ProtocolVersion.VERSION_2025_11_25; + if (httpClient != null) { + this.httpClient = httpClient; + } else { + HttpClient.Builder builder = + HttpClient.newBuilder() + .cookieHandler(new java.net.CookieManager()) + .connectTimeout(Duration.ofSeconds(10)); + if (executor != null) { + builder.executor(executor); + } + this.httpClient = builder.build(); + } + this.objectMapper = new ObjectMapper(); + } + + @Override + public final String getBaseUrl() { + return this.baseUrl; + } + + final CompletableFuture> mergeHeaders( + final Map extraMetadata) { + CompletableFuture authFuture = + this.credentialsProvider != null + ? this.credentialsProvider.getAuthorizationHeader() + : CompletableFuture.completedFuture(null); + + return authFuture.thenApply( + providerAuth -> { + Map merged = new HashMap<>(); + + // 1. Find dynamic or static Authorization header + String finalAuthHeader = null; + + // A. Check extraMetadata first + if (extraMetadata != null) { + String authKeyInExtra = + extraMetadata.keySet().stream() + .filter(k -> "Authorization".equalsIgnoreCase(k)) + .findFirst() + .orElse(null); + if (authKeyInExtra != null) { + finalAuthHeader = extraMetadata.get(authKeyInExtra); + } + } + + // B. If not in extraMetadata, check credentialsProvider + if (finalAuthHeader == null) { + finalAuthHeader = providerAuth; + } + + // C. If still null, check clientHeaders + if (finalAuthHeader == null) { + for (Map.Entry entry : this.clientHeaders.entrySet()) { + if ("Authorization".equalsIgnoreCase(entry.getKey())) { + finalAuthHeader = entry.getValue(); + break; + } + } + } + + // 2. Put all client-level headers except Authorization + this.clientHeaders.forEach( + (k, v) -> { + if (!"Authorization".equalsIgnoreCase(k)) { + merged.put(k, v); + } + }); + + // 3. Put all extra/call-level metadata except Authorization + if (extraMetadata != null) { + extraMetadata.forEach( + (k, v) -> { + if (!"Authorization".equalsIgnoreCase(k)) { + merged.put(k, v); + } + }); + } + + // 4. Put the final Authorization header if found + if (finalAuthHeader != null) { + merged.put("Authorization", finalAuthHeader); + } + + return merged; + }); + } + + final CompletableFuture ensureInitialized(final Map extraMetadata) { + synchronized (initLock) { + if (initFuture == null) { + if (sessionStartTime == null) { + sessionStartTime = System.nanoTime(); + } + TelemetryHelper.OperationSpan initSpan = + new TelemetryHelper.OperationSpan( + "initialize", preferredProtocolVersion.getValue(), baseUrl, null); + + Map handshakeMetadata = new HashMap<>(); + if (extraMetadata != null) { + String authKey = + extraMetadata.keySet().stream() + .filter(k -> "Authorization".equalsIgnoreCase(k)) + .findFirst() + .orElse(null); + if (authKey != null) { + handshakeMetadata.put("Authorization", extraMetadata.get(authKey)); + } + } + CompletableFuture future = + mergeHeaders(handshakeMetadata) + .thenCompose( + handshakeHeaders -> { + String authHeader = handshakeHeaders.get("Authorization"); + Map traceHeaders = initSpan.getTraceContextHeaders(); + return performInitialization(authHeader, handshakeHeaders, traceHeaders); + }); + + future.whenComplete( + (v, err) -> { + if (err != null) { + initSpan.recordError(err); + sessionError = err; + synchronized (initLock) { + initFuture = null; + } + } + initSpan.close(); + }); + initFuture = future; + return future; + } + return initFuture; + } + } + + /** + * Performs the version-specific initialization handshake. + * + * @param authHeader The authorization header value, if present. + * @param handshakeHeaders The resolved headers for the handshake. + * @param traceHeaders The trace context headers to propagate. + * @return A CompletableFuture that completes when initialization is done. + */ + protected abstract CompletableFuture performInitialization( + final String authHeader, + final Map handshakeHeaders, + final Map traceHeaders); + + protected abstract void applyProtocolHeaders(final HttpRequest.Builder builder); + + @Override + public final CompletableFuture listTools( + final String toolsetName, final Map metadata) { + if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && !metadata.isEmpty()) { + logger.warning(HTTP_WARNING); + } + return ensureInitialized(metadata) + .thenCompose(v -> mergeHeaders(metadata)) + .thenCompose( + mergedHeaders -> { + String path = toolsetName != null && !toolsetName.isEmpty() ? "/" + toolsetName : ""; + String url = baseUrl + path; + + TelemetryHelper.OperationSpan listSpan = + new TelemetryHelper.OperationSpan( + "tools/list", + negotiatedProtocolVersion != null + ? negotiatedProtocolVersion.getValue() + : preferredProtocolVersion.getValue(), + url, + null); + + try { + Map traceHeaders = listSpan.getTraceContextHeaders(); + JsonRpc.RequestMetadata reqMetadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + + JsonRpc.Request listReq = + new JsonRpc.Request( + "tools/list", new JsonRpc.ListToolsParams(null, reqMetadata)); + String body = objectMapper.writeValueAsString(listReq); + HttpRequest.Builder req = + HttpRequest.newBuilder() + .uri(URI.create(url)) + .POST(HttpRequest.BodyPublishers.ofString(body)); + mergedHeaders.forEach(req::setHeader); + applyProtocolHeaders(req); + + return httpClient + .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) + .thenApply(res -> handleListToolsResponse(res, listSpan)) + .whenComplete( + (res, err) -> { + if (err != null) { + listSpan.recordError(err); + } + listSpan.close(); + }); + } catch (Exception e) { + listSpan.recordError(e); + listSpan.close(); + return CompletableFuture.failedFuture(e); + } + }); + } + + @Override + public final CompletableFuture invokeTool( + final String toolName, + final Map arguments, + final Map metadata) { + if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && !metadata.isEmpty()) { + logger.warning(HTTP_WARNING); + } + return ensureInitialized(metadata) + .thenCompose(v -> mergeHeaders(metadata)) + .thenCompose( + mergedHeaders -> { + TelemetryHelper.OperationSpan callSpan = + new TelemetryHelper.OperationSpan( + "tools/call", + negotiatedProtocolVersion != null + ? negotiatedProtocolVersion.getValue() + : preferredProtocolVersion.getValue(), + baseUrl, + toolName); + + try { + Map traceHeaders = callSpan.getTraceContextHeaders(); + JsonRpc.RequestMetadata reqMetadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + + JsonRpc.Request invokeReq = + new JsonRpc.Request( + "tools/call", new JsonRpc.CallToolParams(toolName, arguments, reqMetadata)); + String requestBody = objectMapper.writeValueAsString(invokeReq); + + HttpRequest.Builder requestBuilder = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(requestBody)); + + mergedHeaders.forEach(requestBuilder::setHeader); + applyProtocolHeaders(requestBuilder); + + return httpClient + .sendAsync(requestBuilder.build(), HttpResponse.BodyHandlers.ofString()) + .thenApply( + res -> { + if (res.statusCode() < 200 || res.statusCode() >= 300) { + callSpan.recordError( + res.statusCode(), "Error " + res.statusCode() + ": " + res.body()); + } else { + try { + JsonNode root = objectMapper.readTree(res.body()); + if (root.has("error")) { + JsonNode errNode = root.get("error"); + int code = errNode.has("code") ? errNode.get("code").asInt() : -1; + String msg = + errNode.has("message") + ? errNode.get("message").asText() + : errNode.toString(); + callSpan.recordError(code, msg); + } + } catch (Exception ignored) { + // Ignore parsing exceptions here + } + } + return new TransportResponse(res.statusCode(), res.body()); + }) + .whenComplete( + (res, err) -> { + if (err != null) { + callSpan.recordError(err); + } + callSpan.close(); + }); + } catch (Exception e) { + callSpan.recordError(e); + callSpan.close(); + return CompletableFuture.failedFuture(e); + } + }); + } + + @Override + public void close() { + if (sessionStartTime != null) { + double durationSeconds = (System.nanoTime() - sessionStartTime) / 1e9; + TelemetryHelper.recordSessionDuration( + durationSeconds, + negotiatedProtocolVersion != null + ? negotiatedProtocolVersion.getValue() + : preferredProtocolVersion.getValue(), + baseUrl, + sessionError); + } + } + + private TransportManifest handleListToolsResponse( + final HttpResponse response, TelemetryHelper.OperationSpan span) { + if (response.statusCode() != 200) { + if (span != null) { + span.recordError( + response.statusCode(), + "Failed to list tools. Status: " + response.statusCode() + " " + response.body()); + } + throw new RuntimeException( + "Failed to list tools. Status: " + response.statusCode() + " " + response.body()); + } + try { + JsonNode root = objectMapper.readTree(response.body()); + if (root.has("error")) { + JsonNode errNode = root.get("error"); + int code = errNode.has("code") ? errNode.get("code").asInt() : -1; + String msg = errNode.has("message") ? errNode.get("message").asText() : errNode.toString(); + if (span != null) { + span.recordError(code, msg); + } + throw new RuntimeException("MCP Error: " + msg); + } + JsonNode result = root.get("result"); + JsonNode toolsNode = result.get("tools"); + + Map toolsMap = new HashMap<>(); + if (toolsNode != null && toolsNode.isArray()) { + for (JsonNode toolNode : toolsNode) { + String name = toolNode.get("name").asText(); + String description = + toolNode.has("description") ? toolNode.get("description").asText() : ""; + + List authRequired = new ArrayList<>(); + JsonNode metaNode = toolNode.get("_meta"); + if (metaNode != null && metaNode.has("toolbox/authInvoke")) { + JsonNode invokeAuthNode = metaNode.get("toolbox/authInvoke"); + if (invokeAuthNode != null && invokeAuthNode.isArray()) { + for (JsonNode src : invokeAuthNode) { + authRequired.add(src.asText()); + } + } + } + + List params = new ArrayList<>(); + JsonNode inputSchema = toolNode.get("inputSchema"); + JsonNode requiredNode = inputSchema != null ? inputSchema.get("required") : null; + Set requiredSet = new HashSet<>(); + if (requiredNode != null && requiredNode.isArray()) { + for (JsonNode req : requiredNode) { + requiredSet.add(req.asText()); + } + } + + JsonNode propertiesNode = inputSchema != null ? inputSchema.get("properties") : null; + if (propertiesNode != null && propertiesNode.isObject()) { + Iterator> fields = propertiesNode.fields(); + while (fields.hasNext()) { + Map.Entry entry = fields.next(); + String paramName = entry.getKey(); + JsonNode propNode = entry.getValue(); + + String paramType = propNode.has("type") ? propNode.get("type").asText() : "string"; + String paramDesc = + propNode.has("description") ? propNode.get("description").asText() : ""; + + List authSources = new ArrayList<>(); + if (metaNode != null && metaNode.has("toolbox/authParam")) { + JsonNode paramAuthNode = metaNode.get("toolbox/authParam").get(paramName); + if (paramAuthNode != null && paramAuthNode.isArray()) { + for (JsonNode src : paramAuthNode) { + authSources.add(src.asText()); + } + } + } + + Object defaultValue = null; + if (propNode.has("default")) { + JsonNode defNode = propNode.get("default"); + defaultValue = objectMapper.treeToValue(defNode, Object.class); + } + + params.add( + new ToolDefinition.Parameter( + paramName, + paramType, + requiredSet.contains(paramName), + paramDesc, + authSources, + defaultValue)); + } + } + + Boolean readOnlyHint = + toolNode.has("readOnlyHint") ? toolNode.get("readOnlyHint").asBoolean() : null; + Boolean destructiveHint = + toolNode.has("destructiveHint") ? toolNode.get("destructiveHint").asBoolean() : null; + + toolsMap.put( + name, + new ToolDefinition(description, params, authRequired, readOnlyHint, destructiveHint)); + } + } + return new TransportManifest(toolsMap); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git a/src/main/java/com/google/cloud/mcp/transport/HttpMcpTransport.java b/src/main/java/com/google/cloud/mcp/transport/HttpMcpTransport.java new file mode 100644 index 0000000..d5833fc --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/HttpMcpTransport.java @@ -0,0 +1,175 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport; + +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.transport.v20241105.HttpMcpTransportV20241105; +import com.google.cloud.mcp.transport.v20250326.HttpMcpTransportV20250326; +import com.google.cloud.mcp.transport.v20250618.HttpMcpTransportV20250618; +import com.google.cloud.mcp.transport.v20251125.HttpMcpTransportV20251125; +import java.net.http.HttpClient; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** Default HTTP transport implementation routing requests to version-specific handlers. */ +public final class HttpMcpTransport implements Transport { + + private final Transport delegate; + + /** + * Constructs a new HttpMcpTransport with a base URL. + * + * @param baseUrl The base URL of the remote service. + */ + public HttpMcpTransport(final String baseUrl) { + this(baseUrl, Map.of(), (CredentialsProvider) null); + } + + /** + * Constructs a new HttpMcpTransport with base URL and default headers. + * + * @param baseUrl The base URL of the remote service. + * @param clientHeaders Default HTTP headers to include in every request. + */ + public HttpMcpTransport(final String baseUrl, final Map clientHeaders) { + this(baseUrl, clientHeaders, (CredentialsProvider) null); + } + + /** + * Constructs a new HttpMcpTransport with base URL, default headers and credentials provider. + * + * @param baseUrl The base URL of the remote service. + * @param clientHeaders Default HTTP headers to include in every request. + * @param credentialsProvider Provider for retrieving authorization credentials. + */ + public HttpMcpTransport( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider) { + this(baseUrl, clientHeaders, credentialsProvider, null, null, null); + } + + /** + * Constructs a HttpMcpTransport. + * + * @param baseUrl The base URL of the remote service. + * @param clientHeaders Default HTTP headers to include in every request. + * @param preferredProtocolVersion Preferred MCP protocol version. + * @param httpClient Custom HTTP Client. + * @param executor Optional Executor for handling async requests. + */ + public HttpMcpTransport( + final String baseUrl, + final Map clientHeaders, + final ProtocolVersion preferredProtocolVersion, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + this(baseUrl, clientHeaders, null, preferredProtocolVersion, httpClient, executor); + } + + /** + * Primary constructor for HttpMcpTransport. + * + * @param baseUrl The base URL of the remote service. + * @param clientHeaders Default HTTP headers to include in every request. + * @param credentialsProvider Provider for retrieving authorization credentials. + * @param preferredProtocolVersion Preferred MCP protocol version. + * @param httpClient Custom HTTP Client. + * @param executor Optional Executor for handling async requests. + */ + public HttpMcpTransport( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final ProtocolVersion preferredProtocolVersion, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + final ProtocolVersion version = + preferredProtocolVersion != null + ? preferredProtocolVersion + : ProtocolVersion.VERSION_2025_11_25; + + switch (version) { + case VERSION_2025_11_25: + this.delegate = + new HttpMcpTransportV20251125( + baseUrl, clientHeaders, credentialsProvider, httpClient, executor); + break; + case VERSION_2025_06_18: + this.delegate = + new HttpMcpTransportV20250618( + baseUrl, clientHeaders, credentialsProvider, httpClient, executor); + break; + case VERSION_2025_03_26: + this.delegate = + new HttpMcpTransportV20250326( + baseUrl, clientHeaders, credentialsProvider, httpClient, executor); + break; + case VERSION_2024_11_05: + this.delegate = + new HttpMcpTransportV20241105( + baseUrl, clientHeaders, credentialsProvider, httpClient, executor); + break; + default: + throw new IllegalArgumentException("Unsupported protocol version: " + version); + } + } + + /** Internal constructor for testing purposes. */ + public HttpMcpTransport(final String baseUrl, final HttpClient httpClient) { + this(baseUrl, Map.of(), null, null, httpClient, null); + } + + /** Internal constructor for testing purposes. */ + public HttpMcpTransport( + final String baseUrl, final Map clientHeaders, final HttpClient httpClient) { + this(baseUrl, clientHeaders, null, null, httpClient, null); + } + + HttpMcpTransport( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final HttpClient httpClient) { + this(baseUrl, clientHeaders, credentialsProvider, null, httpClient, null); + } + + @Override + public String getBaseUrl() { + return delegate.getBaseUrl(); + } + + @Override + public CompletableFuture listTools( + final String toolsetName, final Map metadata) { + return delegate.listTools(toolsetName, metadata); + } + + @Override + public CompletableFuture invokeTool( + final String toolName, + final Map arguments, + final Map metadata) { + return delegate.invokeTool(toolName, arguments, metadata); + } + + @Override + public void close() { + delegate.close(); + } +} diff --git a/src/main/java/com/google/cloud/mcp/transport/Transport.java b/src/main/java/com/google/cloud/mcp/transport/Transport.java new file mode 100644 index 0000000..37ac88b --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/Transport.java @@ -0,0 +1,56 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport; + +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Defines the contract for an MCP transport layer that manages protocol-level formatting and + * network communication. + */ +public interface Transport { + /** + * Returns the base URL of the remote service. + * + * @return The base URL string. + */ + String getBaseUrl(); + + /** + * Asynchronously fetches available tools from the server. + * + * @param toolsetName The name of the toolset to load (optional). + * @param metadata Request metadata or extra options to include. + * @return A CompletableFuture containing the raw DTO manifest. + */ + CompletableFuture listTools(String toolsetName, Map metadata); + + /** + * Asynchronously invokes a tool on the server. + * + * @param toolName The name of the tool to invoke. + * @param arguments The arguments to pass to the tool. + * @param metadata Request metadata or extra options to include. + * @return A CompletableFuture containing the raw TransportResponse result of the tool execution. + */ + CompletableFuture invokeTool( + String toolName, Map arguments, Map metadata); + + /** Closes any underlying network connections/resources. */ + void close(); +} diff --git a/src/main/java/com/google/cloud/mcp/transport/TransportManifest.java b/src/main/java/com/google/cloud/mcp/transport/TransportManifest.java new file mode 100644 index 0000000..f8a8dac --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/TransportManifest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport; + +import com.google.cloud.mcp.tool.ToolDefinition; +import java.util.Map; + +/** Represents the raw tools manifest returned by the transport. */ +public final class TransportManifest { + private final Map tools; + + /** + * Constructs a new TransportManifest with a map of tool definitions. + * + * @param tools Map of tool name to definition. + */ + public TransportManifest(Map tools) { + this.tools = tools; + } + + /** + * Returns the map of tools in the manifest. + * + * @return The tools map. + */ + public Map getTools() { + return tools; + } +} diff --git a/src/main/java/com/google/cloud/mcp/transport/TransportResponse.java b/src/main/java/com/google/cloud/mcp/transport/TransportResponse.java new file mode 100644 index 0000000..5044af4 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/TransportResponse.java @@ -0,0 +1,52 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport; + +/** Represents a raw transport response containing status code and response body. */ +public final class TransportResponse { + private final int statusCode; + private final String body; + + /** + * Constructs a new TransportResponse. + * + * @param statusCode The HTTP status code. + * @param body The response body. + */ + public TransportResponse(int statusCode, String body) { + this.statusCode = statusCode; + this.body = body; + } + + /** + * Returns the status code. + * + * @return The status code. + */ + public int getStatusCode() { + return statusCode; + } + + /** + * Returns the response body. + * + * @return The response body. + */ + public String getBody() { + return body; + } +} diff --git a/src/main/java/com/google/cloud/mcp/transport/v20241105/HttpMcpTransportV20241105.java b/src/main/java/com/google/cloud/mcp/transport/v20241105/HttpMcpTransportV20241105.java new file mode 100644 index 0000000..4fdb288 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/v20241105/HttpMcpTransportV20241105.java @@ -0,0 +1,137 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport.v20241105; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public final class HttpMcpTransportV20241105 extends BaseMcpTransport { + + public HttpMcpTransportV20241105( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + super( + baseUrl, + clientHeaders, + credentialsProvider, + ProtocolVersion.VERSION_2024_11_05, + httpClient, + executor); + } + + @Override + protected CompletableFuture performInitialization( + final String authHeader, + final Map handshakeHeaders, + final Map traceHeaders) { + try { + if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && authHeader != null) { + logger.warning(HTTP_WARNING); + } + JsonRpc.RequestMetadata metadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + JsonRpc.Request initReq = + new JsonRpc.Request( + "initialize", + new JsonRpc.InitializeParams( + ProtocolVersion.VERSION_2024_11_05.getValue(), "mcp-toolbox-sdk-java", metadata)); + String body = objectMapper.writeValueAsString(initReq); + HttpRequest.Builder req = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(body)); + + handshakeHeaders.forEach(req::setHeader); + applyProtocolHeaders(req); + + return httpClient + .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) + .thenCompose( + res -> { + if (res.statusCode() != 200) { + return CompletableFuture.failedFuture( + new McpException("Init failed: " + res.statusCode() + " " + res.body())); + } + try { + JsonNode responseJson = objectMapper.readTree(res.body()); + if (responseJson.has("error")) { + return CompletableFuture.failedFuture( + new McpException("MCP Error: " + responseJson.get("error").toString())); + } + JsonNode result = responseJson.get("result"); + String serverVersion; + if (result != null && result.has("protocolVersion")) { + serverVersion = result.get("protocolVersion").asText(); + } else { + serverVersion = ProtocolVersion.VERSION_2024_11_05.getValue(); + } + + if (!ProtocolVersion.VERSION_2024_11_05.getValue().equals(serverVersion)) { + return CompletableFuture.failedFuture( + new McpException( + "MCP version mismatch: client (" + + ProtocolVersion.VERSION_2024_11_05.getValue() + + ") != server (" + + serverVersion + + ")")); + } + + this.negotiatedProtocolVersion = ProtocolVersion.VERSION_2024_11_05; + + JsonRpc.Notification notif = + new JsonRpc.Notification("notifications/initialized", Map.of()); + String notifBody = objectMapper.writeValueAsString(notif); + HttpRequest.Builder nReq = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(notifBody)); + + handshakeHeaders.forEach(nReq::setHeader); + applyProtocolHeaders(nReq); + + return httpClient + .sendAsync(nReq.build(), HttpResponse.BodyHandlers.ofString()) + .thenAccept(nRes -> {}); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + }); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + } + + @Override + protected void applyProtocolHeaders(final HttpRequest.Builder builder) { + builder.header("Content-Type", "application/json"); + } +} diff --git a/src/main/java/com/google/cloud/mcp/transport/v20250326/HttpMcpTransportV20250326.java b/src/main/java/com/google/cloud/mcp/transport/v20250326/HttpMcpTransportV20250326.java new file mode 100644 index 0000000..b3e46ca --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/v20250326/HttpMcpTransportV20250326.java @@ -0,0 +1,153 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport.v20250326; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +public final class HttpMcpTransportV20250326 extends BaseMcpTransport { + + private volatile String sessionId; + + public HttpMcpTransportV20250326( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + super( + baseUrl, + clientHeaders, + credentialsProvider, + ProtocolVersion.VERSION_2025_03_26, + httpClient, + executor); + } + + @Override + protected CompletableFuture performInitialization( + final String authHeader, + final Map handshakeHeaders, + final Map traceHeaders) { + try { + if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && authHeader != null) { + logger.warning(HTTP_WARNING); + } + JsonRpc.RequestMetadata metadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + JsonRpc.Request initReq = + new JsonRpc.Request( + "initialize", + new JsonRpc.InitializeParams( + ProtocolVersion.VERSION_2025_03_26.getValue(), "mcp-toolbox-sdk-java", metadata)); + String body = objectMapper.writeValueAsString(initReq); + HttpRequest.Builder req = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(body)); + + handshakeHeaders.forEach(req::setHeader); + applyProtocolHeaders(req); + + return httpClient + .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) + .thenCompose( + res -> { + if (res.statusCode() != 200) { + return CompletableFuture.failedFuture( + new McpException("Init failed: " + res.statusCode() + " " + res.body())); + } + try { + JsonNode responseJson = objectMapper.readTree(res.body()); + if (responseJson.has("error")) { + return CompletableFuture.failedFuture( + new McpException("MCP Error: " + responseJson.get("error").toString())); + } + JsonNode result = responseJson.get("result"); + String serverVersion; + if (result != null && result.has("protocolVersion")) { + serverVersion = result.get("protocolVersion").asText(); + } else { + serverVersion = ProtocolVersion.VERSION_2025_03_26.getValue(); + } + + if (!ProtocolVersion.VERSION_2025_03_26.getValue().equals(serverVersion)) { + return CompletableFuture.failedFuture( + new McpException( + "MCP version mismatch: client (" + + ProtocolVersion.VERSION_2025_03_26.getValue() + + ") != server (" + + serverVersion + + ")")); + } + + Optional sessionIdOpt = res.headers().firstValue("Mcp-Session-Id"); + if (sessionIdOpt.isEmpty()) { + return CompletableFuture.failedFuture( + new McpException( + "Server did not return a Mcp-Session-Id header during" + + " initialization.")); + } + this.sessionId = sessionIdOpt.get(); + + this.negotiatedProtocolVersion = ProtocolVersion.VERSION_2025_03_26; + + JsonRpc.Notification notif = + new JsonRpc.Notification("notifications/initialized", Map.of()); + String notifBody = objectMapper.writeValueAsString(notif); + HttpRequest.Builder nReq = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(notifBody)); + + handshakeHeaders.forEach(nReq::setHeader); + applyProtocolHeaders(nReq); + + return httpClient + .sendAsync(nReq.build(), HttpResponse.BodyHandlers.ofString()) + .thenAccept(nRes -> {}); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + }); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + } + + @Override + protected void applyProtocolHeaders(final HttpRequest.Builder builder) { + builder.header("Content-Type", "application/json"); + builder.header("Accept", "application/json"); + if (sessionId != null) { + builder.header("Mcp-Session-Id", sessionId); + } + } +} diff --git a/src/main/java/com/google/cloud/mcp/transport/v20250618/HttpMcpTransportV20250618.java b/src/main/java/com/google/cloud/mcp/transport/v20250618/HttpMcpTransportV20250618.java new file mode 100644 index 0000000..4973068 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/v20250618/HttpMcpTransportV20250618.java @@ -0,0 +1,139 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport.v20250618; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public final class HttpMcpTransportV20250618 extends BaseMcpTransport { + + public HttpMcpTransportV20250618( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + super( + baseUrl, + clientHeaders, + credentialsProvider, + ProtocolVersion.VERSION_2025_06_18, + httpClient, + executor); + } + + @Override + protected CompletableFuture performInitialization( + final String authHeader, + final Map handshakeHeaders, + final Map traceHeaders) { + try { + if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && authHeader != null) { + logger.warning(HTTP_WARNING); + } + JsonRpc.RequestMetadata metadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + JsonRpc.Request initReq = + new JsonRpc.Request( + "initialize", + new JsonRpc.InitializeParams( + ProtocolVersion.VERSION_2025_06_18.getValue(), "mcp-toolbox-sdk-java", metadata)); + String body = objectMapper.writeValueAsString(initReq); + HttpRequest.Builder req = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(body)); + + handshakeHeaders.forEach(req::setHeader); + applyProtocolHeaders(req); + + return httpClient + .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) + .thenCompose( + res -> { + if (res.statusCode() != 200) { + return CompletableFuture.failedFuture( + new McpException("Init failed: " + res.statusCode() + " " + res.body())); + } + try { + JsonNode responseJson = objectMapper.readTree(res.body()); + if (responseJson.has("error")) { + return CompletableFuture.failedFuture( + new McpException("MCP Error: " + responseJson.get("error").toString())); + } + JsonNode result = responseJson.get("result"); + String serverVersion; + if (result != null && result.has("protocolVersion")) { + serverVersion = result.get("protocolVersion").asText(); + } else { + serverVersion = ProtocolVersion.VERSION_2025_06_18.getValue(); + } + + if (!ProtocolVersion.VERSION_2025_06_18.getValue().equals(serverVersion)) { + return CompletableFuture.failedFuture( + new McpException( + "MCP version mismatch: client (" + + ProtocolVersion.VERSION_2025_06_18.getValue() + + ") != server (" + + serverVersion + + ")")); + } + + this.negotiatedProtocolVersion = ProtocolVersion.VERSION_2025_06_18; + + JsonRpc.Notification notif = + new JsonRpc.Notification("notifications/initialized", Map.of()); + String notifBody = objectMapper.writeValueAsString(notif); + HttpRequest.Builder nReq = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(notifBody)); + + handshakeHeaders.forEach(nReq::setHeader); + applyProtocolHeaders(nReq); + + return httpClient + .sendAsync(nReq.build(), HttpResponse.BodyHandlers.ofString()) + .thenAccept(nRes -> {}); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + }); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + } + + @Override + protected void applyProtocolHeaders(final HttpRequest.Builder builder) { + builder.header("Content-Type", "application/json"); + builder.header("Accept", "application/json"); + builder.header("MCP-Protocol-Version", ProtocolVersion.VERSION_2025_06_18.getValue()); + } +} diff --git a/src/main/java/com/google/cloud/mcp/transport/v20251125/HttpMcpTransportV20251125.java b/src/main/java/com/google/cloud/mcp/transport/v20251125/HttpMcpTransportV20251125.java new file mode 100644 index 0000000..4a3a5ea --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/v20251125/HttpMcpTransportV20251125.java @@ -0,0 +1,139 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport.v20251125; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public final class HttpMcpTransportV20251125 extends BaseMcpTransport { + + public HttpMcpTransportV20251125( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + super( + baseUrl, + clientHeaders, + credentialsProvider, + ProtocolVersion.VERSION_2025_11_25, + httpClient, + executor); + } + + @Override + protected CompletableFuture performInitialization( + final String authHeader, + final Map handshakeHeaders, + final Map traceHeaders) { + try { + if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && authHeader != null) { + logger.warning(HTTP_WARNING); + } + JsonRpc.RequestMetadata metadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + JsonRpc.Request initReq = + new JsonRpc.Request( + "initialize", + new JsonRpc.InitializeParams( + ProtocolVersion.VERSION_2025_11_25.getValue(), "mcp-toolbox-sdk-java", metadata)); + String body = objectMapper.writeValueAsString(initReq); + HttpRequest.Builder req = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(body)); + + handshakeHeaders.forEach(req::setHeader); + applyProtocolHeaders(req); + + return httpClient + .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) + .thenCompose( + res -> { + if (res.statusCode() != 200) { + return CompletableFuture.failedFuture( + new McpException("Init failed: " + res.statusCode() + " " + res.body())); + } + try { + JsonNode responseJson = objectMapper.readTree(res.body()); + if (responseJson.has("error")) { + return CompletableFuture.failedFuture( + new McpException("MCP Error: " + responseJson.get("error").toString())); + } + JsonNode result = responseJson.get("result"); + String serverVersion; + if (result != null && result.has("protocolVersion")) { + serverVersion = result.get("protocolVersion").asText(); + } else { + serverVersion = ProtocolVersion.VERSION_2025_11_25.getValue(); + } + + if (!ProtocolVersion.VERSION_2025_11_25.getValue().equals(serverVersion)) { + return CompletableFuture.failedFuture( + new McpException( + "MCP version mismatch: client (" + + ProtocolVersion.VERSION_2025_11_25.getValue() + + ") != server (" + + serverVersion + + ")")); + } + + this.negotiatedProtocolVersion = ProtocolVersion.VERSION_2025_11_25; + + JsonRpc.Notification notif = + new JsonRpc.Notification("notifications/initialized", Map.of()); + String notifBody = objectMapper.writeValueAsString(notif); + HttpRequest.Builder nReq = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(notifBody)); + + handshakeHeaders.forEach(nReq::setHeader); + applyProtocolHeaders(nReq); + + return httpClient + .sendAsync(nReq.build(), HttpResponse.BodyHandlers.ofString()) + .thenAccept(nRes -> {}); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + }); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + } + + @Override + protected void applyProtocolHeaders(final HttpRequest.Builder builder) { + builder.header("Content-Type", "application/json"); + builder.header("Accept", "application/json"); + builder.header("MCP-Protocol-Version", ProtocolVersion.VERSION_2025_11_25.getValue()); + } +} diff --git a/src/test/java/com/google/cloud/mcp/McpCoverageTest.java b/src/test/java/com/google/cloud/mcp/McpCoverageTest.java new file mode 100644 index 0000000..34e6f73 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/McpCoverageTest.java @@ -0,0 +1,114 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +import com.google.cloud.mcp.auth.AuthTokenGetter; +import com.google.cloud.mcp.client.McpToolboxClientImpl; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.tool.Tool; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolResult; +import com.google.cloud.mcp.transport.Transport; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +/** Miscellaneous unit tests to achieve 100% code coverage. */ +@Timeout(5) +public class McpCoverageTest { + + @Test + public void testMcpExceptionCoverage() { + McpException ex = new McpException("error message", new RuntimeException("cause")); + assertEquals("error message", ex.getMessage()); + assertEquals("cause", ex.getCause().getMessage()); + } + + @Test + public void testMcpToolboxClientDefaultClose() { + McpToolboxClient dummyClient = + new McpToolboxClient() { + @Override + public CompletableFuture> listTools() { + return null; + } + + @Override + public CompletableFuture> loadToolset(String name) { + return null; + } + + @Override + public CompletableFuture> loadToolset( + String name, + Map> p, + Map> a, + boolean s) { + return null; + } + + @Override + public CompletableFuture loadTool(String name) { + return null; + } + + @Override + public CompletableFuture loadTool( + String name, Map getters) { + return null; + } + + @Override + public CompletableFuture invokeTool(String name, Map args) { + return null; + } + + @Override + public CompletableFuture invokeTool( + String name, Map args, Map headers) { + return null; + } + }; + // Call default close (no-op) + dummyClient.close(); + } + + @Test + public void testMcpToolboxClientImplCloseThrowsException() throws Exception { + Transport mockTransport = mock(Transport.class); + doThrow(new RuntimeException("transport close error")).when(mockTransport).close(); + + McpToolboxClientImpl client = new McpToolboxClientImpl(mockTransport, java.util.Map.of(), null); + McpException ex = assertThrows(McpException.class, client::close); + assertEquals("Failed to close transport", ex.getMessage()); + assertEquals("transport close error", ex.getCause().getMessage()); + } + + @Test + public void testProtocolVersionFromString() { + assertNull(ProtocolVersion.fromString(null)); + assertNull(ProtocolVersion.fromString("invalid-version-string")); + assertEquals(ProtocolVersion.VERSION_2025_11_25, ProtocolVersion.fromString("2025-11-25")); + } +} diff --git a/src/test/java/com/google/cloud/mcp/TelemetryTest.java b/src/test/java/com/google/cloud/mcp/TelemetryTest.java new file mode 100644 index 0000000..9749bd9 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/TelemetryTest.java @@ -0,0 +1,245 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolResult; +import com.sun.net.httpserver.HttpServer; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.sdk.testing.junit5.OpenTelemetryExtension; +import io.opentelemetry.sdk.trace.data.SpanData; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.RegisterExtension; + +@Timeout(value = 15, unit = TimeUnit.SECONDS) +public class TelemetryTest { + + @RegisterExtension + static final OpenTelemetryExtension otelTesting = OpenTelemetryExtension.create(); + + private HttpServer server; + private String serverUrl; + private final List receivedRequests = Collections.synchronizedList(new ArrayList<>()); + private final ObjectMapper mapper = new ObjectMapper(); + + @BeforeEach + public void setUp() throws Exception { + receivedRequests.clear(); + server = HttpServer.create(new InetSocketAddress("localhost", 0), 0); + server.createContext( + "/mcp", + exchange -> { + try { + byte[] reqBytes = exchange.getRequestBody().readAllBytes(); + JsonNode reqNode = mapper.readTree(reqBytes); + receivedRequests.add(reqNode); + + String method = reqNode.has("method") ? reqNode.get("method").asText() : ""; + String responseBody = "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{}}"; + + if ("tools/list".equals(method)) { + responseBody = + "{\n" + + " \"jsonrpc\": \"2.0\",\n" + + " \"id\": \"1\",\n" + + " \"result\": {\n" + + " \"tools\": [\n" + + " {\n" + + " \"name\": \"test-tool\",\n" + + " \"description\": \"A test tool\",\n" + + " \"inputSchema\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {}\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + } else if ("tools/call".equals(method)) { + responseBody = + "{\n" + + " \"jsonrpc\": \"2.0\",\n" + + " \"id\": \"1\",\n" + + " \"result\": {\n" + + " \"content\": [\n" + + " {\n" + + " \"type\": \"text\",\n" + + " \"text\": \"Success\"\n" + + " }\n" + + " ],\n" + + " \"isError\": false\n" + + " }\n" + + "}"; + } + + exchange.getResponseHeaders().set("Content-Type", "application/json"); + byte[] responseBytes = responseBody.getBytes(); + exchange.sendResponseHeaders(200, responseBytes.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(responseBytes); + } + } catch (Exception e) { + exchange.sendResponseHeaders(500, 0); + exchange.close(); + } + }); + server.start(); + int port = server.getAddress().getPort(); + serverUrl = "http://localhost:" + port + "/mcp"; + } + + @AfterEach + public void tearDown() { + if (server != null) { + server.stop(0); + } + } + + @Test + public void testTelemetrySpansAndContextPropagation() throws Exception { + try (McpToolboxClient client = McpToolboxClient.builder().baseUrl(serverUrl).build()) { + // 1. Load toolset (triggers initialize and tools/list) + Map tools = client.loadToolset().get(); + assertNotNull(tools); + assertTrue(tools.containsKey("test-tool")); + + // 2. Invoke tool + ToolResult result = client.invokeTool("test-tool", Map.of()).get(); + assertNotNull(result); + assertFalse(result.isError()); + } + + // Verify Spans were created + List spans = otelTesting.getSpans(); + + // Spans should be: "initialize", "tools/list", "tools/call test-tool" + assertTrue(spans.stream().anyMatch(s -> "initialize".equals(s.getName()))); + assertTrue(spans.stream().anyMatch(s -> "tools/list".equals(s.getName()))); + assertTrue(spans.stream().anyMatch(s -> "tools/call test-tool".equals(s.getName()))); + + SpanData initSpan = + spans.stream().filter(s -> "initialize".equals(s.getName())).findFirst().orElseThrow(); + SpanData listSpan = + spans.stream().filter(s -> "tools/list".equals(s.getName())).findFirst().orElseThrow(); + SpanData callSpan = + spans.stream() + .filter(s -> "tools/call test-tool".equals(s.getName())) + .findFirst() + .orElseThrow(); + + // Verify Span attributes + assertEquals( + "initialize", initSpan.getAttributes().get(AttributeKey.stringKey("mcp.method.name"))); + assertEquals( + "tools/list", listSpan.getAttributes().get(AttributeKey.stringKey("mcp.method.name"))); + assertEquals( + "tools/call", callSpan.getAttributes().get(AttributeKey.stringKey("mcp.method.name"))); + assertEquals( + "test-tool", callSpan.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))); + + // Verify context propagation in JSON-RPC metadata + // Note: invokeTool does not trigger initialization again since it was already initialized + // So invokeTool adds tools/call request, making it 4 requests total. + // Wait, let's verify if the list size is 4. + // index 0: initialize (Request) + // index 1: notifications/initialized (Notification) + // index 2: tools/list (Request) + // index 3: tools/call (Request) + assertEquals(4, receivedRequests.size()); + + JsonNode initReq = receivedRequests.get(0); + JsonNode listReq = receivedRequests.get(2); + JsonNode callReq = receivedRequests.get(3); + + // Verify traceparent in requests matches the span's traceId + String initTraceParent = initReq.get("params").get("_meta").get("traceparent").asText(); + assertNotNull(initTraceParent); + assertTrue(initTraceParent.contains(initSpan.getTraceId())); + + String listTraceParent = listReq.get("params").get("_meta").get("traceparent").asText(); + assertNotNull(listTraceParent); + assertTrue(listTraceParent.contains(listSpan.getTraceId())); + + String callTraceParent = callReq.get("params").get("_meta").get("traceparent").asText(); + assertNotNull(callTraceParent); + assertTrue(callTraceParent.contains(callSpan.getTraceId())); + } + + @Test + public void testTelemetryHelperEdgeCases() { + // 1. Test ServerInfo record methods (equals, hashCode, toString, and accessors) + TelemetryHelper.ServerInfo info1 = new TelemetryHelper.ServerInfo("localhost", 8080, "http"); + TelemetryHelper.ServerInfo info2 = new TelemetryHelper.ServerInfo("localhost", 8080, "http"); + TelemetryHelper.ServerInfo info3 = new TelemetryHelper.ServerInfo("example.com", 9090, "https"); + + assertEquals(info1, info2); + assertNotEquals(info1, info3); + assertEquals(info1.hashCode(), info2.hashCode()); + assertNotNull(info1.toString()); + assertEquals("localhost", info1.address()); + assertEquals(8080, info1.port()); + assertEquals("http", info1.protocol()); + + // 2. Test extractServerInfo with various edge-case URLs + TelemetryHelper.ServerInfo invalid = TelemetryHelper.extractServerInfo(":::"); + assertEquals("", invalid.address()); + assertNull(invalid.port()); + assertEquals("http", invalid.protocol()); + + TelemetryHelper.ServerInfo noHost = TelemetryHelper.extractServerInfo("http:///mcp"); + assertEquals("", noHost.address()); + assertNull(noHost.port()); + + TelemetryHelper.ServerInfo noHostWithPort = + TelemetryHelper.extractServerInfo("http://my_server:8080"); + assertEquals("my_server", noHostWithPort.address()); + assertEquals(8080, noHostWithPort.port()); + + TelemetryHelper.ServerInfo invalidPort = + TelemetryHelper.extractServerInfo("http://my_server:invalidport"); + assertEquals("my_server", invalidPort.address()); + assertNull(invalidPort.port()); + + TelemetryHelper.ServerInfo noProtocol = TelemetryHelper.extractServerInfo("//localhost:8080"); + assertEquals("localhost", noProtocol.address()); + assertEquals(8080, noProtocol.port()); + assertEquals("http", noProtocol.protocol()); + + // 3. Test recordSessionDuration with error + TelemetryHelper.recordSessionDuration( + 5.5, "2025-11-25", "http://localhost:8080", new RuntimeException("session error")); + } +} diff --git a/src/test/java/com/google/cloud/mcp/auth/AuthMethodsTest.java b/src/test/java/com/google/cloud/mcp/auth/AuthMethodsTest.java new file mode 100644 index 0000000..66a497f --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/auth/AuthMethodsTest.java @@ -0,0 +1,215 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.auth; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.mockito.Mockito.withSettings; + +import com.google.auth.oauth2.GoogleCredentials; +import com.google.auth.oauth2.IdToken; +import com.google.auth.oauth2.IdTokenProvider; +import java.io.IOException; +import java.lang.reflect.Constructor; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class AuthMethodsTest { + + private int loadCount; + + @BeforeEach + void setUp() { + loadCount = 0; + } + + @Test + void testGetGoogleIdToken_Success() throws Exception { + String mockToken = "mock-id-token-xyz"; + String audience = "https://test-mcp-service.com"; + + // Setup Mock credentials implementing GoogleCredentials and IdTokenProvider + GoogleCredentials credentials = + mock(GoogleCredentials.class, withSettings().extraInterfaces(IdTokenProvider.class)); + IdToken mockIdToken = mock(IdToken.class); + when(mockIdToken.getTokenValue()).thenReturn(mockToken); + when(((IdTokenProvider) credentials).idTokenWithAudience(eq(audience), any())) + .thenReturn(mockIdToken); + + String token = AuthMethods.getGoogleIdToken(credentials, audience); + + assertEquals("Bearer " + mockToken, token); + } + + @Test + void testGetGoogleIdToken_NotAnIdTokenProvider() { + String audience = "https://test-mcp-service.com"; + + // Regular credentials that do not implement IdTokenProvider + GoogleCredentials credentials = mock(GoogleCredentials.class); + + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> AuthMethods.getGoogleIdToken(credentials, audience)); + assertTrue(exception.getMessage().contains("not an instance of IdTokenProvider")); + } + + @Test + void testGoogleCredentialsProvider_Success() throws Exception { + String mockToken = "mock-id-token-provider"; + String audience = "https://test-mcp-service.com"; + + GoogleCredentials credentials = + mock(GoogleCredentials.class, withSettings().extraInterfaces(IdTokenProvider.class)); + IdToken mockIdToken = mock(IdToken.class); + when(mockIdToken.getTokenValue()).thenReturn(mockToken); + when(((IdTokenProvider) credentials).idTokenWithAudience(eq(audience), any())) + .thenReturn(mockIdToken); + + GoogleCredentialsProvider provider = new GoogleCredentialsProvider(audience, () -> credentials); + String header = provider.getAuthorizationHeader().get(); + + assertEquals("Bearer " + mockToken, header); + } + + @Test + void testGoogleCredentialsProvider_Caching() throws Exception { + String mockToken = "mock-id-token-caching"; + String audience = "https://test-mcp-service.com"; + + GoogleCredentials credentials = + mock(GoogleCredentials.class, withSettings().extraInterfaces(IdTokenProvider.class)); + IdToken mockIdToken = mock(IdToken.class); + when(mockIdToken.getTokenValue()).thenReturn(mockToken); + when(((IdTokenProvider) credentials).idTokenWithAudience(eq(audience), any())) + .thenReturn(mockIdToken); + + GoogleCredentialsProvider.CredentialsLoader loader = + () -> { + loadCount++; + return credentials; + }; + + GoogleCredentialsProvider provider = new GoogleCredentialsProvider(audience, loader); + + // First call loads the credentials + String token1 = provider.getAuthorizationHeader().get(); + // Second call should reuse the cached credentials + String token2 = provider.getAuthorizationHeader().get(); + + assertEquals("Bearer " + mockToken, token1); + assertEquals("Bearer " + mockToken, token2); + assertEquals(1, loadCount, "Credentials should be loaded exactly once due to caching"); + } + + @Test + void testGoogleCredentialsProvider_FallbackOnException() throws Exception { + String audience = "https://test-mcp-service.com"; + + // Fail loading credentials + GoogleCredentialsProvider.CredentialsLoader loader = + () -> { + throw new IOException("Cannot load credentials"); + }; + + GoogleCredentialsProvider provider = new GoogleCredentialsProvider(audience, loader); + String header = provider.getAuthorizationHeader().get(); + + // Verification that it gracefully returns null (proceed without auth) + assertNull(header); + } + + @Test + void testGoogleCredentialsProvider_OidcFailure() throws Exception { + String audience = "https://test-mcp-service.com"; + GoogleCredentials creds = mock(GoogleCredentials.class); // Does NOT implement IdTokenProvider + GoogleCredentialsProvider provider = new GoogleCredentialsProvider(audience, () -> creds); + String header = provider.getAuthorizationHeader().get(); + assertNull(header, "OIDC incompatible credentials should return null auth header"); + } + + @Test + void testGoogleCredentialsProvider_InvalidAudience() { + assertThrows( + IllegalArgumentException.class, () -> new GoogleCredentialsProvider(null, () -> null)); + assertThrows( + IllegalArgumentException.class, () -> new GoogleCredentialsProvider("", () -> null)); + } + + @Test + void testGoogleCredentialsProvider_PublicConstructor() throws Exception { + GoogleCredentialsProvider provider = new GoogleCredentialsProvider("https://test.com"); + // Should run gracefully, even if ADC fails locally it returns null + String header = provider.getAuthorizationHeader().get(); + // No assertion on value, just verify it runs without crashing to cover constructor instructions + } + + @Test + void testAuthMethods_PrivateConstructor() throws Exception { + Constructor constructor = AuthMethods.class.getDeclaredConstructor(); + constructor.setAccessible(true); + AuthMethods instance = constructor.newInstance(); + assertNotNull(instance); + } + + @Test + void testAuthMethods_NullCredentials() { + assertThrows( + IllegalArgumentException.class, () -> AuthMethods.getGoogleIdToken(null, "audience")); + } + + @Test + void testAuthMethods_RefreshException() throws Exception { + GoogleCredentials creds = mock(GoogleCredentials.class); + IOException simulatedException = new IOException("Refresh failed"); + org.mockito.Mockito.doThrow(simulatedException).when(creds).refreshIfExpired(); + + assertThrows(IOException.class, () -> AuthMethods.getGoogleIdToken(creds, "audience")); + } + + @Test + void testGoogleCredentialsProvider_NullCredentialsLoaded() throws Exception { + String audience = "https://test-mcp-service.com"; + GoogleCredentialsProvider provider = new GoogleCredentialsProvider(audience, () -> null); + String header = provider.getAuthorizationHeader().get(); + assertNull(header, "Null credentials from loader should return null auth header"); + } + + @Test + void testAuthMethods_BearerTokenAlreadyPrefixed() throws Exception { + String mockToken = "Bearer custom-already-prefixed-token"; + String audience = "https://test-mcp-service.com"; + + GoogleCredentials credentials = + mock(GoogleCredentials.class, withSettings().extraInterfaces(IdTokenProvider.class)); + IdToken mockIdToken = mock(IdToken.class); + when(mockIdToken.getTokenValue()).thenReturn(mockToken); + when(((IdTokenProvider) credentials).idTokenWithAudience(eq(audience), any())) + .thenReturn(mockIdToken); + + String resolvedToken = AuthMethods.getGoogleIdToken(credentials, audience); + assertEquals(mockToken, resolvedToken, "Should not double-prefix Bearer tokens"); + } +} diff --git a/src/test/java/com/google/cloud/mcp/client/HttpMcpToolboxClientTest.java b/src/test/java/com/google/cloud/mcp/client/HttpMcpToolboxClientTest.java new file mode 100644 index 0000000..9ea1dad --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/client/HttpMcpToolboxClientTest.java @@ -0,0 +1,342 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.ProtocolVersion; +import com.sun.net.httpserver.HttpExchange; +import com.sun.net.httpserver.HttpHandler; +import com.sun.net.httpserver.HttpServer; +import java.io.IOException; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +@Timeout(value = 5, unit = java.util.concurrent.TimeUnit.SECONDS) +class HttpMcpToolboxClientTest { + + private HttpServer server; + private int port; + private MockMcpHandler handler; + private ObjectMapper objectMapper = new ObjectMapper(); + + @BeforeEach + void startServer() throws IOException { + server = HttpServer.create(new InetSocketAddress(0), 0); + port = server.getAddress().getPort(); + handler = new MockMcpHandler(); + server.createContext("/", handler); + server.start(); + } + + @AfterEach + void stopServer() { + if (server != null) { + server.stop(0); + } + } + + private String getBaseUrl() { + return "http://localhost:" + port; + } + + @Test + void testVersionNegotiation_default_2025_11_25() throws Exception { + handler.serverProtocolVersion = "2025-11-25"; + + McpToolboxClient client = McpToolboxClient.builder().baseUrl(getBaseUrl()).build(); + + // Trigger loadToolset which runs initialization + client.loadToolset().get(); + + // Verify initialization requests + assertTrue(handler.requestsReceived.size() >= 2); + + // First request: initialize + MockRequest initReq = handler.requestsReceived.get(0); + assertEquals("initialize", initReq.method); + assertEquals("2025-11-25", initReq.params.get("protocolVersion").asText()); + + // Second request: notifications/initialized + MockRequest notifReq = handler.requestsReceived.get(1); + assertEquals("notifications/initialized", notifReq.method); + assertEquals("2025-11-25", notifReq.headers.get("mcp-protocol-version")); + assertEquals("application/json", notifReq.headers.get("accept")); + + // Third request: tools/list + MockRequest listReq = handler.requestsReceived.get(2); + assertEquals("tools/list", listReq.method); + assertEquals("2025-11-25", listReq.headers.get("mcp-protocol-version")); + assertEquals("application/json", listReq.headers.get("accept")); + } + + @Test + void testVersionNegotiation_mismatch_throws() { + // Client proposes 2025-11-25 by default, server returns 2025-03-26 + handler.serverProtocolVersion = "2025-03-26"; + handler.serverSessionId = "sess-12345"; + + McpToolboxClient client = McpToolboxClient.builder().baseUrl(getBaseUrl()).build(); + + ExecutionException exception = + assertThrows( + ExecutionException.class, + () -> { + client.loadToolset().get(); + }); + + assertTrue(exception.getCause().getMessage().contains("MCP version mismatch")); + } + + @Test + void testVersionNegotiation_success_2025_03_26() throws Exception { + handler.serverProtocolVersion = "2025-03-26"; + handler.serverSessionId = "sess-12345"; + + McpToolboxClient client = + McpToolboxClient.builder() + .baseUrl(getBaseUrl()) + .protocolVersion(ProtocolVersion.VERSION_2025_03_26) + .build(); + + client.loadToolset().get(); + + assertTrue(handler.requestsReceived.size() >= 2); + + // Initial check: proposed version is 2025-03-26 + MockRequest initReq = handler.requestsReceived.get(0); + assertEquals("initialize", initReq.method); + assertEquals("2025-03-26", initReq.params.get("protocolVersion").asText()); + + // Check initialized notification uses negotiated session ID and no version header + MockRequest notifReq = handler.requestsReceived.get(1); + assertEquals("notifications/initialized", notifReq.method); + assertEquals("sess-12345", notifReq.headers.get("mcp-session-id")); + assertTrue(!notifReq.headers.containsKey("mcp-protocol-version")); + assertEquals("application/json", notifReq.headers.get("accept")); + + // Check subsequent tools/list uses negotiated session ID + MockRequest listReq = handler.requestsReceived.get(2); + assertEquals("tools/list", listReq.method); + assertEquals("sess-12345", listReq.headers.get("mcp-session-id")); + assertTrue(!listReq.headers.containsKey("mcp-protocol-version")); + assertEquals("application/json", listReq.headers.get("accept")); + } + + @Test + void testVersionNegotiation_success_2024_11_05() throws Exception { + handler.serverProtocolVersion = "2024-11-05"; + + McpToolboxClient client = + McpToolboxClient.builder() + .baseUrl(getBaseUrl()) + .protocolVersion(ProtocolVersion.VERSION_2024_11_05) + .build(); + + client.loadToolset().get(); + + assertTrue(handler.requestsReceived.size() >= 2); + + MockRequest notifReq = handler.requestsReceived.get(1); + assertEquals("notifications/initialized", notifReq.method); + assertTrue(!notifReq.headers.containsKey("mcp-protocol-version")); + assertTrue(!notifReq.headers.containsKey("mcp-session-id")); + assertTrue(!notifReq.headers.containsKey("accept")); + + MockRequest listReq = handler.requestsReceived.get(2); + assertEquals("tools/list", listReq.method); + assertTrue(!listReq.headers.containsKey("mcp-protocol-version")); + assertTrue(!listReq.headers.containsKey("mcp-session-id")); + assertTrue(!listReq.headers.containsKey("accept")); + } + + @Test + void testVersionNegotiation_unsupportedVersion() { + handler.serverProtocolVersion = "2023-01-01"; // Unsupported version + + McpToolboxClient client = McpToolboxClient.builder().baseUrl(getBaseUrl()).build(); + + ExecutionException exception = + assertThrows( + ExecutionException.class, + () -> { + client.loadToolset().get(); + }); + + assertTrue(exception.getCause().getMessage().contains("MCP version mismatch")); + } + + @Test + void testVersionNegotiation_missingSessionId_for_2025_03_26() { + handler.serverProtocolVersion = "2025-03-26"; + handler.serverSessionId = null; // Missing session ID + + McpToolboxClient client = + McpToolboxClient.builder() + .baseUrl(getBaseUrl()) + .protocolVersion(ProtocolVersion.VERSION_2025_03_26) + .build(); + + ExecutionException exception = + assertThrows( + ExecutionException.class, + () -> { + client.loadToolset().get(); + }); + + assertTrue(exception.getCause().getMessage().contains("Mcp-Session-Id")); + } + + @Test + void testBuilderPreferredVersion() throws Exception { + handler.serverProtocolVersion = "2025-03-26"; + handler.serverSessionId = "sess-prefer"; + + // Build client with preferred version + McpToolboxClient client = + McpToolboxClient.builder() + .baseUrl(getBaseUrl()) + .protocolVersion(ProtocolVersion.VERSION_2025_03_26) + .build(); + + client.loadToolset().get(); + + assertTrue(handler.requestsReceived.size() >= 2); + + // Initial check: client proposed preferred version + MockRequest initReq = handler.requestsReceived.get(0); + assertEquals("initialize", initReq.method); + assertEquals("2025-03-26", initReq.params.get("protocolVersion").asText()); + } + + private static class MockRequest { + String method; + JsonNode params; + Map headers; + + MockRequest(String method, JsonNode params, Map headers) { + this.method = method; + this.params = params; + this.headers = headers; + } + } + + private class MockMcpHandler implements HttpHandler { + String serverProtocolVersion = "2025-11-25"; + String serverSessionId = null; + List requestsReceived = Collections.synchronizedList(new ArrayList<>()); + + @Override + public void handle(HttpExchange exchange) throws IOException { + if (!"POST".equalsIgnoreCase(exchange.getRequestMethod())) { + exchange.sendResponseHeaders(405, -1); + return; + } + + String bodyStr = new String(exchange.getRequestBody().readAllBytes(), StandardCharsets.UTF_8); + JsonNode jsonReq; + try { + jsonReq = objectMapper.readTree(bodyStr); + } catch (Exception e) { + exchange.sendResponseHeaders(400, -1); + return; + } + + String method = jsonReq.get("method").asText(); + JsonNode params = jsonReq.get("params"); + + // Extract headers case-insensitively for checking (lowercase keys) + java.util.Map requestHeaders = new java.util.HashMap<>(); + exchange + .getRequestHeaders() + .forEach( + (k, v) -> { + if (v != null && !v.isEmpty()) { + requestHeaders.put(k.toLowerCase(java.util.Locale.ROOT), v.get(0)); + } + }); + + requestsReceived.add(new MockRequest(method, params, requestHeaders)); + + String responseBody = ""; + int statusCode = 200; + + if ("initialize".equals(method)) { + if (serverSessionId != null) { + exchange.getResponseHeaders().set("Mcp-Session-Id", serverSessionId); + } + responseBody = + "{\n" + + " \"jsonrpc\": \"2.0\",\n" + + " \"id\": \"" + + jsonReq.get("id").asText() + + "\",\n" + + " \"result\": {\n" + + " \"protocolVersion\": \"" + + serverProtocolVersion + + "\",\n" + + " \"capabilities\": {\n" + + " \"tools\": {}\n" + + " },\n" + + " \"serverInfo\": {\n" + + " \"name\": \"mock-server\",\n" + + " \"version\": \"1.0.0\"\n" + + " }\n" + + " }\n" + + "}"; + } else if ("notifications/initialized".equals(method)) { + statusCode = 204; + } else if ("tools/list".equals(method)) { + responseBody = + "{\n" + + " \"jsonrpc\": \"2.0\",\n" + + " \"id\": \"" + + jsonReq.get("id").asText() + + "\",\n" + + " \"result\": {\n" + + " \"tools\": []\n" + + " }\n" + + "}"; + } + + byte[] bytes = responseBody.getBytes(StandardCharsets.UTF_8); + if (statusCode == 204 || bytes.length == 0) { + exchange.sendResponseHeaders(statusCode, -1); + } else { + exchange.sendResponseHeaders(statusCode, bytes.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(bytes); + } + } + } + } +} diff --git a/src/test/java/com/google/cloud/mcp/client/McpToolboxClientBuilderTest.java b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientBuilderTest.java new file mode 100644 index 0000000..b5b1168 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientBuilderTest.java @@ -0,0 +1,202 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.tool.ToolPostProcessor; +import com.google.cloud.mcp.tool.ToolPreProcessor; +import com.google.cloud.mcp.transport.Transport; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; + +class McpToolboxClientBuilderTest { + + @Test + void testHeadersAndApiKey() { + McpToolboxClient client = + McpToolboxClient.builder() + .baseUrl("http://localhost:8080") + .apiKey("my-api-key") + .headers(Map.of("X-Custom-Header", "value1", "Authorization", "Bearer custom-token")) + .build(); + + assertNotNull(client); + assertTrue(client instanceof McpToolboxClientImpl); + } + + @Test + void testBaseUrlValidation() { + assertThrows( + IllegalArgumentException.class, () -> McpToolboxClient.builder().baseUrl(null).build()); + + assertThrows( + IllegalArgumentException.class, () -> McpToolboxClient.builder().baseUrl("").build()); + } + + @Test + void testBaseUrlTrailingSlashNormalization() throws Exception { + McpToolboxClient client = McpToolboxClient.builder().baseUrl("http://localhost:8080/").build(); + + Field transportField = McpToolboxClientImpl.class.getDeclaredField("transport"); + transportField.setAccessible(true); + Transport transport = (Transport) transportField.get(client); + assertEquals("http://localhost:8080", transport.getBaseUrl()); + } + + @Test + @SuppressWarnings("unchecked") + void testApiKeyPreprocessing() throws Exception { + Method getAuthHeaderMethod = + McpToolboxClientImpl.class.getDeclaredMethod("getAuthorizationHeader"); + getAuthHeaderMethod.setAccessible(true); + + // 1. ApiKey is null or empty + McpToolboxClient clientNullKey = + McpToolboxClient.builder().baseUrl("http://localhost:8080").apiKey(null).build(); + CompletableFuture futureNull = + (CompletableFuture) getAuthHeaderMethod.invoke(clientNullKey); + assertNull(futureNull.join()); + + // 2. ApiKey is raw (not prefixed with Bearer) + McpToolboxClient clientRawKey = + McpToolboxClient.builder().baseUrl("http://localhost:8080").apiKey("raw-key-123").build(); + CompletableFuture futureRaw = + (CompletableFuture) getAuthHeaderMethod.invoke(clientRawKey); + assertEquals("Bearer raw-key-123", futureRaw.join()); + + // 3. ApiKey already contains Bearer prefix + McpToolboxClient clientBearerKey = + McpToolboxClient.builder() + .baseUrl("http://localhost:8080") + .apiKey("Bearer token-456") + .build(); + CompletableFuture futureBearer = + (CompletableFuture) getAuthHeaderMethod.invoke(clientBearerKey); + assertEquals("Bearer token-456", futureBearer.join()); + + // 4. ApiKey does not override existing Authorization header + McpToolboxClient clientOverrideKey = + McpToolboxClient.builder() + .baseUrl("http://localhost:8080") + .headers(Map.of("Authorization", "Bearer existing-token")) + .apiKey("new-key-should-be-ignored") + .build(); + CompletableFuture futureOverride = + (CompletableFuture) getAuthHeaderMethod.invoke(clientOverrideKey); + assertEquals("Bearer existing-token", futureOverride.join()); + } + + @Test + void testHeadersNullHandledSafely() { + McpToolboxClient client = + McpToolboxClient.builder().baseUrl("http://localhost:8080").headers(null).build(); + assertNotNull(client); + } + + @Test + @SuppressWarnings("unchecked") + void testCredentialsProviderConfiguration() throws Exception { + CredentialsProvider provider = () -> CompletableFuture.completedFuture("Bearer test-token"); + McpToolboxClient client = + McpToolboxClient.builder() + .baseUrl("http://localhost:8080") + .credentialsProvider(provider) + .build(); + assertNotNull(client); + + Method getAuthHeaderMethod = + McpToolboxClientImpl.class.getDeclaredMethod("getAuthorizationHeader"); + getAuthHeaderMethod.setAccessible(true); + CompletableFuture future = + (CompletableFuture) getAuthHeaderMethod.invoke(client); + assertEquals("Bearer test-token", future.join()); + } + + @Test + @SuppressWarnings("unchecked") + void testEmptyApiKey_TreatedAsNoKey() throws Exception { + McpToolboxClient client = + McpToolboxClient.builder().baseUrl("http://localhost:8080").apiKey("").build(); + + Method getAuthHeaderMethod = + McpToolboxClientImpl.class.getDeclaredMethod("getAuthorizationHeader"); + getAuthHeaderMethod.setAccessible(true); + CompletableFuture future = + (CompletableFuture) getAuthHeaderMethod.invoke(client); + assertNull(future.join()); + } + + @Test + void testCustomHttpClientAndExecutor() { + java.net.http.HttpClient customClient = java.net.http.HttpClient.newHttpClient(); + java.util.concurrent.Executor customExecutor = java.util.concurrent.ForkJoinPool.commonPool(); + + McpToolboxClient client = + McpToolboxClient.builder() + .baseUrl("http://localhost:8080") + .httpClient(customClient) + .executor(customExecutor) + .protocolVersion(ProtocolVersion.VERSION_2025_11_25) + .build(); + + assertNotNull(client); + } + + @Test + void testProcessorsConfiguration() { + ToolPreProcessor pre = (name, args) -> CompletableFuture.completedFuture(args); + ToolPostProcessor post = (name, result) -> CompletableFuture.completedFuture(result); + + McpToolboxClient client = + McpToolboxClient.builder() + .baseUrl("http://localhost:8080") + .preProcessor(pre) + .preProcessor(null) + .postProcessor(post) + .postProcessor(null) + .build(); + assertNotNull(client); + } + + @Test + void testMcpExceptionConstructor() { + RuntimeException cause = new RuntimeException("root cause"); + McpException ex = new McpException("error message", cause); + assertEquals("error message", ex.getMessage()); + assertSame(cause, ex.getCause()); + } + + @Test + void testProtocolVersionFromString() { + assertNull(ProtocolVersion.fromString(null)); + assertNull(ProtocolVersion.fromString("invalid-version")); + assertEquals(ProtocolVersion.VERSION_2025_11_25, ProtocolVersion.fromString("2025-11-25")); + } +} diff --git a/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplErrorsTest.java b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplErrorsTest.java new file mode 100644 index 0000000..49732a8 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplErrorsTest.java @@ -0,0 +1,276 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolResult; +import com.google.cloud.mcp.transport.HttpMcpTransport; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +@Timeout(value = 5, unit = java.util.concurrent.TimeUnit.SECONDS) +class McpToolboxClientImplErrorsTest { + + private McpToolboxClientImpl client; + private HttpClient mockHttpClient; + private ObjectMapper objectMapper = new ObjectMapper(); + + @BeforeEach + @SuppressWarnings("unchecked") + void setUp() throws Exception { + mockHttpClient = mock(HttpClient.class); + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + CredentialsProvider provider = () -> CompletableFuture.completedFuture("Bearer test-api-key"); + client = new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), provider); + } + + @Test + void testEnsureInitializedFailsWith500() { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(500); + when(initResponse.body()).thenReturn("Internal Server Error"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)); + + CompletableFuture> future = client.listTools(); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows(Exception.class, future::join); + + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertTrue(cause instanceof RuntimeException); + assertTrue(cause.getMessage().contains("Init failed: 500")); + assertTrue(cause.getMessage().contains("Internal Server Error")); + } + + @Test + void testEnsureInitializedFailsWith401() { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(401); + when(initResponse.body()).thenReturn("Unauthorized"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)); + + CompletableFuture> future = client.listTools(); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows(Exception.class, future::join); + + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertTrue(cause instanceof RuntimeException); + assertTrue(cause.getMessage().contains("Init failed: 401")); + assertTrue(cause.getMessage().contains("Unauthorized")); + } + + @Test + void testInvokeToolFailsDuringInitializationWith403() { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(403); + when(initResponse.body()).thenReturn("Forbidden"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)); + + CompletableFuture future = client.invokeTool("test-tool", Map.of()); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows(Exception.class, future::join); + + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertTrue(cause instanceof RuntimeException); + assertTrue(cause.getMessage().contains("Init failed: 403")); + assertTrue(cause.getMessage().contains("Forbidden")); + } + + @Test + void testListToolsFailsWith500AfterInit() { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(500); + when(listResponse.body()).thenReturn("Internal Server Error"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + CompletableFuture> future = client.listTools(); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows(Exception.class, future::join); + + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertTrue(cause instanceof RuntimeException); + assertTrue(cause.getMessage().contains("Failed to list tools. Status: 500")); + assertTrue(cause.getMessage().contains("Internal Server Error")); + } + + @Test + void testInvokeToolReturnsErrorOnNon200Response() { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse callResponse = mock(HttpResponse.class); + when(callResponse.statusCode()).thenReturn(500); + when(callResponse.body()).thenReturn("Internal Server Error"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(callResponse)); + + ToolResult result = client.invokeTool("test-tool", Map.of()).join(); + + assertNotNull(result); + assertTrue(result.isError()); + assertEquals(1, result.content().size()); + assertTrue(result.content().get(0).text().contains("Error 500")); + assertTrue(result.content().get(0).text().contains("Internal Server Error")); + } + + @Test + void testListToolsThrowsIOExceptionOnSend() { + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.failedFuture(new java.io.IOException("Connection reset"))); + + CompletableFuture> future = client.listTools(); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows(Exception.class, future::join); + + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertTrue(cause instanceof java.io.IOException); + assertEquals("Connection reset", cause.getMessage()); + } + + @Test + void testListToolsThrowsIOExceptionOnListRequest() { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.failedFuture(new java.io.IOException("Connection timeout"))); + + CompletableFuture> future = client.listTools(); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows(Exception.class, future::join); + + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertTrue(cause instanceof java.io.IOException); + assertEquals("Connection timeout", cause.getMessage()); + } + + @Test + void testInvokeToolThrowsIOExceptionOnSend() { + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.failedFuture(new java.io.IOException("Timeout occurred"))); + + CompletableFuture future = client.invokeTool("test-tool", Map.of()); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows(Exception.class, future::join); + + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertTrue(cause instanceof java.io.IOException); + assertEquals("Timeout occurred", cause.getMessage()); + } + + private String getBodyStringQuietly(HttpRequest request) { + try { + return getBodyString(request); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private String getBodyString(HttpRequest request) throws Exception { + if (request.bodyPublisher().isPresent()) { + var publisher = request.bodyPublisher().get(); + var subscriber = + HttpResponse.BodySubscribers.ofString(java.nio.charset.StandardCharsets.UTF_8); + publisher.subscribe( + new java.util.concurrent.Flow.Subscriber() { + @Override + public void onSubscribe(java.util.concurrent.Flow.Subscription subscription) { + subscriber.onSubscribe(subscription); + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(java.nio.ByteBuffer item) { + subscriber.onNext(java.util.List.of(item)); + } + + @Override + public void onError(Throwable throwable) { + subscriber.onError(throwable); + } + + @Override + public void onComplete() { + subscriber.onComplete(); + } + }); + return subscriber.getBody().toCompletableFuture().join(); + } + return ""; + } +} diff --git a/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplHeadersTest.java b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplHeadersTest.java new file mode 100644 index 0000000..7dc2aab --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplHeadersTest.java @@ -0,0 +1,319 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import com.google.cloud.mcp.transport.HttpMcpTransport; +import java.lang.reflect.Field; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; + +@Timeout(value = 5, unit = java.util.concurrent.TimeUnit.SECONDS) +class McpToolboxClientImplHeadersTest { + + private McpToolboxClientImpl client; + private HttpClient mockHttpClient; + private ObjectMapper objectMapper = new ObjectMapper(); + + @BeforeEach + @SuppressWarnings("unchecked") + void setUp() throws Exception { + mockHttpClient = mock(HttpClient.class); + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + CredentialsProvider provider = () -> CompletableFuture.completedFuture("Bearer test-api-key"); + client = new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), provider); + } + + @Test + void testCustomHeadersPopulatedInAllRequests() throws Exception { + McpToolboxClient client = + new McpToolboxClientBuilder() + .baseUrl("http://localhost:8080") + .apiKey("client-api-key") + .headers(Map.of("X-Client-Header", "client-value", "X-Common-Header", "client-common")) + .build(); + + HttpClient mockHttpClient = mock(HttpClient.class); + Field transportField = McpToolboxClientImpl.class.getDeclaredField("transport"); + transportField.setAccessible(true); + HttpMcpTransport transport = (HttpMcpTransport) transportField.get(client); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field httpClientField = BaseMcpTransport.class.getDeclaredField("httpClient"); + httpClientField.setAccessible(true); + httpClientField.set(delegate, mockHttpClient); + + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\"," + + "\"description\":\"A test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{\"type\":\"string\"}}," + + "\"required\":[\"param1\"]}}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + String callBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"content\":[{\"type\":\"text\"," + + "\"text\":\"success\"}],\"isError\":false}}"; + HttpResponse callResponse = mock(HttpResponse.class); + when(callResponse.statusCode()).thenReturn(200); + when(callResponse.body()).thenReturn(callBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)) + .thenReturn(CompletableFuture.completedFuture(callResponse)); + + // Call listTools (which initializes first) + client.listTools().join(); + // Call invokeTool + client.invokeTool("test-tool", Map.of("param1", "value1")).join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(mockHttpClient, times(4)).sendAsync(requestCaptor.capture(), any()); + + // 1st request: initialize + HttpRequest initReq = requestCaptor.getAllValues().get(0); + assertEquals("client-value", initReq.headers().firstValue("X-Client-Header").orElse(null)); + assertEquals("client-common", initReq.headers().firstValue("X-Common-Header").orElse(null)); + assertEquals( + "Bearer client-api-key", initReq.headers().firstValue("Authorization").orElse(null)); + + // 2nd request: notifications/initialized + HttpRequest notifReq = requestCaptor.getAllValues().get(1); + assertEquals("client-value", notifReq.headers().firstValue("X-Client-Header").orElse(null)); + assertEquals("client-common", notifReq.headers().firstValue("X-Common-Header").orElse(null)); + assertEquals( + "Bearer client-api-key", notifReq.headers().firstValue("Authorization").orElse(null)); + + // 3rd request: tools/list + HttpRequest listReq = requestCaptor.getAllValues().get(2); + assertEquals("client-value", listReq.headers().firstValue("X-Client-Header").orElse(null)); + assertEquals("client-common", listReq.headers().firstValue("X-Common-Header").orElse(null)); + assertEquals( + "Bearer client-api-key", listReq.headers().firstValue("Authorization").orElse(null)); + + // 4th request: tools/call + HttpRequest callReq = requestCaptor.getAllValues().get(3); + assertEquals("client-value", callReq.headers().firstValue("X-Client-Header").orElse(null)); + assertEquals("client-common", callReq.headers().firstValue("X-Common-Header").orElse(null)); + assertEquals( + "Bearer client-api-key", callReq.headers().firstValue("Authorization").orElse(null)); + } + + @Test + void testExtraHeadersOverrideAndAuthPriority() throws Exception { + McpToolboxClient client = + new McpToolboxClientBuilder() + .baseUrl("http://localhost:8080") + .apiKey("client-api-key") + .headers(Map.of("X-Client-Header", "client-value", "X-Common-Header", "client-common")) + .build(); + + HttpClient mockHttpClient = mock(HttpClient.class); + Field transportField = McpToolboxClientImpl.class.getDeclaredField("transport"); + transportField.setAccessible(true); + HttpMcpTransport transport = (HttpMcpTransport) transportField.get(client); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field httpClientField = BaseMcpTransport.class.getDeclaredField("httpClient"); + httpClientField.setAccessible(true); + httpClientField.set(delegate, mockHttpClient); + + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + String callBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"content\":[{\"type\":\"text\"," + + "\"text\":\"success\"}],\"isError\":false}}"; + HttpResponse callResponse = mock(HttpResponse.class); + when(callResponse.statusCode()).thenReturn(200); + when(callResponse.body()).thenReturn(callBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(callResponse)); + + // Call invokeTool directly (which will initialize client first) + // Pass extraHeaders containing X-Common-Header override and Authorization override + Map extraHeaders = + Map.of("X-Common-Header", "override-common", "Authorization", "Bearer extra-auth-key"); + client.invokeTool("test-tool", Map.of("param1", "value1"), extraHeaders).join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(mockHttpClient, times(3)).sendAsync(requestCaptor.capture(), any()); + + // 1st request: initialize + HttpRequest initReq = requestCaptor.getAllValues().get(0); + assertEquals("client-value", initReq.headers().firstValue("X-Client-Header").orElse(null)); + assertEquals("client-common", initReq.headers().firstValue("X-Common-Header").orElse(null)); + assertEquals( + "Bearer extra-auth-key", initReq.headers().firstValue("Authorization").orElse(null)); + + // 2nd request: notifications/initialized + HttpRequest notifReq = requestCaptor.getAllValues().get(1); + assertEquals("client-value", notifReq.headers().firstValue("X-Client-Header").orElse(null)); + assertEquals("client-common", notifReq.headers().firstValue("X-Common-Header").orElse(null)); + assertEquals( + "Bearer extra-auth-key", notifReq.headers().firstValue("Authorization").orElse(null)); + + // 3rd request: tools/call + HttpRequest callReq = requestCaptor.getAllValues().get(2); + assertEquals("client-value", callReq.headers().firstValue("X-Client-Header").orElse(null)); + assertEquals("override-common", callReq.headers().firstValue("X-Common-Header").orElse(null)); + assertEquals( + "Bearer extra-auth-key", callReq.headers().firstValue("Authorization").orElse(null)); + } + + @Test + void testNoDuplicateHeaders() throws Exception { + Map customHeaders = new HashMap<>(); + customHeaders.put("X-Test-Header", "value1"); + customHeaders.put("x-test-header", "value2"); + customHeaders.put("Authorization", "Bearer initial-token"); + customHeaders.put("authorization", "Bearer lowercase-token"); + + McpToolboxClient client = + new McpToolboxClientBuilder() + .baseUrl("http://localhost:8080") + .headers(customHeaders) + .build(); + + HttpClient mockHttpClient = mock(HttpClient.class); + Field transportField = McpToolboxClientImpl.class.getDeclaredField("transport"); + transportField.setAccessible(true); + HttpMcpTransport transport = (HttpMcpTransport) transportField.get(client); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field httpClientField = BaseMcpTransport.class.getDeclaredField("httpClient"); + httpClientField.setAccessible(true); + httpClientField.set(delegate, mockHttpClient); + + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()) + .thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + client.listTools().join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(mockHttpClient, times(3)).sendAsync(requestCaptor.capture(), any()); + + for (HttpRequest request : requestCaptor.getAllValues()) { + java.net.http.HttpHeaders headers = request.headers(); + + // Verify Authorization is not duplicated + List authHeaders = headers.allValues("Authorization"); + assertEquals(1, authHeaders.size(), "Authorization header should have exactly one value"); + + // Verify X-Test-Header is not duplicated + List testHeaders = headers.allValues("X-Test-Header"); + assertEquals(1, testHeaders.size(), "X-Test-Header should have exactly one value"); + } + } + + private String getBodyStringQuietly(HttpRequest request) { + try { + return getBodyString(request); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private String getBodyString(HttpRequest request) throws Exception { + if (request.bodyPublisher().isPresent()) { + var publisher = request.bodyPublisher().get(); + var subscriber = + HttpResponse.BodySubscribers.ofString(java.nio.charset.StandardCharsets.UTF_8); + publisher.subscribe( + new java.util.concurrent.Flow.Subscriber() { + @Override + public void onSubscribe(java.util.concurrent.Flow.Subscription subscription) { + subscriber.onSubscribe(subscription); + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(java.nio.ByteBuffer item) { + subscriber.onNext(java.util.List.of(item)); + } + + @Override + public void onError(Throwable throwable) { + subscriber.onError(throwable); + } + + @Override + public void onComplete() { + subscriber.onComplete(); + } + }); + return subscriber.getBody().toCompletableFuture().join(); + } + return ""; + } +} diff --git a/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplJsonRpcTest.java b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplJsonRpcTest.java new file mode 100644 index 0000000..6ec19e6 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplJsonRpcTest.java @@ -0,0 +1,622 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolResult; +import com.google.cloud.mcp.transport.HttpMcpTransport; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +@Timeout(value = 5, unit = java.util.concurrent.TimeUnit.SECONDS) +class McpToolboxClientImplJsonRpcTest { + + private McpToolboxClientImpl client; + private HttpClient mockHttpClient; + private ObjectMapper objectMapper = new ObjectMapper(); + + @BeforeEach + @SuppressWarnings("unchecked") + void setUp() throws Exception { + mockHttpClient = mock(HttpClient.class); + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + CredentialsProvider provider = () -> CompletableFuture.completedFuture("Bearer test-api-key"); + client = new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), provider); + } + + @Test + void testListTools_ProtocolError() throws Exception { + HttpResponse errorResponse = mock(HttpResponse.class); + when(errorResponse.statusCode()).thenReturn(200); + when(errorResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":1,\"error\":{\"code\":-32601,\"message\":\"Method not" + + " found\"}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenAnswer( + invocation -> { + HttpRequest req = invocation.getArgument(0); + if (getBodyString(req).contains("tools/list")) { + return CompletableFuture.completedFuture(errorResponse); + } + HttpResponse initRes = mock(HttpResponse.class); + when(initRes.statusCode()).thenReturn(200); + when(initRes.body()).thenReturn("{}"); + return CompletableFuture.completedFuture(initRes); + }); + + CompletionException exception = + assertThrows(CompletionException.class, () -> client.listTools().join()); + assertTrue(exception.getCause() instanceof RuntimeException); + assertTrue(exception.getCause().getMessage().contains("MCP Error:")); + assertTrue(exception.getCause().getMessage().contains("Method not found")); + } + + @Test + void testInvokeTool_ProtocolError() throws Exception { + HttpResponse errorResponse = mock(HttpResponse.class); + when(errorResponse.statusCode()).thenReturn(200); + when(errorResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":1,\"error\":{\"code\":-32602,\"message\":\"Invalid" + + " params\"}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenAnswer( + invocation -> { + HttpRequest req = invocation.getArgument(0); + if (getBodyString(req).contains("tools/call")) { + return CompletableFuture.completedFuture(errorResponse); + } + HttpResponse initRes = mock(HttpResponse.class); + when(initRes.statusCode()).thenReturn(200); + when(initRes.body()).thenReturn("{}"); + return CompletableFuture.completedFuture(initRes); + }); + + ToolResult result = client.invokeTool("test-tool", Collections.emptyMap()).join(); + assertNotNull(result); + assertTrue(result.isError()); + assertEquals(1, result.content().size()); + assertTrue(result.content().get(0).text().contains("MCP Error:")); + assertTrue(result.content().get(0).text().contains("Invalid params")); + } + + @Test + void testListTools_MalformedJson() throws Exception { + HttpResponse malformedResponse = mock(HttpResponse.class); + when(malformedResponse.statusCode()).thenReturn(200); + when(malformedResponse.body()).thenReturn("{\"invalid_json"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenAnswer( + invocation -> { + HttpRequest req = invocation.getArgument(0); + if (getBodyString(req).contains("tools/list")) { + return CompletableFuture.completedFuture(malformedResponse); + } + HttpResponse initRes = mock(HttpResponse.class); + when(initRes.statusCode()).thenReturn(200); + when(initRes.body()).thenReturn("{}"); + return CompletableFuture.completedFuture(initRes); + }); + + CompletionException exception = + assertThrows(CompletionException.class, () -> client.listTools().join()); + assertTrue(exception.getCause() instanceof RuntimeException); + assertTrue( + exception.getCause().getCause() + instanceof com.fasterxml.jackson.core.JsonProcessingException); + } + + @Test + void testInvokeTool_MalformedJson() throws Exception { + HttpResponse malformedResponse = mock(HttpResponse.class); + when(malformedResponse.statusCode()).thenReturn(200); + when(malformedResponse.body()).thenReturn("{\"invalid_json"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenAnswer( + invocation -> { + HttpRequest req = invocation.getArgument(0); + if (getBodyString(req).contains("tools/call")) { + return CompletableFuture.completedFuture(malformedResponse); + } + HttpResponse initRes = mock(HttpResponse.class); + when(initRes.statusCode()).thenReturn(200); + when(initRes.body()).thenReturn("{}"); + return CompletableFuture.completedFuture(initRes); + }); + + ToolResult result = client.invokeTool("test-tool", Collections.emptyMap()).join(); + assertNotNull(result); + assertFalse(result.isError()); + assertEquals(1, result.content().size()); + assertEquals("{\"invalid_json", result.content().get(0).text()); + } + + @Test + void testListTools_MissingResult() throws Exception { + HttpResponse missingResultResponse = mock(HttpResponse.class); + when(missingResultResponse.statusCode()).thenReturn(200); + when(missingResultResponse.body()).thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenAnswer( + invocation -> { + HttpRequest req = invocation.getArgument(0); + if (getBodyString(req).contains("tools/list")) { + return CompletableFuture.completedFuture(missingResultResponse); + } + HttpResponse initRes = mock(HttpResponse.class); + when(initRes.statusCode()).thenReturn(200); + when(initRes.body()).thenReturn("{}"); + return CompletableFuture.completedFuture(initRes); + }); + + CompletionException exception = + assertThrows(CompletionException.class, () -> client.listTools().join()); + assertTrue(exception.getCause() instanceof RuntimeException); + assertTrue(exception.getCause().getCause() instanceof NullPointerException); + } + + @Test + void testListTools_EmptyResult() throws Exception { + HttpResponse emptyResultResponse = mock(HttpResponse.class); + when(emptyResultResponse.statusCode()).thenReturn(200); + when(emptyResultResponse.body()).thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenAnswer( + invocation -> { + HttpRequest req = invocation.getArgument(0); + if (getBodyString(req).contains("tools/list")) { + return CompletableFuture.completedFuture(emptyResultResponse); + } + HttpResponse initRes = mock(HttpResponse.class); + when(initRes.statusCode()).thenReturn(200); + when(initRes.body()).thenReturn("{}"); + return CompletableFuture.completedFuture(initRes); + }); + + Map tools = client.listTools().join(); + assertNotNull(tools); + assertTrue(tools.isEmpty()); + } + + @Test + void testInvokeTool_MissingResult() throws Exception { + HttpResponse missingResultResponse = mock(HttpResponse.class); + when(missingResultResponse.statusCode()).thenReturn(200); + when(missingResultResponse.body()).thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenAnswer( + invocation -> { + HttpRequest req = invocation.getArgument(0); + if (getBodyString(req).contains("tools/call")) { + return CompletableFuture.completedFuture(missingResultResponse); + } + HttpResponse initRes = mock(HttpResponse.class); + when(initRes.statusCode()).thenReturn(200); + when(initRes.body()).thenReturn("{}"); + return CompletableFuture.completedFuture(initRes); + }); + + ToolResult result = client.invokeTool("test-tool", Collections.emptyMap()).join(); + assertNotNull(result); + assertFalse(result.isError()); + assertEquals(1, result.content().size()); + assertEquals("{\"jsonrpc\":\"2.0\",\"id\":1}", result.content().get(0).text()); + } + + @Test + void testInvokeTool_EmptyResult() throws Exception { + HttpResponse emptyResultResponse = mock(HttpResponse.class); + when(emptyResultResponse.statusCode()).thenReturn(200); + when(emptyResultResponse.body()).thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenAnswer( + invocation -> { + HttpRequest req = invocation.getArgument(0); + if (getBodyString(req).contains("tools/call")) { + return CompletableFuture.completedFuture(emptyResultResponse); + } + HttpResponse initRes = mock(HttpResponse.class); + when(initRes.statusCode()).thenReturn(200); + when(initRes.body()).thenReturn("{}"); + return CompletableFuture.completedFuture(initRes); + }); + + ToolResult result = client.invokeTool("test-tool", Collections.emptyMap()).join(); + assertNotNull(result); + assertFalse(result.isError()); + assertEquals(1, result.content().size()); + assertEquals("", result.content().get(0).text()); + } + + private String getBodyStringQuietly(HttpRequest request) { + try { + return getBodyString(request); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private String getBodyString(HttpRequest request) throws Exception { + if (request.bodyPublisher().isPresent()) { + var publisher = request.bodyPublisher().get(); + var subscriber = + HttpResponse.BodySubscribers.ofString(java.nio.charset.StandardCharsets.UTF_8); + publisher.subscribe( + new java.util.concurrent.Flow.Subscriber() { + @Override + public void onSubscribe(java.util.concurrent.Flow.Subscription subscription) { + subscriber.onSubscribe(subscription); + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(java.nio.ByteBuffer item) { + subscriber.onNext(java.util.List.of(item)); + } + + @Override + public void onError(Throwable throwable) { + subscriber.onError(throwable); + } + + @Override + public void onComplete() { + subscriber.onComplete(); + } + }); + return subscriber.getBody().toCompletableFuture().join(); + } + return ""; + } + + @Test + void testListTools_withParamsMissingTypeAndDescription() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + // parameter "param1" does not have type or description + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\"," + + "\"description\":\"A test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{}}," + + "\"required\":[]}}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map tools = client.listTools().join(); + assertNotNull(tools); + ToolDefinition toolDef = tools.get("test-tool"); + assertEquals(1, toolDef.parameters().size()); + ToolDefinition.Parameter param = toolDef.parameters().get(0); + assertEquals("param1", param.name()); + assertEquals("string", param.type()); // defaulted to string + assertEquals("", param.description()); // defaulted to empty string + } + + @Test + void testListTools_withAuthParamMetadata() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + // parameter "param1" has authParam metadata in toolbox/authParam + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\",\"description\":\"A" + + " test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{\"type\":\"string\"}}," + + "\"required\":[]},\"_meta\":{\"toolbox/authParam\":{\"param1\":[\"my-auth-source\"]}}}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map tools = client.listTools().join(); + assertNotNull(tools); + ToolDefinition toolDef = tools.get("test-tool"); + ToolDefinition.Parameter param = toolDef.parameters().get(0); + assertEquals(1, param.authSources().size()); + assertEquals("my-auth-source", param.authSources().get(0)); + } + + @Test + void testInvokeTool_withIsErrorFlag() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + String callBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"failed" + + " to run tool\"}],\"isError\":true}}"; + HttpResponse callResponse = mock(HttpResponse.class); + when(callResponse.statusCode()).thenReturn(200); + when(callResponse.body()).thenReturn(callBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(callResponse)); + + ToolResult result = client.invokeTool("test-tool", Map.of()).join(); + assertNotNull(result); + assertTrue(result.isError()); + assertEquals("failed to run tool", result.content().get(0).text()); + } + + @Test + void testInvokeTool_non200WithResponseBody() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse callResponse = mock(HttpResponse.class); + when(callResponse.statusCode()).thenReturn(500); + when(callResponse.body()).thenReturn("Internal Failure"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(callResponse)); + + ToolResult result = client.invokeTool("test-tool", Map.of()).join(); + assertNotNull(result); + assertTrue(result.isError()); + assertTrue(result.content().get(0).text().contains("Error 500:")); + assertTrue(result.content().get(0).text().contains("Internal Failure")); + } + + @Test + void testInvokeTool_resultWithoutContentField() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + // result is a JSON object but has no content field + String callBody = "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{}}"; + HttpResponse callResponse = mock(HttpResponse.class); + when(callResponse.statusCode()).thenReturn(200); + when(callResponse.body()).thenReturn(callBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(callResponse)); + + ToolResult result = client.invokeTool("test-tool", Map.of()).join(); + assertNotNull(result); + assertFalse(result.isError()); + assertEquals(1, result.content().size()); + assertEquals("", result.content().get(0).text()); + } + + @Test + void testListTools_withMissingInputSchemaOrProperties() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + // tool definition with no inputSchema at all + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\"}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map tools = client.listTools().join(); + assertNotNull(tools); + assertTrue(tools.containsKey("test-tool")); + ToolDefinition toolDef = tools.get("test-tool"); + assertTrue(toolDef.parameters().isEmpty()); + } + + @Test + void testJsonRpcInstantiation() throws Exception { + // Instantiate package-private JsonRpc namespace to cover its default constructor + java.lang.reflect.Constructor constructor = JsonRpc.class.getDeclaredConstructor(); + constructor.setAccessible(true); + JsonRpc rpc = constructor.newInstance(); + assertNotNull(rpc); + } + + @Test + void testListTools_withMetaNodeEdgeCases() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + // Case A: _meta has no toolbox/authInvoke key + // Case B: toolbox/authInvoke is not an array (string instead) + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\",\"description\":\"A" + + " test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{\"type\":\"string\"}}}," + + "\"_meta\":{\"toolbox/authInvoke\":\"not-an-array\",\"toolbox/authParam\":\"not-an-object\"}}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map tools = client.listTools().join(); + assertNotNull(tools); + ToolDefinition toolDef = tools.get("test-tool"); + assertTrue(toolDef.authRequired().isEmpty()); // should skip invalid authInvoke formats + } + + @Test + void testListTools_withRequiredNodeNotArray() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + // inputSchema.required is a string instead of array + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\",\"description\":\"A" + + " test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{\"type\":\"string\"}},\"required\":\"not-an-array\"}}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map tools = client.listTools().join(); + assertNotNull(tools); + ToolDefinition toolDef = tools.get("test-tool"); + assertFalse(toolDef.parameters().get(0).required()); // should default to false + } + + @Test + void testListTools_withAuthParamNotArray() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + // toolbox/authParam.param1 is a string instead of an array + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\"," + + "\"description\":\"A test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{\"type\":\"string\"}}}," + + "\"_meta\":{\"toolbox/authParam\":{\"param1\":\"not-an-array\"}}}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map tools = client.listTools().join(); + assertNotNull(tools); + ToolDefinition toolDef = tools.get("test-tool"); + assertTrue(toolDef.parameters().get(0).authSources().isEmpty()); + } + + @Test + void testInvokeTool_isErrorTrueWithoutResultNode() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + // Response has isError = true but NO result or content node + String callBody = "{\"jsonrpc\":\"2.0\",\"id\":1,\"isError\":true}"; + HttpResponse callResponse = mock(HttpResponse.class); + when(callResponse.statusCode()).thenReturn(200); + when(callResponse.body()).thenReturn(callBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(callResponse)); + + ToolResult result = client.invokeTool("test-tool", Map.of()).join(); + assertNotNull(result); + assertTrue(result.isError()); + assertEquals(callBody, result.content().get(0).text()); + } +} diff --git a/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplTest.java b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplTest.java new file mode 100644 index 0000000..2ff4387 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplTest.java @@ -0,0 +1,1088 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.AuthTokenGetter; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.tool.Tool; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolPostProcessor; +import com.google.cloud.mcp.tool.ToolPreProcessor; +import com.google.cloud.mcp.tool.ToolResult; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import com.google.cloud.mcp.transport.HttpMcpTransport; +import com.google.cloud.mcp.transport.Transport; +import com.google.cloud.mcp.transport.TransportManifest; +import com.google.cloud.mcp.transport.TransportResponse; +import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; +import org.mockito.MockedStatic; + +@Timeout(value = 5, unit = java.util.concurrent.TimeUnit.SECONDS) +class McpToolboxClientImplTest { + + private McpToolboxClientImpl client; + private HttpClient mockHttpClient; + private ObjectMapper objectMapper = new ObjectMapper(); + + @BeforeEach + @SuppressWarnings("unchecked") + void setUp() throws Exception { + mockHttpClient = mock(HttpClient.class); + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + CredentialsProvider provider = () -> CompletableFuture.completedFuture("Bearer test-api-key"); + client = new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), provider); + } + + @Test + void testEnsureInitializedCalledOnce() throws Exception { + // Setup mock responses + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\"," + + "\"description\":\"A test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{\"type\":\"string\"}}," + + "\"required\":[\"param1\"]}}]}}"); + + // The order of requests will be: initialize, notifications/initialized, tools/list + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + // Call listTools multiple times + client.listTools().join(); + client.listTools().join(); + + // Verify requests + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(mockHttpClient, times(4)).sendAsync(requestCaptor.capture(), any()); + + long initCount = + requestCaptor.getAllValues().stream() + .filter(req -> getBodyStringQuietly(req).contains("\"method\":\"initialize\"")) + .count(); + long notifCount = + requestCaptor.getAllValues().stream() + .filter( + req -> + getBodyStringQuietly(req).contains("\"method\":\"notifications/initialized\"")) + .count(); + long listCount = + requestCaptor.getAllValues().stream() + .filter(req -> getBodyStringQuietly(req).contains("\"method\":\"tools/list\"")) + .count(); + + assertEquals(1, initCount, "initialize should be called exactly once"); + assertEquals(1, notifCount, "notifications/initialized should be called exactly once"); + assertEquals(2, listCount, "tools/list should be called twice"); + } + + @Test + void testListTools() throws Exception { + // Setup mock responses + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\"," + + "\"description\":\"A test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{\"type\":\"string\",\"description\":\"param desc\"}}," + + "\"required\":[\"param1\"]},\"_meta\":{\"toolbox/authInvoke\":[\"auth1\"]}}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map tools = client.listTools().join(); + + assertNotNull(tools); + assertEquals(1, tools.size()); + assertTrue(tools.containsKey("test-tool")); + + ToolDefinition toolDef = tools.get("test-tool"); + assertEquals("A test tool", toolDef.description()); + assertEquals(1, toolDef.authRequired().size()); + assertEquals("auth1", toolDef.authRequired().get(0)); + + assertEquals(1, toolDef.parameters().size()); + ToolDefinition.Parameter param = toolDef.parameters().get(0); + assertEquals("param1", param.name()); + assertEquals("string", param.type()); + assertEquals("param desc", param.description()); + assertTrue(param.required()); + } + + @Test + void testInvokeTool() throws Exception { + // Setup mock responses + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + String callBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"content\":[{\"type\":\"text\"," + + "\"text\":\"success\"}],\"isError\":false}}"; + HttpResponse callResponse = mock(HttpResponse.class); + when(callResponse.statusCode()).thenReturn(200); + when(callResponse.body()).thenReturn(callBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(callResponse)); + + ToolResult result = client.invokeTool("test-tool", Map.of("param1", "value1")).join(); + + assertNotNull(result); + assertFalse(result.isError()); + assertEquals(1, result.content().size()); + assertEquals("success", result.content().get(0).text()); + + // Verify request payload + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(mockHttpClient, times(3)).sendAsync(requestCaptor.capture(), any()); + + HttpRequest callReq = requestCaptor.getAllValues().get(2); + String bodyStr = getBodyString(callReq); + + JsonNode root = objectMapper.readTree(bodyStr); + assertEquals("tools/call", root.get("method").asText()); + JsonNode params = root.get("params"); + assertEquals("test-tool", params.get("name").asText()); + assertEquals("value1", params.get("arguments").get("param1").asText()); + } + + private String getBodyStringQuietly(HttpRequest request) { + try { + return getBodyString(request); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private String getBodyString(HttpRequest request) throws Exception { + if (request.bodyPublisher().isPresent()) { + var publisher = request.bodyPublisher().get(); + var subscriber = + HttpResponse.BodySubscribers.ofString(java.nio.charset.StandardCharsets.UTF_8); + publisher.subscribe( + new java.util.concurrent.Flow.Subscriber() { + @Override + public void onSubscribe(java.util.concurrent.Flow.Subscription subscription) { + subscriber.onSubscribe(subscription); + subscription.request(Long.MAX_VALUE); + } + + @Override + public void onNext(java.nio.ByteBuffer item) { + subscriber.onNext(java.util.List.of(item)); + } + + @Override + public void onError(Throwable throwable) { + subscriber.onError(throwable); + } + + @Override + public void onComplete() { + subscriber.onComplete(); + } + }); + return subscriber.getBody().toCompletableFuture().join(); + } + return ""; + } + + @Test + @SuppressWarnings("unchecked") + void testConstructor_withNullAndEmptyAndRawApiKeys() throws Exception { + Method getAuthHeaderMethod = + McpToolboxClientImpl.class.getDeclaredMethod("getAuthorizationHeader"); + getAuthHeaderMethod.setAccessible(true); + + McpToolboxClientImpl clientNull = + new McpToolboxClientImpl("http://localhost:8080", (String) null); + CompletableFuture futureNull = + (CompletableFuture) getAuthHeaderMethod.invoke(clientNull); + assertNull(futureNull.join()); + + McpToolboxClientImpl clientEmpty = new McpToolboxClientImpl("http://localhost:8080", ""); + CompletableFuture futureEmpty = + (CompletableFuture) getAuthHeaderMethod.invoke(clientEmpty); + assertNull(futureEmpty.join()); + + McpToolboxClientImpl clientRaw = new McpToolboxClientImpl("http://localhost:8080", "my-key"); + CompletableFuture futureRaw = + (CompletableFuture) getAuthHeaderMethod.invoke(clientRaw); + assertEquals("Bearer my-key", futureRaw.join()); + + McpToolboxClientImpl clientBearer = + new McpToolboxClientImpl("http://localhost:8080", "Bearer already-bearer"); + CompletableFuture futureBearer = + (CompletableFuture) getAuthHeaderMethod.invoke(clientBearer); + assertEquals("Bearer already-bearer", futureBearer.join()); + } + + @Test + void testLoadToolset_strictMode_unknownToolsThrowsException() throws Exception { + // Setup mock responses to return empty tools + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()) + .thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + // Try strict loading with binding for unknown tool + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows( + Exception.class, + () -> + client.loadToolset("my-set", Map.of("unknown-tool", Map.of()), null, true).join()); + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertTrue(cause instanceof IllegalArgumentException); + assertTrue( + cause + .getMessage() + .contains("Strict mode error: Bindings provided for unknown tools: [unknown-tool]")); + } + + @Test + void testLoadTool_notFoundThrowsException() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()) + .thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows( + Exception.class, () -> client.loadTool("non-existent-tool").join()); + Throwable cause = exception.getCause(); + assertNotNull(cause); + assertTrue(cause instanceof RuntimeException); + assertTrue(cause.getMessage().contains("Tool not found: non-existent-tool")); + } + + @Test + void testLoadTool_successWithAuthTokenGetters() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\"," + + "\"description\":\"A test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{\"type\":\"string\"}}," + + "\"required\":[\"param1\"]}}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Tool tool = + client + .loadTool("test-tool", Map.of("my-svc", () -> CompletableFuture.completedFuture("tok"))) + .join(); + + assertNotNull(tool); + assertEquals("test-tool", tool.name()); + } + + @Test + void testLoadToolset_successWithAuthBinds() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\"," + + "\"description\":\"A test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{\"type\":\"string\"}}," + + "\"required\":[\"param1\"]}}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map> authBinds = + Map.of("test-tool", Map.of("my-svc", () -> CompletableFuture.completedFuture("tok"))); + + Map tools = client.loadToolset("my-set", null, authBinds, false).join(); + assertNotNull(tools); + assertTrue(tools.containsKey("test-tool")); + Tool tool = tools.get("test-tool"); + assertEquals("test-tool", tool.name()); + } + + @Test + void testEnsureInitialized_withHttpsBaseUrl() throws Exception { + HttpMcpTransport transport = new HttpMcpTransport("https://localhost:8443", mockHttpClient); + CredentialsProvider provider = () -> CompletableFuture.completedFuture("Bearer test-api-key"); + McpToolboxClientImpl httpsClient = + new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), provider); + + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()) + .thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + httpsClient.listTools().join(); // should succeed and NOT print any HTTP_WARNING + } + + @Test + void testEnsureInitialized_withoutApiKeyFallbackToAdcException() throws Exception { + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + McpToolboxClientImpl noAuthClient = + new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), null); + + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()) + .thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + // This triggers getAuthorizationHeader() -> OIDC / ADC resolution -> Exception -> fallback to + // null + noAuthClient.listTools().join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(mockHttpClient, times(3)).sendAsync(requestCaptor.capture(), any()); + + HttpRequest initReq = requestCaptor.getAllValues().get(0); + if (initReq.headers().map().containsKey("Authorization")) { + String auth = initReq.headers().firstValue("Authorization").get(); + assertTrue(auth.startsWith("Bearer ")); + } + } + + @Test + void testLoadToolset_withNullToolsetName() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()) + .thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map tools = client.loadToolset(null).join(); + assertNotNull(tools); + assertTrue(tools.isEmpty()); + } + + @Test + void testLoadToolset_withInvalidUriThrowsException() { + HttpMcpTransport transport = new HttpMcpTransport("http://invalid uri", mockHttpClient); + McpToolboxClientImpl badClient = + new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), null); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows( + Exception.class, () -> badClient.listTools().join()); + assertNotNull(exception.getCause()); + assertTrue(exception.getCause() instanceof IllegalArgumentException); + } + + @Test + void testInvokeTool_withInvalidUriThrowsException() throws Exception { + HttpMcpTransport transport = new HttpMcpTransport("http://invalid uri", mockHttpClient); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field initFutureField = BaseMcpTransport.class.getDeclaredField("initFuture"); + initFutureField.setAccessible(true); + initFutureField.set(delegate, CompletableFuture.completedFuture(null)); // bypass initialization + McpToolboxClientImpl badClient = + new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), null); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows( + Exception.class, () -> badClient.invokeTool("test-tool", Map.of()).join()); + assertNotNull(exception.getCause()); + assertTrue(exception.getCause() instanceof IllegalArgumentException); + } + + @Test + void testListTools_withInvalidJsonResponseThrowsException() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn("invalid-json-body"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Exception exception = + org.junit.jupiter.api.Assertions.assertThrows( + Exception.class, () -> client.listTools().join()); + assertNotNull(exception.getCause()); + assertTrue(exception.getCause() instanceof RuntimeException); + assertTrue( + exception.getCause().getCause() + instanceof com.fasterxml.jackson.core.JsonProcessingException); + } + + @Test + @SuppressWarnings("unchecked") + void testGetAuthorizationHeader_withAdcException() throws Exception { + McpToolboxClientImpl noAuthClient = + new McpToolboxClientImpl("http://localhost:8080", (String) null); + Method method = McpToolboxClientImpl.class.getDeclaredMethod("getAuthorizationHeader"); + method.setAccessible(true); + + try (MockedStatic mockedCredentials = mockStatic(GoogleCredentials.class)) { + mockedCredentials + .when(GoogleCredentials::getApplicationDefault) + .thenThrow(new IOException("Simulated ADC exception")); + + CompletableFuture future = (CompletableFuture) method.invoke(noAuthClient); + String header = future.join(); + org.junit.jupiter.api.Assertions.assertNull(header); + } + } + + @Test + void testEnsureInitialized_withNullAuthHeader() throws Exception { + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + McpToolboxClientImpl noAuthClient = + new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), null); + + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)); + + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + + Method initMethod = BaseMcpTransport.class.getDeclaredMethod("ensureInitialized", Map.class); + initMethod.setAccessible(true); + + CompletableFuture future = + (CompletableFuture) initMethod.invoke(delegate, java.util.Collections.emptyMap()); + future.join(); // should complete and NOT set Authorization header + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(mockHttpClient, times(2)).sendAsync(requestCaptor.capture(), any()); + + HttpRequest initReq = requestCaptor.getAllValues().get(0); + assertFalse(initReq.headers().map().containsKey("Authorization")); + } + + @Test + void testConstructor_withTrailingSlashAndNullHeaders() throws Exception { + McpToolboxClientImpl clientWithSlash = + new McpToolboxClientImpl("http://localhost:8080/", (Map) null); + + Field transportField = McpToolboxClientImpl.class.getDeclaredField("transport"); + transportField.setAccessible(true); + Transport transport = (Transport) transportField.get(clientWithSlash); + assertEquals("http://localhost:8080", transport.getBaseUrl()); + + Field headersField = McpToolboxClientImpl.class.getDeclaredField("headers"); + headersField.setAccessible(true); + Map headersMap = (Map) headersField.get(clientWithSlash); + assertNotNull(headersMap); + assertTrue(headersMap.isEmpty()); + } + + @Test + void testEnsureInitialized_withCustomHeaders() throws Exception { + Map customHeaders = + Map.of("X-Custom-Header", "custom-val", "Authorization", "some-apiKey"); + HttpMcpTransport transport = + new HttpMcpTransport("http://localhost:8080", customHeaders, mockHttpClient); + McpToolboxClientImpl customClient = new McpToolboxClientImpl(transport, customHeaders, null); + + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()) + .thenReturn("{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + customClient.listTools().join(); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(mockHttpClient, times(3)).sendAsync(requestCaptor.capture(), any()); + + HttpRequest initReq = requestCaptor.getAllValues().get(0); + assertEquals("custom-val", initReq.headers().firstValue("X-Custom-Header").orElse(null)); + assertEquals("some-apiKey", initReq.headers().firstValue("Authorization").orElse(null)); + } + + @Test + void testLoadToolset_withVariousBinds() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + String listBody = + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{\"name\":\"test-tool\"," + + "\"description\":\"A test tool\",\"inputSchema\":{\"type\":\"object\"," + + "\"properties\":{\"param1\":{\"type\":\"string\"}}," + + "\"required\":[\"param1\"]}}]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map> paramBinds = Map.of("test-tool", Map.of("param1", "value1")); + Map> authBinds = + Map.of("test-tool", Map.of("my-svc", () -> CompletableFuture.completedFuture("tok"))); + + Map tools = client.loadToolset("my-set", paramBinds, authBinds, true).join(); + assertNotNull(tools); + assertTrue(tools.containsKey("test-tool")); + Tool tool = tools.get("test-tool"); + assertEquals("test-tool", tool.name()); + } + + @Test + @SuppressWarnings("unchecked") + void testConstructor_withCredentialsProvider() throws Exception { + CredentialsProvider provider = () -> CompletableFuture.completedFuture("Bearer provider-token"); + McpToolboxClientImpl client = new McpToolboxClientImpl("http://localhost:8080", provider); + assertNotNull(client); + + Method getAuthHeaderMethod = + McpToolboxClientImpl.class.getDeclaredMethod("getAuthorizationHeader"); + getAuthHeaderMethod.setAccessible(true); + CompletableFuture future = + (CompletableFuture) getAuthHeaderMethod.invoke(client); + assertEquals("Bearer provider-token", future.join()); + } + + @Test + @SuppressWarnings("unchecked") + void testDefaultLoadToolset() throws Exception { + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + McpToolboxClientImpl client = + new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), null); + + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + String listBody = "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[]}}"; + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()).thenReturn(listBody); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map tools = ((McpToolboxClient) client).loadToolset().join(); + assertNotNull(tools); + assertTrue(tools.isEmpty()); + } + + @Test + @SuppressWarnings("deprecation") + void testCoverageBoosters() throws Exception { + // 1. Cover HttpMcpTransport close() method + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + transport.close(); + + // 2. Cover HttpMcpTransport constructor with null headers + HttpMcpTransport transportWithNullHeaders = + new HttpMcpTransport("http://localhost:8080", null, mockHttpClient); + assertNotNull(transportWithNullHeaders); + + // 3. Cover McpToolboxClientImpl deprecated constructor 1 + McpToolboxClientImpl client1 = + new McpToolboxClientImpl("http://localhost:8080", java.util.Collections.emptyMap(), null); + assertNotNull(client1); + + // 4. Cover McpToolboxClientImpl deprecated constructor 2 + McpToolboxClientImpl client2 = new McpToolboxClientImpl(transport, null); + assertNotNull(client2); + } + + @Test + void testInvokeTool_withNullHeadersThrows() { + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + McpToolboxClientImpl client = + new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), null); + + CompletableFuture future = + client.invokeTool("test-tool", java.util.Collections.emptyMap(), null); + java.util.concurrent.ExecutionException ex = + org.junit.jupiter.api.Assertions.assertThrows( + java.util.concurrent.ExecutionException.class, future::get); + assertTrue(ex.getCause() instanceof NullPointerException); + } + + @Test + void testListTools_withInvalidToolsetNameThrows() throws Exception { + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + + // Force transport to be initialized first + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field initFutureField = BaseMcpTransport.class.getDeclaredField("initFuture"); + initFutureField.setAccessible(true); + initFutureField.set(delegate, CompletableFuture.completedFuture(null)); + + CompletableFuture future = + transport.listTools("invalid path with spaces \\", java.util.Collections.emptyMap()); + java.util.concurrent.ExecutionException ex = + org.junit.jupiter.api.Assertions.assertThrows( + java.util.concurrent.ExecutionException.class, future::get); + assertTrue(ex.getCause() instanceof IllegalArgumentException); + } + + @Test + void testEnsureInitialized_withNotificationSerializationFailure() throws Exception { + HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); + + // Mock ObjectMapper to throw on notification + ObjectMapper mockMapper = mock(ObjectMapper.class); + when(mockMapper.readTree(any(String.class))).thenReturn(new ObjectMapper().readTree("{}")); + when(mockMapper.writeValueAsString(any(JsonRpc.Request.class))).thenReturn("{}"); + when(mockMapper.writeValueAsString(any(JsonRpc.Notification.class))) + .thenThrow(new RuntimeException("Simulated notification serialization failure")); + + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field mapperField = BaseMcpTransport.class.getDeclaredField("objectMapper"); + mapperField.setAccessible(true); + mapperField.set(delegate, mockMapper); + + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)); + + Method initMethod = BaseMcpTransport.class.getDeclaredMethod("ensureInitialized", Map.class); + initMethod.setAccessible(true); + + CompletableFuture future = + (CompletableFuture) initMethod.invoke(delegate, java.util.Collections.emptyMap()); + + java.util.concurrent.ExecutionException ex = + org.junit.jupiter.api.Assertions.assertThrows( + java.util.concurrent.ExecutionException.class, future::get); + assertTrue(ex.getCause().getMessage().contains("Simulated notification serialization failure")); + } + + @SuppressWarnings("unchecked") + @Test + void testClientPrePostProcessorsPropagation() throws Exception { + Transport mockTransport = mock(Transport.class); + ToolDefinition def = + new ToolDefinition("desc", java.util.List.of(), java.util.List.of(), null, null); + TransportManifest manifest = new TransportManifest(java.util.Map.of("test-tool", def)); + + when(mockTransport.listTools(any(), any())) + .thenReturn(CompletableFuture.completedFuture(manifest)); + when(mockTransport.getBaseUrl()).thenReturn("http://localhost:8080"); + + ToolPreProcessor mockPre = mock(ToolPreProcessor.class); + ToolPostProcessor mockPost = mock(ToolPostProcessor.class); + + McpToolboxClientImpl customClient = + new McpToolboxClientImpl( + mockTransport, + java.util.Collections.emptyMap(), + null, + java.util.List.of(mockPre), + java.util.List.of(mockPost)); + + // 1. Verify loadToolset propagates processors + java.util.Map tools = customClient.loadToolset("", null, null, false).get(); + Tool tool1 = tools.get("test-tool"); + assertNotNull(tool1); + + Field preField = Tool.class.getDeclaredField("preProcessors"); + preField.setAccessible(true); + java.util.List toolPre1 = + (java.util.List) preField.get(tool1); + assertEquals(1, toolPre1.size()); + org.junit.jupiter.api.Assertions.assertSame(mockPre, toolPre1.get(0)); + + Field postField = Tool.class.getDeclaredField("postProcessors"); + postField.setAccessible(true); + java.util.List toolPost1 = + (java.util.List) postField.get(tool1); + assertEquals(1, toolPost1.size()); + org.junit.jupiter.api.Assertions.assertSame(mockPost, toolPost1.get(0)); + + // 2. Verify loadTool propagates processors + Tool tool2 = customClient.loadTool("test-tool").get(); + assertNotNull(tool2); + + java.util.List toolPre2 = + (java.util.List) preField.get(tool2); + assertEquals(1, toolPre2.size()); + org.junit.jupiter.api.Assertions.assertSame(mockPre, toolPre2.get(0)); + + java.util.List toolPost2 = + (java.util.List) postField.get(tool2); + assertEquals(1, toolPost2.size()); + org.junit.jupiter.api.Assertions.assertSame(mockPost, toolPost2.get(0)); + } + + @Test + void testInvokeTool_MalformedJsonResponse_GracefullyFallsBack() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse invokeResponse = mock(HttpResponse.class); + when(invokeResponse.statusCode()).thenReturn(200); + when(invokeResponse.body()).thenReturn("{invalid-json"); // Trigger JSON parse exception + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(invokeResponse)); + + ToolResult res = client.invokeTool("test-tool", Map.of()).join(); + assertNotNull(res); + assertFalse(res.isError()); + assertEquals("{invalid-json", res.content().get(0).text()); + } + + @Test + @SuppressWarnings("unchecked") + void testGetAuthorizationHeader_WithNoAuthInHeaders() throws Exception { + McpToolboxClientImpl clientNoAuth = + new McpToolboxClientImpl(mock(Transport.class), Map.of("X-Other", "SomeValue"), null); + + Method getAuthHeaderMethod = + McpToolboxClientImpl.class.getDeclaredMethod("getAuthorizationHeader"); + getAuthHeaderMethod.setAccessible(true); + CompletableFuture future = + (CompletableFuture) getAuthHeaderMethod.invoke(clientNoAuth); + assertNull(future.join()); + } + + @Test + void testConstructor_WithOnlyTransport() { + Transport mockTransport = mock(Transport.class); + McpToolboxClientImpl simpleClient = new McpToolboxClientImpl(mockTransport); + assertNotNull(simpleClient); + } + + @Test + @SuppressWarnings("unchecked") + void testGetMergedMetadata_WithMockGenericTransport_AllBranches() throws Exception { + Transport mockTransport = mock(Transport.class); + when(mockTransport.getBaseUrl()).thenReturn("https://test-mcp-service.com"); + when(mockTransport.invokeTool(any(), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new TransportResponse(200, "{\"result\":{}}"))); + + CredentialsProvider provider = () -> CompletableFuture.completedFuture("Bearer test-api-key"); + McpToolboxClientImpl genericClient = + new McpToolboxClientImpl(mockTransport, Map.of("Custom-Header", "Value"), provider); + + // Call invokeTool with extra metadata to trigger merge + genericClient + .invokeTool( + "test-tool", + Map.of(), + Map.of("Extra-Header", "ExtraValue", "Authorization", "Bearer overridden")) + .join(); + + ArgumentCaptor> metadataCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockTransport).invokeTool(any(), any(), metadataCaptor.capture()); + + Map mergedMetadata = metadataCaptor.getValue(); + assertEquals("Value", mergedMetadata.get("Custom-Header")); + assertEquals("ExtraValue", mergedMetadata.get("Extra-Header")); + assertEquals("Bearer overridden", mergedMetadata.get("Authorization")); + } + + @Test + @SuppressWarnings("unchecked") + void testGetMergedMetadata_WithMockGenericTransport_NullExtraMetadata() throws Exception { + Transport mockTransport = mock(Transport.class); + when(mockTransport.getBaseUrl()).thenReturn("https://test-mcp-service.com"); + when(mockTransport.invokeTool(any(), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new TransportResponse(200, "{\"result\":{}}"))); + + CredentialsProvider provider = () -> CompletableFuture.completedFuture("Bearer test-api-key"); + McpToolboxClientImpl genericClient = + new McpToolboxClientImpl(mockTransport, Map.of("Custom-Header", "Value"), provider); + + // Call invokeTool with null extra metadata to trigger branch + genericClient.invokeTool("test-tool", Map.of(), (Map) null).join(); + + ArgumentCaptor> metadataCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockTransport).invokeTool(any(), any(), metadataCaptor.capture()); + + Map mergedMetadata = metadataCaptor.getValue(); + assertEquals("Value", mergedMetadata.get("Custom-Header")); + assertEquals("Bearer test-api-key", mergedMetadata.get("Authorization")); + } + + @Test + @SuppressWarnings("unchecked") + void testGetMergedMetadata_WithMockGenericTransport_NullProviderAndEmptyHeaders() + throws Exception { + Transport mockTransport = mock(Transport.class); + when(mockTransport.getBaseUrl()).thenReturn("https://test-mcp-service.com"); + when(mockTransport.invokeTool(any(), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new TransportResponse(200, "{\"result\":{}}"))); + + McpToolboxClientImpl genericClient = new McpToolboxClientImpl(mockTransport, Map.of(), null); + + genericClient.invokeTool("test-tool", Map.of(), Map.of("Extra-Header", "ExtraValue")).join(); + + ArgumentCaptor> metadataCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockTransport).invokeTool(any(), any(), metadataCaptor.capture()); + + Map mergedMetadata = metadataCaptor.getValue(); + assertEquals("ExtraValue", mergedMetadata.get("Extra-Header")); + assertFalse(mergedMetadata.containsKey("Authorization")); + } + + @Test + void testLoadToolset_withDefaultValuesAndHints() throws Exception { + HttpResponse initResponse = mock(HttpResponse.class); + when(initResponse.statusCode()).thenReturn(200); + when(initResponse.body()).thenReturn("{}"); + + HttpResponse notifResponse = mock(HttpResponse.class); + when(notifResponse.statusCode()).thenReturn(200); + when(notifResponse.body()).thenReturn("{}"); + + HttpResponse listResponse = mock(HttpResponse.class); + when(listResponse.statusCode()).thenReturn(200); + when(listResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"tools\":[{" + + "\"name\":\"test-tool\"," + + "\"description\":\"A test tool description\"," + + "\"readOnlyHint\":true," + + "\"destructiveHint\":false," + + "\"inputSchema\":{" + + " \"type\":\"object\"," + + " \"properties\":{" + + " \"param1\":{" + + " \"type\":\"string\"," + + " \"description\":\"parameter 1\"," + + " \"default\":\"default-val\"" + + " }" + + " }" + + "}" + + "}]}}"); + + when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(initResponse)) + .thenReturn(CompletableFuture.completedFuture(notifResponse)) + .thenReturn(CompletableFuture.completedFuture(listResponse)); + + Map tools = client.loadToolset("").join(); + assertNotNull(tools); + assertEquals(1, tools.size()); + + ToolDefinition def = tools.get("test-tool"); + assertNotNull(def); + assertEquals("A test tool description", def.description()); + assertEquals(true, def.readOnlyHint()); + assertEquals(false, def.destructiveHint()); + + assertEquals(1, def.parameters().size()); + ToolDefinition.Parameter param = def.parameters().get(0); + assertEquals("param1", param.name()); + assertEquals("string", param.type()); + assertEquals("default-val", param.defaultValue()); + } +} diff --git a/src/test/java/com/google/cloud/mcp/tool/ToolTest.java b/src/test/java/com/google/cloud/mcp/tool/ToolTest.java new file mode 100644 index 0000000..6fa6154 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/tool/ToolTest.java @@ -0,0 +1,561 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.tool; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.ResolvedAuth; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadLocalRandom; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; + +@Timeout(10) +class ToolTest { + + private ExecutorService pool; + private McpToolboxClient mockClient; + private ToolDefinition toolDefinition; + private Tool tool; + + @BeforeEach + void setUp() { + pool = Executors.newFixedThreadPool(8); + mockClient = mock(McpToolboxClient.class); + toolDefinition = new ToolDefinition("Test Tool", null, null); + tool = new Tool("test_tool", toolDefinition, mockClient); + } + + @AfterEach + void tearDown() { + if (pool != null) { + pool.shutdownNow(); + } + } + + /** + * Regression test: when several authenticated services resolve their tokens concurrently, {@link + * Tool#execute(Map)} must not drop any credential header or authenticated argument from the + * outgoing request. The previous implementation mutated the non-thread-safe finalArgs / + * extraHeaders HashMaps from each getter's completion thread, which could lose writes. + */ + @Test + void execute_withManyConcurrentAuthGetters_doesNotDropCredentials() { + int services = Integer.getInteger("toolTest.services", 24); + int iterations = Integer.getInteger("toolTest.iters", 2500); + + List params = new ArrayList<>(); + for (int i = 0; i < services; i++) { + params.add(new ToolDefinition.Parameter("p" + i, "string", false, "", List.of("svc" + i))); + } + ToolDefinition def = new ToolDefinition("race-tool", params, new ArrayList<>()); + + List> capturedArgs = new ArrayList<>(); + List> capturedHeaders = new ArrayList<>(); + + McpToolboxClient client = mock(McpToolboxClient.class); + when(client.invokeTool(anyString(), anyMap(), anyMap())) + .thenAnswer( + inv -> { + capturedArgs.add(new HashMap<>(inv.getArgument(1))); + capturedHeaders.add(new HashMap<>(inv.getArgument(2))); + return CompletableFuture.completedFuture( + new ToolResult(List.of(new ToolResult.Content("text", "ok")), false)); + }); + + Tool raceTool = new Tool("race-tool", def, client); + for (int i = 0; i < services; i++) { + final String token = "tok-" + i; + raceTool.addAuthTokenGetter( + "svc" + i, + () -> + CompletableFuture.supplyAsync( + () -> { + int spins = ThreadLocalRandom.current().nextInt(50); + for (int s = 0; s < spins; s++) { + Thread.onSpinWait(); + } + return token; + }, + pool)); + } + + for (int iter = 0; iter < iterations; iter++) { + raceTool.execute(new HashMap<>()).join(); + } + + assertEquals(iterations, capturedHeaders.size(), "every invocation should reach the client"); + for (int iter = 0; iter < iterations; iter++) { + Map headers = capturedHeaders.get(iter); + Map args = capturedArgs.get(iter); + for (int i = 0; i < services; i++) { + assertEquals( + "tok-" + i, + headers.get("svc" + i + "_token"), + "missing/garbled svc" + i + "_token header on iteration " + iter); + assertEquals( + "tok-" + i, args.get("p" + i), "missing/garbled p" + i + " arg on iteration " + iter); + } + } + } + + @Test + void resolvedAuth_appliesTokensWithCorrectBearerNormalization() { + ToolDefinition def = new ToolDefinition("test-tool", List.of(), List.of()); + Map tokens = + Map.of( + "svc-raw", "rawToken123", + "svc-prefixed", "Bearer alreadyPrefixed456", + "svc-lowercase-prefixed", "bearer alreadyPrefixed789"); + + ResolvedAuth resolvedAuth = new ResolvedAuth(tokens); + Map finalArgs = new HashMap<>(); + Map extraHeaders = new HashMap<>(); + + resolvedAuth.applyTo(finalArgs, extraHeaders, def); + + // Verify token values map to sdk custom headers + assertEquals("rawToken123", extraHeaders.get("svc-raw_token")); + assertEquals("Bearer alreadyPrefixed456", extraHeaders.get("svc-prefixed_token")); + assertEquals("bearer alreadyPrefixed789", extraHeaders.get("svc-lowercase-prefixed_token")); + + // Verify standard OIDC authorization header matches and doesn't double prefix + String authHeader = extraHeaders.get("Authorization"); + assertTrue( + authHeader.equals("Bearer rawToken123") + || authHeader.equals("Bearer alreadyPrefixed456") + || authHeader.equals("bearer alreadyPrefixed789")); + } + + @Test + void resolvedAuth_withNullAndEmptyTokens_ignoresThemSafely() { + ToolDefinition def = new ToolDefinition("test-tool", List.of(), List.of()); + Map tokens = new HashMap<>(); + tokens.put("svc-null", null); + tokens.put("svc-empty", ""); + tokens.put("svc-valid", "validToken"); + + ResolvedAuth resolvedAuth = new ResolvedAuth(tokens); + Map finalArgs = new HashMap<>(); + Map extraHeaders = new HashMap<>(); + + resolvedAuth.applyTo(finalArgs, extraHeaders, def); + + // Only the valid token should be mapped + assertEquals("Bearer validToken", extraHeaders.get("Authorization")); + assertEquals("validToken", extraHeaders.get("svc-valid_token")); + assertTrue(!extraHeaders.containsKey("svc-null_token")); + assertTrue(!extraHeaders.containsKey("svc-empty_token")); + } + + @Test + void testToolGetters() { + ToolDefinition def = new ToolDefinition("test-tool", List.of(), List.of()); + McpToolboxClient client = mock(McpToolboxClient.class); + Tool tool = new Tool("test-tool", def, client); + + assertEquals("test-tool", tool.name()); + assertEquals(def, tool.definition()); + } + + @Test + void testBindParamStaticAndSupplier() throws Exception { + List params = + List.of( + new ToolDefinition.Parameter("p-static", "string", false, "desc", List.of()), + new ToolDefinition.Parameter("p-supplier", "string", false, "desc", List.of())); + ToolDefinition def = new ToolDefinition("test-tool", params, List.of()); + McpToolboxClient client = mock(McpToolboxClient.class); + + List> capturedArgs = new ArrayList<>(); + when(client.invokeTool(anyString(), anyMap(), anyMap())) + .thenAnswer( + inv -> { + capturedArgs.add(new HashMap<>(inv.getArgument(1))); + return CompletableFuture.completedFuture(new ToolResult(List.of(), false)); + }); + + Tool tool = new Tool("test-tool", def, client); + tool.bindParam("p-static", "static-value"); + tool.bindParam("p-supplier", () -> "supplier-value"); + + tool.execute(Map.of()).join(); + + assertEquals(1, capturedArgs.size()); + Map args = capturedArgs.get(0); + assertEquals("static-value", args.get("p-static")); + assertEquals("supplier-value", args.get("p-supplier")); + } + + @Test + void testResolvedAuth_withNullParametersListInDefinition() { + ToolDefinition def = new ToolDefinition("test-tool", null, List.of()); + ResolvedAuth resolvedAuth = new ResolvedAuth(Map.of("svc", "token")); + Map finalArgs = new HashMap<>(); + Map extraHeaders = new HashMap<>(); + + resolvedAuth.applyTo(finalArgs, extraHeaders, def); + + assertEquals("Bearer token", extraHeaders.get("Authorization")); + assertTrue(finalArgs.isEmpty()); + } + + @Test + void testResolvedAuth_withNullTokensMap() { + ToolDefinition def = new ToolDefinition("test-tool", List.of(), List.of()); + ResolvedAuth resolvedAuth = new ResolvedAuth(null); + Map finalArgs = new HashMap<>(); + Map extraHeaders = new HashMap<>(); + + resolvedAuth.applyTo(finalArgs, extraHeaders, def); + + assertTrue(finalArgs.isEmpty()); + assertTrue(extraHeaders.isEmpty()); + } + + @Test + void testResolvedAuth_withNullKeysAndValuesInTokensMap() { + ToolDefinition def = new ToolDefinition("test-tool", List.of(), List.of()); + Map tokens = new HashMap<>(); + tokens.put(null, "val1"); + tokens.put("svc2", null); + tokens.put("svc3", "val3"); + + ResolvedAuth resolvedAuth = new ResolvedAuth(tokens); + Map finalArgs = new HashMap<>(); + Map extraHeaders = new HashMap<>(); + + resolvedAuth.applyTo(finalArgs, extraHeaders, def); + + assertEquals("Bearer val3", extraHeaders.get("Authorization")); + assertEquals("val3", extraHeaders.get("svc3_token")); + assertTrue(!extraHeaders.containsKey("svc2_token")); + assertTrue(!extraHeaders.containsKey("null_token")); + } + + @Test + void testValidateAndSanitizeArgs_customTypeMatch() throws Exception { + List params = + List.of( + new ToolDefinition.Parameter("p-custom", "custom-type-name", false, "desc", List.of())); + ToolDefinition def = new ToolDefinition("test-tool", params, List.of()); + McpToolboxClient client = mock(McpToolboxClient.class); + when(client.invokeTool(anyString(), anyMap(), anyMap())) + .thenReturn(CompletableFuture.completedFuture(new ToolResult(List.of(), false))); + + Tool tool = new Tool("test-tool", def, client); + tool.execute(Map.of("p-custom", "any-value")).join(); // should succeed + } + + @Test + void testValidateAndSanitizeArgs_withNullParameters() throws Exception { + ToolDefinition def = new ToolDefinition("test-tool", null, List.of()); + McpToolboxClient client = mock(McpToolboxClient.class); + when(client.invokeTool(anyString(), anyMap(), anyMap())) + .thenReturn(CompletableFuture.completedFuture(new ToolResult(List.of(), false))); + + Tool tool = new Tool("test-tool", def, client); + tool.execute(Map.of("any-param", "any-value")).join(); // should bypass validation loop safely + } + + @Test + void testDefaultValueInjection() throws Exception { + McpToolboxClient mockClient = mock(McpToolboxClient.class); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "string", false, "A parameter", null, "default_value"); + ToolDefinition.Parameter paramNoDefault = + new ToolDefinition.Parameter("param2", "string", false, "Another parameter", null, null); + + ToolDefinition def = + new ToolDefinition("A test tool", List.of(paramWithDefault, paramNoDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + args.put("param2", "provided_value"); + + CompletableFuture future = tool.execute(args); + future.join(); // Wait for execution + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> headersCaptor = ArgumentCaptor.forClass(Map.class); + + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), headersCaptor.capture()); + + Map capturedArgs = argsCaptor.getValue(); + + assertEquals( + "default_value", + capturedArgs.get("param1"), + "Default value should be injected when not provided"); + assertEquals("provided_value", capturedArgs.get("param2"), "Provided value should be kept"); + } + + @Test + void testDefaultValueNotOverwritten() throws Exception { + McpToolboxClient mockClient = mock(McpToolboxClient.class); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "string", false, "A parameter", null, "default_value"); + + ToolDefinition def = new ToolDefinition("A test tool", List.of(paramWithDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + args.put("param1", "custom_value"); + + CompletableFuture future = tool.execute(args); + future.join(); // Wait for execution + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> headersCaptor = ArgumentCaptor.forClass(Map.class); + + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), headersCaptor.capture()); + + Map capturedArgs = argsCaptor.getValue(); + + assertEquals( + "custom_value", + capturedArgs.get("param1"), + "Provided value should not be overwritten by default value"); + } + + @Test + void testDefaultValueDeepCloning() throws Exception { + McpToolboxClient mockClient = mock(McpToolboxClient.class); + + Map complexDefault = new HashMap<>(); + complexDefault.put("key", "value"); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "object", false, "A parameter", null, complexDefault); + + ToolDefinition def = new ToolDefinition("A test tool", List.of(paramWithDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + CompletableFuture future = tool.execute(args); + future.join(); + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), any()); + + Map capturedArgs = argsCaptor.getValue(); + @SuppressWarnings("unchecked") + Map injectedDefault = (Map) capturedArgs.get("param1"); + + // Mutate the injected map + injectedDefault.put("key", "mutated_value"); + + // Ensure the original defaultValue stored in the definition remains untouched + @SuppressWarnings("unchecked") + Map defValueInDefinition = + (Map) def.parameters().get(0).defaultValue(); + assertEquals( + "value", + defValueInDefinition.get("key"), + "The default value in definition must remain unmutated"); + } + + @Test + void testDefaultValueDeepCloning_withList() throws Exception { + McpToolboxClient mockClient = mock(McpToolboxClient.class); + + List complexDefault = new ArrayList<>(); + complexDefault.add("value"); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter("param1", "array", false, "A parameter", null, complexDefault); + + ToolDefinition def = new ToolDefinition("A test tool", List.of(paramWithDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + CompletableFuture future = tool.execute(args); + future.join(); + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), any()); + + Map capturedArgs = argsCaptor.getValue(); + @SuppressWarnings("unchecked") + List injectedDefault = (List) capturedArgs.get("param1"); + + // Mutate the injected list + injectedDefault.set(0, "mutated_value"); + + // Ensure the original defaultValue stored in the definition remains untouched + @SuppressWarnings("unchecked") + List defValueInDefinition = (List) def.parameters().get(0).defaultValue(); + assertEquals( + "value", + defValueInDefinition.get(0), + "The default value in definition must remain unmutated"); + } + + @Test + void testToolDefinitionHints() { + ToolDefinition defWithHints = + new ToolDefinition("A test tool", List.of(), List.of(), true, false); + + assertEquals(true, defWithHints.readOnlyHint()); + assertEquals(false, defWithHints.destructiveHint()); + + ToolDefinition defWithoutHints = new ToolDefinition("A test tool", List.of(), List.of()); + assertEquals(null, defWithoutHints.readOnlyHint()); + assertEquals(null, defWithoutHints.destructiveHint()); + } + + @Test + @SuppressWarnings("unchecked") + void testExecute_withPreAndPostProcessors_modifiesArgsAndResult() throws Exception { + // Arrange + Map initialArgs = new HashMap<>(); + initialArgs.put("arg1", "val1"); + + ToolResult originalResult = + new ToolResult(List.of(new ToolResult.Content("text", "original")), false); + ToolResult modifiedResult = + new ToolResult(List.of(new ToolResult.Content("text", "modified")), false); + + ToolPreProcessor preProcessor1 = + (name, args) -> { + Map newArgs = new HashMap<>(args); + newArgs.put("arg2", "val2"); + return CompletableFuture.completedFuture(newArgs); + }; + + ToolPreProcessor preProcessor2 = + (name, args) -> { + Map newArgs = new HashMap<>(args); + newArgs.put("arg3", "val3"); + return CompletableFuture.completedFuture(newArgs); + }; + + ToolPostProcessor postProcessor = + (name, result) -> { + if (result.content().get(0).text().equals("original")) { + return CompletableFuture.completedFuture(modifiedResult); + } + return CompletableFuture.completedFuture(result); + }; + + tool.addPreProcessor(preProcessor1); + tool.addPreProcessor(preProcessor2); + tool.addPostProcessor(postProcessor); + + when(mockClient.invokeTool(eq("test_tool"), anyMap(), anyMap())) + .thenReturn(CompletableFuture.completedFuture(originalResult)); + + // Act + CompletableFuture futureResult = tool.execute(initialArgs); + ToolResult finalResult = futureResult.get(); + + // Assert + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockClient, times(1)).invokeTool(eq("test_tool"), argsCaptor.capture(), anyMap()); + + Map capturedArgs = argsCaptor.getValue(); + assertEquals(3, capturedArgs.size()); + assertEquals("val1", capturedArgs.get("arg1")); + assertEquals("val2", capturedArgs.get("arg2")); + assertEquals("val3", capturedArgs.get("arg3")); + + assertSame(modifiedResult, finalResult); + } + + @Test + void testExecute_preProcessorException_failsFutureWithoutInvokingClient() { + // Arrange + Map initialArgs = new HashMap<>(); + + ToolPreProcessor preProcessor = + (name, args) -> CompletableFuture.failedFuture(new RuntimeException("PreProcessor failed")); + + tool.addPreProcessor(preProcessor); + + // Act + CompletableFuture futureResult = tool.execute(initialArgs); + + // Assert + assertTrue(futureResult.isCompletedExceptionally()); + + Exception exception = null; + try { + futureResult.get(); + } catch (InterruptedException | ExecutionException e) { + exception = e; + } + assertTrue(exception.getCause() instanceof RuntimeException); + assertEquals("PreProcessor failed", exception.getCause().getMessage()); + + verify(mockClient, never()).invokeTool(eq("test_tool"), anyMap(), anyMap()); + verify(mockClient, never()).invokeTool(eq("test_tool"), anyMap()); + } +} diff --git a/src/test/java/com/google/cloud/mcp/tool/ToolValidationTest.java b/src/test/java/com/google/cloud/mcp/tool/ToolValidationTest.java new file mode 100644 index 0000000..e645a73 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/tool/ToolValidationTest.java @@ -0,0 +1,387 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.tool; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.cloud.mcp.McpToolboxClient; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.mockito.ArgumentCaptor; + +@Timeout(10) +class ToolValidationTest { + + private McpToolboxClient mockClient; + + @BeforeEach + void setUp() { + mockClient = mock(McpToolboxClient.class); + } + + @Test + void testValidateAndSanitizeArgs_nullsRemoved() throws Exception { + ToolDefinition def = new ToolDefinition("test-tool", List.of(), List.of()); + + List> capturedArgs = new ArrayList<>(); + when(mockClient.invokeTool(anyString(), anyMap(), anyMap())) + .thenAnswer( + inv -> { + capturedArgs.add(new HashMap<>(inv.getArgument(1))); + return CompletableFuture.completedFuture(new ToolResult(List.of(), false)); + }); + + Tool tool = new Tool("test-tool", def, mockClient); + Map inputArgs = new HashMap<>(); + inputArgs.put("param-null", null); + inputArgs.put("param-valid", "value"); + + tool.execute(inputArgs).join(); + + assertEquals(1, capturedArgs.size()); + Map args = capturedArgs.get(0); + assertTrue(args.containsKey("param-valid")); + assertFalse(args.containsKey("param-null")); + } + + @Test + void testValidateAndSanitizeArgs_missingRequired() { + List params = + List.of(new ToolDefinition.Parameter("p-required", "string", true, "desc", List.of())); + ToolDefinition def = new ToolDefinition("test-tool", params, List.of()); + Tool tool = new Tool("test-tool", def, mockClient); + + CompletionException exception = + org.junit.jupiter.api.Assertions.assertThrows( + CompletionException.class, () -> tool.execute(Map.of()).join()); + assertTrue(exception.getCause() instanceof IllegalArgumentException); + assertTrue( + exception.getCause().getMessage().contains("Missing required parameter 'p-required'")); + } + + @Test + void testValidateAndSanitizeArgs_typeMismatches() { + List params = + List.of( + new ToolDefinition.Parameter("p-string", "string", false, "desc", List.of()), + new ToolDefinition.Parameter("p-int", "integer", false, "desc", List.of()), + new ToolDefinition.Parameter("p-number", "number", false, "desc", List.of()), + new ToolDefinition.Parameter("p-bool", "boolean", false, "desc", List.of()), + new ToolDefinition.Parameter("p-array", "array", false, "desc", List.of()), + new ToolDefinition.Parameter("p-obj", "object", false, "desc", List.of())); + ToolDefinition def = new ToolDefinition("test-tool", params, List.of()); + Tool tool = new Tool("test-tool", def, mockClient); + + // Expected string, got integer + CompletionException ex1 = + org.junit.jupiter.api.Assertions.assertThrows( + CompletionException.class, () -> tool.execute(Map.of("p-string", 123)).join()); + assertTrue(ex1.getCause() instanceof IllegalArgumentException); + + // Expected integer, got string + CompletionException ex2 = + org.junit.jupiter.api.Assertions.assertThrows( + CompletionException.class, () -> tool.execute(Map.of("p-int", "not-an-int")).join()); + assertTrue(ex2.getCause() instanceof IllegalArgumentException); + + // Expected number, got string + CompletionException ex3 = + org.junit.jupiter.api.Assertions.assertThrows( + CompletionException.class, + () -> tool.execute(Map.of("p-number", "not-a-number")).join()); + assertTrue(ex3.getCause() instanceof IllegalArgumentException); + + // Expected boolean, got string + CompletionException ex4 = + org.junit.jupiter.api.Assertions.assertThrows( + CompletionException.class, + () -> tool.execute(Map.of("p-bool", "not-a-boolean")).join()); + assertTrue(ex4.getCause() instanceof IllegalArgumentException); + + // Expected array, got string + CompletionException ex5 = + org.junit.jupiter.api.Assertions.assertThrows( + CompletionException.class, + () -> tool.execute(Map.of("p-array", "not-an-array")).join()); + assertTrue(ex5.getCause() instanceof IllegalArgumentException); + + // Expected object, got string + CompletionException ex6 = + org.junit.jupiter.api.Assertions.assertThrows( + CompletionException.class, () -> tool.execute(Map.of("p-obj", "not-an-object")).join()); + assertTrue(ex6.getCause() instanceof IllegalArgumentException); + } + + @Test + void testValidateAndSanitizeArgs_typeMatches() throws Exception { + List params = + List.of( + new ToolDefinition.Parameter("p-string", "string", false, "desc", List.of()), + new ToolDefinition.Parameter("p-int", "integer", false, "desc", List.of()), + new ToolDefinition.Parameter("p-int-val", "integer", false, "desc", List.of()), + new ToolDefinition.Parameter("p-number", "number", false, "desc", List.of()), + new ToolDefinition.Parameter("p-bool", "boolean", false, "desc", List.of()), + new ToolDefinition.Parameter("p-array", "array", false, "desc", List.of()), + new ToolDefinition.Parameter("p-array-arr", "array", false, "desc", List.of()), + new ToolDefinition.Parameter("p-obj", "object", false, "desc", List.of())); + ToolDefinition def = new ToolDefinition("test-tool", params, List.of()); + when(mockClient.invokeTool(anyString(), anyMap(), anyMap())) + .thenReturn(CompletableFuture.completedFuture(new ToolResult(List.of(), false))); + + Tool tool = new Tool("test-tool", def, mockClient); + tool.execute( + Map.of( + "p-string", + "valid-string", + "p-int", + 123L, + "p-int-val", + 123, + "p-number", + 4.56, + "p-bool", + true, + "p-array", + List.of("item"), + "p-array-arr", + new String[] {"item"}, + "p-obj", + Map.of("key", "val"))) + .join(); // should succeed without exceptions + } + + @Test + void testValidateAndSanitizeArgs_customTypeMatch() throws Exception { + List params = + List.of( + new ToolDefinition.Parameter("p-custom", "custom-type-name", false, "desc", List.of())); + ToolDefinition def = new ToolDefinition("test-tool", params, List.of()); + when(mockClient.invokeTool(anyString(), anyMap(), anyMap())) + .thenReturn(CompletableFuture.completedFuture(new ToolResult(List.of(), false))); + + Tool tool = new Tool("test-tool", def, mockClient); + tool.execute(Map.of("p-custom", "any-value")).join(); // should succeed + } + + @Test + void testValidateAndSanitizeArgs_withNullParameters() throws Exception { + ToolDefinition def = new ToolDefinition("test-tool", null, List.of()); + when(mockClient.invokeTool(anyString(), anyMap(), anyMap())) + .thenReturn(CompletableFuture.completedFuture(new ToolResult(List.of(), false))); + + Tool tool = new Tool("test-tool", def, mockClient); + tool.execute(Map.of("any-param", "any-value")).join(); // should bypass validation loop safely + } + + @Test + void testDefaultValueInjection() throws Exception { + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "string", false, "A parameter", null, "default_value"); + ToolDefinition.Parameter paramNoDefault = + new ToolDefinition.Parameter("param2", "string", false, "Another parameter", null, null); + + ToolDefinition def = + new ToolDefinition("A test tool", List.of(paramWithDefault, paramNoDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + args.put("param2", "provided_value"); + + CompletableFuture future = tool.execute(args); + future.join(); // Wait for execution + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> headersCaptor = ArgumentCaptor.forClass(Map.class); + + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), headersCaptor.capture()); + + Map capturedArgs = argsCaptor.getValue(); + + assertEquals( + "default_value", + capturedArgs.get("param1"), + "Default value should be injected when not provided"); + assertEquals("provided_value", capturedArgs.get("param2"), "Provided value should be kept"); + } + + @Test + void testDefaultValueNotOverwritten() throws Exception { + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "string", false, "A parameter", null, "default_value"); + + ToolDefinition def = new ToolDefinition("A test tool", List.of(paramWithDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + args.put("param1", "custom_value"); + + CompletableFuture future = tool.execute(args); + future.join(); // Wait for execution + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> headersCaptor = ArgumentCaptor.forClass(Map.class); + + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), headersCaptor.capture()); + + Map capturedArgs = argsCaptor.getValue(); + + assertEquals( + "custom_value", + capturedArgs.get("param1"), + "Provided value should not be overwritten by default value"); + } + + @Test + void testDefaultValueDeepCloning() throws Exception { + Map complexDefault = new HashMap<>(); + complexDefault.put("key", "value"); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "object", false, "A parameter", null, complexDefault); + + ToolDefinition def = new ToolDefinition("A test tool", List.of(paramWithDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + CompletableFuture future = tool.execute(args); + future.join(); + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), any()); + + Map capturedArgs = argsCaptor.getValue(); + @SuppressWarnings("unchecked") + Map injectedDefault = (Map) capturedArgs.get("param1"); + + // Mutate the injected map + injectedDefault.put("key", "mutated_value"); + + // Ensure the original defaultValue stored in the definition remains untouched + @SuppressWarnings("unchecked") + Map defValueInDefinition = + (Map) def.parameters().get(0).defaultValue(); + assertEquals( + "value", + defValueInDefinition.get("key"), + "The default value in definition must remain unmutated"); + } + + @Test + void testDefaultValueDeepCloning_withList() throws Exception { + List complexDefault = new ArrayList<>(); + complexDefault.add("item1"); + complexDefault.add(Map.of("nestedKey", "nestedValue")); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter("param1", "array", false, "A parameter", null, complexDefault); + + ToolDefinition def = new ToolDefinition("A test tool", List.of(paramWithDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + CompletableFuture future = tool.execute(args); + future.join(); + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), any()); + + Map capturedArgs = argsCaptor.getValue(); + @SuppressWarnings("unchecked") + List injectedDefault = (List) capturedArgs.get("param1"); + + // Mutate the injected list + injectedDefault.set(0, "mutated_item"); + + // Ensure the original defaultValue stored in the definition remains untouched + @SuppressWarnings("unchecked") + List defValueInDefinition = (List) def.parameters().get(0).defaultValue(); + assertEquals( + "item1", + defValueInDefinition.get(0), + "The default value in definition must remain unmutated"); + } + + @Test + void testValidateAndSanitizeArgs_requiredParameterProvided() throws Exception { + List params = + List.of(new ToolDefinition.Parameter("p-required", "string", true, "desc", List.of())); + ToolDefinition def = new ToolDefinition("test-tool", params, List.of()); + when(mockClient.invokeTool(anyString(), anyMap(), anyMap())) + .thenReturn(CompletableFuture.completedFuture(new ToolResult(List.of(), false))); + + Tool tool = new Tool("test-tool", def, mockClient); + tool.execute(Map.of("p-required", "provided-value")).join(); // should succeed + } + + @Test + void testValidateAndSanitizeArgs_nullTypeWithNonNullValue() throws Exception { + List params = + List.of(new ToolDefinition.Parameter("p-no-type", null, false, "desc", List.of())); + ToolDefinition def = new ToolDefinition("test-tool", params, List.of()); + when(mockClient.invokeTool(anyString(), anyMap(), anyMap())) + .thenReturn(CompletableFuture.completedFuture(new ToolResult(List.of(), false))); + + Tool tool = new Tool("test-tool", def, mockClient); + tool.execute(Map.of("p-no-type", "some-value")).join(); // should succeed without checking type + } +} diff --git a/src/test/java/com/google/cloud/mcp/transport/HttpMcpTransportTest.java b/src/test/java/com/google/cloud/mcp/transport/HttpMcpTransportTest.java new file mode 100644 index 0000000..3d1ad86 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/transport/HttpMcpTransportTest.java @@ -0,0 +1,503 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.tool.ToolDefinition; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +@Timeout(value = 5, unit = java.util.concurrent.TimeUnit.SECONDS) +class HttpMcpTransportTest { + + private HttpClient mockClient; + private HttpMcpTransport transport; + + @BeforeEach + @SuppressWarnings("unchecked") + void setUp() { + mockClient = mock(HttpClient.class); + transport = new HttpMcpTransport("https://test-mcp-service.com", mockClient); + } + + @Test + @SuppressWarnings("unchecked") + void testListTools_PerformsHandshakeAndFetchesTools() throws Exception { + // 1. Mock response for 'initialize' + HttpResponse mockInitResponse = mock(HttpResponse.class); + when(mockInitResponse.statusCode()).thenReturn(200); + when(mockInitResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"protocolVersion\":\"2025-11-25\"}}"); + + // 2. Mock response for 'notifications/initialized' + HttpResponse mockInitializedResponse = mock(HttpResponse.class); + when(mockInitializedResponse.statusCode()).thenReturn(200); + when(mockInitializedResponse.body()).thenReturn(""); + + // 3. Mock response for 'tools/list' + HttpResponse mockListResponse = mock(HttpResponse.class); + when(mockListResponse.statusCode()).thenReturn(200); + when(mockListResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"2\",\"result\":{\"tools\":[{\"name\":\"test-tool\",\"description\":\"A" + + " test" + + " tool\",\"inputSchema\":{\"type\":\"object\",\"properties\":{\"param1\":{\"type\":\"string\",\"description\":\"param" + + " desc\"}}," + + "\"required\":[\"param1\"]},\"_meta\":{\"toolbox/authInvoke\":[\"gcp\"]}}]}}"); + + CompletableFuture> initFuture = + CompletableFuture.completedFuture(mockInitResponse); + CompletableFuture> initializedFuture = + CompletableFuture.completedFuture(mockInitializedResponse); + CompletableFuture> listFuture = + CompletableFuture.completedFuture(mockListResponse); + + // Set up mock calls sequentially with type hint + when(mockClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(initFuture) + .thenReturn(initializedFuture) + .thenReturn(listFuture); + + CompletableFuture futureManifest = + transport.listTools("", Collections.emptyMap()); + TransportManifest manifest = futureManifest.get(); + + assertNotNull(manifest); + assertEquals(1, manifest.getTools().size()); + assertTrue(manifest.getTools().containsKey("test-tool")); + ToolDefinition def = manifest.getTools().get("test-tool"); + assertEquals("A test tool", def.description()); + assertEquals(1, def.parameters().size()); + assertEquals("param1", def.parameters().get(0).name()); + assertTrue(def.parameters().get(0).required()); + } + + @Test + @SuppressWarnings("unchecked") + void testInvokeTool_PerformsHandshakeAndExecutesCall() throws Exception { + // 1. Mock response for 'initialize' + HttpResponse mockInitResponse = mock(HttpResponse.class); + when(mockInitResponse.statusCode()).thenReturn(200); + when(mockInitResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"protocolVersion\":\"2025-11-25\"}}"); + + // 2. Mock response for 'notifications/initialized' + HttpResponse mockInitializedResponse = mock(HttpResponse.class); + when(mockInitializedResponse.statusCode()).thenReturn(200); + when(mockInitializedResponse.body()).thenReturn(""); + + // 3. Mock response for 'tools/call' + HttpResponse mockInvokeResponse = mock(HttpResponse.class); + when(mockInvokeResponse.statusCode()).thenReturn(200); + when(mockInvokeResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"3\",\"result\":{\"content\":[{\"type\":\"text\",\"text\":\"success\"}]}}"); + + CompletableFuture> initFuture = + CompletableFuture.completedFuture(mockInitResponse); + CompletableFuture> initializedFuture = + CompletableFuture.completedFuture(mockInitializedResponse); + CompletableFuture> invokeFuture = + CompletableFuture.completedFuture(mockInvokeResponse); + + when(mockClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(initFuture) + .thenReturn(initializedFuture) + .thenReturn(invokeFuture); + + CompletableFuture futureResult = + transport.invokeTool("test-tool", Map.of("param1", "value1"), Collections.emptyMap()); + TransportResponse response = futureResult.get(); + + assertNotNull(response); + assertEquals(200, response.getStatusCode()); + assertTrue(response.getBody().contains("success")); + } + + @Test + @SuppressWarnings("unchecked") + void testSubsequentCalls_DoNotReinitialize() throws Exception { + // 1. Mock response for 'initialize' + HttpResponse mockInitResponse = mock(HttpResponse.class); + when(mockInitResponse.statusCode()).thenReturn(200); + when(mockInitResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"protocolVersion\":\"2025-11-25\"}}"); + + // 2. Mock response for 'notifications/initialized' + HttpResponse mockInitializedResponse = mock(HttpResponse.class); + when(mockInitializedResponse.statusCode()).thenReturn(200); + when(mockInitializedResponse.body()).thenReturn(""); + + // 3. Mock response for first 'tools/list' + HttpResponse mockListResponse1 = mock(HttpResponse.class); + when(mockListResponse1.statusCode()).thenReturn(200); + when(mockListResponse1.body()) + .thenReturn("{\"jsonrpc\":\"2.0\",\"id\":\"2\",\"result\":{\"tools\":[]}}"); + + // 4. Mock response for second 'tools/list' + HttpResponse mockListResponse2 = mock(HttpResponse.class); + when(mockListResponse2.statusCode()).thenReturn(200); + when(mockListResponse2.body()) + .thenReturn("{\"jsonrpc\":\"2.0\",\"id\":\"3\",\"result\":{\"tools\":[]}}"); + + CompletableFuture> initFuture = + CompletableFuture.completedFuture(mockInitResponse); + CompletableFuture> initializedFuture = + CompletableFuture.completedFuture(mockInitializedResponse); + CompletableFuture> listFuture1 = + CompletableFuture.completedFuture(mockListResponse1); + CompletableFuture> listFuture2 = + CompletableFuture.completedFuture(mockListResponse2); + + // Set up sequential answers + when(mockClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(initFuture) + .thenReturn(initializedFuture) + .thenReturn(listFuture1) + .thenReturn(listFuture2); + + // First call lists tools (performs handshake + lists) + transport.listTools("", Collections.emptyMap()).get(); + + // Second call lists tools (should only list tools directly) + transport.listTools("", Collections.emptyMap()).get(); + + // Total calls to sendAsync should be 4 (1: init, 2: initialized, 3: list1, 4: list2) + verify(mockClient, times(4)) + .sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class)); + } + + @Test + void testConstructor_InvalidBaseUrlThrows() { + org.junit.jupiter.api.Assertions.assertThrows( + IllegalArgumentException.class, () -> new HttpMcpTransport(null)); + org.junit.jupiter.api.Assertions.assertThrows( + IllegalArgumentException.class, () -> new HttpMcpTransport("")); + } + + @Test + void testConstructor_WithOnlyBaseUrl() { + HttpMcpTransport simpleTransport = new HttpMcpTransport("https://test-mcp-service.com"); + assertNotNull(simpleTransport); + assertEquals("https://test-mcp-service.com", simpleTransport.getBaseUrl()); + } + + @Test + void testOtherOverloadedConstructors() { + java.net.http.HttpClient client = java.net.http.HttpClient.newHttpClient(); + java.util.concurrent.Executor executor = java.util.concurrent.ForkJoinPool.commonPool(); + CredentialsProvider provider = () -> CompletableFuture.completedFuture("Bearer test"); + + HttpMcpTransport transport1 = + new HttpMcpTransport( + "https://test-mcp.com", + Map.of("X-Header", "value"), + ProtocolVersion.VERSION_2025_11_25, + client, + executor); + assertNotNull(transport1); + + HttpMcpTransport transport2 = + new HttpMcpTransport("https://test-mcp.com", Map.of("X-Header", "value"), provider, client); + assertNotNull(transport2); + } + + @Test + void testConstructor_WithCustomExecutorConfiguresHttpClient() throws Exception { + java.util.concurrent.atomic.AtomicInteger taskCount = + new java.util.concurrent.atomic.AtomicInteger(0); + java.util.concurrent.Executor customExecutor = + runnable -> { + taskCount.incrementAndGet(); + new Thread(runnable).start(); + }; + + HttpMcpTransport transport = + new HttpMcpTransport( + "http://localhost:8080", + java.util.Map.of(), + ProtocolVersion.VERSION_2025_11_25, + null, + customExecutor); + + java.lang.reflect.Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + + java.lang.reflect.Field httpClientField = BaseMcpTransport.class.getDeclaredField("httpClient"); + httpClientField.setAccessible(true); + java.net.http.HttpClient httpClient = (java.net.http.HttpClient) httpClientField.get(delegate); + + assertNotNull(httpClient); + Object internalExecutor = null; + try { + java.lang.reflect.Field executorField = httpClient.getClass().getDeclaredField("executor"); + executorField.setAccessible(true); + internalExecutor = executorField.get(httpClient); + } catch (NoSuchFieldException e) { + // Fallback + } + + if (internalExecutor != null) { + org.junit.jupiter.api.Assertions.assertSame(customExecutor, internalExecutor); + } + } + + @Test + @SuppressWarnings("unchecked") + void testInitialize_ServerReturnsErrorJsonRpcResponse() throws Exception { + HttpResponse mockInitResponse = mock(HttpResponse.class); + when(mockInitResponse.statusCode()).thenReturn(200); + when(mockInitResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"error\":{\"code\":-32603,\"message\":\"Internal" + + " error\"}}"); + + when(mockClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(mockInitResponse)); + + CompletableFuture future = transport.listTools("", Collections.emptyMap()); + java.util.concurrent.ExecutionException ex = + org.junit.jupiter.api.Assertions.assertThrows( + java.util.concurrent.ExecutionException.class, future::get); + assertTrue(ex.getCause() instanceof McpException); + assertTrue(ex.getCause().getMessage().contains("MCP Error")); + } + + @Test + @SuppressWarnings("unchecked") + void testListTools_WithHttpUrlAndMetadata_LogsWarning() throws Exception { + HttpMcpTransport httpTransport = + new HttpMcpTransport("http://test-mcp-service.com", mockClient); + HttpResponse mockInitResponse = mock(HttpResponse.class); + when(mockInitResponse.statusCode()).thenReturn(200); + when(mockInitResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"protocolVersion\":\"2025-11-25\"}}"); + + HttpResponse mockInitializedResponse = mock(HttpResponse.class); + when(mockInitializedResponse.statusCode()).thenReturn(200); + when(mockInitializedResponse.body()).thenReturn(""); + + HttpResponse mockListResponse = mock(HttpResponse.class); + when(mockListResponse.statusCode()).thenReturn(200); + when(mockListResponse.body()) + .thenReturn("{\"jsonrpc\":\"2.0\",\"id\":\"2\",\"result\":{\"tools\":[]}}"); + + when(mockClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(mockInitResponse)) + .thenReturn(CompletableFuture.completedFuture(mockInitializedResponse)) + .thenReturn(CompletableFuture.completedFuture(mockListResponse)); + + java.util.logging.Logger transportLogger = + java.util.logging.Logger.getLogger(BaseMcpTransport.class.getName()); + java.util.List logRecords = new java.util.ArrayList<>(); + java.util.logging.Handler logHandler = + new java.util.logging.Handler() { + @Override + public void publish(java.util.logging.LogRecord record) { + logRecords.add(record); + } + + @Override + public void flush() {} + + @Override + public void close() throws SecurityException {} + }; + transportLogger.addHandler(logHandler); + + try { + httpTransport.listTools("", Map.of("key", "val")).get(); + } finally { + transportLogger.removeHandler(logHandler); + } + + assertFalse(logRecords.isEmpty()); + assertTrue(logRecords.get(0).getMessage().contains("This connection is using HTTP")); + } + + @Test + @SuppressWarnings("unchecked") + void testListTools_Non200Response_ThrowsException() { + HttpResponse mockInitResponse = mock(HttpResponse.class); + when(mockInitResponse.statusCode()).thenReturn(200); + when(mockInitResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"protocolVersion\":\"2025-11-25\"}}"); + + HttpResponse mockInitializedResponse = mock(HttpResponse.class); + when(mockInitializedResponse.statusCode()).thenReturn(200); + when(mockInitializedResponse.body()).thenReturn(""); + + HttpResponse mockErrorResponse = mock(HttpResponse.class); + when(mockErrorResponse.statusCode()).thenReturn(500); + when(mockErrorResponse.body()).thenReturn("Internal Server Error"); + + when(mockClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(mockInitResponse)) + .thenReturn(CompletableFuture.completedFuture(mockInitializedResponse)) + .thenReturn(CompletableFuture.completedFuture(mockErrorResponse)); + + Exception ex = + org.junit.jupiter.api.Assertions.assertThrows( + Exception.class, () -> transport.listTools("", Collections.emptyMap()).get()); + assertTrue(ex.getCause().getMessage().contains("Status: 500")); + } + + @Test + @SuppressWarnings("unchecked") + void testListTools_JsonRpcError_ThrowsException() { + HttpResponse mockInitResponse = mock(HttpResponse.class); + when(mockInitResponse.statusCode()).thenReturn(200); + when(mockInitResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"protocolVersion\":\"2025-11-25\"}}"); + + HttpResponse mockInitializedResponse = mock(HttpResponse.class); + when(mockInitializedResponse.statusCode()).thenReturn(200); + when(mockInitializedResponse.body()).thenReturn(""); + + HttpResponse mockErrorResponse = mock(HttpResponse.class); + when(mockErrorResponse.statusCode()).thenReturn(200); + when(mockErrorResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"2\",\"error\":{\"code\":-1,\"message\":\"Custom" + + " error\"}}"); + + when(mockClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(mockInitResponse)) + .thenReturn(CompletableFuture.completedFuture(mockInitializedResponse)) + .thenReturn(CompletableFuture.completedFuture(mockErrorResponse)); + + Exception ex = + org.junit.jupiter.api.Assertions.assertThrows( + Exception.class, () -> transport.listTools("", Collections.emptyMap()).get()); + assertTrue(ex.getCause().getMessage().contains("Custom error")); + } + + @Test + @SuppressWarnings("unchecked") + void testListTools_ParsesComplexToolsCorrectly() throws Exception { + HttpResponse mockInitResponse = mock(HttpResponse.class); + when(mockInitResponse.statusCode()).thenReturn(200); + when(mockInitResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"protocolVersion\":\"2025-11-25\"}}"); + + HttpResponse mockInitializedResponse = mock(HttpResponse.class); + when(mockInitializedResponse.statusCode()).thenReturn(200); + when(mockInitializedResponse.body()).thenReturn(""); + + String json = + "{\"jsonrpc\":\"2.0\",\"id\":\"2\",\"result\":{\"tools\":[" + + "{" + + " \"name\":\"test-tool\"," + + " \"description\":\"Desc\"," + + " \"inputSchema\":{" + + " \"type\":\"object\"," + + " \"required\":[\"p1\"]," + + " \"properties\":{" + + " \"p1\": { \"type\":\"string\", \"description\":\"p1 desc\" }," + + " \"p2\": { }" + + " }" + + " }," + + " \"_meta\":{" + + " \"toolbox/authInvoke\": \"not-an-array\"," + + " \"toolbox/authParam\": {" + + " \"p1\": [\"gcp\"]" + + " }" + + " }" + + "}" + + "]}}"; + HttpResponse mockListResponse = mock(HttpResponse.class); + when(mockListResponse.statusCode()).thenReturn(200); + when(mockListResponse.body()).thenReturn(json); + + when(mockClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(mockInitResponse)) + .thenReturn(CompletableFuture.completedFuture(mockInitializedResponse)) + .thenReturn(CompletableFuture.completedFuture(mockListResponse)); + + TransportManifest manifest = transport.listTools("", Collections.emptyMap()).get(); + assertNotNull(manifest); + ToolDefinition tool = manifest.getTools().get("test-tool"); + assertNotNull(tool); + assertEquals("Desc", tool.description()); + assertEquals(2, tool.parameters().size()); + + ToolDefinition.Parameter p1 = + tool.parameters().stream().filter(p -> p.name().equals("p1")).findFirst().get(); + assertTrue(p1.required()); + assertEquals("p1 desc", p1.description()); + assertEquals(List.of("gcp"), p1.authSources()); + + ToolDefinition.Parameter p2 = + tool.parameters().stream().filter(p -> p.name().equals("p2")).findFirst().get(); + assertFalse(p2.required()); + assertEquals("string", p2.type()); + } + + @Test + @Timeout(5) + @SuppressWarnings("unchecked") + void testInvokeTool_ExceptionRecording() throws Exception { + HttpResponse mockInitResponse = mock(HttpResponse.class); + when(mockInitResponse.statusCode()).thenReturn(200); + when(mockInitResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"protocolVersion\":\"2025-11-25\"}}"); + + HttpResponse mockInitializedResponse = mock(HttpResponse.class); + when(mockInitializedResponse.statusCode()).thenReturn(200); + when(mockInitializedResponse.body()).thenReturn(""); + + when(mockClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(mockInitResponse)) + .thenReturn(CompletableFuture.completedFuture(mockInitializedResponse)) + .thenReturn(CompletableFuture.failedFuture(new java.io.IOException("connection failure"))); + + CompletableFuture futureResult = + transport.invokeTool("test-tool", Map.of(), Collections.emptyMap()); + + assertThrows(Exception.class, futureResult::get); + } +}