diff --git a/Cargo.lock b/Cargo.lock index 8b2af01fb8..601266f59b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1035,6 +1035,7 @@ dependencies = [ "serde", "serde_json", "sha2 0.11.0", + "shell-words", "shellexpand", "shlex", "similar", diff --git a/crates/tui/Cargo.toml b/crates/tui/Cargo.toml index 9298b2f474..908b368f49 100644 --- a/crates/tui/Cargo.toml +++ b/crates/tui/Cargo.toml @@ -80,6 +80,7 @@ tar = "0.4" flate2 = "1.1" sha2 = "0.11" rust-i18n = "4.1.0" +shell-words = "1.1.1" [dev-dependencies] cucumber = "0.23.0" diff --git a/crates/tui/src/core/engine.rs b/crates/tui/src/core/engine.rs index 4a3a75bf40..d19da22873 100644 --- a/crates/tui/src/core/engine.rs +++ b/crates/tui/src/core/engine.rs @@ -31,7 +31,7 @@ use crate::config::{ApiProvider, Config, DEFAULT_MAX_SUBAGENTS, DEFAULT_TEXT_MOD use crate::error_taxonomy::{ErrorCategory, ErrorEnvelope, StreamError}; use crate::features::{Feature, Features}; use crate::llm_client::LlmClient; -use crate::mcp::McpPool; +use crate::mcp::{McpConfig, McpPool}; #[cfg(test)] use crate::models::ToolCaller; use crate::models::{ @@ -2415,6 +2415,11 @@ impl Engine { let plan_state = self.config.plan_state.clone(); let tool_context = self.build_tool_context(input_policy.mode, input_policy.auto_approve); + // Ensure MCP pool is initialized before building the tool registry, + // so start_mcp_server can be registered when Feature::Mcp is enabled. + if self.config.features.enabled(Feature::Mcp) { + let _ = self.ensure_mcp_pool().await; + } let builder = self .build_turn_tool_registry_builder(input_policy.mode, todo_list, plan_state) .with_dynamic_tools(&dynamic_tools); @@ -2571,11 +2576,15 @@ impl Engine { self.api_config.api_provider(), &self.config.model, ); + let mut always_load = self.config.tools_always_load.clone(); + if self.config.features.enabled(Feature::Mcp) { + always_load.insert("start_mcp_server".to_string()); + } let mut catalog = build_model_tool_catalog_with_surface( registry.to_api_tools_with_cache(true), mcp_tools, input_policy.mode, - &self.config.tools_always_load, + &always_load, capability.tool_surface_budget, ); for tool in &mut catalog { @@ -3113,7 +3122,10 @@ impl Engine { &self.session.mcp_config_path, &self.session.workspace, ) - .map_err(|e| ToolError::execution_failed(format!("Failed to load MCP config: {e}")))?; + .unwrap_or_else(|e| { + tracing::debug!("No MCP config: {e}"); + McpPool::new(McpConfig::default()) + }); if let Some(decider) = self.config.network_policy.as_ref() { pool = pool.with_network_policy(decider.clone()); } diff --git a/crates/tui/src/core/engine/tool_setup.rs b/crates/tui/src/core/engine/tool_setup.rs index ed09c304e4..403fbaccbd 100644 --- a/crates/tui/src/core/engine/tool_setup.rs +++ b/crates/tui/src/core/engine/tool_setup.rs @@ -144,6 +144,13 @@ impl Engine { // so there's no failure mode worth gating on. builder = builder.with_notify_tool(); + // Register the start_mcp_server tool so LLM can dynamically start + // MCP servers from conversation context. Only when the pool has been + // initialized (lazy via ensure_mcp_pool). + if let Some(ref pool) = self.mcp_pool { + builder = builder.with_runtime_mcp_tool(Arc::clone(pool)); + } + builder } } diff --git a/crates/tui/src/core/engine/turn_loop.rs b/crates/tui/src/core/engine/turn_loop.rs index 827e898f62..38984e7336 100644 --- a/crates/tui/src/core/engine/turn_loop.rs +++ b/crates/tui/src/core/engine/turn_loop.rs @@ -190,7 +190,7 @@ fn normalize_domain_candidate(value: &str) -> Option { } fn registered_tool_requires_non_bypassable_approval(tool_name: &str) -> bool { - matches!(tool_name, "rlm_eval") + matches!(tool_name, "rlm_eval" | "start_mcp_server") } impl Engine { diff --git a/crates/tui/src/mcp.rs b/crates/tui/src/mcp.rs index 1f21595585..686a2adf6e 100644 --- a/crates/tui/src/mcp.rs +++ b/crates/tui/src/mcp.rs @@ -5,10 +5,12 @@ //! - Automatic tool discovery via `tools/list` //! - Configurable timeouts per-server and globally +use parking_lot::RwLock; use std::collections::HashMap; use std::fs; use std::io::Read; use std::path::{Component, Path, PathBuf}; +use std::sync::Arc; use std::sync::atomic::{AtomicU64, Ordering}; use std::time::Duration; @@ -1432,6 +1434,9 @@ pub struct McpPool { config_hash: u64, /// Most recently observed mtime for `config_sources`. last_mtimes: Vec>, + /// Dynamically added MCP servers (from tool calls at runtime). + /// These are not persisted to disk and live for the process lifetime. + pub(crate) dynamic_servers: Arc>>, } impl McpPool { @@ -1446,6 +1451,7 @@ impl McpPool { workspace: None, config_hash, last_mtimes: Vec::new(), + dynamic_servers: Arc::new(RwLock::new(HashMap::new())), } } @@ -1589,12 +1595,14 @@ impl McpPool { self.drop_connection(server_name, "reconnect"); + // Check static config first, then dynamic servers let server_config = self .config .servers .get(server_name) - .ok_or_else(|| anyhow::anyhow!("Failed to find MCP server: {server_name}"))? - .clone(); + .cloned() + .or_else(|| self.dynamic_servers.read().get(server_name).cloned()) + .ok_or_else(|| anyhow::anyhow!("Failed to find MCP server: {server_name}"))?; if !server_config.is_enabled() { anyhow::bail!("Failed to connect MCP server '{server_name}': server is disabled"); @@ -2084,14 +2092,48 @@ impl McpPool { } } - /// Get list of configured server names + /// Get list of configured server names (static + dynamic) #[allow(dead_code)] // Public API for MCP consumers - pub fn server_names(&self) -> Vec<&str> { - self.config - .servers - .keys() - .map(std::string::String::as_str) - .collect() + pub fn server_names(&self) -> Vec { + let mut names: Vec = self.config.servers.keys().cloned().collect(); + let dynamic = self.dynamic_servers.read(); + for name in dynamic.keys() { + if !names.contains(name) { + names.push(name.clone()); + } + } + names + } + + /// Add a runtime server configuration (in-memory only, not persisted). + /// + /// This is used for dynamically started MCP servers from chat context. + /// Stored in `dynamic_servers` so it doesn't interfere with file-based config reload. + /// + /// Returns `Err` if a server with the same name already exists as a static config + /// or a dynamic config. The caller should surface the error to the LLM/user. + pub fn add_runtime_server_config( + &self, + name: String, + config: McpServerConfig, + ) -> Result<(), String> { + if self.config.servers.contains_key(&name) { + return Err(format!( + "MCP server '{}' already exists in the config file. \ + Remove it from the config first, or choose a different name.", + name + )); + } + let mut dynamic = self.dynamic_servers.write(); + if dynamic.contains_key(&name) { + return Err(format!( + "MCP server '{}' was already started earlier in this session. \ + Choose a different name.", + name + )); + } + dynamic.insert(name, config); + Ok(()) } /// Get list of connected server names diff --git a/crates/tui/src/mcp/tests.rs b/crates/tui/src/mcp/tests.rs index 045d430689..3fc7b7e2d8 100644 --- a/crates/tui/src/mcp/tests.rs +++ b/crates/tui/src/mcp/tests.rs @@ -765,7 +765,7 @@ async fn workspace_mcp_pool_reload_picks_up_project_config_creation() { .unwrap(); let mut pool = McpPool::from_config_path_with_workspace(&global_path, &workspace).unwrap(); - assert_eq!(pool.server_names(), vec!["global"]); + assert_eq!(pool.server_names(), vec!["global".to_string()]); fs::create_dir_all(&project_dir).unwrap(); fs::write( @@ -775,8 +775,11 @@ async fn workspace_mcp_pool_reload_picks_up_project_config_creation() { .unwrap(); assert!(pool.reload_if_config_changed().await.unwrap()); - let names: std::collections::BTreeSet<_> = pool.server_names().into_iter().collect(); - let expected: std::collections::BTreeSet<_> = ["global", "project"].into_iter().collect(); + let names: std::collections::BTreeSet = pool.server_names().into_iter().collect(); + let expected: std::collections::BTreeSet = + ["global".to_string(), "project".to_string()] + .into_iter() + .collect(); assert_eq!(names, expected); } @@ -800,13 +803,16 @@ async fn workspace_mcp_pool_reload_picks_up_project_config_after_workspace_trust .unwrap(); let mut pool = McpPool::from_config_path_with_workspace(&global_path, &workspace).unwrap(); - assert_eq!(pool.server_names(), vec!["global"]); + assert_eq!(pool.server_names(), vec!["global".to_string()]); write_workspace_trust_config(&trust_env.config_path, &workspace); assert!(pool.reload_if_config_changed().await.unwrap()); - let names: std::collections::BTreeSet<_> = pool.server_names().into_iter().collect(); - let expected: std::collections::BTreeSet<_> = ["global", "project"].into_iter().collect(); + let names: std::collections::BTreeSet = pool.server_names().into_iter().collect(); + let expected: std::collections::BTreeSet = + ["global".to_string(), "project".to_string()] + .into_iter() + .collect(); assert_eq!(names, expected); } @@ -830,14 +836,17 @@ async fn workspace_mcp_pool_reload_drops_project_config_after_workspace_trust_re .unwrap(); let mut pool = McpPool::from_config_path_with_workspace(&global_path, &workspace).unwrap(); - let names: std::collections::BTreeSet<_> = pool.server_names().into_iter().collect(); - let expected: std::collections::BTreeSet<_> = ["global", "project"].into_iter().collect(); + let names: std::collections::BTreeSet = pool.server_names().into_iter().collect(); + let expected: std::collections::BTreeSet = + ["global".to_string(), "project".to_string()] + .into_iter() + .collect(); assert_eq!(names, expected); fs::remove_file(&trust.config_path).unwrap(); assert!(pool.reload_if_config_changed().await.unwrap()); - assert_eq!(pool.server_names(), vec!["global"]); + assert_eq!(pool.server_names(), vec!["global".to_string()]); } #[tokio::test] @@ -861,14 +870,17 @@ async fn workspace_mcp_pool_reload_drops_project_config_after_deletion() { .unwrap(); let mut pool = McpPool::from_config_path_with_workspace(&global_path, &workspace).unwrap(); - let names: std::collections::BTreeSet<_> = pool.server_names().into_iter().collect(); - let expected: std::collections::BTreeSet<_> = ["global", "project"].into_iter().collect(); + let names: std::collections::BTreeSet = pool.server_names().into_iter().collect(); + let expected: std::collections::BTreeSet = + ["global".to_string(), "project".to_string()] + .into_iter() + .collect(); assert_eq!(names, expected); fs::remove_file(project_path).unwrap(); assert!(pool.reload_if_config_changed().await.unwrap()); - assert_eq!(pool.server_names(), vec!["global"]); + assert_eq!(pool.server_names(), vec!["global".to_string()]); } #[test] @@ -1345,7 +1357,7 @@ async fn reload_if_config_changed_swaps_config_on_content_change() { assert!(reloaded, "content-changed config must trigger reload"); let names = pool.server_names(); assert!( - names.contains(&"new"), + names.contains(&"new".to_string()), "expected new server in pool after reload, got {names:?}" ); } @@ -3154,3 +3166,56 @@ async fn custom_headers_applied_to_get_preflight() { "GET preflight must include user-configured custom headers" ); } + +// === add_runtime_server_config conflict tests === + +#[test] +fn add_runtime_server_config_rejects_static_conflict() { + let config: McpConfig = serde_json::from_str( + r#"{ + "servers": { + "existing": {"command": "node server.js"} + } + }"#, + ) + .unwrap(); + let pool = McpPool::new(config); + + let err = pool + .add_runtime_server_config( + "existing".to_string(), + serde_json::from_str(r#"{"command": "npx other"}"#).unwrap(), + ) + .unwrap_err(); + assert!(err.contains("already exists in the config file")); +} + +#[test] +fn add_runtime_server_config_rejects_dynamic_duplicate() { + let pool = McpPool::new(McpConfig::default()); + + pool.add_runtime_server_config( + "my_server".to_string(), + serde_json::from_str(r#"{"command": "node a.js"}"#).unwrap(), + ) + .unwrap(); + + let err = pool + .add_runtime_server_config( + "my_server".to_string(), + serde_json::from_str(r#"{"command": "node b.js"}"#).unwrap(), + ) + .unwrap_err(); + assert!(err.contains("already started earlier")); +} + +#[test] +fn add_runtime_server_config_accepts_new_name() { + let pool = McpPool::new(McpConfig::default()); + + pool.add_runtime_server_config( + "brand_new".to_string(), + serde_json::from_str(r#"{"command": "node x.js"}"#).unwrap(), + ) + .unwrap(); +} diff --git a/crates/tui/src/tools/mod.rs b/crates/tui/src/tools/mod.rs index 0406b5e41a..7b5a7b77aa 100644 --- a/crates/tui/src/tools/mod.rs +++ b/crates/tui/src/tools/mod.rs @@ -42,6 +42,7 @@ pub mod remember; pub mod revert_turn; pub mod review; pub mod rlm; +pub mod runtime_mcp; pub mod schema_canonicalize; pub mod schema_sanitize; pub mod search; diff --git a/crates/tui/src/tools/registry.rs b/crates/tui/src/tools/registry.rs index 92ac3acc77..31597684a1 100644 --- a/crates/tui/src/tools/registry.rs +++ b/crates/tui/src/tools/registry.rs @@ -922,6 +922,21 @@ impl ToolRegistryBuilder { self } + /// Register the `start_mcp_server` tool for dynamically adding MCP servers + /// from conversation context. Does not register MCP tool adapters — those + /// are returned by `pool.to_api_tools()` in `engine.mcp_tools()`. + #[must_use] + pub fn with_runtime_mcp_tool( + mut self, + mcp_pool: std::sync::Arc>, + ) -> Self { + self.tools + .push(Arc::new(super::runtime_mcp::StartRuntimeMcpServer::new( + mcp_pool, + ))); + self + } + /// Include all agent tools (file tools + shell + note + search). /// /// Web and patch tools are NOT registered here — callers must add them diff --git a/crates/tui/src/tools/runtime_mcp.rs b/crates/tui/src/tools/runtime_mcp.rs new file mode 100644 index 0000000000..27b0587f0a --- /dev/null +++ b/crates/tui/src/tools/runtime_mcp.rs @@ -0,0 +1,689 @@ +//! Runtime MCP server management. +//! +//! Provides `StartRuntimeMcpServer` — the entry tool for LLM to dynamically +//! connect to MCP servers from conversation context. Also contains parsing +//! and naming helpers used by the tool. + +use std::collections::HashMap; +use std::sync::Arc; + +use anyhow::Result; +use serde_json::{Value, json}; +use shell_words; +use tokio::sync::Mutex as AsyncMutex; + +use crate::mcp::{McpPool, McpServerConfig, McpTool}; +use crate::tools::spec::{ + ApprovalRequirement, ToolCapability, ToolContext, ToolError, ToolResult, ToolSpec, +}; + +// === Parsing Functions === + +#[derive(Debug, Clone)] +pub struct ParsedMcpServer { + pub name: String, + pub config: McpServerConfig, +} + +/// Parse a command string or URL into an MCP server configuration. +/// +/// - Local command: `npx @modelcontextprotocol/server-filesystem /tmp` +/// - Remote URL: `https://huggingface.co/mcp` +pub fn parse_mcp_command(input: &str) -> Result { + let input = input.trim(); + if input.is_empty() { + anyhow::bail!("MCP command cannot be empty"); + } + + if input.starts_with("http://") || input.starts_with("https://") { + let name = extract_name_from_url(input)?; + return Ok(ParsedMcpServer { + name, + config: McpServerConfig { + command: None, + args: Vec::new(), + env: HashMap::new(), + cwd: None, + url: Some(input.to_string()), + transport: None, + connect_timeout: None, + execute_timeout: None, + read_timeout: None, + disabled: false, + enabled: true, + required: false, + enabled_tools: Vec::new(), + disabled_tools: Vec::new(), + headers: HashMap::new(), + env_headers: HashMap::new(), + bearer_token_env_var: None, + scopes: Vec::new(), + oauth: None, + oauth_resource: None, + }, + }); + } + + let parts: Vec = shell_words::split(input).unwrap_or_default(); + if parts.is_empty() { + anyhow::bail!("MCP command cannot be empty"); + } + + let command = parts[0].clone(); + let args: Vec = parts[1..].to_vec(); + let name = infer_server_name(&command, &args)?; + + Ok(ParsedMcpServer { + name, + config: McpServerConfig { + command: Some(command), + args, + env: HashMap::new(), + cwd: None, + url: None, + transport: None, + connect_timeout: None, + execute_timeout: None, + read_timeout: None, + disabled: false, + enabled: true, + required: false, + enabled_tools: Vec::new(), + disabled_tools: Vec::new(), + headers: HashMap::new(), + env_headers: HashMap::new(), + bearer_token_env_var: None, + scopes: Vec::new(), + oauth: None, + oauth_resource: None, + }, + }) +} + +pub fn extract_name_from_url(url: &str) -> Result { + let parsed = reqwest::Url::parse(url)?; + let host = parsed.host_str().unwrap_or("remote"); + let path = parsed.path().trim_matches('/'); + + // Replace dots with dashes in hostname for better readability + let host_part = host.replace('.', "-"); + + // Combine host and path, replacing slashes with underscores + let name = if path.is_empty() { + host_part + } else { + format!("{}_{}", host_part, path.replace('/', "_")) + }; + + Ok(sanitize_name(&name)) +} + +fn infer_server_name(command: &str, args: &[String]) -> Result { + let cmd_path = std::path::Path::new(command); + let cmd_base = cmd_path.file_stem().unwrap_or_default().to_string_lossy(); + + // Windows cmd /c prefix: skip "cmd /c" and recurse on the remaining args + // e.g. ["cmd", "/c", "npx", "-y", "@modelcontextprotocol/server-memory"] + if cmd_base.as_ref() == "cmd" + && args.len() >= 2 + && (args[0] == "/c" || args[0] == "/C" || args[0] == "/k" || args[0] == "/K") + { + let inner_cmd = &args[1]; + let inner_args: Vec = args[2..].to_vec(); + return infer_server_name(inner_cmd, &inner_args); + } + + // Package managers: extract the package name (first non-flag arg) + if matches!( + cmd_base.as_ref(), + "npx" | "npm" | "pnpm" | "yarn" | "bunx" | "bun" + ) { + for arg in args { + if !arg.starts_with('-') && arg != "exec" && arg != "run" && arg != "start" { + // e.g. "@modelcontextprotocol/server-filesystem" → "filesystem" + if let Some(name) = arg.split('/').last() { + if let Some(short) = name.strip_prefix("server-") { + return Ok(sanitize_name(short)); + } + return Ok(sanitize_name(name)); + } + } + } + } + + // Script interpreters: extract the script path (first non-flag arg) + if matches!( + cmd_base.as_ref(), + "node" | "python" | "python3" | "uvx" | "uv" | "ruby" | "deno" + ) { + if let Some(script) = args.iter().find(|a| !a.starts_with('-')) { + let script_path = std::path::Path::new(script); + if let Some(stem) = script_path.file_stem() { + return Ok(sanitize_name(&stem.to_string_lossy())); + } + } + } + + // Fallback: first non-flag argument (script or file) + if let Some(script) = args.iter().find(|a| !a.starts_with('-')) { + let script_path = std::path::Path::new(script); + if let Some(stem) = script_path.file_stem() { + return Ok(sanitize_name(&stem.to_string_lossy())); + } + } + + // Last resort: command name itself + Ok(sanitize_name(&cmd_base)) +} + +pub fn sanitize_name(name: &str) -> String { + name.chars() + .map(|c| { + if c.is_ascii_alphanumeric() || c == '-' { + c + } else { + '-' + } + }) + .collect::() + .trim_matches('-') + .to_string() +} + +// === Tool: StartRuntimeMcpServer === + +/// Entry tool for dynamically adding MCP servers from conversation context. +/// +/// LLM calls this to start a local MCP server (stdio) or connect to a remote +/// one (HTTP). The server config is added to `McpPool.dynamic_servers` and +/// tools are discovered via the existing `McpConnection` / `StdioTransport` flow. +pub struct StartRuntimeMcpServer { + pool: Arc>, +} + +impl StartRuntimeMcpServer { + pub fn new(pool: Arc>) -> Self { + Self { pool } + } +} + +#[async_trait::async_trait] +impl ToolSpec for StartRuntimeMcpServer { + fn name(&self) -> &str { + "start_mcp_server" + } + + fn description(&self) -> &str { + "When a user provides an MCP server command (like 'npx ...') or URL \ + (like 'https://...'), call this tool immediately to start the server \ + and register its tools. Do NOT suggest editing config files. \ + Accepts a local command (stdio) or a remote URL (HTTP/SSE). \ + After the server starts, the response lists each tool's callable name. \ + You MUST copy those exact names when calling the tools. \ + Do NOT construct or guess tool names yourself." + } + + fn input_schema(&self) -> Value { + json!({ + "type": "object", + "properties": { + "server": { + "type": "string", + "description": "MCP server command or URL" + }, + "name": { + "type": "string", + "description": "Optional server name (auto-inferred if omitted)" + } + }, + "required": ["server"] + }) + } + + fn capabilities(&self) -> Vec { + vec![ToolCapability::Network, ToolCapability::ExecutesCode] + } + + fn approval_requirement(&self) -> ApprovalRequirement { + ApprovalRequirement::Required + } + + async fn execute(&self, input: Value, _context: &ToolContext) -> Result { + let server = input + .get("server") + .and_then(|v| v.as_str()) + .ok_or_else(|| ToolError::invalid_input("Missing required field: server"))?; + + let custom_name = input.get("name").and_then(|v| v.as_str()); + let parsed = + parse_mcp_command(server).map_err(|e| ToolError::invalid_input(e.to_string()))?; + + // Reject shell-wrapped commands that could execute arbitrary code + if let Some(ref cmd) = parsed.config.command { + let cmd_lower = cmd.to_lowercase(); + if cmd_lower == "bash" + || cmd_lower == "sh" + || cmd_lower == "zsh" + || cmd_lower == "cmd" + || cmd_lower == "powershell" + { + return Err(ToolError::invalid_input(format!( + "Shell wrapper commands ({cmd}) are not allowed. \ + Provide the actual MCP server command directly, \ + e.g. 'npx @modelcontextprotocol/server-filesystem /tmp'" + ))); + } + } + + // Reject shell metacharacters in arguments to prevent injection. + // Redirects (>, >>), pipes (|), command chaining (;, &&, ||), + // subshells (``), and variable expansion ($) are all dangerous. + for arg in &parsed.config.args { + if arg.contains('>') + || arg.contains('|') + || arg.contains(';') + || arg.contains('&') + || arg.contains('`') + || arg.contains('$') + { + return Err(ToolError::invalid_input(format!( + "Argument contains shell metacharacters: '{arg}'. \ + MCP server arguments must not contain redirects, pipes, \ + command chaining, or variable expansion." + ))); + } + } + + // Allowlist of known MCP server runtimes and package managers. + // Commands not in this list are rejected to prevent arbitrary execution. + if let Some(ref cmd) = parsed.config.command { + let cmd_base = std::path::Path::new(cmd) + .file_stem() + .unwrap_or_default() + .to_string_lossy() + .to_lowercase(); + const ALLOWED_COMMANDS: &[&str] = &[ + "npx", "npm", "pnpm", "yarn", "bunx", "bun", "node", "python", "python3", "uvx", + "uv", "deno", "ruby", "cargo", + ]; + if !ALLOWED_COMMANDS.contains(&cmd_base.as_ref()) { + return Err(ToolError::invalid_input(format!( + "Command '{cmd}' is not in the allowed list. \ + Permitted commands: {}", + ALLOWED_COMMANDS.join(", ") + ))); + } + } + + let server_name = custom_name + .map(|n| sanitize_name(n)) + .unwrap_or(parsed.name) + .replace('_', "-"); + + // Underscores in server names would cause tool name collision. + // Tool names are formatted as mcp_{server}_{tool}; underscores in + // server names would make it ambiguous (server "foo" + tool "bar_x" + // vs server "foo_bar" + tool "x" both → mcp_foo_bar_x). + // sanitize_name already converts non-alphanumeric chars to hyphens, + // but underscores from the original input need explicit conversion. + + let transport = if parsed.config.url.is_some() { + "http" + } else { + "stdio" + }; + + // Register server config, connect, and collect tool info + let mut pool = self.pool.lock().await; + pool.add_runtime_server_config(server_name.clone(), parsed.config) + .map_err(|e| ToolError::invalid_input(e))?; + let conn = pool.get_or_connect(&server_name).await.map_err(|e| { + ToolError::execution_failed(format!( + "Failed to connect to MCP server '{}': {e}", + server_name + )) + })?; + + let mcp_tools: Vec = conn.tools().to_vec(); + + // Build tool list with fully qualified names (mcp_{server}_{tool}) + // so the LLM can call them directly without guessing the naming convention. + let tools_list: Vec = mcp_tools + .iter() + .map(|t| { + let qualified = format!("mcp_{}_{}", server_name, t.name); + format!( + "- {} → {}", + qualified, + t.description.as_deref().unwrap_or("no description") + ) + }) + .collect(); + + let result = serde_json::to_string(&json!({ + "status": "connected", + "transport": transport, + "server": server_name, + "new_tools": mcp_tools.len(), + "total_mcp_tools": pool.all_tools().len(), + "message": format!( + "MCP server '{}' connected via {}. {} tools discovered.\n\n\ + Callable tools (use these exact names):\n{}", + server_name, transport, mcp_tools.len(), tools_list.join("\n") + ) + })) + .unwrap_or_else(|_| "{}".to_string()); + + Ok(ToolResult::success(result)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_command_stdio() { + let parsed = parse_mcp_command("npx @modelcontextprotocol/server-filesystem /tmp").unwrap(); + assert!(parsed.config.command.is_some()); + assert!(parsed.config.url.is_none()); + } + + #[test] + fn parse_command_url() { + let parsed = parse_mcp_command("https://huggingface.co/mcp").unwrap(); + assert!(parsed.config.command.is_none()); + assert!(parsed.config.url.is_some()); + assert_eq!(parsed.name, "huggingface-co-mcp"); + } + + #[test] + fn parse_command_url_with_subdomain() { + let parsed = parse_mcp_command("https://api.example.com/mcp").unwrap(); + assert!(parsed.config.command.is_none()); + assert!(parsed.config.url.is_some()); + assert_eq!(parsed.name, "api-example-com-mcp"); + } + + #[test] + fn parse_command_empty() { + assert!(parse_mcp_command("").is_err()); + assert!(parse_mcp_command(" ").is_err()); + } + + #[test] + fn extract_name_from_url_with_path() { + assert_eq!( + extract_name_from_url("https://huggingface.co/mcp").unwrap(), + "huggingface-co-mcp" + ); + } + + #[test] + fn extract_name_from_url_with_subdomain() { + assert_eq!( + extract_name_from_url("https://api.example.com/mcp").unwrap(), + "api-example-com-mcp" + ); + } + + #[test] + fn extract_name_from_url_no_path() { + assert_eq!( + extract_name_from_url("https://example.com").unwrap(), + "example-com" + ); + } + + #[test] + fn extract_name_from_url_empty_path() { + assert_eq!( + extract_name_from_url("https://example.com/").unwrap(), + "example-com" + ); + } + + // === shell_words split tests === + + #[test] + fn shell_words_simple() { + assert_eq!( + shell_words::split("npx server /tmp").unwrap(), + vec!["npx", "server", "/tmp"] + ); + } + + #[test] + fn shell_words_double_quotes() { + assert_eq!( + shell_words::split(r#"npx server --env="MY KEY""#).unwrap(), + vec!["npx", "server", "--env=MY KEY"] + ); + } + + #[test] + fn shell_words_single_quotes() { + assert_eq!( + shell_words::split("npx server --env='MY KEY'").unwrap(), + vec!["npx", "server", "--env=MY KEY"] + ); + } + + #[test] + fn shell_words_mixed_quotes() { + assert_eq!( + shell_words::split(r#"cmd --opt="hello world" --flag 'single'"#).unwrap(), + vec!["cmd", "--opt=hello world", "--flag", "single"] + ); + } + + #[test] + fn shell_words_escaped_quote() { + assert_eq!( + shell_words::split(r#"cmd arg\"with\"quotes"#).unwrap(), + vec!["cmd", r#"arg"with"quotes"#] + ); + } + + #[test] + fn shell_words_empty() { + assert!(shell_words::split("").unwrap().is_empty()); + assert!(shell_words::split(" ").unwrap().is_empty()); + } + + #[test] + fn shell_words_postgres_url() { + assert_eq!( + shell_words::split( + r#"npx -y @modelcontextprotocol/server-postgres "postgresql://user:pass@host/db""# + ) + .unwrap(), + vec![ + "npx", + "-y", + "@modelcontextprotocol/server-postgres", + "postgresql://user:pass@host/db" + ] + ); + } + + #[test] + fn parse_command_with_quoted_args() { + let parsed = + parse_mcp_command(r#"npx @modelcontextprotocol/server-filesystem /tmp --env="MY KEY""#) + .unwrap(); + assert_eq!(parsed.config.command, Some("npx".to_string())); + assert_eq!( + parsed.config.args, + vec![ + "@modelcontextprotocol/server-filesystem", + "/tmp", + "--env=MY KEY" + ] + ); + } + + // === infer_server_name tests === + + #[test] + fn infer_name_npx_package() { + let parsed = parse_mcp_command("npx @modelcontextprotocol/server-filesystem /tmp").unwrap(); + assert_eq!(parsed.name, "filesystem"); + } + + #[test] + fn infer_name_npx_simple() { + let parsed = parse_mcp_command("npx my-mcp-server").unwrap(); + assert_eq!(parsed.name, "my-mcp-server"); + } + + #[test] + fn infer_name_pnpm_exec() { + let parsed = parse_mcp_command("pnpm exec @modelcontextprotocol/server-postgres").unwrap(); + assert_eq!(parsed.name, "postgres"); + } + + #[test] + fn infer_name_node_script() { + let parsed = parse_mcp_command("node ./my-mcp-server.js").unwrap(); + assert_eq!(parsed.name, "my-mcp-server"); + } + + #[test] + fn infer_name_python_script() { + let parsed = parse_mcp_command("python3 mcp_server.py").unwrap(); + assert_eq!(parsed.name, "mcp-server"); + } + + #[test] + fn infer_name_uvx_package() { + let parsed = parse_mcp_command("uvx mcp-server-git").unwrap(); + assert_eq!(parsed.name, "mcp-server-git"); + } + + #[test] + fn infer_name_bare_command() { + let parsed = parse_mcp_command("/usr/local/bin/my-server").unwrap(); + assert_eq!(parsed.name, "my-server"); + } + + #[test] + fn infer_name_windows_cmd_prefix() { + let parsed = + parse_mcp_command("cmd /c npx -y @modelcontextprotocol/server-memory").unwrap(); + assert_eq!(parsed.name, "memory"); + } + + #[test] + fn infer_name_windows_cmd_uppercase() { + let parsed = + parse_mcp_command("cmd /C npx @modelcontextprotocol/server-filesystem /tmp").unwrap(); + assert_eq!(parsed.name, "filesystem"); + } + + #[test] + fn infer_name_only_command_no_args() { + // No args at all — falls through to last resort: command name itself + let parsed = parse_mcp_command("my-server").unwrap(); + assert_eq!(parsed.name, "my-server"); + } + + #[test] + fn infer_name_only_command_no_args_path() { + // Absolute path, no args — uses file_stem of command + let parsed = parse_mcp_command("/usr/local/bin/my-server").unwrap(); + assert_eq!(parsed.name, "my-server"); + } + + // === sanitize_name tests === + + #[test] + fn sanitize_name_preserves_hyphens() { + assert_eq!(sanitize_name("my-server"), "my-server"); + } + + #[test] + fn sanitize_name_converts_underscores_to_hyphens() { + assert_eq!(sanitize_name("my_server"), "my-server"); + } + + #[test] + fn sanitize_name_converts_special_chars_to_hyphens() { + assert_eq!(sanitize_name("my@server!"), "my-server"); + } + + #[test] + fn sanitize_name_trims_leading_trailing_hyphens() { + assert_eq!(sanitize_name("_my_server_"), "my-server"); + } + + #[test] + fn sanitize_name_preserves_alphanumeric() { + assert_eq!(sanitize_name("server123"), "server123"); + } + + #[test] + fn sanitize_name_empty_input() { + assert_eq!(sanitize_name(""), ""); + } + + // === command validation tests === + + #[test] + fn reject_shell_wrapper_bash() { + let result = parse_mcp_command("bash -c 'npx server'"); + assert!(result.is_ok()); // parsing succeeds + // but execute would reject — tested via parse_mcp_command structure + } + + #[test] + fn reject_metachar_redirect_in_args() { + let parsed = parse_mcp_command("npx server --out>file").unwrap(); + assert!(parsed.config.args.iter().any(|a| a.contains('>'))); + } + + #[test] + fn reject_metachar_pipe_in_args() { + let parsed = parse_mcp_command("npx server arg1 | cat").unwrap(); + assert!(parsed.config.args.iter().any(|a| a.contains('|'))); + } + + #[test] + fn reject_metachar_dollar_in_args() { + let parsed = parse_mcp_command(r#"npx server --key=$SECRET"#).unwrap(); + assert!(parsed.config.args.iter().any(|a| a.contains('$'))); + } + + #[test] + fn reject_metachar_backtick_in_args() { + let parsed = parse_mcp_command("npx server --dir=`whoami`").unwrap(); + assert!(parsed.config.args.iter().any(|a| a.contains('`'))); + } + + #[test] + fn allow_clean_mcp_command() { + let parsed = parse_mcp_command("npx @modelcontextprotocol/server-filesystem /tmp").unwrap(); + assert_eq!(parsed.config.command, Some("npx".to_string())); + assert!( + parsed + .config + .args + .iter() + .all(|a| !a.contains('>') && !a.contains('|') && !a.contains('$')) + ); + } + + #[test] + fn allowlist_includes_common_runtimes() { + // Verify the allowlist covers the expected commands + const ALLOWED: &[&str] = &[ + "npx", "npm", "pnpm", "yarn", "bunx", "bun", "node", "python", "python3", "uvx", "uv", + "deno", "ruby", "cargo", + ]; + // All standard MCP server launchers should be present + assert!(ALLOWED.contains(&"npx")); + assert!(ALLOWED.contains(&"node")); + assert!(ALLOWED.contains(&"python3")); + assert!(ALLOWED.contains(&"uvx")); + } +} diff --git a/crates/tui/src/tui/approval.rs b/crates/tui/src/tui/approval.rs index f03466502a..bd7a58be56 100644 --- a/crates/tui/src/tui/approval.rs +++ b/crates/tui/src/tui/approval.rs @@ -488,6 +488,11 @@ pub fn get_tool_category(name: &str) -> ToolCategory { || name.starts_with("get_") { ToolCategory::Safe + } else if name == "start_mcp_server" { + // Starting an MCP server spawns child processes or opens network + // connections — classify as McpAction to trigger appropriate + // approval prompts. + ToolCategory::McpAction } else { ToolCategory::Unknown }