diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 198c15b..9578c28 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: rust-clippy name: Rust clippy description: Run cargo clippy on files included in the commit. clippy should be installed before-hand. - entry: cargo clippy --all-targets --features vulkan -- -Dclippy::all + entry: cargo clippy --all-targets --features vulkan,server -- -D warnings pass_filenames: false types: [file, rust] language: system diff --git a/Cargo.lock b/Cargo.lock index 58ebe11..f919a97 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -88,6 +88,7 @@ version = "0.0.7" dependencies = [ "anyhow", "arey-core", + "arey-mcp", "arey-tools-search", "async-stream", "async-trait", @@ -111,6 +112,7 @@ dependencies = [ "tracing", "tracing-subscriber", "wiremock", + "yare", ] [[package]] @@ -152,6 +154,23 @@ dependencies = [ "yare", ] +[[package]] +name = "arey-mcp" +version = "0.0.1" +dependencies = [ + "anyhow", + "arey-core", + "arey-mcp", + "async-trait", + "rmcp", + "serde", + "serde_json", + "serde_yaml", + "tokio", + "tracing", + "yare", +] + [[package]] name = "arey-tools-search" version = "0.0.7" @@ -391,6 +410,7 @@ dependencies = [ "iana-time-zone", "js-sys", "num-traits", + "serde", "wasm-bindgen", "windows-link", ] @@ -522,8 +542,18 @@ version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc7f46116c46ff9ab3eb1597a45688b6715c6e628b5c133e288e709a29bcb4ee" dependencies = [ - "darling_core", - "darling_macro", + "darling_core 0.20.11", + "darling_macro 0.20.11", +] + +[[package]] +name = "darling" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25ae13da2f202d56bd7f91c25fba009e7717a1e4a1cc98a76d844b65ae912e9d" +dependencies = [ + "darling_core 0.23.0", + "darling_macro 0.23.0", ] [[package]] @@ -540,13 +570,37 @@ dependencies = [ "syn", ] +[[package]] +name = "darling_core" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9865a50f7c335f53564bb694ef660825eb8610e0a53d3e11bf1b0d3df31e03b0" +dependencies = [ + "ident_case", + "proc-macro2", + "quote", + "strsim", + "syn", +] + [[package]] name = "darling_macro" version = "0.20.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ - "darling_core", + "darling_core 0.20.11", + "quote", + "syn", +] + +[[package]] +name = "darling_macro" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3984ec7bd6cfa798e62b4a642426a5be0e68f9401cfc2a01e3fa9ea2fcdb8d" +dependencies = [ + "darling_core 0.23.0", "quote", "syn", ] @@ -593,7 +647,7 @@ version = "0.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2d5bcf7b024d6835cfb3d473887cd966994907effbe9227e8c8219824d06c4e8" dependencies = [ - "darling", + "darling 0.20.11", "proc-macro2", "quote", "syn", @@ -641,6 +695,12 @@ dependencies = [ "syn", ] +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + [[package]] name = "either" version = "1.15.0" @@ -1735,6 +1795,12 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "pastey" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b867cad97c0791bbd3aaa6472142568c6c9e8f71937e98379f584cfb0cf35bec" + [[package]] name = "path-absolutize" version = "3.1.1" @@ -1842,6 +1908,20 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "process-wrap" +version = "9.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e842efad9119158434d193c6682e2ebee4b44d6ad801d7b349623b3f57cdf55" +dependencies = [ + "futures", + "indexmap", + "nix", + "tokio", + "tracing", + "windows", +] + [[package]] name = "quick-xml" version = "0.38.4" @@ -2024,6 +2104,26 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "regex" version = "1.12.3" @@ -2126,6 +2226,43 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmcp" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f542f74cf247da16f19bbc87e298cd201e912314f4083e88cdd671f44f5fcb53" +dependencies = [ + "async-trait", + "base64", + "chrono", + "futures", + "pastey", + "pin-project-lite", + "process-wrap", + "rmcp-macros", + "schemars", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tokio-util", + "tracing", +] + +[[package]] +name = "rmcp-macros" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2391e4ae47f314e70eaafb6c7bd82e495e770b935448864446302143019151f" +dependencies = [ + "darling 0.23.0", + "proc-macro2", + "quote", + "serde_json", + "syn", +] + [[package]] name = "rustc-hash" version = "2.1.2" @@ -2255,6 +2392,32 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "schemars" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +dependencies = [ + "chrono", + "dyn-clone", + "ref-cast", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d115b50f4aaeea07e79c1912f645c7513d81715d0420f8bc77a18c6260b307f" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn", +] + [[package]] name = "secrecy" version = "0.10.3" @@ -2324,6 +2487,17 @@ dependencies = [ "syn", ] +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "serde_json" version = "1.0.149" @@ -3045,6 +3219,27 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "windows" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "527fadee13e0c05939a6a05d5bd6eec6cd2e3dbd648b9f8e447c6518133d8580" +dependencies = [ + "windows-collections", + "windows-core", + "windows-future", + "windows-numerics", +] + +[[package]] +name = "windows-collections" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23b2d95af1a8a14a3c7367e1ed4fc9c20e0a26e79551b1454d72583c97cc6610" +dependencies = [ + "windows-core", +] + [[package]] name = "windows-core" version = "0.62.2" @@ -3058,6 +3253,17 @@ dependencies = [ "windows-strings", ] +[[package]] +name = "windows-future" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1d6f90251fe18a279739e78025bd6ddc52a7e22f921070ccdc67dde84c605cb" +dependencies = [ + "windows-core", + "windows-link", + "windows-threading", +] + [[package]] name = "windows-implement" version = "0.60.2" @@ -3086,6 +3292,16 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-numerics" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e40844ac143cdb44aead537bbf727de9b044e107a0f1220392177d15b0f26" +dependencies = [ + "windows-core", + "windows-link", +] + [[package]] name = "windows-result" version = "0.4.1" @@ -3164,6 +3380,15 @@ dependencies = [ "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows-threading" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3949bd5b99cafdf1c7ca86b43ca564028dfe27d66958f2470940f73d86d75b37" +dependencies = [ + "windows-link", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" diff --git a/Cargo.toml b/Cargo.toml index 069d556..bdc2cf8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [workspace] -members = ["crates/core", "crates/arey", "crates/tools-search"] +members = ["crates/core", "crates/arey", "crates/tools-search", "crates/mcp"] resolver = "2" [workspace.package] diff --git a/crates/arey/Cargo.toml b/crates/arey/Cargo.toml index b0f0295..d52014e 100644 --- a/crates/arey/Cargo.toml +++ b/crates/arey/Cargo.toml @@ -9,6 +9,7 @@ path = "src/main.rs" [dependencies] arey-core = { path = "../core", version = "*" } +arey-mcp = { path = "../mcp" } arey-tools-search = { path = "../tools-search" } anyhow.workspace = true async-stream.workspace = true @@ -40,11 +41,13 @@ syntect = { version = "5.3.0", default-features = false, features = [ ] } [dev-dependencies] +arey-mcp = { path = "../mcp", features = ["test_utils"] } async-trait.workspace = true tempfile.workspace = true serde_json.workspace = true wiremock.workspace = true once_cell.workspace = true +yare.workspace = true [features] cuda = ["arey-core/cuda"] diff --git a/crates/arey/src/cli/chat/commands.rs b/crates/arey/src/cli/chat/commands.rs index 2608ee2..3b733ac 100644 --- a/crates/arey/src/cli/chat/commands.rs +++ b/crates/arey/src/cli/chat/commands.rs @@ -47,6 +47,16 @@ pub enum Command { /// Tool names to set, or "clear" names: Vec, }, + /// Manage MCP servers. + /// + /// With no arguments, lists all MCP servers and their status. + /// Use "enable", "disable", "start", "stop" with a server name. + Mcp { + /// Subcommand: enable, disable, start, stop + subcommand: Option, + /// Server name for enable/disable/start/stop + server: Option, + }, /// Manage chat agents. /// /// With no arguments, shows the current agent and available agents with their sources. @@ -88,6 +98,13 @@ impl Command { Command::Model { ref name } => self.execute_model(session, name).await, Command::Profile { ref name } => self.execute_profile(session, name).await, Command::Tool { ref names } => self.execute_tool(session, names).await, + Command::Mcp { + ref subcommand, + ref server, + } => { + self.execute_mcp(session, subcommand.as_deref(), server.as_deref()) + .await + } Command::Agent { ref name } => self.execute_agent(session, name).await, Command::System { ref prompt } => self.execute_system(session, prompt).await, Command::Think { ref mode } => self.execute_think(session, mode).await, @@ -124,10 +141,17 @@ impl Command { } } else if names.is_empty() { // Show current and available tools - let chat_guard = session.lock().await; + let mut chat_guard = session.lock().await; let current_tools = chat_guard.tools(); let available_tools = chat_guard.available_tool_names(); + // Get MCP server info + let mcp_servers = if let Some(mgr) = chat_guard.mcp_registry() { + mgr.list().await + } else { + vec![] + }; + if current_tools.is_empty() { println!("No tools are currently active."); } else { @@ -139,6 +163,19 @@ impl Command { if !available_tools.is_empty() { println!("Available tools: {}", available_tools.join(", ")); } + + // Show MCP server status + if !mcp_servers.is_empty() { + println!("\nMCP Servers:"); + for server in mcp_servers { + let status = if server.enabled { + "enabled" + } else { + "disabled" + }; + println!(" {}: {}", server.name, status); + } + } } else { // Set new tools let mut chat_guard = session.lock().await; @@ -379,6 +416,112 @@ impl Command { Ok(true) } + async fn execute_mcp( + &self, + session: Arc>>, + subcommand: Option<&str>, + server: Option<&str>, + ) -> Result { + let mut chat_guard = session.lock().await; + + match (subcommand, server) { + (None, None) => { + // List all MCP servers + let mcp_registry = chat_guard.mcp_registry(); + if let Some(manager) = mcp_registry { + let servers = manager.list().await; + if servers.is_empty() { + println!("No MCP servers configured."); + } else { + println!("MCP Servers:"); + for s in servers { + let status = if s.running { + if s.enabled { + "running, enabled" + } else { + "running, disabled" + } + } else { + "stopped" + }; + println!(" {}: {} ({} tools)", s.name, status, s.tool_count); + } + } + } else { + println!("MCP not available"); + } + } + (Some("list"), None) => { + let mcp_registry = chat_guard.mcp_registry(); + if let Some(manager) = mcp_registry { + let servers = manager.list().await; + for s in servers { + let status = if s.running { + if s.enabled { + "running, enabled" + } else { + "running, disabled" + } + } else { + "stopped" + }; + println!(" {}: {} ({} tools)", s.name, status, s.tool_count); + } + } + } + (Some("enable"), Some(name)) => { + let mcp_registry = chat_guard.mcp_registry(); + if let Some(manager) = mcp_registry { + if let Err(e) = manager.enable(name).await { + eprintln!("Failed to enable MCP server '{}': {}", name, e); + } else { + println!("MCP server '{}' enabled.", name); + } + } else { + eprintln!("MCP not available"); + } + } + (Some("disable"), Some(name)) => { + let mcp_registry = chat_guard.mcp_registry(); + if let Some(manager) = mcp_registry { + if let Err(e) = manager.disable(name).await { + eprintln!("Failed to disable MCP server '{}': {}", name, e); + } else { + println!("MCP server '{}' disabled.", name); + } + } else { + eprintln!("MCP not available"); + } + } + (Some("start"), Some(name)) => { + // This would require access to config, which we don't have here + eprintln!("Use '/mcp enable {}' to enable a running server", name); + } + (Some("stop"), Some(name)) => { + let mcp_registry = chat_guard.mcp_registry(); + if let Some(manager) = mcp_registry { + if let Err(e) = manager.remove_server(name).await { + eprintln!("Failed to stop MCP server '{}': {}", name, e); + } else { + println!("MCP server '{}' stopped.", name); + } + } else { + eprintln!("MCP not available"); + } + } + (Some(sub), _) => { + eprintln!( + "Unknown MCP subcommand: {}. Use list, enable, disable, start, or stop.", + sub + ); + } + (None, Some(_)) => { + eprintln!("Use /mcp "); + } + } + Ok(true) + } + fn execute_exit(&self) -> Result { println!("Bye!"); Ok(false) @@ -1064,4 +1207,159 @@ USER: Run tool Ok(()) } + + #[tokio::test] + async fn test_mcp_command_no_mcp_registry() -> Result<()> { + let config = create_test_config_with_custom_agent()?; + let chat = Chat::new( + &config, + Some("test-model-1".to_string()), + ToolRegistry::new(), + )?; + let chat_session = Arc::new(Mutex::new(chat)); + + // Test /mcp with no MCP registry - should show "MCP not available" + let mcp_cmd = Command::Mcp { + subcommand: None, + server: None, + }; + assert!(mcp_cmd.execute(chat_session.clone()).await?); + + Ok(()) + } + + #[tokio::test] + async fn test_mcp_command_list_no_mcp_registry() -> Result<()> { + let config = create_test_config_with_custom_agent()?; + let chat = Chat::new( + &config, + Some("test-model-1".to_string()), + ToolRegistry::new(), + )?; + let chat_session = Arc::new(Mutex::new(chat)); + + let mcp_cmd = Command::Mcp { + subcommand: Some("list".to_string()), + server: None, + }; + assert!(mcp_cmd.execute(chat_session.clone()).await?); + + Ok(()) + } + + #[tokio::test] + async fn test_mcp_command_enable_no_mcp_registry() -> Result<()> { + let config = create_test_config_with_custom_agent()?; + let chat = Chat::new( + &config, + Some("test-model-1".to_string()), + ToolRegistry::new(), + )?; + let chat_session = Arc::new(Mutex::new(chat)); + + let mcp_cmd = Command::Mcp { + subcommand: Some("enable".to_string()), + server: Some("test-server".to_string()), + }; + assert!(mcp_cmd.execute(chat_session.clone()).await?); + + Ok(()) + } + + #[tokio::test] + async fn test_mcp_command_disable_no_mcp_registry() -> Result<()> { + let config = create_test_config_with_custom_agent()?; + let chat = Chat::new( + &config, + Some("test-model-1".to_string()), + ToolRegistry::new(), + )?; + let chat_session = Arc::new(Mutex::new(chat)); + + let mcp_cmd = Command::Mcp { + subcommand: Some("disable".to_string()), + server: Some("test-server".to_string()), + }; + assert!(mcp_cmd.execute(chat_session.clone()).await?); + + Ok(()) + } + + #[tokio::test] + async fn test_mcp_command_stop_no_mcp_registry() -> Result<()> { + let config = create_test_config_with_custom_agent()?; + let chat = Chat::new( + &config, + Some("test-model-1".to_string()), + ToolRegistry::new(), + )?; + let chat_session = Arc::new(Mutex::new(chat)); + + let mcp_cmd = Command::Mcp { + subcommand: Some("stop".to_string()), + server: Some("test-server".to_string()), + }; + assert!(mcp_cmd.execute(chat_session.clone()).await?); + + Ok(()) + } + + #[tokio::test] + async fn test_mcp_command_unknown_subcommand() -> Result<()> { + let config = create_test_config_with_custom_agent()?; + let chat = Chat::new( + &config, + Some("test-model-1".to_string()), + ToolRegistry::new(), + )?; + let chat_session = Arc::new(Mutex::new(chat)); + + let mcp_cmd = Command::Mcp { + subcommand: Some("invalid".to_string()), + server: None, + }; + assert!(mcp_cmd.execute(chat_session.clone()).await?); + + Ok(()) + } + + #[tokio::test] + async fn test_mcp_command_start_shows_hint() -> Result<()> { + let config = create_test_config_with_custom_agent()?; + let chat = Chat::new( + &config, + Some("test-model-1".to_string()), + ToolRegistry::new(), + )?; + let chat_session = Arc::new(Mutex::new(chat)); + + // /mcp start should show a hint to use enable instead + let mcp_cmd = Command::Mcp { + subcommand: Some("start".to_string()), + server: Some("test-server".to_string()), + }; + assert!(mcp_cmd.execute(chat_session.clone()).await?); + + Ok(()) + } + + #[tokio::test] + async fn test_mcp_command_missing_server_arg() -> Result<()> { + let config = create_test_config_with_custom_agent()?; + let chat = Chat::new( + &config, + Some("test-model-1".to_string()), + ToolRegistry::new(), + )?; + let chat_session = Arc::new(Mutex::new(chat)); + + // /mcp enable without server name should show error + let mcp_cmd = Command::Mcp { + subcommand: Some("enable".to_string()), + server: None, + }; + assert!(mcp_cmd.execute(chat_session.clone()).await?); + + Ok(()) + } } diff --git a/crates/arey/src/cli/chat/mod.rs b/crates/arey/src/cli/chat/mod.rs index 7171b83..d47a8ff 100644 --- a/crates/arey/src/cli/chat/mod.rs +++ b/crates/arey/src/cli/chat/mod.rs @@ -3,6 +3,7 @@ use crate::svc::chat::Chat; use anyhow::{Context, Result}; use arey_core::config::Config; use arey_core::registry::ToolRegistry; +use arey_mcp::McpRegistry; use std::io::stdout; use std::sync::Arc; use tokio::sync::Mutex; @@ -18,9 +19,15 @@ pub async fn execute( model: Option, config: &Config, tool_registry: ToolRegistry, + mcp_registry: Option, ) -> Result<()> { - let chat = + let mut chat = Chat::new(config, model, tool_registry).context("Failed to initialize chat service")?; + + if let Some(mcp) = mcp_registry { + chat = chat.with_mcp_registry(mcp); + } + let theme = get_theme("ansi"); // TODO: Theme from config let mut stdout = stdout(); let mut renderer = TerminalRenderer::new(&mut stdout, &theme); diff --git a/crates/arey/src/cli/chat/test_utils.rs b/crates/arey/src/cli/chat/test_utils.rs index fc39993..babd786 100644 --- a/crates/arey/src/cli/chat/test_utils.rs +++ b/crates/arey/src/cli/chat/test_utils.rs @@ -90,6 +90,7 @@ pub fn create_test_config() -> Result { task: task_mode, theme: "ansi".to_string(), tools: HashMap::new(), + mcp: serde_yaml::Value::Null, }) } diff --git a/crates/arey/src/cli/mod.rs b/crates/arey/src/cli/mod.rs index d6fcbf7..45c3526 100644 --- a/crates/arey/src/cli/mod.rs +++ b/crates/arey/src/cli/mod.rs @@ -7,6 +7,7 @@ pub mod ux; use anyhow::{Context, Result}; use arey_core::config::{Config, get_config}; use arey_core::registry::ToolRegistry; +use arey_mcp::McpRegistry; use clap::{Parser, Subcommand}; use crate::ext::get_tools; @@ -67,13 +68,25 @@ pub async fn run() -> Result<()> { let config = get_config(None).context("Failed to load configuration")?; // Initialize all available tools - let tool_registry = get_tools(&config).context("Failed to get builtin tools")?; + let mut tool_registry = get_tools(&config).context("Failed to get builtin tools")?; + + // Initialize MCP servers from config + let mcp_registry = McpRegistry::from_config(&config).await?; + + // Add MCP tools to the tool registry + if let Some(ref mcp) = mcp_registry { + for tool in mcp.get_all_tools() { + tool_registry.register(tool)?; + } + } match &cli.command { Commands::Run { instruction, model } => { run::execute(instruction.clone(), model.clone(), &config).await } - Commands::Chat { model } => execute_chat(model.clone(), &config, tool_registry).await, + Commands::Chat { model } => { + execute_chat(model.clone(), &config, tool_registry, mcp_registry).await + } Commands::Play { file, no_watch } => { play::execute(file.as_deref(), *no_watch, &config).await } @@ -84,13 +97,7 @@ async fn execute_chat( model: Option, config: &Config, tool_registry: ToolRegistry, + mcp_registry: Option, ) -> Result<()> { - crate::cli::chat::execute(model, config, tool_registry).await -} - -#[cfg(test)] -mod tests { - // TODO: Add integration tests for the CLI entrypoint `run`. - // This would involve running the binary with different arguments and - // checking exit codes and output. + crate::cli::chat::execute(model, config, tool_registry, mcp_registry).await } diff --git a/crates/arey/src/svc/chat.rs b/crates/arey/src/svc/chat.rs index f87cc1c..b93288d 100644 --- a/crates/arey/src/svc/chat.rs +++ b/crates/arey/src/svc/chat.rs @@ -6,6 +6,7 @@ use arey_core::model::ModelConfig; use arey_core::registry::ToolRegistry; use arey_core::session::{Session, SessionConfig, SessionEvent}; use arey_core::tools::Tool; +use arey_mcp::McpRegistry; use futures::{StreamExt, stream::BoxStream}; use std::collections::HashMap; use std::fmt; @@ -22,6 +23,7 @@ pub struct Chat<'a> { current_agent: Agent, session: Option, tool_registry: ToolRegistry, + mcp_registry: Option, } impl<'a> fmt::Debug for Chat<'a> { @@ -97,9 +99,16 @@ impl<'a> Chat<'a> { current_agent: agent, tool_registry, config, + mcp_registry: None, }) } + /// Builder method to set MCP server manager after creating Chat + pub fn with_mcp_registry(mut self, mcp_registry: McpRegistry) -> Self { + self.mcp_registry = Some(mcp_registry); + self + } + /// Loads the session, initializing the model. /// /// This should be called before any chat operations. It can be wrapped @@ -333,6 +342,11 @@ impl<'a> Chat<'a> { self.tool_registry.list() } + /// Get MCP server manager reference + pub fn mcp_registry(&mut self) -> Option<&mut McpRegistry> { + self.mcp_registry.as_mut() + } + /// Get available agents with their sources pub fn available_agents_with_sources(&self) -> Vec<(&str, &AgentSource)> { self.config @@ -593,4 +607,29 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_chat_with_mcp_builder() -> Result<()> { + let server = MockServer::start().await; + let config = get_test_config(&server).await?; + + // Create Chat without MCP + let mut chat = Chat::new(&config, Some("test-model".to_string()), ToolRegistry::new())?; + + // Initially no MCP manager + let mcp_before = chat.mcp_registry(); + assert!(mcp_before.is_none()); + + // Create a mock McpRegistry for testing + let mcp_registry = McpRegistry::new(); + + // Use builder to add MCP registry + chat = chat.with_mcp_registry(mcp_registry); + + // Verify MCP registry is now set + let mcp_after = chat.mcp_registry(); + assert!(mcp_after.is_some()); + + Ok(()) + } } diff --git a/crates/arey/src/svc/run.rs b/crates/arey/src/svc/run.rs index e59f73f..3ad7cb2 100644 --- a/crates/arey/src/svc/run.rs +++ b/crates/arey/src/svc/run.rs @@ -241,6 +241,7 @@ mod tests { }, theme: "light".to_string(), tools: HashMap::new(), + mcp: serde_yaml::Value::Null, }; let task = Task::new( @@ -298,6 +299,7 @@ mod tests { }, theme: "light".to_string(), tools: HashMap::new(), + mcp: serde_yaml::Value::Null, }; let mut task = Task::new( diff --git a/crates/core/data/config.yml b/crates/core/data/config.yml index 09984c8..b50e5b1 100644 --- a/crates/core/data/config.yml +++ b/crates/core/data/config.yml @@ -89,3 +89,26 @@ tools: default_language: "en" default_categories: "general" default_results: 10 + +# MCP (Model Context Protocol) servers +# +# MCP servers provide additional tools that can be used by the AI. +# Each server has a command to spawn and optional arguments/environment. +# Tools from MCP servers are prefixed with the server name (e.g., filesystem_read_file). +# +# Example MCP server configuration: +# mcp: +# filesystem: +# command: npx +# args: +# - -y +# - @modelcontextprotocol/server-filesystem +# - /home/user +# env: {} +# enabled: true +# memory: +# command: npx +# args: +# - -y +# - @modelcontextprotocol/server-memory +# enabled: false diff --git a/crates/core/src/config.rs b/crates/core/src/config.rs index 3b9d94f..056298f 100644 --- a/crates/core/src/config.rs +++ b/crates/core/src/config.rs @@ -101,6 +101,8 @@ pub struct Config { pub theme: String, #[serde(default)] pub tools: HashMap, + #[serde(default)] + pub mcp: serde_yaml::Value, } fn default_theme() -> String { @@ -253,6 +255,8 @@ struct RawConfig { theme: Option, #[serde(default)] tools: HashMap, + #[serde(default)] + mcp: serde_yaml::Value, } impl RawConfig { @@ -387,6 +391,7 @@ impl RawConfig { }, theme: self.theme.clone().unwrap_or_else(default_theme), tools: self.tools.clone(), + mcp: self.mcp.clone(), }) } } @@ -674,6 +679,7 @@ theme: dark }, theme: Some("dark".to_string()), tools: HashMap::new(), + mcp: serde_yaml::Value::Null, }; let temp_dir = tempfile::TempDir::new().unwrap(); @@ -710,6 +716,7 @@ theme: dark }, theme: None, tools: HashMap::new(), + mcp: serde_yaml::Value::Null, }; let err = raw_config.to_config(&PathBuf::from("/tmp")).unwrap_err(); @@ -738,6 +745,7 @@ theme: dark }, theme: None, tools: HashMap::new(), + mcp: serde_yaml::Value::Null, }; let err = raw_config.to_config(&PathBuf::from("/tmp")).unwrap_err(); @@ -763,6 +771,7 @@ theme: dark }, theme: Some("light".to_string()), tools: HashMap::new(), + mcp: serde_yaml::Value::Null, }; let temp_dir = tempfile::TempDir::new().unwrap(); diff --git a/crates/mcp/Cargo.toml b/crates/mcp/Cargo.toml new file mode 100644 index 0000000..ef927d1 --- /dev/null +++ b/crates/mcp/Cargo.toml @@ -0,0 +1,43 @@ +[package] +name = "arey-mcp" +version = "0.0.1" +edition.workspace = true + +[features] +default = [] +test_utils = ["rmcp/server", "rmcp/macros", "rmcp/transport-io"] +server = ["rmcp/server", "rmcp/transport-io", "rmcp/macros", "rmcp/schemars", "test_utils"] + +[dependencies] +arey-core = { path = "../core" } +rmcp = { version = "1.4.0", features = ["client", "transport-child-process", "transport-io"] } +anyhow.workspace = true +async-trait.workspace = true +tokio = { version = "1.50", features = [ + "macros", + "process", + "rt-multi-thread", + "sync", + "signal", + "time", +] } +tracing.workspace = true +serde.workspace = true +serde_json.workspace = true +serde_yaml.workspace = true + +[dev-dependencies] +arey-mcp = { path = ".", features = ["test_utils"] } +tokio = { version = "1.50", features = ["macros", "test-util"] } +serde_yaml.workspace = true +yare.workspace = true + +[[test]] +name = "mcp_integration" + +[[example]] +name = "mcp_weather" +required-features = ["server"] + +[[example]] +name = "mcp_client" diff --git a/crates/mcp/examples/mcp_client.rs b/crates/mcp/examples/mcp_client.rs new file mode 100644 index 0000000..e153e3a --- /dev/null +++ b/crates/mcp/examples/mcp_client.rs @@ -0,0 +1,56 @@ +//! MCP Client Example +//! +//! A simple client that connects to an MCP weather server and calls its tools. +//! Run with: cargo run --package arey-mcp --example mcp_client + +use std::collections::HashMap; + +use anyhow::Result; +use arey_mcp::{McpClient, McpServerConfig}; +use serde_json::json; + +#[tokio::main] +async fn main() -> Result<()> { + println!("=== MCP Weather Client ===\n"); + + // 1. Configure the MCP server connection (using the weather example server) + let mcp_config = McpServerConfig { + command: "cargo".to_string(), + args: vec![ + "run".to_string(), + "--package".to_string(), + "arey-mcp".to_string(), + "--example".to_string(), + "mcp_weather".to_string(), + "--features".to_string(), + "server".to_string(), + "--quiet".to_string(), + "--".to_string(), + "serve".to_string(), + ], + env: HashMap::new(), + enabled: true, + }; + + // 2. Connect to the server + println!("Connecting to weather server..."); + let mcp_client = McpClient::new("weather".to_string(), &mcp_config).await?; + + // 3. List available tools + let mcp_tools = mcp_client.tools(); + println!("Found {} tools:", mcp_tools.len()); + for tool in &mcp_tools { + println!(" - {}: {}", tool.name(), tool.description()); + } + println!(); + + // 4. Call a tool using McpClient + println!("Calling 'weather_get_weather' tool..."); + let args = json!({ "location": "London" }); + let result = mcp_client.call_tool("weather_get_weather", &args).await?; + + println!("Response: {}", serde_json::to_string_pretty(&result)?); + println!("\n✅ Client example finished!"); + + Ok(()) +} diff --git a/crates/mcp/examples/mcp_weather.rs b/crates/mcp/examples/mcp_weather.rs new file mode 100644 index 0000000..fd6f81f --- /dev/null +++ b/crates/mcp/examples/mcp_weather.rs @@ -0,0 +1,24 @@ +//! Weather MCP Server Example +//! +//! A simple MCP server that provides weather information. +//! Run with: cargo run --package arey-mcp --example mcp_weather --features server + +use std::env; + +use anyhow::Result; +use arey_mcp::mock::WeatherServer; +use rmcp::{service::ServiceExt, transport::stdio}; + +#[tokio::main] +async fn main() -> Result<()> { + if env::args().len() < 1 || env::args().nth(1).unwrap() != "serve" { + println!( + "Usage: cargo run --package arey-mcp --example mcp_weather --features server -- serve" + ); + return Ok(()); + } + + let service = WeatherServer.serve(stdio()).await?; + service.waiting().await?; + Ok(()) +} diff --git a/crates/mcp/src/client.rs b/crates/mcp/src/client.rs new file mode 100644 index 0000000..80a4110 --- /dev/null +++ b/crates/mcp/src/client.rs @@ -0,0 +1,353 @@ +use std::sync::Arc; + +use anyhow::Result; +use async_trait::async_trait; +use rmcp::{ + ServiceExt, + model::CallToolRequestParams, + service::{RoleClient, RunningService}, + transport::TokioChildProcess, +}; +use serde_json::Value; +use tokio::process::Command; +use tracing::{debug, info}; + +use crate::config::McpServerConfig; +use arey_core::tools::{Tool, ToolError}; + +pub struct McpClient { + name: String, + service: Arc>, + tools: Vec>, +} + +impl McpClient { + pub async fn new(name: String, config: &McpServerConfig) -> Result { + let mut cmd = Command::new(&config.command); + cmd.args(&config.args); + for (key, value) in &config.env { + cmd.env(key, value); + } + + let transport = TokioChildProcess::new(cmd)?; + let service = Arc::new(().serve(transport).await?); + + Self::with_service(name, service).await + } + + pub async fn with_service( + name: String, + service: Arc>, + ) -> Result { + let tools = Self::list_tools(&service, &name, service.clone()).await?; + + info!("MCP server '{}' connected with {} tools", name, tools.len()); + + Ok(Self { + name, + service, + tools, + }) + } + + async fn list_tools( + service: &Arc>, + server_name: &str, + service_arc: Arc>, + ) -> Result>> { + let tool_defs = service.list_all_tools().await?; + let mut tools = Vec::new(); + + for tool_def in tool_defs { + let raw_name = tool_def.name.to_string(); + let name = format!("{}_{}", server_name, raw_name); + let description = tool_def.description.unwrap_or_default().to_string(); + let input_schema = serde_json::to_value(&*tool_def.input_schema)?; + + let tool = Arc::new(McpTool::with_service( + name, + description, + input_schema, + Arc::new(server_name.to_string()), + service_arc.clone(), + )) as Arc; + tools.push(tool); + } + + Ok(tools) + } + + pub fn tools(&self) -> Vec> { + self.tools.clone() + } + + pub fn name(&self) -> &str { + &self.name + } + + pub async fn call_tool(&self, tool_name: &str, arguments: &Value) -> Result { + // Strip the server prefix to get the actual tool name + let actual_tool_name = if tool_name.starts_with(&format!("{}_", self.name)) { + tool_name.strip_prefix(&format!("{}_", self.name)).unwrap() + } else { + tool_name + }; + let args_map = match arguments { + Value::Object(m) => Some(m.clone()), + _ => { + let mut map = serde_json::Map::new(); + map.insert("input".to_string(), arguments.clone()); + Some(map) + } + }; + + let mut params = CallToolRequestParams::default(); + params.meta = None; + params.name = std::borrow::Cow::Owned(actual_tool_name.to_string()); + params.arguments = args_map; + params.task = None; + debug!( + "MCP {}: calling tool {} with arguments {:?}", + self.name, tool_name, arguments + ); + let result = self.service.call_tool(params).await?; + + let output = result.structured_content.unwrap_or_else(|| { + if let Some(content) = result.content.into_iter().next() { + if let Some(text) = content.as_text() { + serde_json::json!({ "text": text.text }) + } else { + serde_json::json!({ "error": "Unsupported content type" }) + } + } else { + serde_json::json!({ "error": "Empty response" }) + } + }); + + Ok(output) + } +} + +struct McpTool { + name: String, + description: String, + input_schema: serde_json::Value, + server_name: Arc, + #[allow(dead_code)] + service: Option>>, +} + +impl McpTool { + #[allow(dead_code)] + pub fn new( + name: String, + description: String, + input_schema: serde_json::Value, + server_name: Arc, + ) -> Self { + Self { + name, + description, + input_schema, + server_name, + service: None, + } + } + + pub fn with_service( + name: String, + description: String, + input_schema: serde_json::Value, + server_name: Arc, + service: Arc>, + ) -> Self { + Self { + name, + description, + input_schema, + server_name, + service: Some(service), + } + } + + async fn execute_internal(&self, arguments: &Value) -> Result { + let service = self.service.as_ref().ok_or_else(|| { + anyhow::anyhow!("MCP service not available - tool created without service connection") + })?; + + let actual_tool_name = if self.name.starts_with(&format!("{}_", self.server_name)) { + self.name + .strip_prefix(&format!("{}_", self.server_name)) + .unwrap() + } else { + &self.name + }; + + let args_map = match arguments { + Value::Object(m) => Some(m.clone()), + _ => { + let mut map = serde_json::Map::new(); + map.insert("location".to_string(), arguments.clone()); + Some(map) + } + }; + + let mut params = CallToolRequestParams::default(); + params.meta = None; + params.name = std::borrow::Cow::Owned(actual_tool_name.to_string()); + params.arguments = args_map; + params.task = None; + let result = service.call_tool(params).await?; + + let output = result.structured_content.unwrap_or_else(|| { + if let Some(content) = result.content.into_iter().next() { + if let Some(text) = content.as_text() { + serde_json::json!({ "text": text.text }) + } else { + serde_json::json!({ "error": "Unsupported content type" }) + } + } else { + serde_json::json!({ "error": "Empty response" }) + } + }); + + Ok(output) + } +} + +#[async_trait] +impl Tool for McpTool { + fn name(&self) -> String { + self.name.clone() + } + + fn description(&self) -> String { + self.description.clone() + } + + fn parameters(&self) -> Value { + self.input_schema.clone() + } + + async fn execute(&self, arguments: &Value) -> Result { + match self.execute_internal(arguments).await { + Ok(result) => Ok(result), + Err(e) => Err(ToolError::ExecutionError(e.to_string())), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arey_core::tools::Tool; + use serde_json::json; + use yare::parameterized; + + fn create_test_tool() -> McpTool { + McpTool::new( + "test_tool".to_string(), + "A test tool for testing".to_string(), + json!({ + "type": "object", + "properties": { + "input": { "type": "string" } + }, + "required": ["input"] + }), + Arc::new("test_server".to_string()), + ) + } + + #[test] + fn test_mcp_tool_name() { + let tool = create_test_tool(); + assert_eq!(tool.name(), "test_tool"); + } + + #[test] + fn test_mcp_tool_description() { + let tool = create_test_tool(); + assert_eq!(tool.description(), "A test tool for testing"); + } + + #[test] + fn test_mcp_tool_parameters() { + let tool = create_test_tool(); + let params = tool.parameters(); + + assert_eq!(params["type"], "object"); + assert!(params["properties"].is_object()); + assert!(params["required"].is_array()); + } + + #[tokio::test] + async fn test_mcp_tool_execute_returns_error() { + let tool = create_test_tool(); + let arguments = json!({ "input": "hello" }); + + let result = tool.execute(&arguments).await; + + // Without service connection, should return error + assert!(result.is_err() || result.as_ref().ok().and_then(|v| v.get("error")).is_some()); + } + + #[tokio::test] + async fn test_mcp_tool_execute_empty_args() { + let tool = create_test_tool(); + let arguments = json!({}); + + let result = tool.execute(&arguments).await; + + // Without service connection, should return error + assert!(result.is_err() || result.as_ref().ok().and_then(|v| v.get("error")).is_some()); + } + + #[test] + fn test_mcp_tool_different_inputs() { + let tool = McpTool::new( + "read_file".to_string(), + "Read a file from the filesystem".to_string(), + json!({ + "type": "object", + "properties": { + "path": { "type": "string", "description": "The file path" } + } + }), + Arc::new("fs".to_string()), + ); + + assert_eq!(tool.name(), "read_file"); + assert_eq!(tool.description(), "Read a file from the filesystem"); + assert!(tool.parameters().is_object()); + } + + #[test] + fn test_mcp_tool_impl_tool_trait() { + // Verify McpTool implements the Tool trait + let tool: Box = Box::new(create_test_tool()); + + assert_eq!(tool.name(), "test_tool"); + assert_eq!(tool.description(), "A test tool for testing"); + assert!(tool.parameters().is_object()); + } + + #[parameterized( + strip_prefix = { "server1_echo", "server1", "echo" }, + fs_read = { "fs_read_file", "fs", "read_file" }, + memory_context = { "memory_get_context", "memory", "get_context" }, + no_prefix = { "echo", "server1", "echo" }, + exact_match = { "server1", "server1", "server1" }, + different_server = { "other_add", "server1", "other_add" }, + )] + fn test_prefix_stripping_logic(tool_name: &str, server_name: &str, expected: &str) { + let actual = if tool_name.starts_with(&format!("{}_", server_name)) { + tool_name + .strip_prefix(&format!("{}_", server_name)) + .unwrap() + } else { + tool_name + }; + assert_eq!(actual, expected); + } +} diff --git a/crates/mcp/src/config.rs b/crates/mcp/src/config.rs new file mode 100644 index 0000000..447798a --- /dev/null +++ b/crates/mcp/src/config.rs @@ -0,0 +1,261 @@ +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +#[derive(Debug, Deserialize, Serialize, Clone)] +pub struct McpServerConfig { + pub command: String, + pub args: Vec, + #[serde(default)] + pub env: HashMap, + #[serde(default = "default_enabled")] + pub enabled: bool, +} + +fn default_enabled() -> bool { + true +} + +#[derive(Debug, Deserialize, Serialize, Clone, Default)] +pub struct McpConfig { + #[serde(default)] + pub servers: HashMap, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct McpServerStatus { + pub name: String, + pub running: bool, + pub enabled: bool, + pub tool_count: usize, +} + +impl Default for McpServerConfig { + fn default() -> Self { + Self { + command: String::new(), + args: Vec::new(), + env: HashMap::new(), + enabled: true, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use yare::parameterized; + + #[parameterized( + no_enabled_field = { "command: npx\nargs: []", true }, + explicit_true = { "command: npx\nargs: []\nenabled: true", true }, + explicit_false = { "command: npx\nargs: []\nenabled: false", false }, + )] + fn test_mcp_server_config_enabled(yaml: &str, expected: bool) { + let config: McpServerConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.enabled, expected); + } + + #[test] + fn test_mcp_server_config_default() { + let config = McpServerConfig::default(); + assert!(config.command.is_empty()); + assert!(config.args.is_empty()); + assert!(config.env.is_empty()); + assert!(config.enabled); + } + + #[test] + fn test_mcp_server_config_deserialization() { + let yaml = r#"command: npx +args: + - -y + - "@modelcontextprotocol/server-filesystem" + - /home/user +env: + HOME: /home/user +enabled: true"#; + + let config: McpServerConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.command, "npx"); + assert_eq!(config.args.len(), 3); + assert_eq!(config.args[0], "-y"); + assert_eq!(config.env.get("HOME"), Some(&"/home/user".to_string())); + assert!(config.enabled); + } + + #[test] + fn test_mcp_config_default() { + let config = McpConfig::default(); + assert!(config.servers.is_empty()); + } + + #[test] + fn test_mcp_config_deserialization() { + let yaml = r#"servers: + filesystem: + command: npx + args: + - -y + - "@modelcontextprotocol/server-filesystem" + - /home + enabled: true + memory: + command: npx + args: + - -y + - "@modelcontextprotocol/server-memory" + enabled: false"#; + + let config: McpConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.servers.len(), 2); + + let fs = config.servers.get("filesystem").unwrap(); + assert_eq!(fs.command, "npx"); + assert!(fs.enabled); + + let memory = config.servers.get("memory").unwrap(); + assert!(!memory.enabled); + } + + #[parameterized( + empty = { "servers: {}", 0 }, + one_server = { "servers:\n fs:\n command: npx\n args: []", 1 }, + three_servers = { "servers:\n fs:\n command: npx\n args: []\n memory:\n command: npx\n args: []\n github:\n command: npx\n args: []", 3 }, + )] + fn test_mcp_config_server_count(yaml: &str, expected_count: usize) { + let config: McpConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.servers.len(), expected_count); + } + + #[test] + fn test_mcp_config_serialization() { + let mut servers = HashMap::new(); + servers.insert( + "test".to_string(), + McpServerConfig { + command: "npx".to_string(), + args: vec!["-y".to_string(), "server".to_string()], + env: HashMap::new(), + enabled: true, + }, + ); + + let config = McpConfig { servers }; + let yaml = serde_yaml::to_string(&config).unwrap(); + assert!(yaml.contains("test:")); + assert!(yaml.contains("command: npx")); + } + + #[test] + fn test_mcp_server_status_fields() { + let status = McpServerStatus { + name: "test-server".to_string(), + running: true, + enabled: true, + tool_count: 5, + }; + + assert_eq!(status.name, "test-server"); + assert!(status.running); + assert!(status.enabled); + assert_eq!(status.tool_count, 5); + } + + #[test] + fn test_mcp_server_config_clone() { + let config = McpServerConfig { + command: "npx".to_string(), + args: vec!["-y".to_string()], + env: HashMap::from([("KEY".to_string(), "value".to_string())]), + enabled: false, + }; + + let cloned = config.clone(); + assert_eq!(cloned.command, config.command); + assert_eq!(cloned.args, config.args); + assert_eq!(cloned.env.get("KEY"), Some(&"value".to_string())); + assert_eq!(cloned.enabled, config.enabled); + } + + #[test] + fn test_mcp_server_config_with_env() { + let yaml = r#"command: npx +args: + - -y + - server +env: + HOME: /home/user + PATH: /usr/bin +enabled: true"#; + + let config: McpServerConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.env.len(), 2); + assert_eq!(config.env.get("HOME"), Some(&"/home/user".to_string())); + assert_eq!(config.env.get("PATH"), Some(&"/usr/bin".to_string())); + } + + #[test] + fn test_mcp_config_multiple_servers() { + let yaml = r#"servers: + fs: + command: npx + args: [-y, fs-server] + enabled: true + memory: + command: npx + args: [-y, memory-server] + enabled: false + github: + command: npx + args: [-y, github-server] + enabled: true"#; + + let config: McpConfig = serde_yaml::from_str(yaml).unwrap(); + assert_eq!(config.servers.len(), 3); + + assert!(config.servers.get("fs").unwrap().enabled); + assert!(!config.servers.get("memory").unwrap().enabled); + assert!(config.servers.get("github").unwrap().enabled); + } + + #[test] + fn test_mcp_config_empty_servers() { + let yaml = "servers: {}"; + let config: McpConfig = serde_yaml::from_str(yaml).unwrap(); + assert!(config.servers.is_empty()); + } + + #[test] + fn test_mcp_server_status_equality() { + let status1 = McpServerStatus { + name: "test".to_string(), + running: true, + enabled: true, + tool_count: 5, + }; + + let status2 = McpServerStatus { + name: "test".to_string(), + running: true, + enabled: true, + tool_count: 5, + }; + + assert_eq!(status1, status2); + } + + #[test] + fn test_mcp_server_status_display() { + let status = McpServerStatus { + name: "test-server".to_string(), + running: true, + enabled: false, + tool_count: 3, + }; + + assert_eq!(status.name, "test-server"); + assert!(status.running); + assert!(!status.enabled); + assert_eq!(status.tool_count, 3); + } +} diff --git a/crates/mcp/src/lib.rs b/crates/mcp/src/lib.rs new file mode 100644 index 0000000..4ac0e48 --- /dev/null +++ b/crates/mcp/src/lib.rs @@ -0,0 +1,9 @@ +pub mod client; +pub mod config; +#[cfg(any(test, feature = "test_utils"))] +pub mod mock; +pub mod registry; + +pub use client::McpClient; +pub use config::{McpConfig, McpServerConfig, McpServerStatus}; +pub use registry::McpRegistry; diff --git a/crates/mcp/src/mock.rs b/crates/mcp/src/mock.rs new file mode 100644 index 0000000..59465b4 --- /dev/null +++ b/crates/mcp/src/mock.rs @@ -0,0 +1,217 @@ +//! Mock MCP tools for testing. +//! +//! This module provides mock tools that can be used for testing MCP integration +//! without spawning external MCP servers. + +#[cfg(any(test, feature = "test_utils"))] +pub mod test_helpers { + use std::sync::Arc; + + use anyhow::Result; + use async_trait::async_trait; + use rmcp::schemars; + use rmcp::{ + handler::server::wrapper::{Json, Parameters}, + schemars::JsonSchema, + tool, tool_router, + }; + use serde::{Deserialize, Serialize}; + use serde_json::{Value, json}; + + use arey_core::tools::{Tool, ToolError}; + + /// Mock tool for testing - echoes back input + #[derive(Clone)] + pub struct MockEchoTool { + server_name: String, + } + + impl MockEchoTool { + pub fn new(server_name: &str) -> Self { + Self { + server_name: server_name.to_string(), + } + } + } + + #[async_trait] + impl Tool for MockEchoTool { + fn name(&self) -> String { + format!("{}_echo", self.server_name) + } + + fn description(&self) -> String { + "Echo back the input (mock tool)".to_string() + } + + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": { + "input": { "type": "string" } + } + }) + } + + async fn execute(&self, arguments: &Value) -> Result { + let input = arguments + .get("input") + .and_then(|v| v.as_str()) + .unwrap_or("mock echo result"); + Ok(json!({ "echoed": input })) + } + } + + /// Mock tool for testing - adds two numbers + #[derive(Clone)] + pub struct MockAddTool { + server_name: String, + } + + impl MockAddTool { + pub fn new(server_name: &str) -> Self { + Self { + server_name: server_name.to_string(), + } + } + } + + #[async_trait] + impl Tool for MockAddTool { + fn name(&self) -> String { + format!("{}_add", self.server_name) + } + + fn description(&self) -> String { + "Add two numbers (mock tool)".to_string() + } + + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": { + "a": { "type": "integer" }, + "b": { "type": "integer" } + }, + "required": ["a", "b"] + }) + } + + async fn execute(&self, arguments: &Value) -> Result { + let a = arguments.get("a").and_then(|v| v.as_i64()).unwrap_or(0); + let b = arguments.get("b").and_then(|v| v.as_i64()).unwrap_or(0); + Ok(json!({ "sum": a + b })) + } + } + + /// Mock tool for testing - returns current time + #[derive(Clone)] + pub struct MockTimeTool { + server_name: String, + } + + impl MockTimeTool { + pub fn new(server_name: &str) -> Self { + Self { + server_name: server_name.to_string(), + } + } + } + + #[async_trait] + impl Tool for MockTimeTool { + fn name(&self) -> String { + format!("{}_get_time", self.server_name) + } + + fn description(&self) -> String { + "Get the current time (mock tool)".to_string() + } + + fn parameters(&self) -> Value { + json!({ + "type": "object", + "properties": {} + }) + } + + async fn execute(&self, _arguments: &Value) -> Result { + Ok(json!({ "time": "2024-01-01T00:00:00Z" })) + } + } + + /// Create mock MCP tools for testing + pub fn create_mock_tools(server_name: &str) -> Vec> { + vec![ + Arc::new(MockEchoTool::new(server_name)) as Arc, + Arc::new(MockAddTool::new(server_name)) as Arc, + Arc::new(MockTimeTool::new(server_name)) as Arc, + ] + } + + /// Create a test configuration with MCP servers for testing + pub fn create_test_mcp_config() -> serde_yaml::Value { + serde_yaml::from_str( + r#" +servers: + test: + command: echo + args: ["test"] + enabled: true +"#, + ) + .unwrap() + } + + // --- Weather Server for Testing --- + #[derive(Clone)] + pub struct WeatherServer; + + #[derive(Serialize, Deserialize, JsonSchema)] + pub struct AddRequest { + pub a: i32, + pub b: i32, + } + + #[derive(Serialize, Deserialize, JsonSchema)] + pub struct GetWeatherRequest { + pub location: String, + } + + #[derive(Serialize, Deserialize, JsonSchema)] + pub struct GetWeatherResponse { + pub location: String, + #[serde(rename = "temp_C")] + pub temp_c: i32, + #[serde(rename = "weatherDesc")] + pub weather_desc: String, + pub humidity: i32, + #[serde(rename = "precipMM")] + pub precip_mm: i32, + } + + #[tool_router(server_handler)] + impl WeatherServer { + #[tool(description = "Add two numbers")] + fn add(&self, Parameters(req): Parameters) -> String { + (req.a + req.b).to_string() + } + + #[tool(description = "Gets the current weather for a given location")] + fn get_weather( + &self, + Parameters(req): Parameters, + ) -> Json { + Json(GetWeatherResponse { + location: req.location, + temp_c: 28, + weather_desc: "sunny".to_string(), + humidity: 54, + precip_mm: 0, + }) + } + } +} + +#[cfg(any(test, feature = "test_utils"))] +pub use test_helpers::*; diff --git a/crates/mcp/src/registry.rs b/crates/mcp/src/registry.rs new file mode 100644 index 0000000..dd3c44f --- /dev/null +++ b/crates/mcp/src/registry.rs @@ -0,0 +1,546 @@ +use std::collections::HashMap; +use std::sync::Arc; + +use anyhow::Result; +use tokio::sync::RwLock; + +use arey_core::tools::Tool; + +use crate::client::McpClient; +pub use crate::config::{McpConfig, McpServerConfig, McpServerStatus}; + +pub struct McpRegistry { + servers: HashMap, + enabled: RwLock>, +} + +struct McpServerState { + client: McpClient, + #[allow(dead_code)] + config: McpServerConfig, +} + +impl McpRegistry { + pub fn new() -> Self { + Self { + servers: HashMap::new(), + enabled: RwLock::new(HashMap::new()), + } + } + + pub async fn add_server(&mut self, name: String, config: &McpServerConfig) -> Result<()> { + let client = McpClient::new(name.clone(), config).await?; + + self.servers.insert( + name.clone(), + McpServerState { + client, + config: config.clone(), + }, + ); + + if config.enabled { + self.enable(&name).await?; + } + + Ok(()) + } + + pub async fn remove_server(&mut self, name: &str) -> Result<()> { + self.disable(name).await?; + self.servers.remove(name); + Ok(()) + } + + pub async fn enable(&mut self, name: &str) -> Result<()> { + if !self.servers.contains_key(name) { + anyhow::bail!("MCP server '{}' not found", name); + } + + let mut enabled = self.enabled.write().await; + enabled.insert(name.to_string(), true); + Ok(()) + } + + pub async fn disable(&mut self, name: &str) -> Result<()> { + let mut enabled = self.enabled.write().await; + enabled.insert(name.to_string(), false); + Ok(()) + } + + pub async fn list(&self) -> Vec { + let enabled = self.enabled.read().await; + + self.servers + .iter() + .map(|(name, state)| { + let tool_count = if enabled.get(name) == Some(&true) { + state.client.tools().len() + } else { + 0 + }; + + McpServerStatus { + name: name.clone(), + running: true, + enabled: *enabled.get(name).unwrap_or(&false), + tool_count, + } + }) + .collect() + } + + pub fn get_tools(&self) -> Vec> { + self.get_all_tools() + } + + pub fn get_all_tools(&self) -> Vec> { + self.servers + .values() + .flat_map(|state| state.client.tools()) + .collect() + } + + pub async fn get_enabled_tools(&self) -> Vec> { + let enabled = self.enabled.read().await; + + self.servers + .iter() + .filter(|(name, _)| enabled.get(*name) == Some(&true)) + .flat_map(|(_, state)| state.client.tools()) + .collect() + } + + pub fn server_names(&self) -> Vec { + self.servers.keys().cloned().collect() + } + + pub async fn is_enabled(&self, name: &str) -> bool { + let enabled = self.enabled.read().await; + *enabled.get(name).unwrap_or(&false) + } +} + +impl Default for McpRegistry { + fn default() -> Self { + Self::new() + } +} + +impl McpRegistry { + /// Creates McpRegistry from config and starts enabled MCP servers. + /// Returns None if no servers are configured or enabled. + pub async fn from_config(config: &arey_core::config::Config) -> Result> { + let mcp_value = &config.mcp; + + let is_mcp_empty = match mcp_value { + serde_yaml::Value::Null => true, + serde_yaml::Value::Sequence(seq) => seq.is_empty(), + _ => false, + }; + + if is_mcp_empty { + return Ok(None); + } + + let mcp_config: McpConfig = match serde_yaml::from_value(mcp_value.clone()) { + Ok(cfg) => cfg, + Err(_) => return Ok(None), + }; + + if mcp_config.servers.is_empty() { + return Ok(None); + } + + let mut registry = Self::new(); + let mut started_any = false; + + for (name, server_config) in mcp_config.servers { + if server_config.enabled { + started_any = true; + match registry.add_server(name.clone(), &server_config).await { + Ok(_) => { + tracing::info!("Started MCP server: {}", name); + } + Err(e) => { + tracing::warn!("Failed to start MCP server '{}': {}", name, e); + } + } + } + } + + if started_any { + Ok(Some(registry)) + } else { + Ok(None) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + // Mock McpClient for testing + struct MockMcpClient { + _name: String, + tools: Vec>, + } + + impl MockMcpClient { + #[allow(dead_code)] + fn new(name: &str, tool_count: usize) -> Self { + let tools: Vec> = (0..tool_count) + .map(|i| { + Arc::new(MockTool { + name: format!("{}_tool_{}", name, i), + description: format!("Mock tool {} from {}", i, name), + }) as Arc + }) + .collect(); + + Self { + _name: name.to_string(), + tools, + } + } + + #[allow(dead_code)] + fn tools(&self) -> Vec> { + self.tools.clone() + } + } + + #[allow(dead_code)] + struct MockTool { + name: String, + description: String, + } + + #[async_trait::async_trait] + impl Tool for MockTool { + fn name(&self) -> String { + self.name.clone() + } + + fn description(&self) -> String { + self.description.clone() + } + + fn parameters(&self) -> serde_json::Value { + serde_json::json!({}) + } + + async fn execute( + &self, + _arguments: &serde_json::Value, + ) -> Result { + Ok(serde_json::json!({ "mock": true })) + } + } + + // Note: Since McpClient::new requires spawning a real process, + // we test the manager methods that don't require a running server + // by testing the logic that doesn't depend on actual MCP connections. + // Full integration tests would require mock MCP servers. + + #[tokio::test] + async fn test_manager_new() { + let manager = McpRegistry::new(); + let names = manager.server_names(); + assert!(names.is_empty()); + } + + #[tokio::test] + async fn test_manager_disable_nonexistent() { + let mut manager = McpRegistry::new(); + let result = manager.disable("nonexistent").await; + // Should succeed but do nothing + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_manager_disable_enables_then_disables() { + let mut manager = McpRegistry::new(); + + // Enable on non-existent should error + let result = manager.enable("nonexistent").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_manager_list_empty() { + let manager = McpRegistry::new(); + let list = manager.list().await; + assert!(list.is_empty()); + } + + #[tokio::test] + async fn test_manager_get_tools_empty() { + let manager = McpRegistry::new(); + let tools = manager.get_tools(); + assert!(tools.is_empty()); + } + + #[tokio::test] + async fn test_manager_get_enabled_tools_empty() { + let manager = McpRegistry::new(); + let tools = manager.get_enabled_tools().await; + assert!(tools.is_empty()); + } + + #[tokio::test] + async fn test_manager_is_enabled_nonexistent() { + let manager = McpRegistry::new(); + let enabled = manager.is_enabled("nonexistent").await; + assert!(!enabled); + } + + #[tokio::test] + async fn test_manager_remove_nonexistent() { + let mut manager = McpRegistry::new(); + let result = manager.remove_server("nonexistent").await; + // Should succeed (disable does nothing, remove returns None) + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_manager_disable_twice() { + let mut manager = McpRegistry::new(); + + // Disable twice should be idempotent + let result1 = manager.disable("test").await; + assert!(result1.is_ok()); + + let result2 = manager.disable("test").await; + assert!(result2.is_ok()); + + // Should still be disabled + let enabled = manager.is_enabled("test").await; + assert!(!enabled); + } + + #[tokio::test] + async fn test_manager_enable_twice() { + let mut manager = McpRegistry::new(); + + // Enable twice should be idempotent + // But enable on nonexistent should error + let result1 = manager.enable("test").await; + assert!(result1.is_err()); // Server doesn't exist + + let result2 = manager.enable("test").await; + assert!(result2.is_err()); // Still doesn't exist + } + + #[tokio::test] + async fn test_manager_server_names_after_disable() { + let mut manager = McpRegistry::new(); + let names_before = manager.server_names(); + assert!(names_before.is_empty()); + + manager.disable("any").await.unwrap(); + + let names_after = manager.server_names(); + assert!(names_after.is_empty()); // No servers added + } +} + +#[cfg(test)] +mod config_tests { + use super::*; + use arey_core::agent::Agent; + use arey_core::config::{Config, ModeConfig, ProfileConfig}; + use arey_core::model::ModelConfig; + use serde_yaml::Value; + use std::collections::HashMap; + use yare::parameterized; + + fn create_test_config_with_mcp(mcp_value: Value) -> Config { + let mut models = HashMap::new(); + models.insert( + "test-model".to_string(), + ModelConfig { + name: "test-model".to_string(), + key: "test-model".to_string(), + ..Default::default() + }, + ); + + let mut profiles = HashMap::new(); + profiles.insert("default".to_string(), ProfileConfig::default()); + + let mut agents = HashMap::new(); + agents.insert( + "default".to_string(), + Agent::new( + "default".to_string(), + "You are helpful.".to_string(), + vec![], + ProfileConfig::default(), + Default::default(), + ), + ); + + let chat_mode = ModeConfig { + model: models.get("test-model").cloned().unwrap(), + agent_name: "default".to_string(), + profile: ProfileConfig::default(), + profile_name: Some("default".to_string()), + }; + + Config { + models, + profiles, + agents, + chat: chat_mode.clone(), + task: chat_mode, + theme: "light".to_string(), + tools: HashMap::new(), + mcp: mcp_value, + } + } + + #[tokio::test] + async fn test_mcp_registry_from_config_null() { + let config = create_test_config_with_mcp(Value::Null); + let result = McpRegistry::from_config(&config).await; + assert!(result.unwrap().is_none()); + } + + #[tokio::test] + async fn test_mcp_registry_from_config_empty_servers() { + let config = create_test_config_with_mcp(serde_yaml::from_str("servers: {}").unwrap()); + let result = McpRegistry::from_config(&config).await; + assert!(result.unwrap().is_none()); + } + + #[tokio::test] + async fn test_mcp_registry_from_config_invalid_yaml() { + let config = create_test_config_with_mcp(serde_yaml::from_str("servers: []").unwrap()); + let result = McpRegistry::from_config(&config).await; + assert!(result.unwrap().is_none()); + } + + #[tokio::test] + async fn test_mcp_registry_from_config_sequence() { + let config = + create_test_config_with_mcp(serde_yaml::from_str("- server1\n- server2").unwrap()); + let result = McpRegistry::from_config(&config).await; + assert!(result.unwrap().is_none()); + } + + #[test] + fn test_mcp_registry_from_config_disabled_server() { + let yaml = r#" +servers: + test-server: + command: npx + args: [-y, test-server] + enabled: false +"#; + let config = create_test_config_with_mcp(serde_yaml::from_str(yaml).unwrap()); + let result = tokio::runtime::Runtime::new() + .unwrap() + .block_on(McpRegistry::from_config(&config)); + let mgr = result.unwrap(); + if let Some(m) = mgr { + assert_eq!(m.server_names().len(), 0); + } + } + + #[parameterized( + null = { serde_yaml::Value::Null }, + empty_servers = { serde_yaml::from_str("servers: {}").unwrap() }, + invalid_servers = { serde_yaml::from_str("servers: []").unwrap() }, + sequence = { serde_yaml::from_str("- server1\n- server2").unwrap() }, + )] + #[test_macro(tokio::test)] + async fn test_mcp_registry_from_config_invalid(mcp_value: serde_yaml::Value) { + let config = create_test_config_with_mcp(mcp_value); + let result = McpRegistry::from_config(&config).await; + assert!(result.unwrap().is_none()); + } + + #[test] + fn test_mcp_config_parsing_enabled_server() { + let yaml = r#" +servers: + fs: + command: npx + args: + - -y + - "@modelcontextprotocol/server-filesystem" + - /tmp + enabled: true + memory: + command: npx + args: + - -y + - "@modelcontextprotocol/server-memory" + enabled: false +"#; + let mcp_config: McpConfig = serde_yaml::from_str(yaml).unwrap(); + + assert_eq!(mcp_config.servers.len(), 2); + + let fs = mcp_config.servers.get("fs").unwrap(); + assert_eq!(fs.command, "npx"); + assert!(fs.enabled); + + let memory = mcp_config.servers.get("memory").unwrap(); + assert!(!memory.enabled); + } + + #[test] + fn test_mcp_server_config_defaults() { + let yaml = r#" +command: npx +args: [-y, server] +"#; + let config: McpServerConfig = serde_yaml::from_str(yaml).unwrap(); + + assert_eq!(config.command, "npx"); + assert_eq!(config.args, vec!["-y", "server"]); + assert!(config.env.is_empty()); + assert!(config.enabled); // default is true + } + + #[test] + fn test_mcp_server_config_with_env() { + let yaml = r#" +command: npx +args: [-y, server] +env: + HOME: /home/user + DEBUG: "true" +enabled: true +"#; + let config: McpServerConfig = serde_yaml::from_str(yaml).unwrap(); + + assert_eq!(config.env.get("HOME"), Some(&"/home/user".to_string())); + assert_eq!(config.env.get("DEBUG"), Some(&"true".to_string())); + } + + #[test] + fn test_mcp_server_config_explicit_enabled() { + let yaml = r#" +command: npx +args: [] +enabled: true +"#; + let config: McpServerConfig = serde_yaml::from_str(yaml).unwrap(); + assert!(config.enabled); + } + + #[test] + fn test_mcp_server_config_explicit_disabled() { + let yaml = r#" +command: npx +args: [] +enabled: false +"#; + let config: McpServerConfig = serde_yaml::from_str(yaml).unwrap(); + assert!(!config.enabled); + } +} diff --git a/crates/mcp/tests/mcp_integration.rs b/crates/mcp/tests/mcp_integration.rs new file mode 100644 index 0000000..42f7859 --- /dev/null +++ b/crates/mcp/tests/mcp_integration.rs @@ -0,0 +1,51 @@ +use std::sync::Arc; + +use anyhow::Result; +use arey_mcp::McpClient; +use arey_mcp::mock::WeatherServer; +use rmcp::service::ServiceExt; +use serde_json::json; +use tokio::io::duplex; + +#[tokio::test] +async fn test_mcp_weather_server_in_process() -> Result<()> { + // 1. Setup in-process transport using duplex + let (client_io, server_io) = duplex(1024); + + // 2. Start MCP weather server on one end of the duplex + tokio::spawn(async move { + if let Ok(service) = WeatherServer.serve(server_io).await { + let _ = service.waiting().await; + } + }); + + // 3. Connect MCP client on the other end of the duplex + let client_service = Arc::new(().serve(client_io).await?); + let mcp_client = McpClient::with_service("weather".to_string(), client_service).await?; + let mcp_tools = mcp_client.tools(); + + assert_eq!(mcp_tools.len(), 2, "Should find 2 tools"); + + // 4. Execute tool via McpClient::call_tool + let args = json!({ "location": "London" }); + let result = mcp_client.call_tool("weather_get_weather", &args).await?; + assert_eq!(result["location"], "London"); + assert!(result.get("temp_C").is_some(), "Should have temp_C field"); + + // 5. Execute tool directly via Tool trait + if let Some(tool) = mcp_tools.iter().find(|t| t.name() == "weather_get_weather") { + let args = json!({ "location": "Tokyo" }); + let result = tool.execute(&args).await?; + assert_eq!(result["location"], "Tokyo"); + } else { + panic!("Tool weather_get_weather not found in registry"); + } + + // 6. Test add tool + let add_args = json!({ "a": 5, "b": 3 }); + let add_result = mcp_client.call_tool("weather_add", &add_args).await?; + // The server returns a string which McpClient wraps in {"text": "8"} + assert_eq!(add_result["text"], "8"); + + Ok(()) +}