diff --git a/Cargo.lock b/Cargo.lock index 9cb16be..a9133d2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -105,6 +105,17 @@ dependencies = [ "derive_arbitrary", ] +[[package]] +name = "async-trait" +version = "0.1.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.114", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -117,6 +128,61 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +[[package]] +name = "axum" +version = "0.7.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edca88bc138befd0323b20752846e6587272d3b03b0343c8ea28a6f819e6e71f" +dependencies = [ + "async-trait", + "axum-core", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "itoa", + "matchit", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "rustversion", + "serde", + "serde_json", + "serde_path_to_error", + "serde_urlencoded", + "sync_wrapper", + "tokio", + "tower", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-core" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f2bd6146b97ae3359fa0cc6d6b376d9539582c7b4220f041a33ec24c226199" +dependencies = [ + "async-trait", + "bytes", + "futures-util", + "http", + "http-body", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "sync_wrapper", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "base64" version = "0.22.1" @@ -204,6 +270,7 @@ name = "chomp" version = "0.1.0" dependencies = [ "anyhow", + "axum", "chrono", "clap", "csv", @@ -216,6 +283,9 @@ dependencies = [ "tabled", "tempfile", "tokio", + "tokio-stream", + "tower-http 0.5.2", + "uuid", "zip", ] @@ -515,6 +585,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" + [[package]] name = "foreign-types" version = "0.3.2" @@ -633,6 +709,19 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "getrandom" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "139ef39800118c7683f2fd3c98c1b23c09ae076556b435f8e9064ae108aaeeec" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", + "wasip3", +] + [[package]] name = "h2" version = "0.4.13" @@ -661,6 +750,15 @@ dependencies = [ "ahash", ] +[[package]] +name = "hashbrown" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1" +dependencies = [ + "foldhash", +] + [[package]] name = "hashbrown" version = "0.16.1" @@ -736,6 +834,12 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "httpdate" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" + [[package]] name = "hyper" version = "1.8.1" @@ -750,6 +854,7 @@ dependencies = [ "http", "http-body", "httparse", + "httpdate", "itoa", "pin-project-lite", "pin-utils", @@ -920,6 +1025,12 @@ dependencies = [ "zerovec", ] +[[package]] +name = "id-arena" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d3067d79b975e8844ca9eb072e16b31c3c1c36928edf9c6789548c524d0d954" + [[package]] name = "idna" version = "1.1.0" @@ -949,6 +1060,8 @@ checksum = "7714e70437a7dc3ac8eb7e6f8df75fd8eb422675fc7678aff7364301092b1017" dependencies = [ "equivalent", "hashbrown 0.16.1", + "serde", + "serde_core", ] [[package]] @@ -1008,6 +1121,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "leb128fmt" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" + [[package]] name = "libc" version = "0.2.180" @@ -1074,6 +1193,12 @@ dependencies = [ "pkg-config", ] +[[package]] +name = "matchit" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94" + [[package]] name = "memchr" version = "2.7.6" @@ -1261,6 +1386,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" +[[package]] +name = "prettyplease" +version = "0.2.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b" +dependencies = [ + "proc-macro2", + "syn 2.0.114", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -1354,7 +1489,7 @@ dependencies = [ "tokio", "tokio-native-tls", "tower", - "tower-http", + "tower-http 0.6.8", "tower-service", "url", "wasm-bindgen", @@ -1480,6 +1615,12 @@ dependencies = [ "libc", ] +[[package]] +name = "semver" +version = "1.0.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2" + [[package]] name = "serde" version = "1.0.228" @@ -1523,6 +1664,17 @@ dependencies = [ "zmij", ] +[[package]] +name = "serde_path_to_error" +version = "0.1.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "10a9ff822e371bb5403e391ecd83e182e0e77ba7f6fe0160b795797109d1b457" +dependencies = [ + "itoa", + "serde", + "serde_core", +] + [[package]] name = "serde_urlencoded" version = "0.7.1" @@ -1822,6 +1974,17 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-stream" +version = "0.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32da49809aab5c3bc678af03902d4ccddea2a87d028d86392a4b1560c6906c70" +dependencies = [ + "futures-core", + "pin-project-lite", + "tokio", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -1848,6 +2011,23 @@ dependencies = [ "tokio", "tower-layer", "tower-service", + "tracing", +] + +[[package]] +name = "tower-http" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5" +dependencies = [ + "bitflags", + "bytes", + "http", + "http-body", + "http-body-util", + "pin-project-lite", + "tower-layer", + "tower-service", ] [[package]] @@ -1886,6 +2066,7 @@ version = "0.1.44" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "63e71662fa4b2a2c3a26f570f037eb95bb1f85397f3cd8076caed2f026a6d100" dependencies = [ + "log", "pin-project-lite", "tracing-core", ] @@ -1923,6 +2104,12 @@ version = "0.1.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7dd6e30e90baa6f72411720665d41d89b9a3d039dc45b8faea1ddd07f617f6af" +[[package]] +name = "unicode-xid" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" + [[package]] name = "untrusted" version = "0.9.0" @@ -1953,6 +2140,17 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" +[[package]] +name = "uuid" +version = "1.21.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b672338555252d43fd2240c714dc444b8c6fb0a5c5335e65a07bba7742735ddb" +dependencies = [ + "getrandom 0.4.1", + "js-sys", + "wasm-bindgen", +] + [[package]] name = "vcpkg" version = "0.2.15" @@ -1989,6 +2187,15 @@ dependencies = [ "wit-bindgen", ] +[[package]] +name = "wasip3" +version = "0.4.0+wasi-0.3.0-rc-2026-01-06" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5428f8bf88ea5ddc08faddef2ac4a67e390b88186c703ce6dbd955e1c145aca5" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-bindgen" version = "0.2.108" @@ -2048,6 +2255,40 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-encoder" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "990065f2fe63003fe337b932cfb5e3b80e0b4d0f5ff650e6985b1048f62c8319" +dependencies = [ + "leb128fmt", + "wasmparser", +] + +[[package]] +name = "wasm-metadata" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb0e353e6a2fbdc176932bbaab493762eb1255a7900fe0fea1a2f96c296cc909" +dependencies = [ + "anyhow", + "indexmap", + "wasm-encoder", + "wasmparser", +] + +[[package]] +name = "wasmparser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" +dependencies = [ + "bitflags", + "hashbrown 0.15.5", + "indexmap", + "semver", +] + [[package]] name = "web-sys" version = "0.3.85" @@ -2355,6 +2596,88 @@ name = "wit-bindgen" version = "0.51.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7249219f66ced02969388cf2bb044a09756a083d0fab1e566056b04d9fbcaa5" +dependencies = [ + "wit-bindgen-rust-macro", +] + +[[package]] +name = "wit-bindgen-core" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea61de684c3ea68cb082b7a88508a8b27fcc8b797d738bfc99a82facf1d752dc" +dependencies = [ + "anyhow", + "heck 0.5.0", + "wit-parser", +] + +[[package]] +name = "wit-bindgen-rust" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7c566e0f4b284dd6561c786d9cb0142da491f46a9fbed79ea69cdad5db17f21" +dependencies = [ + "anyhow", + "heck 0.5.0", + "indexmap", + "prettyplease", + "syn 2.0.114", + "wasm-metadata", + "wit-bindgen-core", + "wit-component", +] + +[[package]] +name = "wit-bindgen-rust-macro" +version = "0.51.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c0f9bfd77e6a48eccf51359e3ae77140a7f50b1e2ebfe62422d8afdaffab17a" +dependencies = [ + "anyhow", + "prettyplease", + "proc-macro2", + "quote", + "syn 2.0.114", + "wit-bindgen-core", + "wit-bindgen-rust", +] + +[[package]] +name = "wit-component" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" +dependencies = [ + "anyhow", + "bitflags", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "wasm-encoder", + "wasm-metadata", + "wasmparser", + "wit-parser", +] + +[[package]] +name = "wit-parser" +version = "0.244.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ecc8ac4bc1dc3381b7f59c34f00b67e18f910c2c0f50015669dde7def656a736" +dependencies = [ + "anyhow", + "id-arena", + "indexmap", + "log", + "semver", + "serde", + "serde_derive", + "serde_json", + "unicode-xid", + "wasmparser", +] [[package]] name = "writeable" diff --git a/Cargo.toml b/Cargo.toml index 28ca296..9302c0e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,7 +9,7 @@ keywords = ["nutrition", "food", "cli", "tracking", "ai"] categories = ["command-line-utilities"] [dependencies] -clap = { version = "4", features = ["derive"] } +clap = { version = "4", features = ["derive", "env"] } rusqlite = { version = "0.31", features = ["bundled"] } serde = { version = "1", features = ["derive"] } serde_json = "1" @@ -18,8 +18,17 @@ fuzzy-matcher = "0.3" dirs = "5" anyhow = "1" tabled = "0.15" -tokio = { version = "1", features = ["rt", "io-std", "io-util", "macros"] } +tokio = { version = "1", features = ["rt", "rt-multi-thread", "io-std", "io-util", "macros", "sync"] } +axum = { version = "0.7", optional = true } +tokio-stream = { version = "0.1", optional = true } +uuid = { version = "1", features = ["v4"], optional = true } +tower-http = { version = "0.5", features = ["cors"], optional = true } + csv = "1" reqwest = { version = "0.12", features = ["blocking"] } zip = "2" tempfile = "3" + +[features] +default = ["sse"] +sse = ["axum", "tokio-stream", "uuid", "tower-http"] diff --git a/README.md b/README.md index e30ce26..51709ff 100644 --- a/README.md +++ b/README.md @@ -64,9 +64,37 @@ chomp "salmon 4oz" --json # log + structured output chomp search salmon --json # nutrition lookup without web search ``` -### MCP Server (for Claude Desktop) +### MCP Server ```bash -chomp serve # starts MCP server on stdio +# stdio transport (Claude Desktop) +chomp serve # default: stdio +chomp serve --transport stdio + +# SSE transport (Poke.com, remote agents) +chomp serve --transport sse # default: http://127.0.0.1:3000 +chomp serve --transport sse --port 3456 --host 0.0.0.0 + +# Both transports simultaneously +chomp serve --transport both --port 3000 +``` + +**Transport options:** + +| Transport | Use case | Endpoint | +|-----------|----------|----------| +| `stdio` | Claude Desktop, local AI | stdin/stdout | +| `sse` | Poke.com, Railway, remote | `GET /sse` + `POST /message` | +| `both` | Run both simultaneously | stdio + HTTP | + +**SSE endpoints:** +- `GET /sse` — SSE event stream (returns `endpoint` event with session POST URL) +- `POST /message?sessionId=` — send JSON-RPC requests +- `GET /health` — health check + +**Environment variables (SSE mode):** +```bash +CHOMP_PORT=3000 # default: 3000 +CHOMP_HOST=0.0.0.0 # default: 127.0.0.1 ``` Exposes tools: diff --git a/src/main.rs b/src/main.rs index 9d89eaf..0c2475e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,8 @@ mod db; mod food; mod logging; mod mcp; +#[cfg(feature = "sse")] +mod sse; #[derive(Parser)] #[command(name = "chomp")] @@ -135,7 +137,17 @@ enum Commands { /// Show database stats Stats, /// Start MCP server (for AI assistants like Claude Desktop) - Serve, + Serve { + /// Transport mode: stdio, sse, or both + #[arg(long, default_value = "stdio")] + transport: String, + /// Port for SSE server (env: CHOMP_PORT) + #[arg(long, default_value_t = 3000, env = "CHOMP_PORT")] + port: u16, + /// Host for SSE server (env: CHOMP_HOST) + #[arg(long, default_value = "127.0.0.1", env = "CHOMP_HOST")] + host: String, + }, } fn main() -> Result<()> { @@ -308,8 +320,43 @@ fn main() -> Result<()> { println!("First entry: {}", stats.first_entry.unwrap_or_default()); println!("Last entry: {}", stats.last_entry.unwrap_or_default()); } - Some(Commands::Serve) => { - mcp::serve()?; + Some(Commands::Serve { + transport, + port, + host, + }) => { + match transport.as_str() { + "stdio" => mcp::serve_stdio()?, + #[cfg(feature = "sse")] + "sse" => { + let rt = tokio::runtime::Runtime::new()?; + rt.block_on(sse::serve_sse(port, &host))?; + } + #[cfg(feature = "sse")] + "both" => { + // Run SSE in a background thread, stdio on main + let host_clone = host.clone(); + let sse_handle = std::thread::spawn(move || { + let rt = tokio::runtime::Runtime::new().expect("tokio runtime"); + rt.block_on(sse::serve_sse(port, &host_clone)) + }); + // Brief startup window — check if SSE died immediately + std::thread::sleep(std::time::Duration::from_millis(100)); + if sse_handle.is_finished() { + match sse_handle.join() { + Ok(Err(e)) => anyhow::bail!("SSE server failed to start: {}", e), + Err(_) => anyhow::bail!("SSE server thread panicked"), + Ok(Ok(())) => anyhow::bail!("SSE server exited unexpectedly"), + } + } + mcp::serve_stdio()?; + } + #[cfg(not(feature = "sse"))] + "sse" | "both" => { + anyhow::bail!("SSE transport requires the 'sse' feature. Rebuild with: cargo build --features sse"); + } + _ => anyhow::bail!("Invalid transport: {}. Use stdio, sse, or both.", transport), + } } None => { // Default action: log food diff --git a/src/mcp.rs b/src/mcp.rs index 7068eda..9322351 100644 --- a/src/mcp.rs +++ b/src/mcp.rs @@ -11,89 +11,49 @@ const SERVER_NAME: &str = "chomp"; const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION"); #[derive(Debug, Deserialize)] -struct JsonRpcRequest { +pub struct JsonRpcRequest { #[allow(dead_code)] - jsonrpc: String, - id: Option, - method: String, + pub jsonrpc: String, + pub id: Option, + pub method: String, #[serde(default)] - params: Value, + pub params: Value, } #[derive(Debug, Serialize)] -struct JsonRpcResponse { - jsonrpc: String, - id: Value, +pub struct JsonRpcResponse { + pub jsonrpc: String, + pub id: Value, #[serde(skip_serializing_if = "Option::is_none")] - result: Option, + pub result: Option, #[serde(skip_serializing_if = "Option::is_none")] - error: Option, + pub error: Option, } #[derive(Debug, Serialize)] -struct JsonRpcError { - code: i32, - message: String, +pub struct JsonRpcError { + pub code: i32, + pub message: String, } -pub fn serve() -> Result<()> { - let db = Database::open()?; - db.init()?; - - let stdin = std::io::stdin(); - let mut stdout = std::io::stdout(); - - for line in stdin.lock().lines() { - let line = line?; - if line.trim().is_empty() { - continue; - } - - let request: JsonRpcRequest = match serde_json::from_str(&line) { - Ok(r) => r, - Err(e) => { - let response = JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id: Value::Null, - result: None, - error: Some(JsonRpcError { - code: -32700, - message: format!("Parse error: {}", e), - }), - }; - writeln!(stdout, "{}", serde_json::to_string(&response)?)?; - stdout.flush()?; - continue; - } - }; - - let response = handle_request(&db, &request); - writeln!(stdout, "{}", serde_json::to_string(&response)?)?; - stdout.flush()?; - } - - Ok(()) -} - -fn handle_request(db: &Database, request: &JsonRpcRequest) -> JsonRpcResponse { - let id = request.id.clone().unwrap_or(Value::Null); +/// Handle a JSON-RPC request and return a response. +/// Returns None for notifications (no id) that don't need a response. +pub fn handle_request(db: &Database, request: &JsonRpcRequest) -> Option { + // Per JSON-RPC 2.0 spec, requests without an id are notifications + // and MUST NOT receive a response. + let id = match &request.id { + Some(id) => id.clone(), + None => return None, + }; let result = match request.method.as_str() { "initialize" => handle_initialize(), "tools/list" => handle_tools_list(), "tools/call" => handle_tools_call(db, &request.params), - "notifications/initialized" => { - return JsonRpcResponse { - jsonrpc: "2.0".to_string(), - id, - result: Some(Value::Null), - error: None, - } - } _ => Err(anyhow::anyhow!("Method not found: {}", request.method)), }; - match result { + Some(match result { Ok(value) => JsonRpcResponse { jsonrpc: "2.0".to_string(), id, @@ -109,7 +69,51 @@ fn handle_request(db: &Database, request: &JsonRpcRequest) -> JsonRpcResponse { message: e.to_string(), }), }, + }) +} + +/// Parse a JSON line into a request, returning an error response on failure. +pub fn parse_request(line: &str) -> std::result::Result { + serde_json::from_str(line).map_err(|e| JsonRpcResponse { + jsonrpc: "2.0".to_string(), + id: Value::Null, + result: None, + error: Some(JsonRpcError { + code: -32700, + message: format!("Parse error: {}", e), + }), + }) +} + +/// Run the MCP server over stdio transport. +pub fn serve_stdio() -> Result<()> { + let db = Database::open()?; + db.init()?; + + let stdin = std::io::stdin(); + let mut stdout = std::io::stdout(); + + for line in stdin.lock().lines() { + let line = line?; + if line.trim().is_empty() { + continue; + } + + match parse_request(&line) { + Ok(request) => { + if let Some(response) = handle_request(&db, &request) { + writeln!(stdout, "{}", serde_json::to_string(&response)?)?; + stdout.flush()?; + } + } + Err(error_response) => { + writeln!(stdout, "{}", serde_json::to_string(&error_response)?)?; + stdout.flush()?; + } + } } + + Ok(()) } fn handle_initialize() -> Result { diff --git a/src/sse.rs b/src/sse.rs new file mode 100644 index 0000000..9c8367d --- /dev/null +++ b/src/sse.rs @@ -0,0 +1,158 @@ +use anyhow::Result; +use axum::{ + extract::{Query, State}, + http::{Method, StatusCode}, + response::{ + sse::{Event, KeepAlive}, + Sse, + }, + routing::{get, post}, + Json, Router, +}; +use serde::Deserialize; +use std::collections::HashMap; +use std::convert::Infallible; +use std::sync::Arc; +use tokio::sync::{mpsc, Mutex}; +use tokio_stream::wrappers::ReceiverStream; +use tower_http::cors::{Any, CorsLayer}; + +use crate::db::Database; +use crate::mcp::{self, JsonRpcRequest}; + +/// Per-session sender for SSE events. +type SessionTx = mpsc::Sender>; + +/// Shared state across all handlers. +struct AppState { + sessions: Mutex>, +} + +#[derive(Deserialize)] +struct MessageQuery { + #[serde(rename = "sessionId")] + session_id: String, +} + +/// Start the SSE MCP server on the given port/host. +pub async fn serve_sse(port: u16, host: &str) -> Result<()> { + let state = Arc::new(AppState { + sessions: Mutex::new(HashMap::new()), + }); + + let cors = CorsLayer::new() + .allow_origin(Any) + .allow_methods([Method::GET, Method::POST, Method::OPTIONS]) + .allow_headers(Any); + + let app = Router::new() + .route("/sse", get(sse_handler)) + .route("/message", post(message_handler)) + .route("/health", get(health_handler)) + .layer(cors) + .with_state(state); + + let addr: std::net::SocketAddr = format!("{}:{}", host, port).parse()?; + eprintln!("chomp MCP server (SSE) listening on http://{}", addr); + eprintln!(" SSE endpoint: http://{}/sse", addr); + eprintln!(" POST endpoint: http://{}/message", addr); + eprintln!(" Health check: http://{}/health", addr); + + let listener = tokio::net::TcpListener::bind(addr).await?; + axum::serve(listener, app).await?; + + Ok(()) +} + +/// GET /sse — client connects here, receives an SSE stream. +/// First event is `endpoint` with the POST URL containing the session ID. +async fn sse_handler( + State(state): State>, +) -> Sse>> { + let session_id = uuid::Uuid::new_v4().to_string(); + let (tx, rx) = mpsc::channel(32); + + // Send the endpoint event so the client knows where to POST + let endpoint_url = format!("/message?sessionId={}", session_id); + let _ = tx + .send(Ok(Event::default().event("endpoint").data(endpoint_url))) + .await; + + // Store session + let tx_clone = tx.clone(); + state.sessions.lock().await.insert(session_id.clone(), tx); + + // Clean up on disconnect: periodically check if the sender's receiver is gone + let state_clone = state.clone(); + let sid = session_id.clone(); + tokio::spawn(async move { + loop { + tokio::time::sleep(tokio::time::Duration::from_secs(30)).await; + if tx_clone.is_closed() { + state_clone.sessions.lock().await.remove(&sid); + break; + } + } + }); + + Sse::new(ReceiverStream::new(rx)).keep_alive(KeepAlive::default()) +} + +/// POST /message?sessionId=xxx — client sends JSON-RPC requests here. +async fn message_handler( + State(state): State>, + Query(query): Query, + Json(request): Json, +) -> StatusCode { + // Lazy cleanup: check if session is dead before processing + let mut sessions = state.sessions.lock().await; + let tx = match sessions.get(&query.session_id) { + Some(tx) if tx.is_closed() => { + sessions.remove(&query.session_id); + return StatusCode::NOT_FOUND; + } + Some(tx) => tx.clone(), + None => return StatusCode::NOT_FOUND, + }; + drop(sessions); // Release lock before blocking DB work + + // Open a fresh DB connection per request (SQLite handles concurrent readers) + let db = match Database::open().and_then(|db| { + db.init()?; + Ok(db) + }) { + Ok(db) => db, + Err(err) => { + eprintln!("Database error in message_handler: {}", err); + return StatusCode::SERVICE_UNAVAILABLE; + } + }; + + if let Some(response) = mcp::handle_request(&db, &request) { + let json = match serde_json::to_string(&response) { + Ok(j) => j, + Err(e) => { + eprintln!("Failed to serialize JSON-RPC response: {e}"); + return StatusCode::INTERNAL_SERVER_ERROR; + } + }; + + let event = Event::default().event("message").data(json); + if tx.send(Ok(event)).await.is_err() { + eprintln!("SSE client disconnected, could not deliver response"); + return StatusCode::INTERNAL_SERVER_ERROR; + } + } + + StatusCode::ACCEPTED +} + +/// GET /health — simple health check. +async fn health_handler() -> Json { + Json(serde_json::json!({ + "status": "healthy", + "transport": "sse", + "server": "chomp", + "version": env!("CARGO_PKG_VERSION") + })) +}