diff --git a/demo-applications/cymbal-transit/pom.xml b/demo-applications/cymbal-transit/pom.xml index 7c420d4..0897eaa 100644 --- a/demo-applications/cymbal-transit/pom.xml +++ b/demo-applications/cymbal-transit/pom.xml @@ -69,7 +69,7 @@ com.google.cloud.mcp mcp-toolbox-sdk-java - 0.2.0 + 0.2.1-SNAPSHOT diff --git a/demo-applications/cymbal-transit/src/main/java/cloudcode/cymbal/CymbalTransitApplication.java b/demo-applications/cymbal-transit/src/main/java/cloudcode/cymbal/CymbalTransitApplication.java index 070c86e..a65ad96 100644 --- a/demo-applications/cymbal-transit/src/main/java/cloudcode/cymbal/CymbalTransitApplication.java +++ b/demo-applications/cymbal-transit/src/main/java/cloudcode/cymbal/CymbalTransitApplication.java @@ -40,6 +40,8 @@ public static void main(final String[] args) throws Exception { // Start the Spring Boot application. app.run(args); logger.info( - "Hello from Cloud Run! The container started successfully and is listening for HTTP requests on " + port); + "Hello from Cloud Run! The container started successfully and is listening for HTTP" + + " requests on " + + port); } } diff --git a/demo-applications/cymbal-transit/src/main/java/cloudcode/cymbal/web/CymbalTransitController.java b/demo-applications/cymbal-transit/src/main/java/cloudcode/cymbal/web/CymbalTransitController.java index 6cea16f..4ef044d 100644 --- a/demo-applications/cymbal-transit/src/main/java/cloudcode/cymbal/web/CymbalTransitController.java +++ b/demo-applications/cymbal-transit/src/main/java/cloudcode/cymbal/web/CymbalTransitController.java @@ -16,250 +16,272 @@ package cloudcode.cymbal.web; -import com.google.cloud.mcp.McpToolboxClient; import cloudcode.cymbal.CymbalTransitApplication; -import com.google.cloud.mcp.AuthTokenGetter; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.IdTokenProvider; - -import org.springframework.boot.SpringApplication; -import org.springframework.boot.autoconfigure.SpringBootApplication; -import org.springframework.web.bind.annotation.*; -import org.springframework.http.ResponseEntity; -import org.springframework.stereotype.Service; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.beans.factory.annotation.Value; - -import javax.annotation.PostConstruct; -import javax.servlet.http.HttpSession; -import java.io.FileInputStream; -import java.util.Collections; -import java.util.Map; -import java.util.concurrent.CompletableFuture; - -// LangChain4j Imports for Agentic Routing & Gemini 3 Flash +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.AuthTokenGetter; +import dev.langchain4j.agent.tool.Tool; +import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.model.chat.ChatLanguageModel; import dev.langchain4j.model.vertexai.VertexAiGeminiChatModel; -import dev.langchain4j.memory.chat.MessageWindowChatMemory; import dev.langchain4j.service.AiServices; -import dev.langchain4j.agent.tool.Tool; -import dev.langchain4j.service.SystemMessage; import dev.langchain4j.service.MemoryId; +import dev.langchain4j.service.SystemMessage; import dev.langchain4j.service.UserMessage; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; +import javax.annotation.PostConstruct; +import javax.servlet.http.HttpSession; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.http.ResponseEntity; +import org.springframework.stereotype.Service; +import org.springframework.web.bind.annotation.*; @SpringBootApplication public class CymbalTransitController { - public static void main(String[] args) { - SpringApplication.run(CymbalTransitApplication.class, args); - } + public static void main(String[] args) { + SpringApplication.run(CymbalTransitApplication.class, args); + } } /** - * 1. AI AGENT CONFIGURATION - * Configures Gemini 3 Flash and binds it to our LangChain4j Agent Interface. + * 1. AI AGENT CONFIGURATION Configures Gemini 3 Flash and binds it to our LangChain4j Agent + * Interface. */ @Configuration class AgentConfiguration { - @Value("${GCP_PROJECT_ID:fallback_project_id}") - private String projectId; - - @Value("${GCP_REGION:fallback_region}") - private String region; - - @Value("${GEMINI_MODEL_NAME:fallback_model}") - private String modelName; - - @Bean - ChatLanguageModel geminiChatModel() { - return VertexAiGeminiChatModel.builder() - .project(projectId) - .location(region) - .modelName(modelName) // Utilizing externalized parameters - .build(); - } - - @Bean - TransitAgent transitAgent(ChatLanguageModel chatLanguageModel, TransitAgentTools tools) { - return AiServices.builder(TransitAgent.class) - .chatLanguageModel(chatLanguageModel) - .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(20)) - .tools(tools) // Exposes our MCP tools to Gemini - .build(); - } + @Value("${GCP_PROJECT_ID:fallback_project_id}") + private String projectId; + + @Value("${GCP_REGION:fallback_region}") + private String region; + + @Value("${GEMINI_MODEL_NAME:fallback_model}") + private String modelName; + + @Bean + ChatLanguageModel geminiChatModel() { + return VertexAiGeminiChatModel.builder() + .project(projectId) + .location(region) + .modelName(modelName) // Utilizing externalized parameters + .build(); + } + + @Bean + TransitAgent transitAgent(ChatLanguageModel chatLanguageModel, TransitAgentTools tools) { + return AiServices.builder(TransitAgent.class) + .chatLanguageModel(chatLanguageModel) + .chatMemoryProvider(memoryId -> MessageWindowChatMemory.withMaxMessages(20)) + .tools(tools) // Exposes our MCP tools to Gemini + .build(); + } } -/** - * 2. THE AI AGENT INTERFACE - * Declarative AI service handling routing logic via System Prompt. - */ +/** 2. THE AI AGENT INTERFACE Declarative AI service handling routing logic via System Prompt. */ interface TransitAgent { - @SystemMessage({ - "You are the Cymbal Transit Concierge.", - "CRITICAL INSTRUCTION: On your very first interaction, you MUST use the 'findAllSchedules' tool to fetch and memorize the broad bus routes.", - "Keep this data handy in your context. Answer general routing questions using this stored data. ", - "If you have to list the route details to the user, show it along with the full UUID and with other details that are meaningful. If the user chooses to book ticket as the next step, prompt them to copy the correct UUID nad paste so the transaction can be confirmed.", - "ONLY if the user asks a specifically narrowed-down question, asks for precise times, or assigns a booking task, or asks about policies should you route to the specific tools like 'querySchedules', 'bookTicket', 'searchPolicies'.", - "Remember the tool 'querySchedules' is for finding schedules between cities, 'bookTicket' is for booking ticket actionable between 2 cities, 'searchPolicies' is for finding matching policies for this company.", - "Be intuitive and intelligent in finding the context even when user has typos. Do no hallucinate and make up stuff though. USe only data from the tools. ", - "Don't show any asterisks while listing results. Keep it formatted and numbered or bulleted. asterisks distract." - }) - String chat(@MemoryId String sessionId, @UserMessage String userMessage); + @SystemMessage({ + "You are the Cymbal Transit Concierge.", + "CRITICAL INSTRUCTION: On your very first interaction, you MUST use the 'findAllSchedules' tool" + + " to fetch and memorize the broad bus routes.", + "Keep this data handy in your context. Answer general routing questions using this stored data." + + " ", + "If you have to list the route details to the user, show it along with the full UUID and with" + + " other details that are meaningful. If the user chooses to book ticket as the next step," + + " prompt them to copy the correct UUID nad paste so the transaction can be confirmed.", + "ONLY if the user asks a specifically narrowed-down question, asks for precise times, or" + + " assigns a booking task, or asks about policies should you route to the specific tools" + + " like 'querySchedules', 'bookTicket', 'searchPolicies'.", + "Remember the tool 'querySchedules' is for finding schedules between cities, 'bookTicket' is" + + " for booking ticket actionable between 2 cities, 'searchPolicies' is for finding" + + " matching policies for this company.", + "Be intuitive and intelligent in finding the context even when user has typos. Do no" + + " hallucinate and make up stuff though. USe only data from the tools. ", + "Don't show any asterisks while listing results. Keep it formatted and numbered or bulleted." + + " asterisks distract." + }) + String chat(@MemoryId String sessionId, @UserMessage String userMessage); } /** - * 3. THE TOOLBOX BRIDGE - * Wraps our asynchronous MCP Client calls into synchronous @Tools that LangChain4j (Gemini) can execute. + * 3. THE TOOLBOX BRIDGE Wraps our asynchronous MCP Client calls into synchronous @Tools that + * LangChain4j (Gemini) can execute. */ @Service class TransitAgentTools { - - private final McpToolboxService mcpService; - - public TransitAgentTools(McpToolboxService mcpService) { - this.mcpService = mcpService; - } - @Tool("Fetches the initial, broad dataset of all available bus schedules and routes. Use this to build your context.") - public String findAllSchedules() { - return mcpService.findAllSchedules().join(); - } - - @Tool("Query specific schedules between an origin and destination city. Use only when the user narrows down their request.") - public String querySchedules(String origin, String destination) { - return mcpService.querySchedules(origin, destination).join(); - } - - @Tool("Book a ticket for a passenger using a specific trip ID.") - public String bookTicket(String tripId, String passengerName) { - return mcpService.bookTicket(tripId, passengerName).join(); - } - - @Tool("Semantic search for transit policies regarding luggage, pets, refunds, and general rules.") - public String searchPolicies(String searchQuery) { - return mcpService.searchPolicies(searchQuery).join(); - } + private final McpToolboxService mcpService; + + public TransitAgentTools(McpToolboxService mcpService) { + this.mcpService = mcpService; + } + + @Tool( + "Fetches the initial, broad dataset of all available bus schedules and routes. Use this to" + + " build your context.") + public String findAllSchedules() { + return mcpService.findAllSchedules().join(); + } + + @Tool( + "Query specific schedules between an origin and destination city. Use only when the user" + + " narrows down their request.") + public String querySchedules(String origin, String destination) { + return mcpService.querySchedules(origin, destination).join(); + } + + @Tool("Book a ticket for a passenger using a specific trip ID.") + public String bookTicket(String tripId, String passengerName) { + return mcpService.bookTicket(tripId, passengerName).join(); + } + + @Tool("Semantic search for transit policies regarding luggage, pets, refunds, and general rules.") + public String searchPolicies(String searchQuery) { + return mcpService.searchPolicies(searchQuery).join(); + } } /** - * 4. THE MCP TOOLBOX SERVICE - * Handles the actual connection and execution against the AlloyDB backend. + * 4. THE MCP TOOLBOX SERVICE Handles the actual connection and execution against the AlloyDB + * backend. */ @Service class McpToolboxService { - - private McpToolboxClient mcpClient; - private String idToken; - - @Value("${MCP_TOOLBOX_URL:fallback_toolbox_url}") - private String targetUrl; - - @PostConstruct - public void init() { - try { - String tokenAudience = targetUrl; - - System.out.println("--- Initializing MCP Toolbox Client ---"); - - GoogleCredentials credentials = GoogleCredentials.getApplicationDefault(); - if (!(credentials instanceof IdTokenProvider)) { - throw new RuntimeException("Loaded credentials do not support ID Tokens."); - } - - this.idToken = ((IdTokenProvider) credentials) - .idTokenWithAudience(tokenAudience, Collections.emptyList()) - .getTokenValue(); - - this.mcpClient = McpToolboxClient.builder() - .baseUrl(targetUrl) - .apiKey(idToken) - .build(); - - mcpClient.listTools().thenAccept(tools -> { - System.out.println("Successfully discovered " + tools.size() + " tools."); - }).join(); - } catch (Exception e) { - System.err.println("Failed to initialize MCP Toolbox Client:"); - e.printStackTrace(); - } - } + private McpToolboxClient mcpClient; + private String idToken; - public CompletableFuture findAllSchedules() { - return mcpClient.invokeTool("find-bus-schedules", Collections.emptyMap()).thenApply(result -> { - if (result.isError() || result.content() == null || result.content().isEmpty()) return "No schedules found."; - //return result.content().get(0).text(); - //return result.text(); - return result.content().stream() - .map(content -> content.text()) - .collect(Collectors.joining(", ", "[", "]")); - }); - } + @Value("${MCP_TOOLBOX_URL:fallback_toolbox_url}") + private String targetUrl; - public CompletableFuture querySchedules(String origin, String destination) { - java.util.Map params = new java.util.HashMap<>(); - params.put("origin", origin); - params.put("destination", destination); - return mcpClient.invokeTool("query-schedules", params).thenApply(result -> { - if (result.isError() || result.content() == null || result.content().isEmpty()) return "No specific schedules found."; - System.out.println(result); - return result.content().stream() - .map(content -> content.text()) - .collect(Collectors.joining(", ", "[", "]")); - }); - } + @PostConstruct + public void init() { + try { + String tokenAudience = targetUrl; + + System.out.println("--- Initializing MCP Toolbox Client ---"); + + GoogleCredentials credentials = GoogleCredentials.getApplicationDefault(); + if (!(credentials instanceof IdTokenProvider)) { + throw new RuntimeException("Loaded credentials do not support ID Tokens."); + } + + this.idToken = + ((IdTokenProvider) credentials) + .idTokenWithAudience(tokenAudience, Collections.emptyList()) + .getTokenValue(); - public CompletableFuture bookTicket(String tripId, String passengerName) { - AuthTokenGetter toolAuthGetter = () -> CompletableFuture.completedFuture(idToken); - return mcpClient.loadTool("book-ticket", Collections.singletonMap("google_auth", toolAuthGetter)) - .thenCompose(tool -> { - tool.bindParam("passenger_name", passengerName); - return tool.execute(Collections.singletonMap("trip_id", tripId)); + this.mcpClient = McpToolboxClient.builder().baseUrl(targetUrl).apiKey(idToken).build(); + + mcpClient + .listTools() + .thenAccept( + tools -> { + System.out.println("Successfully discovered " + tools.size() + " tools."); + }) + .join(); + + } catch (Exception e) { + System.err.println("Failed to initialize MCP Toolbox Client:"); + e.printStackTrace(); + } + } + + public CompletableFuture findAllSchedules() { + return mcpClient + .invokeTool("find-bus-schedules", Collections.emptyMap()) + .thenApply( + result -> { + if (result.isError() || result.content() == null || result.content().isEmpty()) + return "No schedules found."; + // return result.content().get(0).text(); + // return result.text(); + return result.content().stream() + .map(content -> content.text()) + .collect(Collectors.joining(", ", "[", "]")); + }); + } + + public CompletableFuture querySchedules(String origin, String destination) { + java.util.Map params = new java.util.HashMap<>(); + params.put("origin", origin); + params.put("destination", destination); + return mcpClient + .invokeTool("query-schedules", params) + .thenApply( + result -> { + if (result.isError() || result.content() == null || result.content().isEmpty()) + return "No specific schedules found."; + System.out.println(result); + return result.content().stream() + .map(content -> content.text()) + .collect(Collectors.joining(", ", "[", "]")); + }); + } + + public CompletableFuture bookTicket(String tripId, String passengerName) { + AuthTokenGetter toolAuthGetter = () -> CompletableFuture.completedFuture(idToken); + return mcpClient + .loadTool("book-ticket", Collections.singletonMap("google_auth", toolAuthGetter)) + .thenCompose( + tool -> { + tool.bindParam("passenger_name", passengerName); + return tool.execute(Collections.singletonMap("trip_id", tripId)); }) - .thenApply(result -> { - if (result.isError() || result.content() == null || result.content().isEmpty()) { - System.err.println("Tool execution failed: " + result.content().get(0).text()); - return "Transaction failed."; - } - return result.content().get(0).text(); + .thenApply( + result -> { + if (result.isError() || result.content() == null || result.content().isEmpty()) { + System.err.println("Tool execution failed: " + result.content().get(0).text()); + return "Transaction failed."; + } + return result.content().get(0).text(); }); - } - - public CompletableFuture searchPolicies(String searchQuery) { - return mcpClient.invokeTool("search-policies", Map.of("search_query", searchQuery)) - .thenApply(result -> { - if (result.isError() || result.content() == null || result.content().isEmpty()) return "No policy information found."; - return result.content().stream() - .map(content -> content.text()) - .collect(Collectors.joining(", ", "[", "]")); + } + + public CompletableFuture searchPolicies(String searchQuery) { + return mcpClient + .invokeTool("search-policies", Map.of("search_query", searchQuery)) + .thenApply( + result -> { + if (result.isError() || result.content() == null || result.content().isEmpty()) + return "No policy information found."; + return result.content().stream() + .map(content -> content.text()) + .collect(Collectors.joining(", ", "[", "]")); }); - } + } } /** - * 5. THE REST CONTROLLER - * Now radically simplified! No more manual if/else logic or JSON parsing. + * 5. THE REST CONTROLLER Now radically simplified! No more manual if/else logic or JSON parsing. */ @RestController @RequestMapping("/api/agent") class TransitAgentController { - private final TransitAgent transitAgent; + private final TransitAgent transitAgent; - public TransitAgentController(TransitAgent transitAgent) { - this.transitAgent = transitAgent; - } + public TransitAgentController(TransitAgent transitAgent) { + this.transitAgent = transitAgent; + } - @PostMapping("/chat") - public ResponseEntity handleUserChat(@RequestBody String userMessage, HttpSession session) { - // We use the HTTP Session ID to tell LangChain4j which memory context to load - String sessionId = session.getId(); - - // Let Gemini 3 Flash handle the thinking, tool execution, and response generation! - String agentResponse = transitAgent.chat(sessionId, userMessage); - - return ResponseEntity.ok(agentResponse); - } + @PostMapping("/chat") + public ResponseEntity handleUserChat( + @RequestBody String userMessage, HttpSession session) { + // We use the HTTP Session ID to tell LangChain4j which memory context to load + String sessionId = session.getId(); + + // Let Gemini 3 Flash handle the thinking, tool execution, and response generation! + String agentResponse = transitAgent.chat(sessionId, userMessage); + + return ResponseEntity.ok(agentResponse); + } } diff --git a/example/pom.xml b/example/pom.xml index 53ca764..551141e 100644 --- a/example/pom.xml +++ b/example/pom.xml @@ -33,7 +33,7 @@ com.google.cloud.mcp mcp-toolbox-sdk-java - 0.2.0 + 0.2.1-SNAPSHOT diff --git a/example/src/main/java/cloudcode/helloworld/ExampleUsage.java b/example/src/main/java/cloudcode/helloworld/ExampleUsage.java index 504886c..70e18e5 100644 --- a/example/src/main/java/cloudcode/helloworld/ExampleUsage.java +++ b/example/src/main/java/cloudcode/helloworld/ExampleUsage.java @@ -16,145 +16,166 @@ package cloudcode.helloworld; -import java.util.Map; -import java.util.Collections; -import java.util.concurrent.CompletableFuture; -import java.io.FileInputStream; -import com.google.cloud.mcp.McpToolboxClient; -import com.google.cloud.mcp.AuthTokenGetter; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.IdTokenProvider; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.AuthTokenGetter; +import java.io.FileInputStream; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; /** - * Sample Application to demostrate the usage of the MCP Toolbox Java SDK. - * Covers: Global Auth, Parameterized Auth, Discovery, Simple Tool, Authenticated Tool, Parameter Binding. + * Sample Application to demostrate the usage of the MCP Toolbox Java SDK. Covers: Global Auth, + * Parameterized Auth, Discovery, Simple Tool, Authenticated Tool, Parameter Binding. */ public class ExampleUsage { - public static void main(String[] args) { - // CONFIGURATION - String targetUrl = "YOUR_TOOLBOX_SERVICE_ENDPOINT"; - - // Match the Service URL if using Cloud Run OIDC - String tokenAudience = targetUrl; - - // -------------------------------------------------------------------------------- - // AUTHENTICATION SETUP - // -------------------------------------------------------------------------------- - // FOR LOCAL DEVELOPMENT: Use a Service Account Key JSON file. - // FOR PRODUCTION (Cloud Run): Comment out the 'keyPath' logic and use ADC directly. - // -------------------------------------------------------------------------------- - - String keyPath = "YOUR_CREDENTIALS_JSON_FILE_PATH.json"; - - System.out.println("--- Starting MCP Toolbox Integration Test ---"); - System.out.println("Target Server: " + targetUrl); - -try { - System.out.println(" [Init] Fetching ID Token..."); - - GoogleCredentials credentials; - - // --- OPTION A: LOCAL DEV (Explicit Key File) --- - if (keyPath != null && !keyPath.isEmpty()) { - System.out.println(" [Auth] Using Service Account Key File: " + keyPath); - credentials = GoogleCredentials.fromStream(new FileInputStream(keyPath)); - } - // --- OPTION B: PRODUCTION (ADC) --- - else { - System.out.println(" [Auth] Using Application Default Credentials (ADC)"); - credentials = GoogleCredentials.getApplicationDefault(); - } - - if (!(credentials instanceof IdTokenProvider)) { - throw new RuntimeException("Loaded credentials do not support ID Tokens."); - } - - // Generate Token for the specified Audience - String idToken = ((IdTokenProvider) credentials).idTokenWithAudience(tokenAudience, Collections.emptyList()).getTokenValue(); - System.out.println(" [Debug] Token Generated."); - - // Initialize Client with Global Auth (Applies to ALL calls - Gate 1) - McpToolboxClient client = McpToolboxClient.builder() - .baseUrl(targetUrl) - .apiKey(idToken) - .build(); - - // STEP 1: TEST DISCOVERY METHODS - client.listTools() - .thenCompose(tools -> { - System.out.println("\n[1] listTools(): Success. Found " + tools.size() + " tools."); - return client.loadToolset(); - }) - .thenCompose(tools -> { - System.out.println("[2] loadToolset() (Alias): Success."); - return client.loadToolset("retail") - .handle((res, ex) -> { - if (ex == null) System.out.println("[3] loadToolset('retail'): Found " + res.size() + " tools."); - else System.out.println("[3] loadToolset('retail'): Skipped (Not configured on server)."); - return null; + public static void main(String[] args) { + // CONFIGURATION + String targetUrl = "YOUR_TOOLBOX_SERVICE_ENDPOINT"; + + // Match the Service URL if using Cloud Run OIDC + String tokenAudience = targetUrl; + + // -------------------------------------------------------------------------------- + // AUTHENTICATION SETUP + // -------------------------------------------------------------------------------- + // FOR LOCAL DEVELOPMENT: Use a Service Account Key JSON file. + // FOR PRODUCTION (Cloud Run): Comment out the 'keyPath' logic and use ADC directly. + // -------------------------------------------------------------------------------- + + String keyPath = "YOUR_CREDENTIALS_JSON_FILE_PATH.json"; + + System.out.println("--- Starting MCP Toolbox Integration Test ---"); + System.out.println("Target Server: " + targetUrl); + + try { + System.out.println(" [Init] Fetching ID Token..."); + + GoogleCredentials credentials; + + // --- OPTION A: LOCAL DEV (Explicit Key File) --- + if (keyPath != null && !keyPath.isEmpty()) { + System.out.println(" [Auth] Using Service Account Key File: " + keyPath); + credentials = GoogleCredentials.fromStream(new FileInputStream(keyPath)); + } + // --- OPTION B: PRODUCTION (ADC) --- + else { + System.out.println(" [Auth] Using Application Default Credentials (ADC)"); + credentials = GoogleCredentials.getApplicationDefault(); + } + + if (!(credentials instanceof IdTokenProvider)) { + throw new RuntimeException("Loaded credentials do not support ID Tokens."); + } + + // Generate Token for the specified Audience + String idToken = + ((IdTokenProvider) credentials) + .idTokenWithAudience(tokenAudience, Collections.emptyList()) + .getTokenValue(); + System.out.println(" [Debug] Token Generated."); + + // Initialize Client with Global Auth (Applies to ALL calls - Gate 1) + McpToolboxClient client = + McpToolboxClient.builder().baseUrl(targetUrl).apiKey(idToken).build(); + + // STEP 1: TEST DISCOVERY METHODS + client + .listTools() + .thenCompose( + tools -> { + System.out.println("\n[1] listTools(): Success. Found " + tools.size() + " tools."); + return client.loadToolset(); + }) + .thenCompose( + tools -> { + System.out.println("[2] loadToolset() (Alias): Success."); + return client + .loadToolset("retail") + .handle( + (res, ex) -> { + if (ex == null) + System.out.println( + "[3] loadToolset('retail'): Found " + res.size() + " tools."); + else + System.out.println( + "[3] loadToolset('retail'): Skipped (Not configured on server)."); + return null; }); - }) - .thenCompose(ignore -> { - - // STEP 2: INVOKE TOOL WITHOUT EXTRA AUTH - System.out.println("\n[4] Testing Simple Tool: 'get-retail-facet-filters'..."); - return client.invokeTool("get-retail-facet-filters", Map.of()); - }) - .thenCompose(result -> { - System.out.println(" -> Result: " + (result.content() != null ? "Received Data" : "Empty")); - - // STEP 3: INVOKE TOOL WITH AUTHENTICATED PARAMETERS - System.out.println("\n[5] Testing Authenticated Tool: 'get-toy-price'..."); - - // Define the getter for the 'google_auth' service - AuthTokenGetter toolAuthGetter = () -> CompletableFuture.completedFuture(idToken); - - // Load using the sophisticated overload - return client.loadTool("get-toy-price", Map.of("google_auth", toolAuthGetter)); - }) - .thenCompose(tool -> { - System.out.println(" -> Loaded Tool: " + tool.definition().description()); - - // STEP 4: TEST BINDING PARAMETERS SEQUENTIALLY - System.out.println("\n[A] Executing UNBOUND (Runtime arg: 'barbie')..."); - - return tool.execute(Map.of("description", "barbie")) - .thenCompose(result1 -> { - if (result1.content() != null && !result1.content().isEmpty()) { - System.out - .println(" -> Result (Unbound): " + result1.content().get(0).text()); - } - - // NOW bind the parameter - System.out.println("\n[B] Binding 'description' to 'soft toy'..."); - tool.bindParam("description", "soft toy"); - - System.out.println(" -> Executing BOUND (Runtime arg: 'barbie' - should be IGNORED)..."); - // We pass 'barbie', but expecting 'soft toy' price because of binding override - return tool.execute(Map.of("description", "barbie")); + }) + .thenCompose( + ignore -> { + + // STEP 2: INVOKE TOOL WITHOUT EXTRA AUTH + System.out.println("\n[4] Testing Simple Tool: 'get-retail-facet-filters'..."); + return client.invokeTool("get-retail-facet-filters", Map.of()); + }) + .thenCompose( + result -> { + System.out.println( + " -> Result: " + (result.content() != null ? "Received Data" : "Empty")); + + // STEP 3: INVOKE TOOL WITH AUTHENTICATED PARAMETERS + System.out.println("\n[5] Testing Authenticated Tool: 'get-toy-price'..."); + + // Define the getter for the 'google_auth' service + AuthTokenGetter toolAuthGetter = () -> CompletableFuture.completedFuture(idToken); + + // Load using the sophisticated overload + return client.loadTool("get-toy-price", Map.of("google_auth", toolAuthGetter)); + }) + .thenCompose( + tool -> { + System.out.println(" -> Loaded Tool: " + tool.definition().description()); + + // STEP 4: TEST BINDING PARAMETERS SEQUENTIALLY + System.out.println("\n[A] Executing UNBOUND (Runtime arg: 'barbie')..."); + + return tool.execute(Map.of("description", "barbie")) + .thenCompose( + result1 -> { + if (result1.content() != null && !result1.content().isEmpty()) { + System.out.println( + " -> Result (Unbound): " + result1.content().get(0).text()); + } + + // NOW bind the parameter + System.out.println("\n[B] Binding 'description' to 'soft toy'..."); + tool.bindParam("description", "soft toy"); + + System.out.println( + " -> Executing BOUND (Runtime arg: 'barbie' - should be" + + " IGNORED)..."); + // We pass 'barbie', but expecting 'soft toy' price because of binding + // override + return tool.execute(Map.of("description", "barbie")); }); - }) - .thenAccept(result -> { - System.out.println("\n[6] Final Result (Bound):"); - if (result.isError()) { - System.err.println("Tool execution failed: " + result.content().get(0).text()); - } else if (result.content() != null && !result.content().isEmpty()) { - String output = result.content().get(0).text(); - System.out.println(" " + output.substring(0, Math.min(output.length(), 200)) + "..."); - } else { - System.out.println(" Empty Response"); - } - }) - .exceptionally(ex -> { - System.err.println("\n!!! TEST FAILED !!!"); - ex.printStackTrace(); - return null; - }) - .join(); - - } catch (Exception e) { - e.printStackTrace(); - } - System.out.println("\n--- Test Suite Complete ---"); + }) + .thenAccept( + result -> { + System.out.println("\n[6] Final Result (Bound):"); + if (result.isError()) { + System.err.println("Tool execution failed: " + result.content().get(0).text()); + } else if (result.content() != null && !result.content().isEmpty()) { + String output = result.content().get(0).text(); + System.out.println( + " " + output.substring(0, Math.min(output.length(), 200)) + "..."); + } else { + System.out.println(" Empty Response"); + } + }) + .exceptionally( + ex -> { + System.err.println("\n!!! TEST FAILED !!!"); + ex.printStackTrace(); + return null; + }) + .join(); + + } catch (Exception e) { + e.printStackTrace(); } + System.out.println("\n--- Test Suite Complete ---"); + } } diff --git a/example/src/main/java/cloudcode/helloworld/InputValidationTest.java b/example/src/main/java/cloudcode/helloworld/InputValidationTest.java index b171ae7..4971e25 100644 --- a/example/src/main/java/cloudcode/helloworld/InputValidationTest.java +++ b/example/src/main/java/cloudcode/helloworld/InputValidationTest.java @@ -16,100 +16,101 @@ package cloudcode.helloworld; -import com.google.cloud.mcp.McpToolboxClient; -import com.google.cloud.mcp.Tool; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.IdTokenProvider; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.tool.Tool; import java.io.FileInputStream; import java.util.Collections; -import java.util.Map; import java.util.HashMap; +import java.util.Map; import java.util.concurrent.CompletableFuture; public class InputValidationTest { - public static void main(String[] args) { - String targetUrl = "YOUR_TOOLBOX_SERVICE_ENDPOINT"; - String tokenAudience = targetUrl; - // -------------------------------------------------------------------------------- - // AUTHENTICATION SETUP - // -------------------------------------------------------------------------------- - // FOR LOCAL DEVELOPMENT: Use a Service Account Key JSON file. - // FOR PRODUCTION (Cloud Run): Comment out the 'keyPath' logic and use ADC directly. - // -------------------------------------------------------------------------------- + public static void main(String[] args) { + String targetUrl = "YOUR_TOOLBOX_SERVICE_ENDPOINT"; + String tokenAudience = targetUrl; + // -------------------------------------------------------------------------------- + // AUTHENTICATION SETUP + // -------------------------------------------------------------------------------- + // FOR LOCAL DEVELOPMENT: Use a Service Account Key JSON file. + // FOR PRODUCTION (Cloud Run): Comment out the 'keyPath' logic and use ADC directly. + // -------------------------------------------------------------------------------- - String keyPath = "/YOUR_CREDENTIALS_JSON_FILE_PATH.json"; + String keyPath = "/YOUR_CREDENTIALS_JSON_FILE_PATH.json"; - System.out.println("--- Starting MCP Toolbox Input Validation Test ---"); + System.out.println("--- Starting MCP Toolbox Input Validation Test ---"); - try { - // 1. Setup Auth (Same as before) - System.out.println(" [Init] Fetching ID Token..."); - GoogleCredentials credentials = GoogleCredentials.fromStream(new FileInputStream(keyPath)); - if (!(credentials instanceof IdTokenProvider)) { - throw new RuntimeException("Loaded credentials do not support ID Tokens."); - } - String idToken = ((IdTokenProvider) credentials).idTokenWithAudience(tokenAudience, Collections.emptyList()).getTokenValue(); + try { + // 1. Setup Auth (Same as before) + System.out.println(" [Init] Fetching ID Token..."); + GoogleCredentials credentials = GoogleCredentials.fromStream(new FileInputStream(keyPath)); + if (!(credentials instanceof IdTokenProvider)) { + throw new RuntimeException("Loaded credentials do not support ID Tokens."); + } + String idToken = + ((IdTokenProvider) credentials) + .idTokenWithAudience(tokenAudience, Collections.emptyList()) + .getTokenValue(); - // 2. Initialize Client - McpToolboxClient client = McpToolboxClient.builder() - .baseUrl(targetUrl) - .build(); + // 2. Initialize Client + McpToolboxClient client = McpToolboxClient.builder().baseUrl(targetUrl).build(); - // 3. Load the Tool - // We MUST use loadTool() because validation relies on the ToolDefinition fetched from the server. - System.out.println(" [Init] Loading tool 'get-toy-price'..."); - Tool tool = client.loadTool("get-toy-price").join(); + // 3. Load the Tool + // We MUST use loadTool() because validation relies on the ToolDefinition fetched from the + // server. + System.out.println(" [Init] Loading tool 'get-toy-price'..."); + Tool tool = client.loadTool("get-toy-price").join(); - // 4. Register Auth - // We manually register the token getter so the Tool object can inject the header. - tool.addAuthTokenGetter("google_auth", () -> CompletableFuture.completedFuture(idToken)); + // 4. Register Auth + // We manually register the token getter so the Tool object can inject the header. + tool.addAuthTokenGetter("google_auth", () -> CompletableFuture.completedFuture(idToken)); + // --- Test Case A: Valid Input --- + System.out.println("\n[Test A] Sending VALID input (String)..."); + try { + Map validArgs = Map.of("description", "barbie"); + var result = tool.execute(validArgs).join(); + System.out.println( + " ✅ Success! Output: " + + (result.content().isEmpty() ? "Empty" : result.content().get(0).text())); + } catch (Exception e) { + System.err.println(" ❌ Unexpected failure: " + e.getMessage()); + e.printStackTrace(); + } - // --- Test Case A: Valid Input --- - System.out.println("\n[Test A] Sending VALID input (String)..."); - try { - Map validArgs = Map.of("description", "barbie"); - var result = tool.execute(validArgs).join(); - System.out.println(" ✅ Success! Output: " + (result.content().isEmpty() ? "Empty" : result.content().get(0).text())); - } catch (Exception e) { - System.err.println(" ❌ Unexpected failure: " + e.getMessage()); - e.printStackTrace(); - } + // --- Test Case B: Invalid Type (Int instead of String) --- + System.out.println("\n[Test B] Sending INVALID input (Integer instead of String)..."); + try { + // The 'description' parameter is defined as type: string. We pass an Integer. + Map invalidArgs = Map.of("description", 12345); + tool.execute(invalidArgs).join(); + System.err.println(" ❌ FAILED: Validation did not catch the error!"); + } catch (Exception e) { + // We expect a RuntimeException wrapping IllegalArgumentException + Throwable cause = e.getCause(); + System.out.println(" ✅ Caught Expected Error: " + cause.getMessage()); + } - // --- Test Case B: Invalid Type (Int instead of String) --- - System.out.println("\n[Test B] Sending INVALID input (Integer instead of String)..."); - try { - // The 'description' parameter is defined as type: string. We pass an Integer. - Map invalidArgs = Map.of("description", 12345); - - tool.execute(invalidArgs).join(); - System.err.println(" ❌ FAILED: Validation did not catch the error!"); - } catch (Exception e) { - // We expect a RuntimeException wrapping IllegalArgumentException - Throwable cause = e.getCause(); - System.out.println(" ✅ Caught Expected Error: " + cause.getMessage()); - } - + // --- Test Case C: Null Value (Filtering) --- + System.out.println("\n[Test C] Sending NULL value (should be filtered)..."); + try { + // We use a HashMap because Map.of doesn't allow nulls + Map nullArgs = new HashMap<>(); + nullArgs.put("description", "barbie"); // Valid param + nullArgs.put("some_optional_param", null); // Null param - // --- Test Case C: Null Value (Filtering) --- - System.out.println("\n[Test C] Sending NULL value (should be filtered)..."); - try { - // We use a HashMap because Map.of doesn't allow nulls - Map nullArgs = new HashMap<>(); - nullArgs.put("description", "barbie"); // Valid param - nullArgs.put("some_optional_param", null); // Null param - - // If validation works, 'some_optional_param' will be removed before sending - var result = tool.execute(nullArgs).join(); - System.out.println(" ✅ Success! Null value was filtered and request succeeded."); - } catch (Exception e) { - System.out.println(" ❌ Result: " + e.getCause().getMessage()); - } + // If validation works, 'some_optional_param' will be removed before sending + var result = tool.execute(nullArgs).join(); + System.out.println(" ✅ Success! Null value was filtered and request succeeded."); + } catch (Exception e) { + System.out.println(" ❌ Result: " + e.getCause().getMessage()); + } - } catch (Exception e) { - e.printStackTrace(); - } - System.out.println("\n--- Done ---"); + } catch (Exception e) { + e.printStackTrace(); } + System.out.println("\n--- Done ---"); + } } diff --git a/example/src/main/java/cloudcode/helloworld/StrictFlagTest.java b/example/src/main/java/cloudcode/helloworld/StrictFlagTest.java index ba32ff9..df01a8c 100644 --- a/example/src/main/java/cloudcode/helloworld/StrictFlagTest.java +++ b/example/src/main/java/cloudcode/helloworld/StrictFlagTest.java @@ -17,18 +17,16 @@ package cloudcode.helloworld; import com.google.cloud.mcp.McpToolboxClient; -import com.google.cloud.mcp.Tool; -import java.util.Map; +import com.google.cloud.mcp.tool.Tool; import java.util.HashMap; +import java.util.Map; public class StrictFlagTest { public static void main(String[] args) { String targetUrl = "YOUR_TOOLBOX_SERVICE_ENDPOINT"; System.out.println("--- Starting MCP Toolbox Strict Flag Test ---"); - McpToolboxClient client = McpToolboxClient.builder() - .baseUrl(targetUrl) - .build(); + McpToolboxClient client = McpToolboxClient.builder().baseUrl(targetUrl).build(); // Prepare bindings for a NON-EXISTENT tool Map> paramBinds = new HashMap<>(); @@ -38,7 +36,8 @@ public static void main(String[] args) { System.out.println("\n[Test 1] Loading with Strict = FALSE..."); try { Map tools = client.loadToolset(null, paramBinds, null, false).join(); - System.out.println(" ✅ Success! Loaded " + tools.size() + " tools. Unknown binding was ignored."); + System.out.println( + " ✅ Success! Loaded " + tools.size() + " tools. Unknown binding was ignored."); } catch (Exception e) { System.err.println(" ❌ Failed unexpectedly: " + e.getMessage()); } @@ -60,4 +59,4 @@ public static void main(String[] args) { System.out.println("\n--- Done ---"); } -} \ No newline at end of file +} diff --git a/pom.xml b/pom.xml index 2a8076c..8ae672c 100644 --- a/pom.xml +++ b/pom.xml @@ -94,6 +94,13 @@ ${google.auth.version} + + + io.opentelemetry + opentelemetry-api + 1.34.1 + + org.junit.jupiter @@ -119,6 +126,12 @@ 2.32.0 test + + io.opentelemetry + opentelemetry-sdk-testing + 1.34.1 + test + diff --git a/src/main/java/com/google/cloud/mcp/JsonRpc.java b/src/main/java/com/google/cloud/mcp/JsonRpc.java index 902ac2a..6dd5632 100644 --- a/src/main/java/com/google/cloud/mcp/JsonRpc.java +++ b/src/main/java/com/google/cloud/mcp/JsonRpc.java @@ -19,50 +19,173 @@ import java.util.Map; import java.util.UUID; -class JsonRpc { - static class Request { +/** Helper classes representing JSON-RPC requests, notifications, and parameters. */ +public class JsonRpc { + + /** Hide default constructor. */ + private JsonRpc() {} + + /** Represents a JSON-RPC request with an ID. */ + public static class Request { + /** The JSON-RPC version. */ public String jsonrpc = "2.0"; + + /** The request ID. */ public String id; + + /** The method name. */ public String method; + + /** The parameters. */ public Object params; - public Request(String method, Object params) { + /** + * Constructs a new Request. + * + * @param method The method name. + * @param params The parameters. + */ + public Request(final String method, final Object params) { this.id = UUID.randomUUID().toString(); this.method = method; this.params = params; } } - static class Notification { + /** Represents a JSON-RPC notification without an ID. */ + public static class Notification { + /** The JSON-RPC version. */ public String jsonrpc = "2.0"; + + /** The method name. */ public String method; + + /** The parameters. */ public Object params; - public Notification(String method, Object params) { + /** + * Constructs a new Notification. + * + * @param method The method name. + * @param params The parameters. + */ + public Notification(final String method, final Object params) { this.method = method; this.params = params; } } - static class CallToolParams { + /** Represents telemetry metadata in JSON-RPC parameters. */ + public static class RequestMetadata { + /** The traceparent header value. */ + public String traceparent; + + /** The tracestate header value. */ + public String tracestate; + + /** + * Constructs a new RequestMetadata. + * + * @param traceparent The traceparent header value. + * @param tracestate The tracestate header value. + */ + public RequestMetadata(String traceparent, String tracestate) { + this.traceparent = traceparent; + this.tracestate = tracestate; + } + } + + /** Parameters for calling a tool. */ + public static class CallToolParams { + /** The tool name. */ public String name; + + /** The arguments. */ public Map arguments; - public CallToolParams(String name, Map arguments) { + /** Telemetry metadata. */ + public RequestMetadata _meta; + + /** + * Constructs a new CallToolParams without metadata. + * + * @param name The tool name. + * @param arguments The arguments. + */ + public CallToolParams(final String name, final Map arguments) { + this(name, arguments, null); + } + + /** + * Constructs a new CallToolParams with metadata. + * + * @param name The tool name. + * @param arguments The arguments. + * @param meta The telemetry metadata. + */ + public CallToolParams(String name, Map arguments, RequestMetadata meta) { this.name = name; this.arguments = arguments; + this._meta = meta; } } - static class InitializeParams { + /** Parameters for listing tools. */ + public static class ListToolsParams { + /** The pagination cursor. */ + public String cursor; + + /** Telemetry metadata. */ + public RequestMetadata _meta; + + /** + * Constructs a new ListToolsParams. + * + * @param cursor The pagination cursor. + * @param meta The telemetry metadata. + */ + public ListToolsParams(String cursor, RequestMetadata meta) { + this.cursor = cursor; + this._meta = meta; + } + } + + /** Parameters for initializing the connection. */ + public static class InitializeParams { + /** The client protocol version. */ public String protocolVersion; + + /** The client capabilities. */ public Map capabilities; + + /** The client info. */ public Map clientInfo; - public InitializeParams(String version, String clientName) { + /** Telemetry metadata. */ + public RequestMetadata _meta; + + /** + * Constructs a new InitializeParams without metadata. + * + * @param version The protocol version. + * @param clientName The client name. + */ + public InitializeParams(final String version, final String clientName) { + this(version, clientName, null); + } + + /** + * Constructs a new InitializeParams with metadata. + * + * @param version The protocol version. + * @param clientName The client name. + * @param meta The telemetry metadata. + */ + public InitializeParams(String version, String clientName, RequestMetadata meta) { this.protocolVersion = version; this.capabilities = Map.of(); this.clientInfo = Map.of("name", clientName, "version", "1.0.0"); + this._meta = meta; } } } diff --git a/src/main/java/com/google/cloud/mcp/McpToolboxClient.java b/src/main/java/com/google/cloud/mcp/McpToolboxClient.java index 472a6f3..29fee63 100644 --- a/src/main/java/com/google/cloud/mcp/McpToolboxClient.java +++ b/src/main/java/com/google/cloud/mcp/McpToolboxClient.java @@ -16,11 +16,19 @@ package com.google.cloud.mcp; +import com.google.cloud.mcp.auth.AuthTokenGetter; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.client.McpToolboxClientBuilder; +import com.google.cloud.mcp.tool.Tool; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolPostProcessor; +import com.google.cloud.mcp.tool.ToolPreProcessor; +import com.google.cloud.mcp.tool.ToolResult; import java.util.Map; import java.util.concurrent.CompletableFuture; /** The core client for interacting with an MCP Toolbox Server. */ -public interface McpToolboxClient { +public interface McpToolboxClient extends AutoCloseable { /** * Connects to the MCP Server and retrieves the list of all available tools. @@ -190,4 +198,10 @@ interface Builder { */ McpToolboxClient build(); } + + /** Closes the client and records session metrics. */ + @Override + default void close() { + // No-op by default + } } diff --git a/src/main/java/com/google/cloud/mcp/TelemetryHelper.java b/src/main/java/com/google/cloud/mcp/TelemetryHelper.java new file mode 100644 index 0000000..1f52ae7 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/TelemetryHelper.java @@ -0,0 +1,336 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp; + +import io.opentelemetry.api.GlobalOpenTelemetry; +import io.opentelemetry.api.common.Attributes; +import io.opentelemetry.api.common.AttributesBuilder; +import io.opentelemetry.api.metrics.DoubleHistogram; +import io.opentelemetry.api.metrics.Meter; +import io.opentelemetry.api.trace.Span; +import io.opentelemetry.api.trace.SpanKind; +import io.opentelemetry.api.trace.StatusCode; +import io.opentelemetry.api.trace.Tracer; +import io.opentelemetry.context.Context; +import io.opentelemetry.context.Scope; +import io.opentelemetry.context.propagation.TextMapPropagator; +import java.net.URI; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** Helper class for OpenTelemetry metrics and tracing instrumentation. */ +public final class TelemetryHelper { + /** Bucket boundary 0.01. */ + private static final double B_0_01 = 0.01; + + /** Bucket boundary 0.02. */ + private static final double B_0_02 = 0.02; + + /** Bucket boundary 0.05. */ + private static final double B_0_05 = 0.05; + + /** Bucket boundary 0.1. */ + private static final double B_0_1 = 0.1; + + /** Bucket boundary 0.2. */ + private static final double B_0_2 = 0.2; + + /** Bucket boundary 0.5. */ + private static final double B_0_5 = 0.5; + + /** Bucket boundary 1.0. */ + private static final double B_1 = 1.0; + + /** Bucket boundary 2.0. */ + private static final double B_2 = 2.0; + + /** Bucket boundary 5.0. */ + private static final double B_5 = 5.0; + + /** Bucket boundary 10.0. */ + private static final double B_10 = 10.0; + + /** Bucket boundary 30.0. */ + private static final double B_30 = 30.0; + + /** Bucket boundary 60.0. */ + private static final double B_60 = 60.0; + + /** Bucket boundary 120.0. */ + private static final double B_120 = 120.0; + + /** Bucket boundary 300.0. */ + private static final double B_300 = 300.0; + + /** Conversion factor from nanoseconds to seconds. */ + static final double NANOS_IN_SECOND = 1e9; + + /** Name of the instrumentation library. */ + private static final String INSTRUMENTATION_NAME = "toolbox.mcp.sdk"; + + // Dynamic / lazy OpenTelemetry binding cache + private static io.opentelemetry.api.OpenTelemetry lastOtel = null; + private static DoubleHistogram cachedOperationDuration = null; + private static DoubleHistogram cachedSessionDuration = null; + + private static synchronized void checkRebind() { + io.opentelemetry.api.OpenTelemetry currentOtel = GlobalOpenTelemetry.get(); + if (currentOtel != lastOtel) { + lastOtel = currentOtel; + Meter meter = currentOtel.getMeter(INSTRUMENTATION_NAME); + cachedOperationDuration = + meter + .histogramBuilder("mcp.client.operation.duration") + .setUnit("s") + .setDescription( + "Duration of MCP client operations (requests/notifications) from the time it was" + + " sent until the response or ack is received.") + .setExplicitBucketBoundariesAdvice( + Arrays.asList( + B_0_01, B_0_02, B_0_05, B_0_1, B_0_2, B_0_5, B_1, B_2, B_5, B_10, B_30, B_60, + B_120, B_300)) + .build(); + cachedSessionDuration = + meter + .histogramBuilder("mcp.client.session.duration") + .setUnit("s") + .setDescription("Total duration of MCP client sessions") + .setExplicitBucketBoundariesAdvice( + Arrays.asList( + B_0_01, B_0_02, B_0_05, B_0_1, B_0_2, B_0_5, B_1, B_2, B_5, B_10, B_30, B_60, + B_120, B_300)) + .build(); + } + } + + private static DoubleHistogram operationDuration() { + checkRebind(); + return cachedOperationDuration; + } + + private static DoubleHistogram sessionDuration() { + checkRebind(); + return cachedSessionDuration; + } + + private static Tracer tracer() { + return GlobalOpenTelemetry.getTracer(INSTRUMENTATION_NAME); + } + + private static TextMapPropagator propagator() { + return GlobalOpenTelemetry.getPropagators().getTextMapPropagator(); + } + + private TelemetryHelper() {} + + /** + * Helper record to extract ServerInfo. + * + * @param address The server host address. + * @param port The server port. + * @param protocol The network protocol (e.g. http). + */ + record ServerInfo(String address, Integer port, String protocol) {} + + static ServerInfo extractServerInfo(final String urlStr) { + try { + URI uri = new URI(urlStr); + String host = uri.getHost(); + if (host == null) { + host = uri.getAuthority(); + if (host != null && host.contains(":")) { + host = host.substring(0, host.indexOf(':')); + } + } + int port = uri.getPort(); + if (port == -1 && uri.getAuthority() != null && uri.getAuthority().contains(":")) { + try { + String auth = uri.getAuthority(); + port = Integer.parseInt(auth.substring(auth.indexOf(':') + 1)); + } catch (NumberFormatException e) { + // ignore + } + } + String protocol = uri.getScheme(); + if (protocol == null) { + protocol = "http"; + } + return new ServerInfo(host != null ? host : "", port != -1 ? port : null, protocol); + } catch (Exception e) { + return new ServerInfo("", null, "http"); + } + } + + /** Wrapper for recording client operation metrics and tracing spans. */ + public static class OperationSpan implements AutoCloseable { + /** The OpenTelemetry span. */ + private final Span span; + + /** The scope for the current span context. */ + private final Scope scope; + + /** Start time of the span in nanoseconds. */ + private final long startTimeNanos; + + /** Name of the MCP method. */ + private final String methodName; + + /** Protocol version of MCP. */ + private final String protocolVersion; + + /** Server base URL. */ + private final String serverUrl; + + /** Name of the tool. */ + private final String toolName; + + /** Class name of the error if an error occurred. */ + private String errorType = null; + + /** + * Constructs a new OperationSpan. + * + * @param method The MCP method name. + * @param version The protocol version. + * @param url The server base URL. + * @param tool The tool name, or null. + */ + public OperationSpan( + final String method, final String version, final String url, final String tool) { + this.methodName = method; + this.protocolVersion = version; + this.serverUrl = url; + this.toolName = tool; + this.startTimeNanos = System.nanoTime(); + + String spanName = tool != null ? method + " " + tool : method; + this.span = tracer().spanBuilder(spanName).setSpanKind(SpanKind.CLIENT).startSpan(); + this.scope = span.makeCurrent(); + + // Set standard span attributes + span.setAttribute("mcp.method.name", method); + span.setAttribute("mcp.protocol.version", version); + ServerInfo info = extractServerInfo(url); + span.setAttribute("server.address", info.address()); + span.setAttribute("network.protocol.name", info.protocol()); + span.setAttribute("network.transport", "tcp"); + if (info.port() != null) { + span.setAttribute("server.port", (long) info.port()); + } + if (tool != null) { + span.setAttribute("gen_ai.tool.name", tool); + } + if ("tools/call".equals(method)) { + span.setAttribute("gen_ai.operation.name", "execute_tool"); + } + } + + /** + * Gets W3C context headers to inject into the request. + * + * @return A map containing trace context headers. + */ + public Map getTraceContextHeaders() { + Map carrier = new HashMap<>(); + propagator().inject(Context.current(), carrier, Map::put); + return carrier; + } + + /** + * Records a throwable error on the span. + * + * @param t The error thrown. + */ + public void recordError(final Throwable t) { + span.recordException(t); + span.setStatus(StatusCode.ERROR, t.getMessage()); + this.errorType = t.getClass().getName(); + span.setAttribute("error.type", errorType); + } + + /** + * Records a JSON-RPC error on the span. + * + * @param code The JSON-RPC error code. + * @param message The error message. + */ + public void recordError(final int code, final String message) { + span.setStatus(StatusCode.ERROR, message); + this.errorType = "jsonrpc.error." + code; + span.setAttribute("error.type", errorType); + } + + @Override + public void close() { + scope.close(); + span.end(); + + // Record operation duration metric + double durationSeconds = (System.nanoTime() - startTimeNanos) / NANOS_IN_SECOND; + AttributesBuilder attrs = + Attributes.builder() + .put("mcp.method.name", methodName) + .put("mcp.protocol.version", protocolVersion); + ServerInfo info = extractServerInfo(serverUrl); + attrs.put("server.address", info.address()); + attrs.put("network.protocol.name", info.protocol()); + attrs.put("network.transport", "tcp"); + if (info.port() != null) { + attrs.put("server.port", (long) info.port()); + } + if (toolName != null) { + attrs.put("gen_ai.tool.name", toolName); + } + if ("tools/call".equals(methodName)) { + attrs.put("gen_ai.operation.name", "execute_tool"); + } + if (errorType != null) { + attrs.put("error.type", errorType); + } + + operationDuration().record(durationSeconds, attrs.build()); + } + } + + /** + * Records the duration of an MCP session. + * + * @param durationSeconds The duration of the session in seconds. + * @param protocolVersion The negotiated protocol version. + * @param serverUrl The server base URL. + * @param error The session error, or null if successful. + */ + public static void recordSessionDuration( + final double durationSeconds, + final String protocolVersion, + final String serverUrl, + final Throwable error) { + AttributesBuilder attrs = Attributes.builder().put("mcp.protocol.version", protocolVersion); + ServerInfo info = extractServerInfo(serverUrl); + attrs.put("server.address", info.address()); + attrs.put("network.protocol.name", info.protocol()); + attrs.put("network.transport", "tcp"); + if (info.port() != null) { + attrs.put("server.port", (long) info.port()); + } + if (error != null) { + attrs.put("error.type", error.getClass().getName()); + } + sessionDuration().record(durationSeconds, attrs.build()); + } +} diff --git a/src/main/java/com/google/cloud/mcp/AuthMethods.java b/src/main/java/com/google/cloud/mcp/auth/AuthMethods.java similarity index 98% rename from src/main/java/com/google/cloud/mcp/AuthMethods.java rename to src/main/java/com/google/cloud/mcp/auth/AuthMethods.java index cb3815f..a287009 100644 --- a/src/main/java/com/google/cloud/mcp/AuthMethods.java +++ b/src/main/java/com/google/cloud/mcp/auth/AuthMethods.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.auth; import com.google.auth.oauth2.GoogleCredentials; import com.google.auth.oauth2.IdTokenProvider; diff --git a/src/main/java/com/google/cloud/mcp/AuthResolver.java b/src/main/java/com/google/cloud/mcp/auth/AuthResolver.java similarity index 98% rename from src/main/java/com/google/cloud/mcp/AuthResolver.java rename to src/main/java/com/google/cloud/mcp/auth/AuthResolver.java index 109c6ba..0fccadd 100644 --- a/src/main/java/com/google/cloud/mcp/AuthResolver.java +++ b/src/main/java/com/google/cloud/mcp/auth/AuthResolver.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.auth; import java.util.HashMap; import java.util.List; diff --git a/src/main/java/com/google/cloud/mcp/AuthTokenGetter.java b/src/main/java/com/google/cloud/mcp/auth/AuthTokenGetter.java similarity index 96% rename from src/main/java/com/google/cloud/mcp/AuthTokenGetter.java rename to src/main/java/com/google/cloud/mcp/auth/AuthTokenGetter.java index 5d3f736..6067352 100644 --- a/src/main/java/com/google/cloud/mcp/AuthTokenGetter.java +++ b/src/main/java/com/google/cloud/mcp/auth/AuthTokenGetter.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.auth; import java.util.concurrent.CompletableFuture; diff --git a/src/main/java/com/google/cloud/mcp/CredentialsProvider.java b/src/main/java/com/google/cloud/mcp/auth/CredentialsProvider.java similarity index 96% rename from src/main/java/com/google/cloud/mcp/CredentialsProvider.java rename to src/main/java/com/google/cloud/mcp/auth/CredentialsProvider.java index eb428a0..a9c7ca4 100644 --- a/src/main/java/com/google/cloud/mcp/CredentialsProvider.java +++ b/src/main/java/com/google/cloud/mcp/auth/CredentialsProvider.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.auth; import java.util.concurrent.CompletableFuture; diff --git a/src/main/java/com/google/cloud/mcp/GoogleCredentialsProvider.java b/src/main/java/com/google/cloud/mcp/auth/GoogleCredentialsProvider.java similarity index 98% rename from src/main/java/com/google/cloud/mcp/GoogleCredentialsProvider.java rename to src/main/java/com/google/cloud/mcp/auth/GoogleCredentialsProvider.java index 11a7232..bb1490e 100644 --- a/src/main/java/com/google/cloud/mcp/GoogleCredentialsProvider.java +++ b/src/main/java/com/google/cloud/mcp/auth/GoogleCredentialsProvider.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.auth; import com.google.auth.oauth2.GoogleCredentials; import java.io.IOException; diff --git a/src/main/java/com/google/cloud/mcp/ResolvedAuth.java b/src/main/java/com/google/cloud/mcp/auth/ResolvedAuth.java similarity index 97% rename from src/main/java/com/google/cloud/mcp/ResolvedAuth.java rename to src/main/java/com/google/cloud/mcp/auth/ResolvedAuth.java index 017d7e7..e79ecfb 100644 --- a/src/main/java/com/google/cloud/mcp/ResolvedAuth.java +++ b/src/main/java/com/google/cloud/mcp/auth/ResolvedAuth.java @@ -14,8 +14,9 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.auth; +import com.google.cloud.mcp.tool.ToolDefinition; import java.util.Map; /** Represents a resolved set of authentication credentials for a tool execution. */ diff --git a/src/main/java/com/google/cloud/mcp/McpToolboxClientBuilder.java b/src/main/java/com/google/cloud/mcp/client/McpToolboxClientBuilder.java similarity index 91% rename from src/main/java/com/google/cloud/mcp/McpToolboxClientBuilder.java rename to src/main/java/com/google/cloud/mcp/client/McpToolboxClientBuilder.java index 7c2f765..0b66c13 100644 --- a/src/main/java/com/google/cloud/mcp/McpToolboxClientBuilder.java +++ b/src/main/java/com/google/cloud/mcp/client/McpToolboxClientBuilder.java @@ -14,8 +14,15 @@ * limitations under the License. */ -package com.google.cloud.mcp; - +package com.google.cloud.mcp.client; + +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.tool.ToolPostProcessor; +import com.google.cloud.mcp.tool.ToolPreProcessor; +import com.google.cloud.mcp.transport.HttpMcpTransport; +import com.google.cloud.mcp.transport.Transport; import java.util.ArrayList; import java.util.HashMap; import java.util.List; diff --git a/src/main/java/com/google/cloud/mcp/McpToolboxClientImpl.java b/src/main/java/com/google/cloud/mcp/client/McpToolboxClientImpl.java similarity index 92% rename from src/main/java/com/google/cloud/mcp/McpToolboxClientImpl.java rename to src/main/java/com/google/cloud/mcp/client/McpToolboxClientImpl.java index c5d254b..5f586f5 100644 --- a/src/main/java/com/google/cloud/mcp/McpToolboxClientImpl.java +++ b/src/main/java/com/google/cloud/mcp/client/McpToolboxClientImpl.java @@ -14,10 +14,23 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.client; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.AuthTokenGetter; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.tool.Tool; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolPostProcessor; +import com.google.cloud.mcp.tool.ToolPreProcessor; +import com.google.cloud.mcp.tool.ToolResult; +import com.google.cloud.mcp.transport.HttpMcpTransport; +import com.google.cloud.mcp.transport.Transport; +import com.google.cloud.mcp.transport.TransportManifest; +import com.google.cloud.mcp.transport.TransportResponse; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; @@ -364,10 +377,11 @@ private ToolResult handleInvokeResponse(final TransportResponse response, final try { JsonNode root = objectMapper.readTree(body); if (root.has("error")) { + JsonNode errNode = root.get("error"); + int code = errNode.has("code") ? errNode.get("code").asInt() : -1; + String msg = errNode.has("message") ? errNode.get("message").asText() : errNode.toString(); return new ToolResult( - java.util.List.of( - new ToolResult.Content("text", "MCP Error: " + root.get("error").toString())), - true); + java.util.List.of(new ToolResult.Content("text", "MCP Error: " + msg)), true); } boolean isError = root.has("isError") && root.get("isError").asBoolean(); @@ -388,4 +402,13 @@ private ToolResult handleInvokeResponse(final TransportResponse response, final return new ToolResult(java.util.List.of(new ToolResult.Content("text", body)), false); } } + + @Override + public void close() { + try { + transport.close(); + } catch (Exception e) { + throw new McpException("Failed to close transport", e); + } + } } diff --git a/src/main/java/com/google/cloud/mcp/McpException.java b/src/main/java/com/google/cloud/mcp/exception/McpException.java similarity index 96% rename from src/main/java/com/google/cloud/mcp/McpException.java rename to src/main/java/com/google/cloud/mcp/exception/McpException.java index c057016..9016674 100644 --- a/src/main/java/com/google/cloud/mcp/McpException.java +++ b/src/main/java/com/google/cloud/mcp/exception/McpException.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.exception; /** Unchecked exception thrown for MCP Toolbox Client operations and protocol failures. */ public class McpException extends RuntimeException { diff --git a/src/main/java/com/google/cloud/mcp/Tool.java b/src/main/java/com/google/cloud/mcp/tool/Tool.java similarity index 97% rename from src/main/java/com/google/cloud/mcp/Tool.java rename to src/main/java/com/google/cloud/mcp/tool/Tool.java index 49cfe59..4fb0229 100644 --- a/src/main/java/com/google/cloud/mcp/Tool.java +++ b/src/main/java/com/google/cloud/mcp/tool/Tool.java @@ -14,8 +14,11 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.tool; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.AuthResolver; +import com.google.cloud.mcp.auth.AuthTokenGetter; import java.util.ArrayList; import java.util.HashMap; import java.util.List; diff --git a/src/main/java/com/google/cloud/mcp/ToolDefinition.java b/src/main/java/com/google/cloud/mcp/tool/ToolDefinition.java similarity index 96% rename from src/main/java/com/google/cloud/mcp/ToolDefinition.java rename to src/main/java/com/google/cloud/mcp/tool/ToolDefinition.java index ac8e60b..d73d098 100644 --- a/src/main/java/com/google/cloud/mcp/ToolDefinition.java +++ b/src/main/java/com/google/cloud/mcp/tool/ToolDefinition.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.tool; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; @@ -25,7 +25,7 @@ * * @param description A description of what the tool does. * @param parameters A list of parameters the tool accepts. - * @param authRequired List of auth services required by the tool. + * @param authRequired A list of authentication sources required by the tool. * @param readOnlyHint Hint indicating whether the tool is read-only. * @param destructiveHint Hint indicating whether the tool is destructive. */ diff --git a/src/main/java/com/google/cloud/mcp/ToolPostProcessor.java b/src/main/java/com/google/cloud/mcp/tool/ToolPostProcessor.java similarity index 96% rename from src/main/java/com/google/cloud/mcp/ToolPostProcessor.java rename to src/main/java/com/google/cloud/mcp/tool/ToolPostProcessor.java index 09000bb..61280ea 100644 --- a/src/main/java/com/google/cloud/mcp/ToolPostProcessor.java +++ b/src/main/java/com/google/cloud/mcp/tool/ToolPostProcessor.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.tool; import java.util.concurrent.CompletableFuture; diff --git a/src/main/java/com/google/cloud/mcp/ToolPreProcessor.java b/src/main/java/com/google/cloud/mcp/tool/ToolPreProcessor.java similarity index 97% rename from src/main/java/com/google/cloud/mcp/ToolPreProcessor.java rename to src/main/java/com/google/cloud/mcp/tool/ToolPreProcessor.java index a280a85..190e38d 100644 --- a/src/main/java/com/google/cloud/mcp/ToolPreProcessor.java +++ b/src/main/java/com/google/cloud/mcp/tool/ToolPreProcessor.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.tool; import java.util.Map; import java.util.concurrent.CompletableFuture; diff --git a/src/main/java/com/google/cloud/mcp/ToolResult.java b/src/main/java/com/google/cloud/mcp/tool/ToolResult.java similarity index 97% rename from src/main/java/com/google/cloud/mcp/ToolResult.java rename to src/main/java/com/google/cloud/mcp/tool/ToolResult.java index d447adc..29b6a61 100644 --- a/src/main/java/com/google/cloud/mcp/ToolResult.java +++ b/src/main/java/com/google/cloud/mcp/tool/ToolResult.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.tool; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonProperty; diff --git a/src/main/java/com/google/cloud/mcp/HttpMcpTransport.java b/src/main/java/com/google/cloud/mcp/transport/BaseMcpTransport.java similarity index 54% rename from src/main/java/com/google/cloud/mcp/HttpMcpTransport.java rename to src/main/java/com/google/cloud/mcp/transport/BaseMcpTransport.java index 6cfcc86..68342a8 100644 --- a/src/main/java/com/google/cloud/mcp/HttpMcpTransport.java +++ b/src/main/java/com/google/cloud/mcp/transport/BaseMcpTransport.java @@ -14,10 +14,15 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.transport; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.TelemetryHelper; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.tool.ToolDefinition; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -33,91 +38,48 @@ import java.util.concurrent.CompletableFuture; import java.util.logging.Logger; -/** Default HTTP transport implementation using Java 11 HttpClient. */ -public final class HttpMcpTransport implements Transport { +public abstract class BaseMcpTransport implements Transport { - private static final Logger logger = Logger.getLogger(HttpMcpTransport.class.getName()); - private static final String HTTP_WARNING = + protected static final Logger logger = Logger.getLogger(BaseMcpTransport.class.getName()); + protected static final String HTTP_WARNING = "This connection is using HTTP. To prevent credential exposure, please ensure all" + " communication is sent over HTTPS."; - private final String baseUrl; - private final Map clientHeaders; - private final CredentialsProvider credentialsProvider; - private final HttpClient httpClient; - private final ObjectMapper objectMapper; - private final ProtocolVersion preferredProtocolVersion; - private final Object initLock = new Object(); - private CompletableFuture initFuture; - private volatile ProtocolVersion negotiatedProtocolVersion; - private volatile String sessionId; + protected final String baseUrl; + protected final Map clientHeaders; + protected final CredentialsProvider credentialsProvider; + protected final HttpClient httpClient; + protected final ObjectMapper objectMapper; + protected final ProtocolVersion preferredProtocolVersion; + protected final Object initLock = new Object(); + protected CompletableFuture initFuture; - /** - * Constructs a new HttpMcpTransport with a base URL. - * - * @param baseUrl The base URL of the remote service. - */ - public HttpMcpTransport(String baseUrl) { - this(baseUrl, Map.of(), (CredentialsProvider) null); - } + /** The start time of the session in nanoseconds. */ + protected Long sessionStartTime; - /** - * Constructs a new HttpMcpTransport with base URL and default headers. - * - * @param baseUrl The base URL of the remote service. - * @param clientHeaders Default HTTP headers to include in every request. - */ - public HttpMcpTransport(String baseUrl, Map clientHeaders) { - this(baseUrl, clientHeaders, (CredentialsProvider) null); - } + /** The error that occurred during the session, if any. */ + protected Throwable sessionError; - /** - * Constructs a new HttpMcpTransport with base URL, default headers and credentials provider. - * - * @param baseUrl The base URL of the remote service. - * @param clientHeaders Default HTTP headers to include in every request. - * @param credentialsProvider Provider for retrieving authorization credentials. - */ - public HttpMcpTransport( - String baseUrl, Map clientHeaders, CredentialsProvider credentialsProvider) { - this(baseUrl, clientHeaders, credentialsProvider, null, null, null); - } + /** The negotiated protocol version. */ + protected ProtocolVersion negotiatedProtocolVersion; /** - * Constructs a HttpMcpTransport. + * Constructs a new BaseMcpTransport. * - * @param baseUrl The base URL of the remote service. - * @param clientHeaders Default HTTP headers to include in every request. - * @param preferredProtocolVersion Preferred MCP protocol version. - * @param httpClient Custom HTTP Client. - * @param executor Optional Executor for handling async requests. + * @param baseUrl The base URL. + * @param clientHeaders The client headers. + * @param credentialsProvider The credentials provider. + * @param preferredProtocolVersion The preferred protocol version. + * @param httpClient The HTTP client. + * @param executor The executor. */ - public HttpMcpTransport( - String baseUrl, - Map clientHeaders, - ProtocolVersion preferredProtocolVersion, - HttpClient httpClient, - java.util.concurrent.Executor executor) { - this(baseUrl, clientHeaders, null, preferredProtocolVersion, httpClient, executor); - } - - /** - * Primary constructor for HttpMcpTransport. - * - * @param baseUrl The base URL of the remote service. - * @param clientHeaders Default HTTP headers to include in every request. - * @param credentialsProvider Provider for retrieving authorization credentials. - * @param preferredProtocolVersion Preferred MCP protocol version. - * @param httpClient Custom HTTP Client. - * @param executor Optional Executor for handling async requests. - */ - public HttpMcpTransport( - String baseUrl, - Map clientHeaders, - CredentialsProvider credentialsProvider, - ProtocolVersion preferredProtocolVersion, - HttpClient httpClient, - java.util.concurrent.Executor executor) { + protected BaseMcpTransport( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final ProtocolVersion preferredProtocolVersion, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { if (baseUrl == null || baseUrl.isEmpty()) { throw new IllegalArgumentException("Base URL must be provided"); } @@ -146,28 +108,13 @@ public HttpMcpTransport( this.objectMapper = new ObjectMapper(); } - HttpMcpTransport(String baseUrl, HttpClient httpClient) { - this(baseUrl, Map.of(), null, null, httpClient, null); - } - - HttpMcpTransport(String baseUrl, Map clientHeaders, HttpClient httpClient) { - this(baseUrl, clientHeaders, null, null, httpClient, null); - } - - HttpMcpTransport( - String baseUrl, - Map clientHeaders, - CredentialsProvider credentialsProvider, - HttpClient httpClient) { - this(baseUrl, clientHeaders, credentialsProvider, null, httpClient, null); - } - @Override - public String getBaseUrl() { + public final String getBaseUrl() { return this.baseUrl; } - private CompletableFuture> mergeHeaders(Map extraMetadata) { + final CompletableFuture> mergeHeaders( + final Map extraMetadata) { CompletableFuture authFuture = this.credentialsProvider != null ? this.credentialsProvider.getAuthorizationHeader() @@ -234,9 +181,16 @@ private CompletableFuture> mergeHeaders(Map }); } - private CompletableFuture ensureInitialized(Map extraMetadata) { + final CompletableFuture ensureInitialized(final Map extraMetadata) { synchronized (initLock) { if (initFuture == null) { + if (sessionStartTime == null) { + sessionStartTime = System.nanoTime(); + } + TelemetryHelper.OperationSpan initSpan = + new TelemetryHelper.OperationSpan( + "initialize", preferredProtocolVersion.getValue(), baseUrl, null); + Map handshakeMetadata = new HashMap<>(); if (extraMetadata != null) { String authKey = @@ -248,130 +202,51 @@ private CompletableFuture ensureInitialized(Map extraMetad handshakeMetadata.put("Authorization", extraMetadata.get(authKey)); } } - initFuture = + CompletableFuture future = mergeHeaders(handshakeMetadata) .thenCompose( handshakeHeaders -> { String authHeader = handshakeHeaders.get("Authorization"); - return performInitialization(authHeader, handshakeHeaders); + Map traceHeaders = initSpan.getTraceContextHeaders(); + return performInitialization(authHeader, handshakeHeaders, traceHeaders); }); + + future.whenComplete( + (v, err) -> { + if (err != null) { + initSpan.recordError(err); + sessionError = err; + synchronized (initLock) { + initFuture = null; + } + } + initSpan.close(); + }); + initFuture = future; + return future; } return initFuture; } } - private CompletableFuture performInitialization( - String authHeader, Map handshakeHeaders) { - try { - if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") - && authHeader != null) { - logger.warning(HTTP_WARNING); - } - JsonRpc.Request initReq = - new JsonRpc.Request( - "initialize", - new JsonRpc.InitializeParams( - preferredProtocolVersion.getValue(), "mcp-toolbox-sdk-java")); - String body = objectMapper.writeValueAsString(initReq); - HttpRequest.Builder req = - HttpRequest.newBuilder() - .uri(URI.create(baseUrl)) - .POST(HttpRequest.BodyPublishers.ofString(body)); - - handshakeHeaders.forEach(req::setHeader); - applyProtocolHeaders(req); - - return httpClient - .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) - .thenCompose( - res -> { - if (res.statusCode() != 200) { - return CompletableFuture.failedFuture( - new McpException("Init failed: " + res.statusCode() + " " + res.body())); - } - try { - JsonNode responseJson = objectMapper.readTree(res.body()); - if (responseJson.has("error")) { - return CompletableFuture.failedFuture( - new McpException("MCP Error: " + responseJson.get("error").toString())); - } - JsonNode result = responseJson.get("result"); - String serverVersion; - if (result != null && result.has("protocolVersion")) { - serverVersion = result.get("protocolVersion").asText(); - } else { - // Fallback to the client's preferred version for backward-compatible/mock - // servers - serverVersion = preferredProtocolVersion.getValue(); - } - - // Verify strict compliance with Python/Go behavior - if (!preferredProtocolVersion.getValue().equals(serverVersion)) { - return CompletableFuture.failedFuture( - new McpException( - "MCP version mismatch: client (" - + preferredProtocolVersion.getValue() - + ") != server (" - + serverVersion - + ")")); - } - - this.negotiatedProtocolVersion = ProtocolVersion.fromString(serverVersion); - - if (negotiatedProtocolVersion == ProtocolVersion.VERSION_2025_03_26) { - java.util.Optional sessionIdOpt = - res.headers().firstValue("Mcp-Session-Id"); - if (sessionIdOpt.isEmpty()) { - return CompletableFuture.failedFuture( - new McpException( - "Server did not return a Mcp-Session-Id header during" - + " initialization.")); - } - this.sessionId = sessionIdOpt.get(); - } + /** + * Performs the version-specific initialization handshake. + * + * @param authHeader The authorization header value, if present. + * @param handshakeHeaders The resolved headers for the handshake. + * @param traceHeaders The trace context headers to propagate. + * @return A CompletableFuture that completes when initialization is done. + */ + protected abstract CompletableFuture performInitialization( + final String authHeader, + final Map handshakeHeaders, + final Map traceHeaders); - JsonRpc.Notification notif = - new JsonRpc.Notification("notifications/initialized", Map.of()); - String notifBody = objectMapper.writeValueAsString(notif); - HttpRequest.Builder nReq = - HttpRequest.newBuilder() - .uri(URI.create(baseUrl)) - .POST(HttpRequest.BodyPublishers.ofString(notifBody)); - - handshakeHeaders.forEach(nReq::setHeader); - applyProtocolHeaders(nReq); - - return httpClient - .sendAsync(nReq.build(), HttpResponse.BodyHandlers.ofString()) - .thenAccept(nRes -> {}); - } catch (Exception e) { - return CompletableFuture.failedFuture(e); - } - }); - } catch (Exception e) { - return CompletableFuture.failedFuture(e); - } - } - - private void applyProtocolHeaders(HttpRequest.Builder builder) { - builder.header("Content-Type", "application/json"); - if (negotiatedProtocolVersion == null) { - return; - } - if (negotiatedProtocolVersion.requiresAcceptJson()) { - builder.header("Accept", "application/json"); - } - if (negotiatedProtocolVersion.requiresVersionHeader()) { - builder.header("MCP-Protocol-Version", negotiatedProtocolVersion.getValue()); - } - if (negotiatedProtocolVersion.requiresSessionIdHeader() && sessionId != null) { - builder.header("Mcp-Session-Id", sessionId); - } - } + protected abstract void applyProtocolHeaders(final HttpRequest.Builder builder); @Override - public CompletableFuture listTools( - String toolsetName, Map metadata) { + public final CompletableFuture listTools( + final String toolsetName, final Map metadata) { if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") && !metadata.isEmpty()) { logger.warning(HTTP_WARNING); @@ -382,8 +257,25 @@ public CompletableFuture listTools( mergedHeaders -> { String path = toolsetName != null && !toolsetName.isEmpty() ? "/" + toolsetName : ""; String url = baseUrl + path; + + TelemetryHelper.OperationSpan listSpan = + new TelemetryHelper.OperationSpan( + "tools/list", + negotiatedProtocolVersion != null + ? negotiatedProtocolVersion.getValue() + : preferredProtocolVersion.getValue(), + url, + null); + try { - JsonRpc.Request listReq = new JsonRpc.Request("tools/list", Map.of()); + Map traceHeaders = listSpan.getTraceContextHeaders(); + JsonRpc.RequestMetadata reqMetadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + + JsonRpc.Request listReq = + new JsonRpc.Request( + "tools/list", new JsonRpc.ListToolsParams(null, reqMetadata)); String body = objectMapper.writeValueAsString(listReq); HttpRequest.Builder req = HttpRequest.newBuilder() @@ -394,16 +286,27 @@ public CompletableFuture listTools( return httpClient .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) - .thenApply(this::handleListToolsResponse); + .thenApply(res -> handleListToolsResponse(res, listSpan)) + .whenComplete( + (res, err) -> { + if (err != null) { + listSpan.recordError(err); + } + listSpan.close(); + }); } catch (Exception e) { + listSpan.recordError(e); + listSpan.close(); return CompletableFuture.failedFuture(e); } }); } @Override - public CompletableFuture invokeTool( - String toolName, Map arguments, Map metadata) { + public final CompletableFuture invokeTool( + final String toolName, + final Map arguments, + final Map metadata) { if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") && !metadata.isEmpty()) { logger.warning(HTTP_WARNING); @@ -412,10 +315,24 @@ public CompletableFuture invokeTool( .thenCompose(v -> mergeHeaders(metadata)) .thenCompose( mergedHeaders -> { + TelemetryHelper.OperationSpan callSpan = + new TelemetryHelper.OperationSpan( + "tools/call", + negotiatedProtocolVersion != null + ? negotiatedProtocolVersion.getValue() + : preferredProtocolVersion.getValue(), + baseUrl, + toolName); + try { + Map traceHeaders = callSpan.getTraceContextHeaders(); + JsonRpc.RequestMetadata reqMetadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + JsonRpc.Request invokeReq = new JsonRpc.Request( - "tools/call", new JsonRpc.CallToolParams(toolName, arguments)); + "tools/call", new JsonRpc.CallToolParams(toolName, arguments, reqMetadata)); String requestBody = objectMapper.writeValueAsString(invokeReq); HttpRequest.Builder requestBuilder = @@ -428,8 +345,39 @@ public CompletableFuture invokeTool( return httpClient .sendAsync(requestBuilder.build(), HttpResponse.BodyHandlers.ofString()) - .thenApply(res -> new TransportResponse(res.statusCode(), res.body())); + .thenApply( + res -> { + if (res.statusCode() < 200 || res.statusCode() >= 300) { + callSpan.recordError( + res.statusCode(), "Error " + res.statusCode() + ": " + res.body()); + } else { + try { + JsonNode root = objectMapper.readTree(res.body()); + if (root.has("error")) { + JsonNode errNode = root.get("error"); + int code = errNode.has("code") ? errNode.get("code").asInt() : -1; + String msg = + errNode.has("message") + ? errNode.get("message").asText() + : errNode.toString(); + callSpan.recordError(code, msg); + } + } catch (Exception ignored) { + // Ignore parsing exceptions here + } + } + return new TransportResponse(res.statusCode(), res.body()); + }) + .whenComplete( + (res, err) -> { + if (err != null) { + callSpan.recordError(err); + } + callSpan.close(); + }); } catch (Exception e) { + callSpan.recordError(e); + callSpan.close(); return CompletableFuture.failedFuture(e); } }); @@ -437,17 +385,39 @@ public CompletableFuture invokeTool( @Override public void close() { - // No-op for HttpClient in Java 11 + if (sessionStartTime != null) { + double durationSeconds = (System.nanoTime() - sessionStartTime) / 1e9; + TelemetryHelper.recordSessionDuration( + durationSeconds, + negotiatedProtocolVersion != null + ? negotiatedProtocolVersion.getValue() + : preferredProtocolVersion.getValue(), + baseUrl, + sessionError); + } } - private TransportManifest handleListToolsResponse(HttpResponse response) { - if (response.statusCode() != 200) + private TransportManifest handleListToolsResponse( + final HttpResponse response, TelemetryHelper.OperationSpan span) { + if (response.statusCode() != 200) { + if (span != null) { + span.recordError( + response.statusCode(), + "Failed to list tools. Status: " + response.statusCode() + " " + response.body()); + } throw new RuntimeException( "Failed to list tools. Status: " + response.statusCode() + " " + response.body()); + } try { JsonNode root = objectMapper.readTree(response.body()); if (root.has("error")) { - throw new RuntimeException("MCP Error: " + root.get("error").toString()); + JsonNode errNode = root.get("error"); + int code = errNode.has("code") ? errNode.get("code").asInt() : -1; + String msg = errNode.has("message") ? errNode.get("message").asText() : errNode.toString(); + if (span != null) { + span.recordError(code, msg); + } + throw new RuntimeException("MCP Error: " + msg); } JsonNode result = root.get("result"); JsonNode toolsNode = result.get("tools"); diff --git a/src/main/java/com/google/cloud/mcp/transport/HttpMcpTransport.java b/src/main/java/com/google/cloud/mcp/transport/HttpMcpTransport.java new file mode 100644 index 0000000..d5833fc --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/HttpMcpTransport.java @@ -0,0 +1,175 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport; + +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.transport.v20241105.HttpMcpTransportV20241105; +import com.google.cloud.mcp.transport.v20250326.HttpMcpTransportV20250326; +import com.google.cloud.mcp.transport.v20250618.HttpMcpTransportV20250618; +import com.google.cloud.mcp.transport.v20251125.HttpMcpTransportV20251125; +import java.net.http.HttpClient; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** Default HTTP transport implementation routing requests to version-specific handlers. */ +public final class HttpMcpTransport implements Transport { + + private final Transport delegate; + + /** + * Constructs a new HttpMcpTransport with a base URL. + * + * @param baseUrl The base URL of the remote service. + */ + public HttpMcpTransport(final String baseUrl) { + this(baseUrl, Map.of(), (CredentialsProvider) null); + } + + /** + * Constructs a new HttpMcpTransport with base URL and default headers. + * + * @param baseUrl The base URL of the remote service. + * @param clientHeaders Default HTTP headers to include in every request. + */ + public HttpMcpTransport(final String baseUrl, final Map clientHeaders) { + this(baseUrl, clientHeaders, (CredentialsProvider) null); + } + + /** + * Constructs a new HttpMcpTransport with base URL, default headers and credentials provider. + * + * @param baseUrl The base URL of the remote service. + * @param clientHeaders Default HTTP headers to include in every request. + * @param credentialsProvider Provider for retrieving authorization credentials. + */ + public HttpMcpTransport( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider) { + this(baseUrl, clientHeaders, credentialsProvider, null, null, null); + } + + /** + * Constructs a HttpMcpTransport. + * + * @param baseUrl The base URL of the remote service. + * @param clientHeaders Default HTTP headers to include in every request. + * @param preferredProtocolVersion Preferred MCP protocol version. + * @param httpClient Custom HTTP Client. + * @param executor Optional Executor for handling async requests. + */ + public HttpMcpTransport( + final String baseUrl, + final Map clientHeaders, + final ProtocolVersion preferredProtocolVersion, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + this(baseUrl, clientHeaders, null, preferredProtocolVersion, httpClient, executor); + } + + /** + * Primary constructor for HttpMcpTransport. + * + * @param baseUrl The base URL of the remote service. + * @param clientHeaders Default HTTP headers to include in every request. + * @param credentialsProvider Provider for retrieving authorization credentials. + * @param preferredProtocolVersion Preferred MCP protocol version. + * @param httpClient Custom HTTP Client. + * @param executor Optional Executor for handling async requests. + */ + public HttpMcpTransport( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final ProtocolVersion preferredProtocolVersion, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + final ProtocolVersion version = + preferredProtocolVersion != null + ? preferredProtocolVersion + : ProtocolVersion.VERSION_2025_11_25; + + switch (version) { + case VERSION_2025_11_25: + this.delegate = + new HttpMcpTransportV20251125( + baseUrl, clientHeaders, credentialsProvider, httpClient, executor); + break; + case VERSION_2025_06_18: + this.delegate = + new HttpMcpTransportV20250618( + baseUrl, clientHeaders, credentialsProvider, httpClient, executor); + break; + case VERSION_2025_03_26: + this.delegate = + new HttpMcpTransportV20250326( + baseUrl, clientHeaders, credentialsProvider, httpClient, executor); + break; + case VERSION_2024_11_05: + this.delegate = + new HttpMcpTransportV20241105( + baseUrl, clientHeaders, credentialsProvider, httpClient, executor); + break; + default: + throw new IllegalArgumentException("Unsupported protocol version: " + version); + } + } + + /** Internal constructor for testing purposes. */ + public HttpMcpTransport(final String baseUrl, final HttpClient httpClient) { + this(baseUrl, Map.of(), null, null, httpClient, null); + } + + /** Internal constructor for testing purposes. */ + public HttpMcpTransport( + final String baseUrl, final Map clientHeaders, final HttpClient httpClient) { + this(baseUrl, clientHeaders, null, null, httpClient, null); + } + + HttpMcpTransport( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final HttpClient httpClient) { + this(baseUrl, clientHeaders, credentialsProvider, null, httpClient, null); + } + + @Override + public String getBaseUrl() { + return delegate.getBaseUrl(); + } + + @Override + public CompletableFuture listTools( + final String toolsetName, final Map metadata) { + return delegate.listTools(toolsetName, metadata); + } + + @Override + public CompletableFuture invokeTool( + final String toolName, + final Map arguments, + final Map metadata) { + return delegate.invokeTool(toolName, arguments, metadata); + } + + @Override + public void close() { + delegate.close(); + } +} diff --git a/src/main/java/com/google/cloud/mcp/Transport.java b/src/main/java/com/google/cloud/mcp/transport/Transport.java similarity index 97% rename from src/main/java/com/google/cloud/mcp/Transport.java rename to src/main/java/com/google/cloud/mcp/transport/Transport.java index 566eefe..37ac88b 100644 --- a/src/main/java/com/google/cloud/mcp/Transport.java +++ b/src/main/java/com/google/cloud/mcp/transport/Transport.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.transport; import java.util.Map; import java.util.concurrent.CompletableFuture; diff --git a/src/main/java/com/google/cloud/mcp/TransportManifest.java b/src/main/java/com/google/cloud/mcp/transport/TransportManifest.java similarity index 92% rename from src/main/java/com/google/cloud/mcp/TransportManifest.java rename to src/main/java/com/google/cloud/mcp/transport/TransportManifest.java index e294afa..f8a8dac 100644 --- a/src/main/java/com/google/cloud/mcp/TransportManifest.java +++ b/src/main/java/com/google/cloud/mcp/transport/TransportManifest.java @@ -14,8 +14,9 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.transport; +import com.google.cloud.mcp.tool.ToolDefinition; import java.util.Map; /** Represents the raw tools manifest returned by the transport. */ diff --git a/src/main/java/com/google/cloud/mcp/TransportResponse.java b/src/main/java/com/google/cloud/mcp/transport/TransportResponse.java similarity index 97% rename from src/main/java/com/google/cloud/mcp/TransportResponse.java rename to src/main/java/com/google/cloud/mcp/transport/TransportResponse.java index 4532b69..5044af4 100644 --- a/src/main/java/com/google/cloud/mcp/TransportResponse.java +++ b/src/main/java/com/google/cloud/mcp/transport/TransportResponse.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.transport; /** Represents a raw transport response containing status code and response body. */ public final class TransportResponse { diff --git a/src/main/java/com/google/cloud/mcp/transport/v20241105/HttpMcpTransportV20241105.java b/src/main/java/com/google/cloud/mcp/transport/v20241105/HttpMcpTransportV20241105.java new file mode 100644 index 0000000..4fdb288 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/v20241105/HttpMcpTransportV20241105.java @@ -0,0 +1,137 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport.v20241105; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public final class HttpMcpTransportV20241105 extends BaseMcpTransport { + + public HttpMcpTransportV20241105( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + super( + baseUrl, + clientHeaders, + credentialsProvider, + ProtocolVersion.VERSION_2024_11_05, + httpClient, + executor); + } + + @Override + protected CompletableFuture performInitialization( + final String authHeader, + final Map handshakeHeaders, + final Map traceHeaders) { + try { + if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && authHeader != null) { + logger.warning(HTTP_WARNING); + } + JsonRpc.RequestMetadata metadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + JsonRpc.Request initReq = + new JsonRpc.Request( + "initialize", + new JsonRpc.InitializeParams( + ProtocolVersion.VERSION_2024_11_05.getValue(), "mcp-toolbox-sdk-java", metadata)); + String body = objectMapper.writeValueAsString(initReq); + HttpRequest.Builder req = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(body)); + + handshakeHeaders.forEach(req::setHeader); + applyProtocolHeaders(req); + + return httpClient + .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) + .thenCompose( + res -> { + if (res.statusCode() != 200) { + return CompletableFuture.failedFuture( + new McpException("Init failed: " + res.statusCode() + " " + res.body())); + } + try { + JsonNode responseJson = objectMapper.readTree(res.body()); + if (responseJson.has("error")) { + return CompletableFuture.failedFuture( + new McpException("MCP Error: " + responseJson.get("error").toString())); + } + JsonNode result = responseJson.get("result"); + String serverVersion; + if (result != null && result.has("protocolVersion")) { + serverVersion = result.get("protocolVersion").asText(); + } else { + serverVersion = ProtocolVersion.VERSION_2024_11_05.getValue(); + } + + if (!ProtocolVersion.VERSION_2024_11_05.getValue().equals(serverVersion)) { + return CompletableFuture.failedFuture( + new McpException( + "MCP version mismatch: client (" + + ProtocolVersion.VERSION_2024_11_05.getValue() + + ") != server (" + + serverVersion + + ")")); + } + + this.negotiatedProtocolVersion = ProtocolVersion.VERSION_2024_11_05; + + JsonRpc.Notification notif = + new JsonRpc.Notification("notifications/initialized", Map.of()); + String notifBody = objectMapper.writeValueAsString(notif); + HttpRequest.Builder nReq = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(notifBody)); + + handshakeHeaders.forEach(nReq::setHeader); + applyProtocolHeaders(nReq); + + return httpClient + .sendAsync(nReq.build(), HttpResponse.BodyHandlers.ofString()) + .thenAccept(nRes -> {}); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + }); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + } + + @Override + protected void applyProtocolHeaders(final HttpRequest.Builder builder) { + builder.header("Content-Type", "application/json"); + } +} diff --git a/src/main/java/com/google/cloud/mcp/transport/v20250326/HttpMcpTransportV20250326.java b/src/main/java/com/google/cloud/mcp/transport/v20250326/HttpMcpTransportV20250326.java new file mode 100644 index 0000000..b3e46ca --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/v20250326/HttpMcpTransportV20250326.java @@ -0,0 +1,153 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport.v20250326; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +public final class HttpMcpTransportV20250326 extends BaseMcpTransport { + + private volatile String sessionId; + + public HttpMcpTransportV20250326( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + super( + baseUrl, + clientHeaders, + credentialsProvider, + ProtocolVersion.VERSION_2025_03_26, + httpClient, + executor); + } + + @Override + protected CompletableFuture performInitialization( + final String authHeader, + final Map handshakeHeaders, + final Map traceHeaders) { + try { + if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && authHeader != null) { + logger.warning(HTTP_WARNING); + } + JsonRpc.RequestMetadata metadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + JsonRpc.Request initReq = + new JsonRpc.Request( + "initialize", + new JsonRpc.InitializeParams( + ProtocolVersion.VERSION_2025_03_26.getValue(), "mcp-toolbox-sdk-java", metadata)); + String body = objectMapper.writeValueAsString(initReq); + HttpRequest.Builder req = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(body)); + + handshakeHeaders.forEach(req::setHeader); + applyProtocolHeaders(req); + + return httpClient + .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) + .thenCompose( + res -> { + if (res.statusCode() != 200) { + return CompletableFuture.failedFuture( + new McpException("Init failed: " + res.statusCode() + " " + res.body())); + } + try { + JsonNode responseJson = objectMapper.readTree(res.body()); + if (responseJson.has("error")) { + return CompletableFuture.failedFuture( + new McpException("MCP Error: " + responseJson.get("error").toString())); + } + JsonNode result = responseJson.get("result"); + String serverVersion; + if (result != null && result.has("protocolVersion")) { + serverVersion = result.get("protocolVersion").asText(); + } else { + serverVersion = ProtocolVersion.VERSION_2025_03_26.getValue(); + } + + if (!ProtocolVersion.VERSION_2025_03_26.getValue().equals(serverVersion)) { + return CompletableFuture.failedFuture( + new McpException( + "MCP version mismatch: client (" + + ProtocolVersion.VERSION_2025_03_26.getValue() + + ") != server (" + + serverVersion + + ")")); + } + + Optional sessionIdOpt = res.headers().firstValue("Mcp-Session-Id"); + if (sessionIdOpt.isEmpty()) { + return CompletableFuture.failedFuture( + new McpException( + "Server did not return a Mcp-Session-Id header during" + + " initialization.")); + } + this.sessionId = sessionIdOpt.get(); + + this.negotiatedProtocolVersion = ProtocolVersion.VERSION_2025_03_26; + + JsonRpc.Notification notif = + new JsonRpc.Notification("notifications/initialized", Map.of()); + String notifBody = objectMapper.writeValueAsString(notif); + HttpRequest.Builder nReq = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(notifBody)); + + handshakeHeaders.forEach(nReq::setHeader); + applyProtocolHeaders(nReq); + + return httpClient + .sendAsync(nReq.build(), HttpResponse.BodyHandlers.ofString()) + .thenAccept(nRes -> {}); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + }); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + } + + @Override + protected void applyProtocolHeaders(final HttpRequest.Builder builder) { + builder.header("Content-Type", "application/json"); + builder.header("Accept", "application/json"); + if (sessionId != null) { + builder.header("Mcp-Session-Id", sessionId); + } + } +} diff --git a/src/main/java/com/google/cloud/mcp/transport/v20250618/HttpMcpTransportV20250618.java b/src/main/java/com/google/cloud/mcp/transport/v20250618/HttpMcpTransportV20250618.java new file mode 100644 index 0000000..4973068 --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/v20250618/HttpMcpTransportV20250618.java @@ -0,0 +1,139 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport.v20250618; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public final class HttpMcpTransportV20250618 extends BaseMcpTransport { + + public HttpMcpTransportV20250618( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + super( + baseUrl, + clientHeaders, + credentialsProvider, + ProtocolVersion.VERSION_2025_06_18, + httpClient, + executor); + } + + @Override + protected CompletableFuture performInitialization( + final String authHeader, + final Map handshakeHeaders, + final Map traceHeaders) { + try { + if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && authHeader != null) { + logger.warning(HTTP_WARNING); + } + JsonRpc.RequestMetadata metadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + JsonRpc.Request initReq = + new JsonRpc.Request( + "initialize", + new JsonRpc.InitializeParams( + ProtocolVersion.VERSION_2025_06_18.getValue(), "mcp-toolbox-sdk-java", metadata)); + String body = objectMapper.writeValueAsString(initReq); + HttpRequest.Builder req = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(body)); + + handshakeHeaders.forEach(req::setHeader); + applyProtocolHeaders(req); + + return httpClient + .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) + .thenCompose( + res -> { + if (res.statusCode() != 200) { + return CompletableFuture.failedFuture( + new McpException("Init failed: " + res.statusCode() + " " + res.body())); + } + try { + JsonNode responseJson = objectMapper.readTree(res.body()); + if (responseJson.has("error")) { + return CompletableFuture.failedFuture( + new McpException("MCP Error: " + responseJson.get("error").toString())); + } + JsonNode result = responseJson.get("result"); + String serverVersion; + if (result != null && result.has("protocolVersion")) { + serverVersion = result.get("protocolVersion").asText(); + } else { + serverVersion = ProtocolVersion.VERSION_2025_06_18.getValue(); + } + + if (!ProtocolVersion.VERSION_2025_06_18.getValue().equals(serverVersion)) { + return CompletableFuture.failedFuture( + new McpException( + "MCP version mismatch: client (" + + ProtocolVersion.VERSION_2025_06_18.getValue() + + ") != server (" + + serverVersion + + ")")); + } + + this.negotiatedProtocolVersion = ProtocolVersion.VERSION_2025_06_18; + + JsonRpc.Notification notif = + new JsonRpc.Notification("notifications/initialized", Map.of()); + String notifBody = objectMapper.writeValueAsString(notif); + HttpRequest.Builder nReq = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(notifBody)); + + handshakeHeaders.forEach(nReq::setHeader); + applyProtocolHeaders(nReq); + + return httpClient + .sendAsync(nReq.build(), HttpResponse.BodyHandlers.ofString()) + .thenAccept(nRes -> {}); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + }); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + } + + @Override + protected void applyProtocolHeaders(final HttpRequest.Builder builder) { + builder.header("Content-Type", "application/json"); + builder.header("Accept", "application/json"); + builder.header("MCP-Protocol-Version", ProtocolVersion.VERSION_2025_06_18.getValue()); + } +} diff --git a/src/main/java/com/google/cloud/mcp/transport/v20251125/HttpMcpTransportV20251125.java b/src/main/java/com/google/cloud/mcp/transport/v20251125/HttpMcpTransportV20251125.java new file mode 100644 index 0000000..4a3a5ea --- /dev/null +++ b/src/main/java/com/google/cloud/mcp/transport/v20251125/HttpMcpTransportV20251125.java @@ -0,0 +1,139 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp.transport.v20251125; + +import com.fasterxml.jackson.databind.JsonNode; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +public final class HttpMcpTransportV20251125 extends BaseMcpTransport { + + public HttpMcpTransportV20251125( + final String baseUrl, + final Map clientHeaders, + final CredentialsProvider credentialsProvider, + final HttpClient httpClient, + final java.util.concurrent.Executor executor) { + super( + baseUrl, + clientHeaders, + credentialsProvider, + ProtocolVersion.VERSION_2025_11_25, + httpClient, + executor); + } + + @Override + protected CompletableFuture performInitialization( + final String authHeader, + final Map handshakeHeaders, + final Map traceHeaders) { + try { + if (this.baseUrl.toLowerCase(java.util.Locale.ROOT).startsWith("http://") + && authHeader != null) { + logger.warning(HTTP_WARNING); + } + JsonRpc.RequestMetadata metadata = + new JsonRpc.RequestMetadata( + traceHeaders.get("traceparent"), traceHeaders.get("tracestate")); + JsonRpc.Request initReq = + new JsonRpc.Request( + "initialize", + new JsonRpc.InitializeParams( + ProtocolVersion.VERSION_2025_11_25.getValue(), "mcp-toolbox-sdk-java", metadata)); + String body = objectMapper.writeValueAsString(initReq); + HttpRequest.Builder req = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(body)); + + handshakeHeaders.forEach(req::setHeader); + applyProtocolHeaders(req); + + return httpClient + .sendAsync(req.build(), HttpResponse.BodyHandlers.ofString()) + .thenCompose( + res -> { + if (res.statusCode() != 200) { + return CompletableFuture.failedFuture( + new McpException("Init failed: " + res.statusCode() + " " + res.body())); + } + try { + JsonNode responseJson = objectMapper.readTree(res.body()); + if (responseJson.has("error")) { + return CompletableFuture.failedFuture( + new McpException("MCP Error: " + responseJson.get("error").toString())); + } + JsonNode result = responseJson.get("result"); + String serverVersion; + if (result != null && result.has("protocolVersion")) { + serverVersion = result.get("protocolVersion").asText(); + } else { + serverVersion = ProtocolVersion.VERSION_2025_11_25.getValue(); + } + + if (!ProtocolVersion.VERSION_2025_11_25.getValue().equals(serverVersion)) { + return CompletableFuture.failedFuture( + new McpException( + "MCP version mismatch: client (" + + ProtocolVersion.VERSION_2025_11_25.getValue() + + ") != server (" + + serverVersion + + ")")); + } + + this.negotiatedProtocolVersion = ProtocolVersion.VERSION_2025_11_25; + + JsonRpc.Notification notif = + new JsonRpc.Notification("notifications/initialized", Map.of()); + String notifBody = objectMapper.writeValueAsString(notif); + HttpRequest.Builder nReq = + HttpRequest.newBuilder() + .uri(URI.create(baseUrl)) + .POST(HttpRequest.BodyPublishers.ofString(notifBody)); + + handshakeHeaders.forEach(nReq::setHeader); + applyProtocolHeaders(nReq); + + return httpClient + .sendAsync(nReq.build(), HttpResponse.BodyHandlers.ofString()) + .thenAccept(nRes -> {}); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + }); + } catch (Exception e) { + return CompletableFuture.failedFuture(e); + } + } + + @Override + protected void applyProtocolHeaders(final HttpRequest.Builder builder) { + builder.header("Content-Type", "application/json"); + builder.header("Accept", "application/json"); + builder.header("MCP-Protocol-Version", ProtocolVersion.VERSION_2025_11_25.getValue()); + } +} diff --git a/src/test/java/com/google/cloud/mcp/McpCoverageTest.java b/src/test/java/com/google/cloud/mcp/McpCoverageTest.java new file mode 100644 index 0000000..34e6f73 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/McpCoverageTest.java @@ -0,0 +1,114 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +import com.google.cloud.mcp.auth.AuthTokenGetter; +import com.google.cloud.mcp.client.McpToolboxClientImpl; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.tool.Tool; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolResult; +import com.google.cloud.mcp.transport.Transport; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +/** Miscellaneous unit tests to achieve 100% code coverage. */ +@Timeout(5) +public class McpCoverageTest { + + @Test + public void testMcpExceptionCoverage() { + McpException ex = new McpException("error message", new RuntimeException("cause")); + assertEquals("error message", ex.getMessage()); + assertEquals("cause", ex.getCause().getMessage()); + } + + @Test + public void testMcpToolboxClientDefaultClose() { + McpToolboxClient dummyClient = + new McpToolboxClient() { + @Override + public CompletableFuture> listTools() { + return null; + } + + @Override + public CompletableFuture> loadToolset(String name) { + return null; + } + + @Override + public CompletableFuture> loadToolset( + String name, + Map> p, + Map> a, + boolean s) { + return null; + } + + @Override + public CompletableFuture loadTool(String name) { + return null; + } + + @Override + public CompletableFuture loadTool( + String name, Map getters) { + return null; + } + + @Override + public CompletableFuture invokeTool(String name, Map args) { + return null; + } + + @Override + public CompletableFuture invokeTool( + String name, Map args, Map headers) { + return null; + } + }; + // Call default close (no-op) + dummyClient.close(); + } + + @Test + public void testMcpToolboxClientImplCloseThrowsException() throws Exception { + Transport mockTransport = mock(Transport.class); + doThrow(new RuntimeException("transport close error")).when(mockTransport).close(); + + McpToolboxClientImpl client = new McpToolboxClientImpl(mockTransport, java.util.Map.of(), null); + McpException ex = assertThrows(McpException.class, client::close); + assertEquals("Failed to close transport", ex.getMessage()); + assertEquals("transport close error", ex.getCause().getMessage()); + } + + @Test + public void testProtocolVersionFromString() { + assertNull(ProtocolVersion.fromString(null)); + assertNull(ProtocolVersion.fromString("invalid-version-string")); + assertEquals(ProtocolVersion.VERSION_2025_11_25, ProtocolVersion.fromString("2025-11-25")); + } +} diff --git a/src/test/java/com/google/cloud/mcp/TelemetryTest.java b/src/test/java/com/google/cloud/mcp/TelemetryTest.java new file mode 100644 index 0000000..9749bd9 --- /dev/null +++ b/src/test/java/com/google/cloud/mcp/TelemetryTest.java @@ -0,0 +1,245 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.cloud.mcp; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolResult; +import com.sun.net.httpserver.HttpServer; +import io.opentelemetry.api.common.AttributeKey; +import io.opentelemetry.sdk.testing.junit5.OpenTelemetryExtension; +import io.opentelemetry.sdk.trace.data.SpanData; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.extension.RegisterExtension; + +@Timeout(value = 15, unit = TimeUnit.SECONDS) +public class TelemetryTest { + + @RegisterExtension + static final OpenTelemetryExtension otelTesting = OpenTelemetryExtension.create(); + + private HttpServer server; + private String serverUrl; + private final List receivedRequests = Collections.synchronizedList(new ArrayList<>()); + private final ObjectMapper mapper = new ObjectMapper(); + + @BeforeEach + public void setUp() throws Exception { + receivedRequests.clear(); + server = HttpServer.create(new InetSocketAddress("localhost", 0), 0); + server.createContext( + "/mcp", + exchange -> { + try { + byte[] reqBytes = exchange.getRequestBody().readAllBytes(); + JsonNode reqNode = mapper.readTree(reqBytes); + receivedRequests.add(reqNode); + + String method = reqNode.has("method") ? reqNode.get("method").asText() : ""; + String responseBody = "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{}}"; + + if ("tools/list".equals(method)) { + responseBody = + "{\n" + + " \"jsonrpc\": \"2.0\",\n" + + " \"id\": \"1\",\n" + + " \"result\": {\n" + + " \"tools\": [\n" + + " {\n" + + " \"name\": \"test-tool\",\n" + + " \"description\": \"A test tool\",\n" + + " \"inputSchema\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {}\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + } else if ("tools/call".equals(method)) { + responseBody = + "{\n" + + " \"jsonrpc\": \"2.0\",\n" + + " \"id\": \"1\",\n" + + " \"result\": {\n" + + " \"content\": [\n" + + " {\n" + + " \"type\": \"text\",\n" + + " \"text\": \"Success\"\n" + + " }\n" + + " ],\n" + + " \"isError\": false\n" + + " }\n" + + "}"; + } + + exchange.getResponseHeaders().set("Content-Type", "application/json"); + byte[] responseBytes = responseBody.getBytes(); + exchange.sendResponseHeaders(200, responseBytes.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(responseBytes); + } + } catch (Exception e) { + exchange.sendResponseHeaders(500, 0); + exchange.close(); + } + }); + server.start(); + int port = server.getAddress().getPort(); + serverUrl = "http://localhost:" + port + "/mcp"; + } + + @AfterEach + public void tearDown() { + if (server != null) { + server.stop(0); + } + } + + @Test + public void testTelemetrySpansAndContextPropagation() throws Exception { + try (McpToolboxClient client = McpToolboxClient.builder().baseUrl(serverUrl).build()) { + // 1. Load toolset (triggers initialize and tools/list) + Map tools = client.loadToolset().get(); + assertNotNull(tools); + assertTrue(tools.containsKey("test-tool")); + + // 2. Invoke tool + ToolResult result = client.invokeTool("test-tool", Map.of()).get(); + assertNotNull(result); + assertFalse(result.isError()); + } + + // Verify Spans were created + List spans = otelTesting.getSpans(); + + // Spans should be: "initialize", "tools/list", "tools/call test-tool" + assertTrue(spans.stream().anyMatch(s -> "initialize".equals(s.getName()))); + assertTrue(spans.stream().anyMatch(s -> "tools/list".equals(s.getName()))); + assertTrue(spans.stream().anyMatch(s -> "tools/call test-tool".equals(s.getName()))); + + SpanData initSpan = + spans.stream().filter(s -> "initialize".equals(s.getName())).findFirst().orElseThrow(); + SpanData listSpan = + spans.stream().filter(s -> "tools/list".equals(s.getName())).findFirst().orElseThrow(); + SpanData callSpan = + spans.stream() + .filter(s -> "tools/call test-tool".equals(s.getName())) + .findFirst() + .orElseThrow(); + + // Verify Span attributes + assertEquals( + "initialize", initSpan.getAttributes().get(AttributeKey.stringKey("mcp.method.name"))); + assertEquals( + "tools/list", listSpan.getAttributes().get(AttributeKey.stringKey("mcp.method.name"))); + assertEquals( + "tools/call", callSpan.getAttributes().get(AttributeKey.stringKey("mcp.method.name"))); + assertEquals( + "test-tool", callSpan.getAttributes().get(AttributeKey.stringKey("gen_ai.tool.name"))); + + // Verify context propagation in JSON-RPC metadata + // Note: invokeTool does not trigger initialization again since it was already initialized + // So invokeTool adds tools/call request, making it 4 requests total. + // Wait, let's verify if the list size is 4. + // index 0: initialize (Request) + // index 1: notifications/initialized (Notification) + // index 2: tools/list (Request) + // index 3: tools/call (Request) + assertEquals(4, receivedRequests.size()); + + JsonNode initReq = receivedRequests.get(0); + JsonNode listReq = receivedRequests.get(2); + JsonNode callReq = receivedRequests.get(3); + + // Verify traceparent in requests matches the span's traceId + String initTraceParent = initReq.get("params").get("_meta").get("traceparent").asText(); + assertNotNull(initTraceParent); + assertTrue(initTraceParent.contains(initSpan.getTraceId())); + + String listTraceParent = listReq.get("params").get("_meta").get("traceparent").asText(); + assertNotNull(listTraceParent); + assertTrue(listTraceParent.contains(listSpan.getTraceId())); + + String callTraceParent = callReq.get("params").get("_meta").get("traceparent").asText(); + assertNotNull(callTraceParent); + assertTrue(callTraceParent.contains(callSpan.getTraceId())); + } + + @Test + public void testTelemetryHelperEdgeCases() { + // 1. Test ServerInfo record methods (equals, hashCode, toString, and accessors) + TelemetryHelper.ServerInfo info1 = new TelemetryHelper.ServerInfo("localhost", 8080, "http"); + TelemetryHelper.ServerInfo info2 = new TelemetryHelper.ServerInfo("localhost", 8080, "http"); + TelemetryHelper.ServerInfo info3 = new TelemetryHelper.ServerInfo("example.com", 9090, "https"); + + assertEquals(info1, info2); + assertNotEquals(info1, info3); + assertEquals(info1.hashCode(), info2.hashCode()); + assertNotNull(info1.toString()); + assertEquals("localhost", info1.address()); + assertEquals(8080, info1.port()); + assertEquals("http", info1.protocol()); + + // 2. Test extractServerInfo with various edge-case URLs + TelemetryHelper.ServerInfo invalid = TelemetryHelper.extractServerInfo(":::"); + assertEquals("", invalid.address()); + assertNull(invalid.port()); + assertEquals("http", invalid.protocol()); + + TelemetryHelper.ServerInfo noHost = TelemetryHelper.extractServerInfo("http:///mcp"); + assertEquals("", noHost.address()); + assertNull(noHost.port()); + + TelemetryHelper.ServerInfo noHostWithPort = + TelemetryHelper.extractServerInfo("http://my_server:8080"); + assertEquals("my_server", noHostWithPort.address()); + assertEquals(8080, noHostWithPort.port()); + + TelemetryHelper.ServerInfo invalidPort = + TelemetryHelper.extractServerInfo("http://my_server:invalidport"); + assertEquals("my_server", invalidPort.address()); + assertNull(invalidPort.port()); + + TelemetryHelper.ServerInfo noProtocol = TelemetryHelper.extractServerInfo("//localhost:8080"); + assertEquals("localhost", noProtocol.address()); + assertEquals(8080, noProtocol.port()); + assertEquals("http", noProtocol.protocol()); + + // 3. Test recordSessionDuration with error + TelemetryHelper.recordSessionDuration( + 5.5, "2025-11-25", "http://localhost:8080", new RuntimeException("session error")); + } +} diff --git a/src/test/java/com/google/cloud/mcp/AuthMethodsTest.java b/src/test/java/com/google/cloud/mcp/auth/AuthMethodsTest.java similarity index 99% rename from src/test/java/com/google/cloud/mcp/AuthMethodsTest.java rename to src/test/java/com/google/cloud/mcp/auth/AuthMethodsTest.java index 3cd75ab..66a497f 100644 --- a/src/test/java/com/google/cloud/mcp/AuthMethodsTest.java +++ b/src/test/java/com/google/cloud/mcp/auth/AuthMethodsTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.auth; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; diff --git a/src/test/java/com/google/cloud/mcp/HttpMcpToolboxClientTest.java b/src/test/java/com/google/cloud/mcp/client/HttpMcpToolboxClientTest.java similarity index 98% rename from src/test/java/com/google/cloud/mcp/HttpMcpToolboxClientTest.java rename to src/test/java/com/google/cloud/mcp/client/HttpMcpToolboxClientTest.java index a558e6d..9ea1dad 100644 --- a/src/test/java/com/google/cloud/mcp/HttpMcpToolboxClientTest.java +++ b/src/test/java/com/google/cloud/mcp/client/HttpMcpToolboxClientTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.client; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -22,6 +22,8 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.ProtocolVersion; import com.sun.net.httpserver.HttpExchange; import com.sun.net.httpserver.HttpHandler; import com.sun.net.httpserver.HttpServer; diff --git a/src/test/java/com/google/cloud/mcp/McpToolboxClientBuilderTest.java b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientBuilderTest.java similarity index 94% rename from src/test/java/com/google/cloud/mcp/McpToolboxClientBuilderTest.java rename to src/test/java/com/google/cloud/mcp/client/McpToolboxClientBuilderTest.java index 5de1e1f..b5b1168 100644 --- a/src/test/java/com/google/cloud/mcp/McpToolboxClientBuilderTest.java +++ b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientBuilderTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.client; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -23,6 +23,13 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.tool.ToolPostProcessor; +import com.google.cloud.mcp.tool.ToolPreProcessor; +import com.google.cloud.mcp.transport.Transport; import java.lang.reflect.Field; import java.lang.reflect.Method; import java.util.Map; diff --git a/src/test/java/com/google/cloud/mcp/McpToolboxClientImplErrorsTest.java b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplErrorsTest.java similarity index 97% rename from src/test/java/com/google/cloud/mcp/McpToolboxClientImplErrorsTest.java rename to src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplErrorsTest.java index 36d1686..49732a8 100644 --- a/src/test/java/com/google/cloud/mcp/McpToolboxClientImplErrorsTest.java +++ b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplErrorsTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.client; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; @@ -24,6 +24,10 @@ import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolResult; +import com.google.cloud.mcp.transport.HttpMcpTransport; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; diff --git a/src/test/java/com/google/cloud/mcp/McpToolboxClientImplHeadersTest.java b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplHeadersTest.java similarity index 92% rename from src/test/java/com/google/cloud/mcp/McpToolboxClientImplHeadersTest.java rename to src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplHeadersTest.java index 72b7031..7dc2aab 100644 --- a/src/test/java/com/google/cloud/mcp/McpToolboxClientImplHeadersTest.java +++ b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplHeadersTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.client; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; @@ -24,6 +24,10 @@ import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import com.google.cloud.mcp.transport.HttpMcpTransport; import java.lang.reflect.Field; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -66,9 +70,12 @@ void testCustomHeadersPopulatedInAllRequests() throws Exception { Field transportField = McpToolboxClientImpl.class.getDeclaredField("transport"); transportField.setAccessible(true); HttpMcpTransport transport = (HttpMcpTransport) transportField.get(client); - Field httpClientField = HttpMcpTransport.class.getDeclaredField("httpClient"); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field httpClientField = BaseMcpTransport.class.getDeclaredField("httpClient"); httpClientField.setAccessible(true); - httpClientField.set(transport, mockHttpClient); + httpClientField.set(delegate, mockHttpClient); HttpResponse initResponse = mock(HttpResponse.class); when(initResponse.statusCode()).thenReturn(200); @@ -150,9 +157,12 @@ void testExtraHeadersOverrideAndAuthPriority() throws Exception { Field transportField = McpToolboxClientImpl.class.getDeclaredField("transport"); transportField.setAccessible(true); HttpMcpTransport transport = (HttpMcpTransport) transportField.get(client); - Field httpClientField = HttpMcpTransport.class.getDeclaredField("httpClient"); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field httpClientField = BaseMcpTransport.class.getDeclaredField("httpClient"); httpClientField.setAccessible(true); - httpClientField.set(transport, mockHttpClient); + httpClientField.set(delegate, mockHttpClient); HttpResponse initResponse = mock(HttpResponse.class); when(initResponse.statusCode()).thenReturn(200); @@ -223,9 +233,12 @@ void testNoDuplicateHeaders() throws Exception { Field transportField = McpToolboxClientImpl.class.getDeclaredField("transport"); transportField.setAccessible(true); HttpMcpTransport transport = (HttpMcpTransport) transportField.get(client); - Field httpClientField = HttpMcpTransport.class.getDeclaredField("httpClient"); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field httpClientField = BaseMcpTransport.class.getDeclaredField("httpClient"); httpClientField.setAccessible(true); - httpClientField.set(transport, mockHttpClient); + httpClientField.set(delegate, mockHttpClient); HttpResponse initResponse = mock(HttpResponse.class); when(initResponse.statusCode()).thenReturn(200); diff --git a/src/test/java/com/google/cloud/mcp/McpToolboxClientImplJsonRpcTest.java b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplJsonRpcTest.java similarity index 98% rename from src/test/java/com/google/cloud/mcp/McpToolboxClientImplJsonRpcTest.java rename to src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplJsonRpcTest.java index 3a9a99e..6ec19e6 100644 --- a/src/test/java/com/google/cloud/mcp/McpToolboxClientImplJsonRpcTest.java +++ b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplJsonRpcTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.client; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -26,6 +26,11 @@ import static org.mockito.Mockito.when; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolResult; +import com.google.cloud.mcp.transport.HttpMcpTransport; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; @@ -487,9 +492,11 @@ void testListTools_withMissingInputSchemaOrProperties() throws Exception { } @Test - void testJsonRpcInstantiation() { + void testJsonRpcInstantiation() throws Exception { // Instantiate package-private JsonRpc namespace to cover its default constructor - JsonRpc rpc = new JsonRpc(); + java.lang.reflect.Constructor constructor = JsonRpc.class.getDeclaredConstructor(); + constructor.setAccessible(true); + JsonRpc rpc = constructor.newInstance(); assertNotNull(rpc); } diff --git a/src/test/java/com/google/cloud/mcp/McpToolboxClientImplTest.java b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplTest.java similarity index 95% rename from src/test/java/com/google/cloud/mcp/McpToolboxClientImplTest.java rename to src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplTest.java index 075fbf0..2ff4387 100644 --- a/src/test/java/com/google/cloud/mcp/McpToolboxClientImplTest.java +++ b/src/test/java/com/google/cloud/mcp/client/McpToolboxClientImplTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.client; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -31,6 +31,20 @@ import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.auth.oauth2.GoogleCredentials; +import com.google.cloud.mcp.JsonRpc; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.AuthTokenGetter; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.tool.Tool; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolPostProcessor; +import com.google.cloud.mcp.tool.ToolPreProcessor; +import com.google.cloud.mcp.tool.ToolResult; +import com.google.cloud.mcp.transport.BaseMcpTransport; +import com.google.cloud.mcp.transport.HttpMcpTransport; +import com.google.cloud.mcp.transport.Transport; +import com.google.cloud.mcp.transport.TransportManifest; +import com.google.cloud.mcp.transport.TransportResponse; import java.io.IOException; import java.lang.reflect.Field; import java.lang.reflect.Method; @@ -514,10 +528,12 @@ void testLoadToolset_withInvalidUriThrowsException() { @Test void testInvokeTool_withInvalidUriThrowsException() throws Exception { HttpMcpTransport transport = new HttpMcpTransport("http://invalid uri", mockHttpClient); - Field initFutureField = HttpMcpTransport.class.getDeclaredField("initFuture"); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field initFutureField = BaseMcpTransport.class.getDeclaredField("initFuture"); initFutureField.setAccessible(true); - initFutureField.set( - transport, CompletableFuture.completedFuture(null)); // bypass initialization + initFutureField.set(delegate, CompletableFuture.completedFuture(null)); // bypass initialization McpToolboxClientImpl badClient = new McpToolboxClientImpl(transport, java.util.Collections.emptyMap(), null); @@ -594,11 +610,15 @@ void testEnsureInitialized_withNullAuthHeader() throws Exception { .thenReturn(CompletableFuture.completedFuture(initResponse)) .thenReturn(CompletableFuture.completedFuture(notifResponse)); - Method initMethod = HttpMcpTransport.class.getDeclaredMethod("ensureInitialized", Map.class); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + + Method initMethod = BaseMcpTransport.class.getDeclaredMethod("ensureInitialized", Map.class); initMethod.setAccessible(true); CompletableFuture future = - (CompletableFuture) initMethod.invoke(transport, java.util.Collections.emptyMap()); + (CompletableFuture) initMethod.invoke(delegate, java.util.Collections.emptyMap()); future.join(); // should complete and NOT set Authorization header ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); @@ -782,9 +802,12 @@ void testListTools_withInvalidToolsetNameThrows() throws Exception { HttpMcpTransport transport = new HttpMcpTransport("http://localhost:8080", mockHttpClient); // Force transport to be initialized first - Field initFutureField = HttpMcpTransport.class.getDeclaredField("initFuture"); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field initFutureField = BaseMcpTransport.class.getDeclaredField("initFuture"); initFutureField.setAccessible(true); - initFutureField.set(transport, CompletableFuture.completedFuture(null)); + initFutureField.set(delegate, CompletableFuture.completedFuture(null)); CompletableFuture future = transport.listTools("invalid path with spaces \\", java.util.Collections.emptyMap()); @@ -805,9 +828,12 @@ void testEnsureInitialized_withNotificationSerializationFailure() throws Excepti when(mockMapper.writeValueAsString(any(JsonRpc.Notification.class))) .thenThrow(new RuntimeException("Simulated notification serialization failure")); - Field mapperField = HttpMcpTransport.class.getDeclaredField("objectMapper"); + Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + Field mapperField = BaseMcpTransport.class.getDeclaredField("objectMapper"); mapperField.setAccessible(true); - mapperField.set(transport, mockMapper); + mapperField.set(delegate, mockMapper); HttpResponse initResponse = mock(HttpResponse.class); when(initResponse.statusCode()).thenReturn(200); @@ -816,11 +842,11 @@ void testEnsureInitialized_withNotificationSerializationFailure() throws Excepti when(mockHttpClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) .thenReturn(CompletableFuture.completedFuture(initResponse)); - Method initMethod = HttpMcpTransport.class.getDeclaredMethod("ensureInitialized", Map.class); + Method initMethod = BaseMcpTransport.class.getDeclaredMethod("ensureInitialized", Map.class); initMethod.setAccessible(true); CompletableFuture future = - (CompletableFuture) initMethod.invoke(transport, java.util.Collections.emptyMap()); + (CompletableFuture) initMethod.invoke(delegate, java.util.Collections.emptyMap()); java.util.concurrent.ExecutionException ex = org.junit.jupiter.api.Assertions.assertThrows( diff --git a/src/test/java/com/google/cloud/mcp/e2e/McpToolboxClientTest.java b/src/test/java/com/google/cloud/mcp/e2e/McpToolboxClientTest.java index c2b61e9..77f9823 100644 --- a/src/test/java/com/google/cloud/mcp/e2e/McpToolboxClientTest.java +++ b/src/test/java/com/google/cloud/mcp/e2e/McpToolboxClientTest.java @@ -21,9 +21,9 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import com.google.cloud.mcp.McpToolboxClient; -import com.google.cloud.mcp.Tool; -import com.google.cloud.mcp.ToolDefinition; -import com.google.cloud.mcp.ToolResult; +import com.google.cloud.mcp.tool.Tool; +import com.google.cloud.mcp.tool.ToolDefinition; +import com.google.cloud.mcp.tool.ToolResult; import java.util.Map; import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.BeforeEach; diff --git a/src/test/java/com/google/cloud/mcp/ToolTest.java b/src/test/java/com/google/cloud/mcp/tool/ToolTest.java similarity index 63% rename from src/test/java/com/google/cloud/mcp/ToolTest.java rename to src/test/java/com/google/cloud/mcp/tool/ToolTest.java index 4c3682a..6fa6154 100644 --- a/src/test/java/com/google/cloud/mcp/ToolTest.java +++ b/src/test/java/com/google/cloud/mcp/tool/ToolTest.java @@ -14,11 +14,12 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.tool; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; @@ -28,7 +29,10 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.cloud.mcp.McpToolboxClient; +import com.google.cloud.mcp.auth.ResolvedAuth; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -266,6 +270,195 @@ void testResolvedAuth_withNullKeysAndValuesInTokensMap() { assertTrue(!extraHeaders.containsKey("null_token")); } + @Test + void testValidateAndSanitizeArgs_customTypeMatch() throws Exception { + List params = + List.of( + new ToolDefinition.Parameter("p-custom", "custom-type-name", false, "desc", List.of())); + ToolDefinition def = new ToolDefinition("test-tool", params, List.of()); + McpToolboxClient client = mock(McpToolboxClient.class); + when(client.invokeTool(anyString(), anyMap(), anyMap())) + .thenReturn(CompletableFuture.completedFuture(new ToolResult(List.of(), false))); + + Tool tool = new Tool("test-tool", def, client); + tool.execute(Map.of("p-custom", "any-value")).join(); // should succeed + } + + @Test + void testValidateAndSanitizeArgs_withNullParameters() throws Exception { + ToolDefinition def = new ToolDefinition("test-tool", null, List.of()); + McpToolboxClient client = mock(McpToolboxClient.class); + when(client.invokeTool(anyString(), anyMap(), anyMap())) + .thenReturn(CompletableFuture.completedFuture(new ToolResult(List.of(), false))); + + Tool tool = new Tool("test-tool", def, client); + tool.execute(Map.of("any-param", "any-value")).join(); // should bypass validation loop safely + } + + @Test + void testDefaultValueInjection() throws Exception { + McpToolboxClient mockClient = mock(McpToolboxClient.class); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "string", false, "A parameter", null, "default_value"); + ToolDefinition.Parameter paramNoDefault = + new ToolDefinition.Parameter("param2", "string", false, "Another parameter", null, null); + + ToolDefinition def = + new ToolDefinition("A test tool", List.of(paramWithDefault, paramNoDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + args.put("param2", "provided_value"); + + CompletableFuture future = tool.execute(args); + future.join(); // Wait for execution + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> headersCaptor = ArgumentCaptor.forClass(Map.class); + + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), headersCaptor.capture()); + + Map capturedArgs = argsCaptor.getValue(); + + assertEquals( + "default_value", + capturedArgs.get("param1"), + "Default value should be injected when not provided"); + assertEquals("provided_value", capturedArgs.get("param2"), "Provided value should be kept"); + } + + @Test + void testDefaultValueNotOverwritten() throws Exception { + McpToolboxClient mockClient = mock(McpToolboxClient.class); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "string", false, "A parameter", null, "default_value"); + + ToolDefinition def = new ToolDefinition("A test tool", List.of(paramWithDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + args.put("param1", "custom_value"); + + CompletableFuture future = tool.execute(args); + future.join(); // Wait for execution + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + @SuppressWarnings("unchecked") + ArgumentCaptor> headersCaptor = ArgumentCaptor.forClass(Map.class); + + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), headersCaptor.capture()); + + Map capturedArgs = argsCaptor.getValue(); + + assertEquals( + "custom_value", + capturedArgs.get("param1"), + "Provided value should not be overwritten by default value"); + } + + @Test + void testDefaultValueDeepCloning() throws Exception { + McpToolboxClient mockClient = mock(McpToolboxClient.class); + + Map complexDefault = new HashMap<>(); + complexDefault.put("key", "value"); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter( + "param1", "object", false, "A parameter", null, complexDefault); + + ToolDefinition def = new ToolDefinition("A test tool", List.of(paramWithDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + CompletableFuture future = tool.execute(args); + future.join(); + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), any()); + + Map capturedArgs = argsCaptor.getValue(); + @SuppressWarnings("unchecked") + Map injectedDefault = (Map) capturedArgs.get("param1"); + + // Mutate the injected map + injectedDefault.put("key", "mutated_value"); + + // Ensure the original defaultValue stored in the definition remains untouched + @SuppressWarnings("unchecked") + Map defValueInDefinition = + (Map) def.parameters().get(0).defaultValue(); + assertEquals( + "value", + defValueInDefinition.get("key"), + "The default value in definition must remain unmutated"); + } + + @Test + void testDefaultValueDeepCloning_withList() throws Exception { + McpToolboxClient mockClient = mock(McpToolboxClient.class); + + List complexDefault = new ArrayList<>(); + complexDefault.add("value"); + + ToolDefinition.Parameter paramWithDefault = + new ToolDefinition.Parameter("param1", "array", false, "A parameter", null, complexDefault); + + ToolDefinition def = new ToolDefinition("A test tool", List.of(paramWithDefault), null); + + Tool tool = new Tool("testTool", def, mockClient); + + when(mockClient.invokeTool(eq("testTool"), any(), any())) + .thenReturn( + CompletableFuture.completedFuture(new ToolResult(Collections.emptyList(), false))); + + Map args = new HashMap<>(); + CompletableFuture future = tool.execute(args); + future.join(); + + @SuppressWarnings("unchecked") + ArgumentCaptor> argsCaptor = ArgumentCaptor.forClass(Map.class); + verify(mockClient).invokeTool(eq("testTool"), argsCaptor.capture(), any()); + + Map capturedArgs = argsCaptor.getValue(); + @SuppressWarnings("unchecked") + List injectedDefault = (List) capturedArgs.get("param1"); + + // Mutate the injected list + injectedDefault.set(0, "mutated_value"); + + // Ensure the original defaultValue stored in the definition remains untouched + @SuppressWarnings("unchecked") + List defValueInDefinition = (List) def.parameters().get(0).defaultValue(); + assertEquals( + "value", + defValueInDefinition.get(0), + "The default value in definition must remain unmutated"); + } + @Test void testToolDefinitionHints() { ToolDefinition defWithHints = diff --git a/src/test/java/com/google/cloud/mcp/ToolValidationTest.java b/src/test/java/com/google/cloud/mcp/tool/ToolValidationTest.java similarity index 99% rename from src/test/java/com/google/cloud/mcp/ToolValidationTest.java rename to src/test/java/com/google/cloud/mcp/tool/ToolValidationTest.java index bef9d1a..e645a73 100644 --- a/src/test/java/com/google/cloud/mcp/ToolValidationTest.java +++ b/src/test/java/com/google/cloud/mcp/tool/ToolValidationTest.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.tool; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -27,6 +27,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.cloud.mcp.McpToolboxClient; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; diff --git a/src/test/java/com/google/cloud/mcp/HttpMcpTransportTest.java b/src/test/java/com/google/cloud/mcp/transport/HttpMcpTransportTest.java similarity index 91% rename from src/test/java/com/google/cloud/mcp/HttpMcpTransportTest.java rename to src/test/java/com/google/cloud/mcp/transport/HttpMcpTransportTest.java index 1787e89..3d1ad86 100644 --- a/src/test/java/com/google/cloud/mcp/HttpMcpTransportTest.java +++ b/src/test/java/com/google/cloud/mcp/transport/HttpMcpTransportTest.java @@ -14,11 +14,12 @@ * limitations under the License. */ -package com.google.cloud.mcp; +package com.google.cloud.mcp.transport; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.mock; @@ -26,6 +27,10 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import com.google.cloud.mcp.ProtocolVersion; +import com.google.cloud.mcp.auth.CredentialsProvider; +import com.google.cloud.mcp.exception.McpException; +import com.google.cloud.mcp.tool.ToolDefinition; import java.net.http.HttpClient; import java.net.http.HttpRequest; import java.net.http.HttpResponse; @@ -253,9 +258,13 @@ void testConstructor_WithCustomExecutorConfiguresHttpClient() throws Exception { null, customExecutor); - java.lang.reflect.Field httpClientField = HttpMcpTransport.class.getDeclaredField("httpClient"); + java.lang.reflect.Field delegateField = HttpMcpTransport.class.getDeclaredField("delegate"); + delegateField.setAccessible(true); + Object delegate = delegateField.get(transport); + + java.lang.reflect.Field httpClientField = BaseMcpTransport.class.getDeclaredField("httpClient"); httpClientField.setAccessible(true); - java.net.http.HttpClient httpClient = (java.net.http.HttpClient) httpClientField.get(transport); + java.net.http.HttpClient httpClient = (java.net.http.HttpClient) httpClientField.get(delegate); assertNotNull(httpClient); Object internalExecutor = null; @@ -319,7 +328,7 @@ void testListTools_WithHttpUrlAndMetadata_LogsWarning() throws Exception { .thenReturn(CompletableFuture.completedFuture(mockListResponse)); java.util.logging.Logger transportLogger = - java.util.logging.Logger.getLogger(HttpMcpTransport.class.getName()); + java.util.logging.Logger.getLogger(BaseMcpTransport.class.getName()); java.util.List logRecords = new java.util.ArrayList<>(); java.util.logging.Handler logHandler = new java.util.logging.Handler() { @@ -466,4 +475,29 @@ void testListTools_ParsesComplexToolsCorrectly() throws Exception { assertFalse(p2.required()); assertEquals("string", p2.type()); } + + @Test + @Timeout(5) + @SuppressWarnings("unchecked") + void testInvokeTool_ExceptionRecording() throws Exception { + HttpResponse mockInitResponse = mock(HttpResponse.class); + when(mockInitResponse.statusCode()).thenReturn(200); + when(mockInitResponse.body()) + .thenReturn( + "{\"jsonrpc\":\"2.0\",\"id\":\"1\",\"result\":{\"protocolVersion\":\"2025-11-25\"}}"); + + HttpResponse mockInitializedResponse = mock(HttpResponse.class); + when(mockInitializedResponse.statusCode()).thenReturn(200); + when(mockInitializedResponse.body()).thenReturn(""); + + when(mockClient.sendAsync(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(CompletableFuture.completedFuture(mockInitResponse)) + .thenReturn(CompletableFuture.completedFuture(mockInitializedResponse)) + .thenReturn(CompletableFuture.failedFuture(new java.io.IOException("connection failure"))); + + CompletableFuture futureResult = + transport.invokeTool("test-tool", Map.of(), Collections.emptyMap()); + + assertThrows(Exception.class, futureResult::get); + } }