diff --git a/.github/workflows/skit.yml b/.github/workflows/skit.yml index 2356ce04..f17b6206 100644 --- a/.github/workflows/skit.yml +++ b/.github/workflows/skit.yml @@ -57,6 +57,7 @@ jobs: run: | cargo clippy --locked --workspace --all-targets -- -D warnings cargo clippy --locked -p streamkit-server --all-targets --features "moq" -- -D warnings + cargo clippy --locked -p streamkit-server --all-targets --features "mcp" -- -D warnings - name: Install cargo-deny run: cargo install --locked cargo-deny @@ -111,6 +112,7 @@ jobs: run: | cargo test --locked --workspace -- --skip gpu_tests:: cargo test --locked -p streamkit-server --features "moq" + cargo test --locked -p streamkit-server --features "mcp" test-gpu: name: Test (GPU) diff --git a/Cargo.lock b/Cargo.lock index e3790295..d3045bf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -413,7 +413,7 @@ dependencies = [ "log", "num-rational", "num-traits", - "pastey", + "pastey 0.1.1", "rayon", "thiserror 2.0.18", "v_frame", @@ -1832,7 +1832,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -2948,7 +2948,7 @@ checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" dependencies = [ "hermit-abi", "libc", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -3039,7 +3039,7 @@ dependencies = [ "portable-atomic", "portable-atomic-util", "serde_core", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -3735,7 +3735,7 @@ version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -4231,6 +4231,12 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"35fb2e5f958ec131621fdd531e9fc186ed768cbe395337403ae56c17a74c68ec" +[[package]] +name = "pastey" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5a797f0e07bdf071d15742978fc3128ec6c22891c31a3a931513263904c982a" + [[package]] name = "pear" version = "0.2.9" @@ -5235,6 +5241,50 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmcp" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67d69668de0b0ccd9cc435f700f3b39a7861863cf37a15e1f304ea78688a4826" +dependencies = [ + "async-trait", + "base64 0.22.1", + "bytes", + "chrono", + "futures", + "http", + "http-body", + "http-body-util", + "pastey 0.2.2", + "pin-project-lite", + "rand 0.10.1", + "rmcp-macros", + "schemars 1.2.1", + "serde", + "serde_json", + "sse-stream", + "thiserror 2.0.18", + "tokio", + "tokio-stream", + "tokio-util 0.7.18", + "tower-service", + "tracing", + "uuid", +] + +[[package]] +name = "rmcp-macros" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48fdc01c81097b0aed18633e676e269fefa3a78ec1df56b4fe597c1241b92025" +dependencies = [ + "darling", + "proc-macro2", + "quote", + "serde_json", + "syn 2.0.117", +] + [[package]] name = "roxmltree" version = "0.21.1" @@ -5407,7 +5457,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.12.1", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -5485,7 +5535,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -5506,7 +5556,7 @@ dependencies = [ "security-framework", "security-framework-sys", "webpki-root-certs", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -5616,6 +5666,7 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" dependencies = [ + "chrono", 
"dyn-clone", "ref-cast", "schemars_derive", @@ -5692,7 +5743,7 @@ checksum = "09fbdfe7a27a1b1633dfc0c4c8e65940b8d819c5ddb9cca48ebc3223b00c8b14" dependencies = [ "ahash", "annotate-snippets", - "base64 0.21.7", + "base64 0.22.1", "encoding_rs_io", "getrandom 0.3.4", "nohash-hasher", @@ -6031,6 +6082,19 @@ dependencies = [ "bitflags 2.11.0", ] +[[package]] +name = "sse-stream" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c5e6deb40826033bd7b11c7ef25ef71193fabd71f680f40dd16538a2704d2f4" +dependencies = [ + "bytes", + "futures-util", + "http-body", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" @@ -6054,6 +6118,7 @@ name = "streamkit-api" version = "0.2.0" dependencies = [ "indexmap 2.14.0", + "schemars 1.2.1", "serde", "serde-saphyr", "serde_json", @@ -6267,6 +6332,7 @@ dependencies = [ "getrandom 0.4.2", "glob", "hex", + "http", "http-body-util", "hyper", "image", @@ -6284,6 +6350,7 @@ dependencies = [ "opus", "pprof", "reqwest 0.13.2", + "rmcp", "rust-embed", "rustls", "schemars 1.2.1", @@ -6632,7 +6699,7 @@ dependencies = [ "getrandom 0.4.2", "once_cell", "rustix 1.1.4", - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] @@ -8482,7 +8549,7 @@ version = "0.1.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" dependencies = [ - "windows-sys 0.59.0", + "windows-sys 0.61.2", ] [[package]] diff --git a/apps/skit/Cargo.toml b/apps/skit/Cargo.toml index c9a2be8a..f8577d5d 100644 --- a/apps/skit/Cargo.toml +++ b/apps/skit/Cargo.toml @@ -137,6 +137,18 @@ moq-lite = { version = "0.15.13", optional = true } blake2 = "0.10.6" async-stream = "0.3.6" +# MCP (Model Context Protocol) server (optional) +rmcp = { version = "1.5", features = [ + "server", # ServerHandler trait + service runner + "transport-streamable-http-server", # Streamable HTTP transport (tower 
service) + "transport-streamable-http-server-session", # Stateful session management (LocalSessionManager) + "transport-io", # STDIO transport (stdin/stdout) + "macros", # #[tool], #[tool_router] proc macros + "schemars", # JSON Schema generation for tool arguments +], optional = true } +# Used by the MCP module to extract HTTP request parts from rmcp's request context. +http = { version = "1", optional = true } + [features] default = ["script", "compositor", "gpu", "moq"] tokio-console = ["console-subscriber"] @@ -145,6 +157,7 @@ profiling = ["dep:pprof", "dep:tikv-jemallocator", "dep:jemalloc_pprof"] # Use this to find hot allocation sites. Output is written on graceful shutdown. dhat-heap = ["dep:dhat"] moq = ["dep:moq-native", "dep:moq-lite"] +mcp = ["dep:rmcp", "dep:http"] script = ["streamkit-nodes/script", "streamkit-engine/script"] compositor = ["streamkit-nodes/compositor", "streamkit-engine/compositor"] gpu = ["compositor", "streamkit-nodes/gpu", "streamkit-engine/gpu"] diff --git a/apps/skit/src/cli.rs b/apps/skit/src/cli.rs index 1bb52d0a..bb1b7a27 100644 --- a/apps/skit/src/cli.rs +++ b/apps/skit/src/cli.rs @@ -51,6 +51,9 @@ pub struct Cli { pub enum Commands { /// Starts the skit server Serve, + /// Run the MCP server over STDIO (for MCP client integration) + #[cfg(feature = "mcp")] + Mcp, /// Manage configuration #[command(subcommand)] Config(ConfigCommands), @@ -251,6 +254,39 @@ fn log_startup_info(config: &config::Config) { ); } +/// Handle the "mcp" command — start the MCP server over STDIO. +/// Exits the process on error with status code 1. +/// +/// Uses [`crate::logging::init_logging_stderr`] so that tracing output goes +/// to stderr, keeping stdout clean for the JSON-RPC message stream. 
+#[cfg(feature = "mcp")] +#[allow(clippy::disallowed_macros)] +async fn handle_mcp_command(config_path: &str, _init_logging: LogInitFn) { + let config_result = match config::load(config_path) { + Ok(result) => result, + Err(e) => { + eprintln!("Failed to load configuration: {e}"); + std::process::exit(1); + }, + }; + + let _log_guard = match crate::logging::init_logging_stderr( + &config_result.config.log, + &config_result.config.telemetry, + ) { + Ok(guard) => guard, + Err(e) => { + eprintln!("Failed to initialize logging: {e}"); + std::process::exit(1); + }, + }; + + if let Err(e) = crate::server::start_mcp_stdio(&config_result.config).await { + error!(error = %e, "Failed to start MCP STDIO server"); + std::process::exit(1); + } +} + /// Handle the "serve" command - start the server /// Exits the process on error with status code 1 // Allow eprintln before logging is initialized (CLI output) @@ -724,6 +760,10 @@ pub async fn handle_command(cli: &Cli, init_logging: LogInitFn) { Commands::Serve => { handle_serve_command(&cli.config, init_logging).await; }, + #[cfg(feature = "mcp")] + Commands::Mcp => { + handle_mcp_command(&cli.config, init_logging).await; + }, Commands::Config(ConfigCommands::Default) => { handle_config_default_command(); }, diff --git a/apps/skit/src/config.rs b/apps/skit/src/config.rs index ce6ee7cd..0238e6e8 100644 --- a/apps/skit/src/config.rs +++ b/apps/skit/src/config.rs @@ -857,6 +857,103 @@ impl Default for SecurityConfig { } } +fn default_mcp_endpoint() -> String { + "/api/v1/mcp".to_string() +} + +/// MCP (Model Context Protocol) server configuration. +#[derive(Deserialize, Serialize, Debug, Clone, JsonSchema)] +pub struct McpConfig { + /// Enable the embedded MCP endpoint (default: false). + #[serde(default)] + pub enabled: bool, + /// Streamable HTTP endpoint path (default: "/api/v1/mcp"). 
+ #[serde(default = "default_mcp_endpoint")] + pub endpoint: String, + /// Hostnames accepted by the MCP transport's `Host` header check + /// (DNS rebinding protection). + /// + /// When empty (default), the check is disabled — acceptable when the + /// endpoint sits behind `auth_guard_middleware` and + /// `origin_guard_middleware`. For deployments exposed to untrusted + /// networks, set this to the public hostname(s) of the server. + #[serde(default)] + pub allowed_hosts: Vec<String>, +} + +impl Default for McpConfig { + fn default() -> Self { + Self { enabled: false, endpoint: default_mcp_endpoint(), allowed_hosts: Vec::new() } + } +} + +impl McpConfig { + /// Validate MCP configuration. + /// + /// The endpoint MUST live under `/api/` so that `auth_guard_middleware`, + /// `origin_guard_middleware`, CORS, tracing, and metrics all apply. + /// It must NOT start with `/api/v1/auth/` because that prefix is + /// short-circuited by the auth guard. Only paths matching + /// `/api/v<N>/mcp` (with optional trailing subpath) are accepted. + /// + /// # Errors + /// + /// Returns an error string describing the misconfiguration. + pub fn validate(&self) -> Result<(), String> { + if !self.enabled { + return Ok(()); + } + + let ep = &self.endpoint; + + // Must start with /api/ + if !ep.starts_with("/api/") { + return Err(format!( + "mcp.endpoint must start with /api/ to ensure auth and origin guards apply. Got: '{ep}'" + )); + } + + // Must not sit under the auth prefix (auth_guard_middleware short-circuits it) + if ep.starts_with("/api/v1/auth/") || ep == "/api/v1/auth" { + return Err(format!( + "mcp.endpoint must not start with /api/v1/auth/ (bypasses auth guard). Got: '{ep}'" + )); + } + + // Reject unsafe path segments (e.g. ".." which could escape the + // intended mount point). This is conservative — it also rejects + // legitimate paths like "/api/v1/mcp/foo..bar" — but mount paths + // should never need such patterns.
+ if ep.contains("..") { + return Err(format!("mcp.endpoint must not contain unsafe path segments (..): '{ep}'")); + } + + // Must match /api/v<N>/mcp or /api/v<N>/mcp/... + let parts: Vec<&str> = ep.trim_start_matches('/').split('/').collect(); + if parts.len() < 3 { + return Err(format!("mcp.endpoint must be at least /api/v<N>/mcp. Got: '{ep}'")); + } + // parts[0] = "api", parts[1] = "v<N>", parts[2] = "mcp" + let version_part = parts[1]; + if !version_part.starts_with('v') + || !version_part[1..].chars().all(|c| c.is_ascii_digit()) + || version_part.len() < 2 + { + return Err(format!( + "mcp.endpoint version segment must be v<N> (e.g. v1). Got: '{version_part}' in '{ep}'" + )); + } + if parts[2] != "mcp" { + return Err(format!( + "mcp.endpoint third segment must be 'mcp'. Got: '{}' in '{ep}'", + parts[2] + )); + } + + Ok(()) + } +} + /// Root configuration for the StreamKit server. #[derive(Deserialize, Serialize, Default, Debug, Clone, JsonSchema)] pub struct Config { @@ -892,6 +989,9 @@ pub struct Config { #[serde(default)] pub auth: AuthConfig, + + #[serde(default)] + pub mcp: McpConfig, } #[derive(Debug)] @@ -926,6 +1026,10 @@ pub fn load(config_path: &str) -> Result<ConfigLoadResult, Box<figment::Error>> normalize_permissions_config(&mut config); + if let Err(e) = config.mcp.validate() { + return Err(Box::new(figment::Error::from(e))); + } + Ok(ConfigLoadResult { config, file_missing }) } diff --git a/apps/skit/src/lib.rs b/apps/skit/src/lib.rs index aa1ce36d..1dfc0e1b 100644 --- a/apps/skit/src/lib.rs +++ b/apps/skit/src/lib.rs @@ -12,6 +12,8 @@ pub mod logging; pub mod marketplace; pub mod marketplace_installer; pub mod marketplace_security; +#[cfg(feature = "mcp")] +pub mod mcp; #[cfg(feature = "moq")] pub mod moq_gateway; pub mod mse_gateway; diff --git a/apps/skit/src/logging.rs b/apps/skit/src/logging.rs index b8aad75a..09ab921f 100644 --- a/apps/skit/src/logging.rs +++ b/apps/skit/src/logging.rs @@ -31,6 +31,18 @@ fn make_console_layer(console_level: tracing::Level) -> DynLayer {
tracing_subscriber::fmt::layer().with_filter(env_filter_or_level(console_level)).boxed() } +/// Console layer that writes to **stderr** instead of stdout. +/// +/// Used by the STDIO MCP transport so that tracing output does not interfere +/// with the JSON-RPC message stream on stdout. +#[cfg(feature = "mcp")] +fn make_stderr_console_layer(console_level: tracing::Level) -> DynLayer { + tracing_subscriber::fmt::layer() + .with_writer(std::io::stderr) + .with_filter(env_filter_or_level(console_level)) + .boxed() +} + fn make_file_layer( non_blocking: tracing_appender::non_blocking::NonBlocking, file_level: tracing::Level, @@ -181,3 +193,79 @@ pub fn init_logging( Ok(guard) } + +/// Variant of [`init_logging`] that sends console output to **stderr**. +/// +/// This is required when stdout is reserved for a protocol stream (e.g. the +/// MCP STDIO transport). All other behaviour (file logging, OpenTelemetry, +/// tokio-console) is identical to [`init_logging`]. +/// +/// # Errors +/// +/// Same as [`init_logging`]. 
+#[cfg(feature = "mcp")] +#[allow(clippy::too_many_lines)] +pub fn init_logging_stderr( + log_config: &config::LogConfig, + telemetry_config: &config::TelemetryConfig, +) -> Result<Option<tracing_appender::non_blocking::WorkerGuard>, Box<dyn std::error::Error>> { + let mut guard = None; + + let enable_otel = should_enable_otel_tracing(telemetry_config); + + let setup_file_appender = |log_config: &config::LogConfig| -> Result< + (tracing_appender::non_blocking::NonBlocking, tracing_appender::non_blocking::WorkerGuard), + Box<dyn std::error::Error>, + > { + let log_path = std::path::Path::new(&log_config.file_path); + let log_dir = log_path.parent().unwrap_or_else(|| std::path::Path::new(".")); + let log_filename = log_path.file_name().unwrap_or_else(|| std::ffi::OsStr::new("skit.log")); + + if let Err(e) = std::fs::create_dir_all(log_dir) { + return Err( + format!("Failed to create log directory {}: {}", log_dir.display(), e).into() + ); + } + + let file_appender = tracing_appender::rolling::never(log_dir, log_filename); + Ok(tracing_appender::non_blocking(file_appender)) + }; + + let mut layers: Vec<DynLayer> = Vec::new(); + + #[cfg(feature = "tokio-console")] + if telemetry_config.tokio_console { + let tokio_console_layer = + console_subscriber::ConsoleLayer::builder().with_default_env().spawn(); + layers.push(tokio_console_layer.boxed()); + } + + if log_config.file_enable { + let (non_blocking, file_guard) = setup_file_appender(log_config)?; + guard = Some(file_guard); + let file_level: tracing::Level = log_config.file_level.clone().into(); + layers.push(make_file_layer(non_blocking, file_level, log_config.file_format)); + } + + if log_config.console_enable { + let console_level: tracing::Level = log_config.console_level.clone().into(); + layers.push(make_stderr_console_layer(console_level)); + } + + if !log_config.console_enable && !log_config.file_enable { + layers.push(make_stderr_console_layer(tracing::Level::INFO)); + } + + if enable_otel { + let telemetry_default_level = telemetry_default_level_for_config(log_config); + layers.push( +
telemetry::init_tracing_with_otlp(telemetry_config)? + .with_filter(env_filter_or_level(telemetry_default_level)) + .boxed(), + ); + } + + tracing_subscriber::registry().with(layers).init(); + + Ok(guard) +} diff --git a/apps/skit/src/main.rs b/apps/skit/src/main.rs index 61d46f3b..70b2d622 100644 --- a/apps/skit/src/main.rs +++ b/apps/skit/src/main.rs @@ -40,6 +40,8 @@ mod logging; mod marketplace; mod marketplace_installer; mod marketplace_security; +#[cfg(feature = "mcp")] +mod mcp; #[cfg(feature = "moq")] mod moq_gateway; mod mse_gateway; diff --git a/apps/skit/src/mcp/mod.rs b/apps/skit/src/mcp/mod.rs new file mode 100644 index 00000000..37719c3e --- /dev/null +++ b/apps/skit/src/mcp/mod.rs @@ -0,0 +1,773 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Embedded MCP (Model Context Protocol) server for StreamKit. +//! +//! Exposes StreamKit control-plane capabilities (node discovery, pipeline +//! validation, session management) as MCP tools over Streamable HTTP or +//! STDIO. The endpoint reuses the existing Axum application state, auth, +//! and permission model — no separate bridge process required. +//! +//! # Security — STDIO transport +//! +//! `skit mcp` runs unauthenticated: the STDIO caller is implicitly trusted +//! with admin-level permissions (see [`extract_auth`]). Only expose its +//! stdin to trusted local processes (e.g. Devin, Claude Desktop, Cursor). 
+ +mod oneshot; +mod prompts; + +use std::sync::Arc; + +use rmcp::handler::server::router::prompt::PromptRouter; +use rmcp::handler::server::router::tool::ToolRouter; +use rmcp::handler::server::wrapper::Parameters; +use rmcp::model::{ + CallToolResult, Content, GetPromptRequestParams, GetPromptResult, ListPromptsResult, + PaginatedRequestParams, ServerCapabilities, ServerInfo, +}; +use rmcp::schemars; +use rmcp::service::RequestContext; +use rmcp::transport::streamable_http_server::session::local::LocalSessionManager; +use rmcp::transport::streamable_http_server::{StreamableHttpServerConfig, StreamableHttpService}; +use rmcp::{prompt_handler, tool, tool_handler, tool_router}; +use rmcp::{ErrorData as McpError, RoleServer, ServerHandler}; + +use serde::{Deserialize, Serialize}; +use streamkit_api::Pipeline; +use streamkit_core::NodeDefinition; +use tracing::{info, warn}; + +use crate::permissions::Permissions; +use crate::session::Session; +use crate::state::AppState; + +// --------------------------------------------------------------------------- +// Auth helper +// --------------------------------------------------------------------------- + +/// Extract `(role_name, permissions)` from the HTTP request parts that `rmcp` +/// injects into the request-context extensions. +/// +/// For STDIO transport there are no HTTP parts in the context — the caller is +/// a local, trusted process. In that case we fall back to admin-level +/// permissions (resolved via the configured `default_role`, which defaults to +/// `"admin"`). +#[allow(clippy::unnecessary_wraps)] +fn extract_auth( + ctx: &RequestContext<RoleServer>, + app_state: &Arc<AppState>, +) -> Result<(String, Permissions), McpError> { + Ok(ctx.extensions.get::<http::request::Parts>().map_or_else( + || { + // STDIO transport — no HTTP context. Treat as local/trusted.
+ let empty_headers = axum::http::HeaderMap::new(); + crate::role_extractor::get_role_and_permissions(&empty_headers, app_state) + }, + |parts| crate::role_extractor::get_role_and_permissions(&parts.headers, app_state), + )) +} + +// --------------------------------------------------------------------------- +// Shared helpers +// --------------------------------------------------------------------------- + +/// Look up a session by name or ID, verify permission, and return the session +/// along with the caller's role name and permissions. +async fn resolve_session( + app_state: &Arc<AppState>, + session_id: &str, + ctx: &RequestContext<RoleServer>, + check_perm: impl FnOnce(&Permissions) -> bool, + perm_label: &str, +) -> Result<(Session, String, Permissions), McpError> { + let (role_name, perms) = extract_auth(ctx, app_state)?; + + if !check_perm(&perms) { + return Err(McpError::invalid_request( + format!("Permission denied: {perm_label} required"), + None, + )); + } + + let session = { + let sm = app_state.session_manager.lock().await; + sm.get_session_by_name_or_id(session_id) + }; + + let Some(session) = session else { + return Err(McpError::invalid_params(format!("Session '{session_id}' not found"), None)); + }; + + if !perms.access_all_sessions && session.created_by.as_ref().is_some_and(|c| c != &role_name) { + return Err(McpError::invalid_request( + "Permission denied: you do not own this session", + None, + )); + } + + Ok((session, role_name, perms)) +} + +/// Serialize a value to pretty-printed JSON and wrap it in a successful +/// `CallToolResult`. +fn json_tool_result<T: Serialize>(value: &T) -> Result<CallToolResult, McpError> { + let json = serde_json::to_string_pretty(value) + .map_err(|e| McpError::internal_error(format!("serialization error: {e}"), None))?; + Ok(CallToolResult::success(vec![Content::text(json)])) +} + +/// Return permission-filtered node definitions, including synthetic oneshot +/// nodes.
+fn filtered_node_definitions( + app_state: &Arc<AppState>, + perms: &Permissions, +) -> Result<Vec<NodeDefinition>, McpError> { + let mut definitions = app_state + .engine + .registry + .read() + .map_err(|e| McpError::internal_error(format!("Failed to read node registry: {e}"), None))? + .definitions(); + + definitions.extend(crate::server::synthetic_node_definitions()); + + definitions.retain(|def| { + if !perms.is_node_allowed(&def.kind) { + return false; + } + if def.kind.starts_with("plugin::") { + return perms.is_plugin_allowed(&def.kind); + } + true + }); + + Ok(definitions) +} + +/// Assemble the full pipeline state for a session, merging node states, view +/// data, and runtime schemas into the cloned pipeline. +async fn assemble_pipeline_state(session: &Session) -> Pipeline { + let node_states = session.get_node_states().await.unwrap_or_default(); + let node_view_data = session.get_node_view_data().await.unwrap_or_default(); + let runtime_schemas = session.get_runtime_schemas().await.unwrap_or_default(); + + let mut api_pipeline = { + let pipeline = session.pipeline.lock().await; + pipeline.clone() + }; + for (id, node) in &mut api_pipeline.nodes { + node.state = node_states.get(id).cloned(); + } + if !node_view_data.is_empty() { + api_pipeline.view_data = Some(Arc::unwrap_or_clone(node_view_data)); + } + if !runtime_schemas.is_empty() { + api_pipeline.runtime_schemas = Some(runtime_schemas); + } + + api_pipeline +} + +// --------------------------------------------------------------------------- +// MCP tool argument structs +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct OneshotInput { + /// Input field name matching a node ID in the pipeline (e.g., "input"). + pub field: String, + /// Path to the input file on the local filesystem.
+ pub path: String, +} + +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct GenerateOneshotCommandArgs { + /// Pipeline YAML for the oneshot run. + pub yaml: String, + /// Input file(s) to include in the request. + pub inputs: Vec<OneshotInput>, + /// Path where the output should be saved. + pub output: String, + /// Server URL (defaults to "http://localhost:4545"). + #[serde(default)] + pub server_url: Option<String>, + /// Command format: "curl" or "skit-cli". Defaults to "curl". + #[serde(default)] + pub format: Option<String>, +} + +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct ValidatePipelineArgs { + /// Pipeline YAML to validate. + pub yaml: String, + /// Optional mode: "dynamic" or "oneshot". + #[serde(default)] + pub mode: Option<String>, +} + +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct CreateSessionArgs { + /// Pipeline YAML for the new session. + pub yaml: String, + /// Optional human-readable session name. + #[serde(default)] + pub name: Option<String>, +} + +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct SessionIdArgs { + /// Session ID or name. + pub session_id: String, +} + +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct ValidateBatchArgs { + /// Session ID or name. + pub session_id: String, + /// List of batch operations to validate. + pub operations: Vec<streamkit_api::BatchOperation>, +} + +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct ApplyBatchArgs { + /// Session ID or name. + pub session_id: String, + /// List of batch operations to apply atomically. + pub operations: Vec<streamkit_api::BatchOperation>, +} + +#[derive(Debug, Deserialize, schemars::JsonSchema)] +pub struct TuneNodeArgs { + /// Session ID or name. + pub session_id: String, + /// Node ID to send the control message to. + pub node_id: String, + /// The control message (e.g., UpdateParams with a JSON value).
+ pub message: streamkit_core::control::NodeControlMessage, +} + +// --------------------------------------------------------------------------- +// StreamKit MCP service +// --------------------------------------------------------------------------- + +/// StreamKit MCP service implementing `rmcp::ServerHandler`. +#[derive(Clone)] +pub struct StreamKitMcp { + app_state: Arc<AppState>, + tool_router: ToolRouter<StreamKitMcp>, + prompt_router: PromptRouter<StreamKitMcp>, +} + +#[tool_router] +impl StreamKitMcp { + pub fn new(app_state: Arc<AppState>) -> Self { + Self { + app_state, + tool_router: Self::tool_router(), + prompt_router: prompts::create_prompt_router(), + } + } + + // -- list_nodes -------------------------------------------------------- + + #[tool( + description = "List available StreamKit node types with their schemas, pins, and categories. Returns permission-filtered node definitions including synthetic oneshot nodes." + )] + async fn list_nodes( + &self, + ctx: RequestContext<RoleServer>, + ) -> Result<CallToolResult, McpError> { + let (_role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + let definitions = filtered_node_definitions(&self.app_state, &perms)?; + + info!(count = definitions.len(), "MCP list_nodes"); + + json_tool_result(&definitions) + } + + // -- validate_pipeline ------------------------------------------------- + + #[tool( + description = "Validate a StreamKit pipeline YAML without creating a session. Returns diagnostics (errors, warnings) and the parsed graph. Optionally pass mode='dynamic' or mode='oneshot' to apply mode-specific rules."
+ )] + async fn validate_pipeline( + &self, + Parameters(args): Parameters<ValidatePipelineArgs>, + ctx: RequestContext<RoleServer>, + ) -> Result<CallToolResult, McpError> { + let (_role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + if !perms.create_sessions { + return Err(McpError::invalid_request( + "Permission denied: create_sessions required", + None, + )); + } + + let mode = match args.mode.as_deref() { + Some("dynamic") => Some(crate::server::PipelineMode::Dynamic), + Some("oneshot") => Some(crate::server::PipelineMode::Oneshot), + None => None, + Some(other) => { + return Err(McpError::invalid_params( + format!("Invalid mode '{other}'. Must be 'dynamic' or 'oneshot'."), + None, + )); + }, + }; + + let response = + crate::server::validate_pipeline_yaml(&self.app_state, &perms, &args.yaml, mode) + .map_err(|e| McpError::internal_error(e, None))?; + + json_tool_result(&response) + } + + // -- create_session ---------------------------------------------------- + + #[tool( + description = "Create a new dynamic StreamKit session from pipeline YAML. Returns the session ID, generated name, and creation timestamp."
+ )] + async fn create_session( + &self, + Parameters(args): Parameters<CreateSessionArgs>, + ctx: RequestContext<RoleServer>, + ) -> Result<CallToolResult, McpError> { + let (role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + if !perms.create_sessions { + return Err(McpError::invalid_request( + "Permission denied: cannot create sessions", + None, + )); + } + + let r = crate::server::create_dynamic_session( + &self.app_state, + &args.yaml, + args.name, + role_name, + &perms, + ) + .await + .map_err(|e| match e { + crate::server::CreateSessionError::InvalidInput(msg) => { + McpError::invalid_params(msg, None) + }, + crate::server::CreateSessionError::Forbidden(msg) + | crate::server::CreateSessionError::Conflict(msg) + | crate::server::CreateSessionError::LimitReached(msg) => { + McpError::invalid_request(msg, None) + }, + crate::server::CreateSessionError::Internal(msg) => McpError::internal_error(msg, None), + })?; + + let result = serde_json::json!({ + "session_id": r.session_id, + "name": r.name, + "created_at": r.created_at, + }); + json_tool_result(&result) + } + + // -- list_sessions ----------------------------------------------------- + + #[tool( + description = "List active StreamKit sessions. Returns session IDs, names, and creation timestamps."
+ )] + async fn list_sessions( + &self, + ctx: RequestContext<RoleServer>, + ) -> Result<CallToolResult, McpError> { + let (role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + if !perms.list_sessions { + return Err(McpError::invalid_request("Permission denied: cannot list sessions", None)); + } + + let sessions = self.app_state.session_manager.lock().await.list_sessions(); + let infos: Vec<streamkit_api::SessionInfo> = sessions + .into_iter() + .filter(|s| { + if perms.access_all_sessions { + return true; + } + s.created_by.as_ref().is_none_or(|c| c == &role_name) + }) + .map(|s| streamkit_api::SessionInfo { + id: s.id, + name: s.name, + created_at: crate::session::system_time_to_rfc3339(s.created_at), + }) + .collect(); + + info!(count = infos.len(), "MCP list_sessions"); + + json_tool_result(&infos) + } + + // -- get_pipeline ------------------------------------------------------ + + #[tool( + description = "Get the full pipeline state for a StreamKit session, including nodes, connections, node states, view data, and runtime schemas." + )] + async fn get_pipeline( + &self, + Parameters(args): Parameters<SessionIdArgs>, + ctx: RequestContext<RoleServer>, + ) -> Result<CallToolResult, McpError> { + let (session, _role_name, _perms) = resolve_session( + &self.app_state, + &args.session_id, + &ctx, + |p| p.list_sessions, + "list_sessions", + ) + .await?; + + let api_pipeline = assemble_pipeline_state(&session).await; + + info!(session_id = %args.session_id, "MCP get_pipeline"); + + json_tool_result(&api_pipeline) + } + + // -- generate_oneshot_command ------------------------------------------- + + #[tool( + description = "Generate a curl or skit-cli command to execute a oneshot (batch processing) pipeline. The oneshot runs through the HTTP data plane (POST /api/v1/process), not through MCP. Use validate_pipeline with mode='oneshot' first to ensure the YAML is valid."
+ )] + async fn generate_oneshot_command( + &self, + Parameters(args): Parameters<GenerateOneshotCommandArgs>, + ctx: RequestContext<RoleServer>, + ) -> Result<CallToolResult, McpError> { + let (_role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + if !perms.create_sessions { + return Err(McpError::invalid_request( + "Permission denied: create_sessions required", + None, + )); + } + + // Validate the YAML before generating a command. + let validation = crate::server::validate_pipeline_yaml( + &self.app_state, + &perms, + &args.yaml, + Some(crate::server::PipelineMode::Oneshot), + ) + .map_err(|e| McpError::internal_error(e, None))?; + + if !validation.valid { + let pretty = serde_json::to_string_pretty(&validation) + .map_err(|e| McpError::internal_error(format!("serialization error: {e}"), None))?; + return Ok(CallToolResult::success(vec![Content::text(format!( + "Pipeline validation failed. Fix the errors before generating a command:\n{pretty}" + ))])); + } + + let server_url = args.server_url.as_deref().unwrap_or("http://localhost:4545"); + let format = args.format.as_deref().unwrap_or("curl"); + + let command = match format { + "curl" => { + oneshot::generate_curl_command(&args.yaml, &args.inputs, &args.output, server_url) + }, + "skit-cli" => oneshot::generate_skit_cli_command( + &args.yaml, + &args.inputs, + &args.output, + server_url, + ), + other => { + return Err(McpError::invalid_params( + format!("Invalid format '{other}'. Must be 'curl' or 'skit-cli'."), + None, + )); + }, + }; + + info!(format, "MCP generate_oneshot_command"); + + Ok(CallToolResult::success(vec![Content::text(command)])) + } + + // -- destroy_session --------------------------------------------------- + + #[tool( + description = "Destroy (stop and remove) a StreamKit session. Shuts down the engine and frees all resources."
+ )] + async fn destroy_session( + &self, + Parameters(args): Parameters, + ctx: RequestContext, + ) -> Result { + let (role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + if !perms.destroy_sessions { + return Err(McpError::invalid_request( + "Permission denied: cannot destroy sessions", + None, + )); + } + + let removed_session = { + let mut sm = self.app_state.session_manager.lock().await; + let Some(session) = sm.get_session_by_name_or_id(&args.session_id) else { + return Err(McpError::invalid_params( + format!("Session '{}' not found", args.session_id), + None, + )); + }; + + if !perms.access_all_sessions + && session.created_by.as_ref().is_some_and(|c| c != &role_name) + { + warn!( + session_id = %args.session_id, + role = %role_name, + "MCP: blocked attempt to destroy session: not owner" + ); + return Err(McpError::invalid_request( + "Permission denied: you do not own this session", + None, + )); + } + + sm.remove_session_by_id(&session.id) + }; + + let Some(session) = removed_session else { + return Err(McpError::invalid_params( + format!("Session '{}' not found", args.session_id), + None, + )); + }; + + let destroyed_id = session.id.clone(); + + // Broadcast event + let event = streamkit_api::Event { + message_type: streamkit_api::MessageType::Event, + correlation_id: None, + payload: streamkit_api::EventPayload::SessionDestroyed { + session_id: destroyed_id.clone(), + }, + }; + if let Err(e) = self.app_state.event_tx.send(crate::state::BroadcastEvent::to_all(event)) { + tracing::error!("Failed to broadcast SessionDestroyed event: {}", e); + } + + // Background shutdown + let shutdown_id = destroyed_id.clone(); + let tracker = self.app_state.shutdown_tracker.clone(); + let handle = tokio::spawn(async move { + #[cfg(feature = "moq")] + crate::server::preview::teardown_all_previews(&session).await; + + if let Err(e) = session.shutdown_and_wait().await { + warn!(session_id = %shutdown_id, error = %e, "Error during engine shutdown"); + 
opentelemetry::global::meter("skit_server") + .u64_counter("session.shutdown.errors") + .build() + .add(1, &[]); + } else { + info!(session_id = %shutdown_id, "Session destroyed successfully via MCP"); + } + }); + tracker.track(handle).await; + + info!(session_id = %destroyed_id, "MCP destroy_session"); + + let result = serde_json::json!({ "session_id": destroyed_id }); + json_tool_result(&result) + } + + // -- validate_batch ---------------------------------------------------- + + #[tool( + description = "Validate a batch of graph mutations against a running session without applying them. Returns validation errors for any operations that would fail. Operations: addnode, removenode, connect, disconnect." + )] + async fn validate_batch( + &self, + Parameters(args): Parameters, + ctx: RequestContext, + ) -> Result { + let (session, _role_name, perms) = resolve_session( + &self.app_state, + &args.session_id, + &ctx, + |p| p.modify_sessions, + "modify_sessions", + ) + .await?; + + let errors = crate::server::validate_batch_operations( + &session, + &args.operations, + &perms, + &self.app_state.config.security, + ) + .await; + + info!( + session_id = %args.session_id, + operation_count = args.operations.len(), + error_count = errors.len(), + "MCP validate_batch" + ); + + json_tool_result(&errors) + } + + // -- apply_batch ------------------------------------------------------- + + #[tool( + description = "Apply a batch of graph mutations to a running session as a single validated batch. All operations are validated before any are applied; if validation fails, none are applied. Note: engine-side errors after validation do not roll back already-applied operations. Operations: addnode, removenode, connect, disconnect." 
+ )] + async fn apply_batch( + &self, + Parameters(args): Parameters, + ctx: RequestContext, + ) -> Result { + let (session, _role_name, perms) = resolve_session( + &self.app_state, + &args.session_id, + &ctx, + |p| p.modify_sessions, + "modify_sessions", + ) + .await?; + + crate::server::apply_batch_operations( + &session, + args.operations, + &perms, + &self.app_state.config.security, + ) + .await + .map_err(|e| McpError::invalid_params(e, None))?; + + info!(session_id = %args.session_id, "MCP apply_batch"); + + let result = serde_json::json!({ "success": true }); + json_tool_result(&result) + } + + // -- tune_node --------------------------------------------------------- + + #[tool( + description = "Send a control message to a specific node in a running session. Commonly used with UpdateParams to modify node parameters at runtime." + )] + async fn tune_node( + &self, + Parameters(args): Parameters, + ctx: RequestContext, + ) -> Result { + let (session, _role_name, _perms) = resolve_session( + &self.app_state, + &args.session_id, + &ctx, + |p| p.tune_nodes, + "tune_nodes", + ) + .await?; + + crate::server::tune_session_node( + &session, + args.node_id.clone(), + args.message, + &self.app_state.config.security, + &self.app_state.event_tx, + ) + .await + .map_err(|e| McpError::invalid_params(e, None))?; + + info!(session_id = %args.session_id, node_id = %args.node_id, "MCP tune_node"); + + let result = serde_json::json!({ "success": true }); + json_tool_result(&result) + } +} + +// --------------------------------------------------------------------------- +// ServerHandler trait impl +// --------------------------------------------------------------------------- + +#[tool_handler(router = self.tool_router)] +#[prompt_handler(router = self.prompt_router)] +impl ServerHandler for StreamKitMcp { + fn get_info(&self) -> ServerInfo { + let capabilities = ServerCapabilities::builder().enable_tools().enable_prompts().build(); + let mut info = 
ServerInfo::new(capabilities).with_instructions( + "StreamKit MCP server. Use list_nodes to discover available \ + processing nodes, validate_pipeline to check YAML, and \ + create_session / list_sessions / get_pipeline / destroy_session \ + to manage dynamic pipeline sessions. Use validate_batch and \ + apply_batch to mutate a running session's graph as a validated batch, \ + tune_node to send control messages, and \ + generate_oneshot_command to get a command for batch processing.", + ); + info.server_info = rmcp::model::Implementation::new("streamkit", env!("CARGO_PKG_VERSION")); + info + } +} + +// --------------------------------------------------------------------------- +// Service factory +// --------------------------------------------------------------------------- + +/// Create the `StreamableHttpService` tower service for mounting in the Axum +/// router via `nest_service`. +/// +/// ## `StreamableHttpServerConfig` defaults (rmcp 1.5) +/// +/// | Field | Default | +/// |--------------------|--------------------------------------| +/// | `sse_keep_alive` | 15 s | +/// | `sse_retry` | 3 s | +/// | `stateful_mode` | true | +/// | `json_response` | false | +/// | `allowed_hosts` | localhost, 127.0.0.1, ::1 | +/// +/// ## `SessionConfig` defaults (rmcp 1.5) +/// +/// | Field | Default | +/// |------------------------|----------| +/// | `channel_capacity` | 16 | +/// | `keep_alive` | 5 min | +/// | `sse_retry` | 3 s | +/// | `completed_cache_ttl` | 60 s | +/// +/// The 5-minute `keep_alive` TTL automatically evicts idle MCP sessions, +/// preventing unbounded growth from dropped connections. All sessions +/// require authentication, which further bounds creation rate. +/// +/// `allowed_hosts` is configured from `mcp.allowed_hosts` in the config. +/// When the list is empty (default), the `Host`-header check is disabled. 
+/// This is acceptable because Axum's `auth_guard_middleware` (bearer-token +/// validation) already prevents DNS rebinding exploitation — a rebound +/// request cannot supply a valid token. `origin_guard_middleware` +/// additionally restricts browser-initiated cross-origin requests. +/// For deployments exposed to untrusted networks *without* auth enabled, +/// populate `mcp.allowed_hosts` to re-enable `Host`-header validation. +pub fn streamable_http_service( + app_state: Arc, +) -> StreamableHttpService { + let mut config = StreamableHttpServerConfig::default(); + if app_state.config.mcp.allowed_hosts.is_empty() { + config = config.disable_allowed_hosts(); + } else { + config = config.with_allowed_hosts(app_state.config.mcp.allowed_hosts.clone()); + } + StreamableHttpService::new( + move || Ok(StreamKitMcp::new(Arc::clone(&app_state))), + Arc::new(LocalSessionManager::default()), + config, + ) +} diff --git a/apps/skit/src/mcp/oneshot.rs b/apps/skit/src/mcp/oneshot.rs new file mode 100644 index 00000000..c56ce2a3 --- /dev/null +++ b/apps/skit/src/mcp/oneshot.rs @@ -0,0 +1,110 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Oneshot command generation helpers for the MCP module. + +use std::fmt::Write; + +use super::OneshotInput; + +/// Shell-quote a value by wrapping it in single quotes and escaping any +/// embedded single quotes (`'` → `'\''`). +pub(super) fn shell_quote(s: &str) -> String { + format!("'{}'", s.replace('\'', "'\\''")) +} + +/// Return a heredoc delimiter that does not appear in `content`. +pub(super) fn unique_heredoc_delimiter(content: &str) -> String { + let base = "PIPELINE_EOF"; + if !content.contains(base) { + return base.to_string(); + } + for i in 0u32.. 
{ + let candidate = format!("{base}_{i}"); + if !content.contains(&candidate) { + return candidate; + } + } + unreachable!() +} + +pub(super) fn generate_curl_command( + yaml: &str, + inputs: &[OneshotInput], + output: &str, + server_url: &str, +) -> String { + let delim = unique_heredoc_delimiter(yaml); + + let mut cmd = String::new(); + let _ = writeln!(cmd, "# Save pipeline YAML to a temporary file, then run curl."); + let _ = writeln!(cmd, "PIPELINE=$(mktemp /tmp/pipeline-XXXXXX.yaml)"); + let _ = writeln!(cmd, "cat > \"$PIPELINE\" <<'{delim}'"); + let _ = writeln!(cmd, "{yaml}"); + let _ = writeln!(cmd, "{delim}"); + let _ = writeln!(cmd); + let url = format!("{server_url}/api/v1/process"); + let _ = write!(cmd, "curl -X POST {} \\\n -F 'config=<'\"$PIPELINE\"''", shell_quote(&url)); + for input in inputs { + let _ = + write!(cmd, " \\\n -F {}", shell_quote(&format!("{}=@{}", input.field, input.path))); + } + let _ = write!(cmd, " \\\n -o {}", shell_quote(output)); + cmd +} + +pub(super) fn generate_skit_cli_command( + yaml: &str, + inputs: &[OneshotInput], + output: &str, + server_url: &str, +) -> String { + let delim = unique_heredoc_delimiter(yaml); + + let mut cmd = String::new(); + let _ = writeln!(cmd, "# Save pipeline YAML to a temporary file, then run the CLI."); + let _ = writeln!(cmd, "PIPELINE=$(mktemp /tmp/pipeline-XXXXXX.yaml)"); + let _ = writeln!(cmd, "cat > \"$PIPELINE\" <<'{delim}'"); + let _ = writeln!(cmd, "{yaml}"); + let _ = writeln!(cmd, "{delim}"); + let _ = writeln!(cmd); + + // The CLI takes one positional input mapped to the "media" field, + // plus optional --input field=path for additional inputs. 
+ let (primary, extras): (Vec<_>, Vec<_>) = inputs.iter().partition(|i| i.field == "media"); + + if let Some(primary_input) = primary.first() { + let _ = write!( + cmd, + "streamkit-client oneshot \"$PIPELINE\" {}", + shell_quote(&primary_input.path) + ); + } else if let Some(first) = inputs.first() { + // No input named "media" — use the first as positional and re-add + // it via --input so the server receives the correct field name. + let _ = write!(cmd, "streamkit-client oneshot \"$PIPELINE\" {}", shell_quote(&first.path)); + } else { + let _ = write!(cmd, "streamkit-client oneshot \"$PIPELINE\" "); + } + + let _ = write!(cmd, " {}", shell_quote(output)); + + // Emit --input flags: when a "media" input exists, only extras need + // flags; otherwise all inputs are emitted (the first was used as the + // positional arg but with a non-"media" field name). + if primary.is_empty() { + for input in inputs { + let _ = + write!(cmd, " --input {}", shell_quote(&format!("{}={}", input.field, input.path))); + } + } else { + for input in &extras { + let _ = + write!(cmd, " --input {}", shell_quote(&format!("{}={}", input.field, input.path))); + } + } + + let _ = write!(cmd, " --server {}", shell_quote(server_url)); + cmd +} diff --git a/apps/skit/src/mcp/prompts.rs b/apps/skit/src/mcp/prompts.rs new file mode 100644 index 00000000..33d86d90 --- /dev/null +++ b/apps/skit/src/mcp/prompts.rs @@ -0,0 +1,352 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Prompt definitions and content builder helpers for the MCP module. +//! +//! The `#[prompt_router]` impl block defines the MCP prompts exposed via +//! `prompts/list` and `prompts/get`. Content builder functions remain as +//! plain helpers called by the prompt methods. 
+ +use std::fmt::Write; + +use rmcp::handler::server::router::prompt::PromptRouter; +use rmcp::handler::server::wrapper::Parameters; +use rmcp::model::{GetPromptResult, PromptMessage, PromptMessageRole}; +use rmcp::service::RequestContext; +use rmcp::{prompt, prompt_router, ErrorData as McpError, RoleServer}; +use serde::Deserialize; +use streamkit_api::Pipeline; +use streamkit_core::NodeDefinition; +use tracing::info; + +use super::{ + assemble_pipeline_state, extract_auth, filtered_node_definitions, resolve_session, StreamKitMcp, +}; + +/// Create a [`PromptRouter`] for [`StreamKitMcp`]. +/// +/// Exposed to the parent module so the struct constructor can store +/// the router alongside the tool router. +pub(super) fn create_prompt_router() -> PromptRouter { + StreamKitMcp::prompt_router() +} + +// --------------------------------------------------------------------------- +// Prompt argument structs +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize, rmcp::schemars::JsonSchema)] +pub(super) struct DesignPipelinePromptArgs { + /// Optional natural language description of the desired pipeline. + #[serde(default)] + pub description: Option, +} + +#[derive(Debug, Deserialize, rmcp::schemars::JsonSchema)] +pub(super) struct DebugPipelinePromptArgs { + /// Session ID or name to debug. + pub session_id: String, +} + +// --------------------------------------------------------------------------- +// Prompt router +// --------------------------------------------------------------------------- + +#[prompt_router] +impl StreamKitMcp { + /// Design a StreamKit pipeline from scratch. Provides available node + /// definitions, YAML format, connection rules, and workflow steps. 
+ #[prompt( + name = "design_pipeline", + description = "Design a StreamKit pipeline with available nodes and YAML format" + )] + async fn design_pipeline( + &self, + Parameters(args): Parameters, + ctx: RequestContext, + ) -> Result { + let (_role_name, perms) = extract_auth(&ctx, &self.app_state)?; + + let definitions = filtered_node_definitions(&self.app_state, &perms)?; + + let content = build_design_pipeline_content(&definitions, args.description.as_deref()); + + let messages = vec![PromptMessage::new_text(PromptMessageRole::User, content)]; + Ok(GetPromptResult::new(messages).with_description("Design a StreamKit pipeline")) + } + + /// Debug a running StreamKit session. Shows pipeline state, node states, + /// connections, and diagnostic checklist. + #[prompt( + name = "debug_pipeline", + description = "Debug a running StreamKit session with pipeline state and diagnostics" + )] + async fn debug_pipeline( + &self, + Parameters(args): Parameters, + ctx: RequestContext, + ) -> Result { + let (session, _role_name, _perms) = resolve_session( + &self.app_state, + &args.session_id, + &ctx, + |p| p.list_sessions, + "list_sessions", + ) + .await?; + + let api_pipeline = assemble_pipeline_state(&session).await; + + let content = build_debug_pipeline_content(&args.session_id, &api_pipeline) + .map_err(|e| McpError::internal_error(e, None))?; + + info!(session_id = %args.session_id, "MCP debug_pipeline prompt"); + + let messages = vec![PromptMessage::new_text(PromptMessageRole::User, content)]; + Ok(GetPromptResult::new(messages) + .with_description(format!("Debug StreamKit session '{}'", args.session_id))) + } +} + +// --------------------------------------------------------------------------- +// Content builder helpers +// --------------------------------------------------------------------------- + +/// Build the `design_pipeline` prompt content string. 
+fn build_design_pipeline_content( + definitions: &[NodeDefinition], + description: Option<&str>, +) -> String { + // Group definitions by first category (or "uncategorized"). + let mut by_category: std::collections::BTreeMap> = + std::collections::BTreeMap::new(); + for def in definitions { + let cat = def.categories.first().cloned().unwrap_or_else(|| "uncategorized".to_string()); + by_category.entry(cat).or_default().push(def); + } + + let mut content = String::with_capacity(8192); + + content.push_str( + "You are helping design a StreamKit pipeline. StreamKit pipelines are \ + defined in YAML with two sections: `nodes` and `connections`.\n\n", + ); + + // YAML format explanation + content.push_str("## YAML Format\n\n"); + content.push_str("```yaml\nnodes:\n :\n kind: \n"); + content.push_str(" params: # optional, node-specific\n : \n"); + content.push_str("connections:\n - from_node: \n from_pin: \n"); + content.push_str(" to_node: \n to_pin: \n"); + content.push_str(" mode: reliable # or best_effort\n```\n\n"); + + // Available nodes by category + content.push_str("## Available Nodes (by category)\n\n"); + + for (category, defs) in &by_category { + let _ = write!(content, "### {category}\n\n"); + for def in defs { + let _ = write!(content, "- **`{}`**", def.kind); + if let Some(desc) = &def.description { + let _ = write!(content, " — {desc}"); + } + content.push('\n'); + + if !def.inputs.is_empty() { + let pins: Vec = def + .inputs + .iter() + .map(|p| format!("`{}` ({:?})", p.name, p.accepts_types)) + .collect(); + let _ = writeln!(content, " - Inputs: {}", pins.join(", ")); + } + if !def.outputs.is_empty() { + let pins: Vec = def + .outputs + .iter() + .map(|p| format!("`{}` ({:?})", p.name, p.produces_type)) + .collect(); + let _ = writeln!(content, " - Outputs: {}", pins.join(", ")); + } + // Param schema summary (skip trivially empty schemas) + if def.param_schema != serde_json::json!({}) + && def.param_schema != serde_json::json!(null) + { + if let 
Some(props) = def.param_schema.get("properties") { + if let Some(obj) = props.as_object() { + if !obj.is_empty() { + let keys: Vec<&String> = obj.keys().collect(); + let _ = writeln!( + content, + " - Params: {}", + keys.iter() + .map(|k| format!("`{k}`")) + .collect::>() + .join(", ") + ); + } + } + } + } + } + content.push('\n'); + } + + // Connection rules + content.push_str("## Connection Rules\n\n"); + content.push_str( + "- Pins have types (RawAudio, RawVideo, EncodedAudio, EncodedVideo, \ + Text, Transcription, Binary, Any, Passthrough, Custom) — only \ + matching types can connect. `Any` accepts all types; `Passthrough` \ + adapts to the connected input type.\n", + ); + content.push_str( + "- Pin cardinality: `One` (single connection), `Broadcast` (fan-out \ + to many), `Dynamic` (runtime-created pin family, e.g. mixer inputs).\n", + ); + content.push_str( + "- Connection modes: `reliable` (backpressure — sender blocks if \ + receiver is slow), `best_effort` (drop packets if the receiver \ + can't keep up).\n\n", + ); + + // Pipeline modes + content.push_str("## Pipeline Modes\n\n"); + content.push_str( + "- **dynamic**: Real-time, hot-reconfigurable pipeline. Nodes run \ + continuously and the graph can be mutated at runtime.\n", + ); + content.push_str( + "- **oneshot**: Stateless batch/request-response pipeline. Processes \ + a single request and exits. Uses synthetic `streamkit::http_input` / \ + `streamkit::http_output` nodes.\n\n", + ); + + // Workflow + content.push_str("## Workflow\n\n"); + content.push_str("1. Design the YAML based on user requirements.\n"); + content.push_str( + "2. Call the `validate_pipeline` tool to check for errors before creating a session.\n", + ); + content.push_str("3. Fix any issues reported by validation.\n"); + content.push_str("4. 
Call `create_session` to start the pipeline.\n"); + + // Optional user description + if let Some(desc) = description { + let _ = write!(content, "\n## User Request\n\n{desc}\n"); + } + + content +} + +/// Build the `debug_pipeline` prompt content string. +fn build_debug_pipeline_content( + session_id: &str, + api_pipeline: &Pipeline, +) -> Result { + let pipeline_json = serde_json::to_string_pretty(api_pipeline) + .map_err(|e| format!("serialization error: {e}"))?; + + let mut content = String::with_capacity(4096); + + let _ = write!(content, "You are debugging StreamKit session `{session_id}`.\n\n"); + + // Current pipeline state + content.push_str("## Current Pipeline State\n\n"); + content.push_str("```json\n"); + content.push_str(&pipeline_json); + content.push_str("\n```\n\n"); + + // Per-node state summary + content.push_str("## Node States\n\n"); + let mut has_errors = false; + for (id, node) in &api_pipeline.nodes { + let state_str = + node.state.as_ref().map_or_else(|| "unknown".to_string(), |s| format!("{s:?}")); + let _ = write!(content, "- **`{id}`** (`{}`): {state_str}", node.kind); + if let Some(ref state) = node.state { + match state { + streamkit_core::NodeState::Failed { reason } => { + has_errors = true; + let _ = write!(content, " — error: {reason}"); + }, + streamkit_core::NodeState::Recovering { reason, .. } => { + let _ = write!(content, " — recovering: {reason}"); + }, + streamkit_core::NodeState::Degraded { reason, .. 
} => { + let _ = write!(content, " — degraded: {reason}"); + }, + _ => {}, + } + } + content.push('\n'); + } + + // Connection summary + if !api_pipeline.connections.is_empty() { + content.push_str("\n## Connections\n\n"); + for conn in &api_pipeline.connections { + let _ = writeln!( + content, + "- `{}`.`{}` → `{}`.`{}` ({})", + conn.from_node, + conn.from_pin, + conn.to_node, + conn.to_pin, + match conn.mode { + streamkit_core::control::ConnectionMode::Reliable => "reliable", + streamkit_core::control::ConnectionMode::BestEffort => "best_effort", + }, + ); + } + } + + // Diagnostic guidance + content.push_str("\n## Diagnostic Checklist\n\n"); + content.push_str("1. Are all nodes in a **running** state?\n"); + if has_errors { + content.push_str( + "2. **Errors detected** — review the error messages above and \ + check node parameters.\n", + ); + } else { + content.push_str("2. No errors reported so far.\n"); + } + content.push_str( + "3. Are all connections type-compatible? (Check that output pin types \ + match the connected input pin's accepted types.)\n", + ); + content.push_str("4. Are all required node parameters set correctly?\n"); + content.push_str( + "5. Are connection modes appropriate? (`reliable` for lossless \ + processing, `best_effort` for real-time streaming where drops are \ + acceptable.)\n\n", + ); + + // Remediation tools + content.push_str("## Available Tools for Fixing Issues\n\n"); + content.push_str( + "- `validate_batch` — dry-run a set of graph mutations (add/remove \ + nodes, connect/disconnect) without applying them.\n", + ); + content.push_str( + "- `apply_batch` — atomically apply validated mutations to the \ + running session.\n", + ); + content.push_str( + "- `tune_node` — send a control message to a specific node \ + (e.g. 
UpdateParams to change parameters at runtime).\n", + ); + content.push_str( + "- `validate_pipeline` — re-validate the pipeline YAML to catch \ + structural issues.\n", + ); + content.push_str( + "- `get_pipeline` — fetch the latest pipeline state (node states may \ + change over time).\n", + ); + content.push_str("- `destroy_session` — tear down the session if it is unrecoverable.\n"); + + Ok(content) +} diff --git a/apps/skit/src/server/mod.rs b/apps/skit/src/server/mod.rs index 9cfe52a0..b4bbbb0e 100644 --- a/apps/skit/src/server/mod.rs +++ b/apps/skit/src/server/mod.rs @@ -519,7 +519,7 @@ async fn list_node_definitions_handler( /// A single node in the validated graph. #[derive(Serialize)] -struct ValidateGraphNode { +pub struct ValidateGraphNode { id: String, kind: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -528,7 +528,7 @@ struct ValidateGraphNode { /// A single connection in the validated graph. #[derive(Serialize)] -struct ValidateGraphConnection { +pub struct ValidateGraphConnection { from_node: String, from_pin: String, to_node: String, @@ -537,7 +537,7 @@ struct ValidateGraphConnection { /// The parsed graph structure — always returned so the UI can highlight nodes. #[derive(Serialize)] -struct ValidateGraph { +pub struct ValidateGraph { nodes: Vec, connections: Vec, } @@ -545,7 +545,7 @@ struct ValidateGraph { /// Diagnostic category. #[derive(Debug, Clone, Copy, Serialize)] #[serde(rename_all = "snake_case")] -enum DiagnosticKind { +pub enum DiagnosticKind { Parse, Schema, Connection, @@ -555,7 +555,7 @@ enum DiagnosticKind { /// A single validation diagnostic. #[derive(Debug, Serialize)] -struct ValidateDiagnostic { +pub struct ValidateDiagnostic { kind: DiagnosticKind, message: String, #[serde(skip_serializing_if = "Option::is_none")] @@ -564,10 +564,11 @@ struct ValidateDiagnostic { connection_id: Option, } -/// Top-level response for `POST /api/v1/validate`. 
+/// Top-level response for `POST /api/v1/validate` and the MCP +/// `validate_pipeline` tool. #[derive(Serialize)] -struct ValidateResponse { - valid: bool, +pub struct ValidateResponse { + pub(crate) valid: bool, errors: Vec, warnings: Vec, graph: Option, @@ -576,7 +577,7 @@ struct ValidateResponse { /// Pipeline mode for validation — determines which synthetic-node rules apply. #[derive(Deserialize, Clone, Copy, PartialEq, Eq)] #[serde(rename_all = "lowercase")] -enum PipelineMode { +pub enum PipelineMode { Dynamic, Oneshot, } @@ -598,7 +599,10 @@ static SYNTHETIC_KINDS: std::sync::LazyLock> = std::sync::LazyLock::new(|| synthetic_node_definitions().into_iter().map(|d| d.kind).collect()); /// Returns `true` for node kinds that are synthetic oneshot-only markers. -fn is_synthetic_kind(kind: &str) -> bool { +/// +/// Used by both the HTTP and MCP `create_session` paths to reject +/// oneshot-only nodes in dynamic pipelines. +pub fn is_synthetic_kind(kind: &str) -> bool { SYNTHETIC_KINDS.iter().any(|k| k == kind) } @@ -607,7 +611,7 @@ fn is_synthetic_kind(kind: &str) -> bool { /// /// Used by both `list_node_definitions_handler` and the validate endpoint so /// there is a single source of truth for these definitions. -fn synthetic_node_definitions() -> Vec { +pub fn synthetic_node_definitions() -> Vec { use streamkit_core::types::PacketType; use streamkit_core::{InputPin, NodeDefinition, OutputPin, PinCardinality}; @@ -977,99 +981,17 @@ async fn validate_pipeline_handler( return Err((StatusCode::FORBIDDEN, "Permission denied: create_sessions required".into())); } - let mut errors: Vec = Vec::new(); - let mut warnings: Vec = Vec::new(); - - // 1. 
YAML parsing - let user_pipeline = match streamkit_api::yaml::parse_yaml(&payload.yaml) { - Ok(p) => p, - Err(e) => { - debug!(error = %e, "Pipeline YAML parse error"); - errors.push(ValidateDiagnostic { - kind: DiagnosticKind::Parse, - message: e, - node_id: None, - connection_id: None, - }); - return Ok(Json(ValidateResponse { valid: false, errors, warnings, graph: None })); - }, - }; - - // 2. Compile to internal Pipeline - let pipeline = match compile(user_pipeline) { - Ok(p) => p, - Err(e) => { - debug!(error = %e, "Pipeline compilation error"); - errors.push(ValidateDiagnostic { - kind: DiagnosticKind::Parse, - message: e, - node_id: None, - connection_id: None, - }); - return Ok(Json(ValidateResponse { valid: false, errors, warnings, graph: None })); - }, - }; - - // 3. Reject empty pipelines (matches create_session_handler) - if pipeline.nodes.is_empty() { - errors.push(ValidateDiagnostic { - kind: DiagnosticKind::Schema, - message: "Pipeline is empty. Add some nodes before validating.".into(), - node_id: None, - connection_id: None, - }); - return Ok(Json(ValidateResponse { valid: false, errors, warnings, graph: None })); - } - - // 4. Validate nodes against the registry - let registry_guard = - read_registry(&app_state).map_err(|sc| (sc, "Failed to read node registry".to_string()))?; - let node_defs = - validate_nodes(&pipeline, ®istry_guard, Some(&perms), &mut errors, &mut warnings); - drop(registry_guard); - - // 5. Mode-specific checks: reject synthetic nodes in dynamic mode - check_mode(&pipeline, payload.mode, &mut errors); - - // 6. Validate connections - validate_connections(&pipeline, &node_defs, &mut errors); - - // 7. File-path security checks (reuse session/oneshot helpers) - collect_file_path_errors(&pipeline, &app_state.config.security, &mut errors); - - // 8. 
Build graph (always included so the UI can highlight bad nodes) - let graph = Some(ValidateGraph { - nodes: pipeline - .nodes - .iter() - .map(|(id, n)| ValidateGraphNode { - id: id.clone(), - kind: n.kind.clone(), - params: n.params.clone(), - }) - .collect(), - connections: pipeline - .connections - .iter() - .map(|c| ValidateGraphConnection { - from_node: c.from_node.clone(), - from_pin: c.from_pin.clone(), - to_node: c.to_node.clone(), - to_pin: c.to_pin.clone(), - }) - .collect(), - }); - - let valid = errors.is_empty(); + let response = validate_pipeline_yaml(&app_state, &perms, &payload.yaml, payload.mode) + .map_err(|e| (StatusCode::SERVICE_UNAVAILABLE, e))?; debug!( - valid = valid, - error_count = errors.len(), - warning_count = warnings.len(), + valid = response.valid, + error_count = response.errors.len(), + warning_count = response.warnings.len(), "Pipeline validation completed" ); - Ok(Json(ValidateResponse { valid, errors, warnings, graph })) + Ok(Json(response)) } /// Extract a human-readable message from an `AppError`. @@ -2520,7 +2442,10 @@ struct CreateSessionResponse { /// Helper function to populate the session's in-memory pipeline representation /// from the compiled engine pipeline definition. -async fn populate_session_pipeline(session: &crate::session::Session, engine_pipeline: &Pipeline) { +pub async fn populate_session_pipeline( + session: &crate::session::Session, + engine_pipeline: &Pipeline, +) { let mut pipeline = session.pipeline.lock().await; // Forward top-level metadata so the UI can read it from the session snapshot. @@ -2554,7 +2479,10 @@ async fn populate_session_pipeline(session: &crate::session::Session, engine_pip } /// Helper function to send all node and connection control messages to the engine actor. 
-async fn send_pipeline_to_engine(session: &crate::session::Session, engine_pipeline: &Pipeline) { +pub async fn send_pipeline_to_engine( + session: &crate::session::Session, + engine_pipeline: &Pipeline, +) { // Send control messages to engine actor (asynchronous) // The engine will actually instantiate the nodes for (node_id, node_spec) in &engine_pipeline.nodes { @@ -2604,194 +2532,22 @@ async fn create_session_handler( )); } - // Global session limit - let (current_count, name_taken) = { - let session_manager = app_state.session_manager.lock().await; - let current_count = session_manager.session_count(); - let name_taken = req.name.as_deref().is_some_and(|n| session_manager.is_name_taken(n)); - drop(session_manager); - (current_count, name_taken) - }; - if let Some(ref session_name) = req.name { - if name_taken { - return Err(( - StatusCode::CONFLICT, - format!( - "Failed to create session: Session with name '{session_name}' already exists" - ), - )); - } - } - if !app_state.config.permissions.can_accept_session(current_count) { - return Err(( - StatusCode::TOO_MANY_REQUESTS, - "Maximum concurrent sessions limit reached".to_string(), - )); - } - - // Parse and compile the YAML pipeline - let user_pipeline: UserPipeline = - streamkit_api::yaml::parse_yaml(&req.yaml).map_err(|e| (StatusCode::BAD_REQUEST, e))?; - - let engine_pipeline = compile(user_pipeline) - .map_err(|e| (StatusCode::BAD_REQUEST, format!("Invalid pipeline: {e}")))?; - - // Validate the pipeline has at least one node - if engine_pipeline.nodes.is_empty() { - return Err(( - StatusCode::BAD_REQUEST, - "Pipeline is empty. 
Add some nodes before creating a session.".to_string(), - )); - } - - for (node_id, node) in &engine_pipeline.nodes { - if node.kind == "streamkit::http_input" || node.kind == "streamkit::http_output" { - return Err(( - StatusCode::BAD_REQUEST, - format!( - "Node '{node_id}' kind '{}' is oneshot-only and cannot be used in dynamic sessions", - node.kind - ), - )); - } - - if !perms.is_node_allowed(&node.kind) { - return Err(( - StatusCode::FORBIDDEN, - format!("Permission denied: node '{node_id}' kind '{}' not allowed", node.kind), - )); - } + let result = create_dynamic_session(&app_state, &req.yaml, req.name, role_name, &perms).await; - if node.kind.starts_with("plugin::") && !perms.is_plugin_allowed(&node.kind) { - return Err(( - StatusCode::FORBIDDEN, - format!("Permission denied: node '{node_id}' plugin '{}' not allowed", node.kind), - )); - } + match result { + Ok(r) => Ok(Json(CreateSessionResponse { + session_id: r.session_id, + name: r.name, + created_at: r.created_at, + })), + Err(e) => Err(match e { + CreateSessionError::InvalidInput(msg) => (StatusCode::BAD_REQUEST, msg), + CreateSessionError::Forbidden(msg) => (StatusCode::FORBIDDEN, msg), + CreateSessionError::Conflict(msg) => (StatusCode::CONFLICT, msg), + CreateSessionError::LimitReached(msg) => (StatusCode::TOO_MANY_REQUESTS, msg), + CreateSessionError::Internal(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg), + }), } - - validate_file_reader_paths(&engine_pipeline, &app_state.config.security).map_err( - |e| match e { - AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg), - AppError::PipelineCompilation(msg) => { - (StatusCode::BAD_REQUEST, format!("Invalid pipeline: {msg}")) - }, - AppError::Serde(err) => { - (StatusCode::BAD_REQUEST, format!("Invalid YAML config format: {err}")) - }, - AppError::Multipart(err) => { - (StatusCode::BAD_REQUEST, format!("Invalid multipart payload: {err}")) - }, - AppError::Engine(err) => { - (StatusCode::INTERNAL_SERVER_ERROR, format!("Pipeline execution 
error: {err}")) - }, - AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, msg), - }, - )?; - - validate_file_writer_paths(&engine_pipeline, &app_state.config.security).map_err( - |e| match e { - AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg), - AppError::PipelineCompilation(msg) => { - (StatusCode::BAD_REQUEST, format!("Invalid pipeline: {msg}")) - }, - AppError::Serde(err) => { - (StatusCode::BAD_REQUEST, format!("Invalid YAML config format: {err}")) - }, - AppError::Multipart(err) => { - (StatusCode::BAD_REQUEST, format!("Invalid multipart payload: {err}")) - }, - AppError::Engine(err) => { - (StatusCode::INTERNAL_SERVER_ERROR, format!("Pipeline execution error: {err}")) - }, - AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, msg), - }, - )?; - - validate_script_paths(&engine_pipeline, &app_state.config.security).map_err(|e| match e { - AppError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg), - AppError::PipelineCompilation(msg) => { - (StatusCode::BAD_REQUEST, format!("Invalid pipeline: {msg}")) - }, - AppError::Serde(err) => { - (StatusCode::BAD_REQUEST, format!("Invalid YAML config format: {err}")) - }, - AppError::Multipart(err) => { - (StatusCode::BAD_REQUEST, format!("Invalid multipart payload: {err}")) - }, - AppError::Engine(err) => { - (StatusCode::INTERNAL_SERVER_ERROR, format!("Pipeline execution error: {err}")) - }, - AppError::Forbidden(msg) => (StatusCode::FORBIDDEN, msg), - })?; - - // Create the session without holding the session manager lock. - let session = crate::session::Session::create( - &app_state.engine, - &app_state.config, - req.name.clone(), - app_state.event_tx.clone(), - Some(role_name.clone()), - ) - .await - .map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, format!("Failed to create session: {e}")))?; - - // Insert the session with short lock hold and re-check limits to avoid races. 
- let insert_result = { - let mut session_manager = app_state.session_manager.lock().await; - let current_count = session_manager.session_count(); - if app_state.config.permissions.can_accept_session(current_count) { - session_manager.add_session(session.clone()) - } else { - Err("Maximum concurrent sessions limit reached".to_string()) - } - }; - if let Err(error_msg) = insert_result { - let _ = session.shutdown_and_wait().await; - if error_msg == "Maximum concurrent sessions limit reached" { - return Err((StatusCode::TOO_MANY_REQUESTS, error_msg)); - } - return Err(( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to create session: {error_msg}"), - )); - } - - let session_id = session.id.clone(); - let session_name = session.name.clone(); - let created_at_str = crate::session::system_time_to_rfc3339(session.created_at); - - info!(session_id = %session_id, name = ?session_name, "Created new session via HTTP"); - - // Update the session pipeline immediately (synchronous) - // This ensures GET /sessions/{id}/pipeline returns the nodes right away - populate_session_pipeline(&session, &engine_pipeline).await; - - // Send control messages to engine actor to instantiate nodes and connections - send_pipeline_to_engine(&session, &engine_pipeline).await; - - info!( - "Session {} initialized with {} nodes and {} connections", - session_id, - engine_pipeline.nodes.len(), - engine_pipeline.connections.len() - ); - - // Broadcast event to all WebSocket clients - let event = ApiEvent { - message_type: MessageType::Event, - correlation_id: None, - payload: EventPayload::SessionCreated { - session_id: session_id.clone(), - name: session_name.clone(), - created_at: created_at_str.clone(), - }, - }; - if app_state.event_tx.send(crate::state::BroadcastEvent::to_all(event)).is_err() { - debug!("No WebSocket clients connected to receive SessionCreated event"); - } - - Ok(Json(CreateSessionResponse { session_id, name: session_name, created_at: created_at_str })) } /// Axum 
handler to get the list of active sessions. @@ -3953,39 +3709,594 @@ async fn metrics_middleware(req: axum::http::Request, next: Next) -> Respo response } -/// Creates the Axum application with all routes and middleware. +// --------------------------------------------------------------------------- +// Shared helpers — used by both HTTP handlers and crate::mcp +// --------------------------------------------------------------------------- + +/// Validate a pipeline YAML string with optional mode. /// -/// # Arguments +/// Shared implementation behind `POST /api/v1/validate` and the MCP +/// `validate_pipeline` tool. /// -/// * `config` - The server configuration -/// * `auth` - Optional pre-initialized AuthState. If None, creates a disabled auth state. +/// # Errors /// -/// # Panics +/// Returns an error string only if the node registry lock is poisoned. +pub fn validate_pipeline_yaml( + app_state: &Arc, + perms: &crate::permissions::Permissions, + yaml: &str, + mode: Option, +) -> Result { + let mut errors: Vec = Vec::new(); + let mut warnings: Vec = Vec::new(); + + let user_pipeline = match streamkit_api::yaml::parse_yaml(yaml) { + Ok(p) => p, + Err(e) => { + debug!(error = %e, "Pipeline YAML parse error"); + errors.push(ValidateDiagnostic { + kind: DiagnosticKind::Parse, + message: e, + node_id: None, + connection_id: None, + }); + return Ok(ValidateResponse { valid: false, errors, warnings, graph: None }); + }, + }; + + let pipeline = match compile(user_pipeline) { + Ok(p) => p, + Err(e) => { + debug!(error = %e, "Pipeline compilation error"); + errors.push(ValidateDiagnostic { + kind: DiagnosticKind::Parse, + message: e, + node_id: None, + connection_id: None, + }); + return Ok(ValidateResponse { valid: false, errors, warnings, graph: None }); + }, + }; + + if pipeline.nodes.is_empty() { + errors.push(ValidateDiagnostic { + kind: DiagnosticKind::Schema, + message: "Pipeline is empty. 
Add some nodes before validating.".into(), + node_id: None, + connection_id: None, + }); + return Ok(ValidateResponse { valid: false, errors, warnings, graph: None }); + } + + let registry_guard = + read_registry(app_state).map_err(|_| "Failed to read node registry".to_string())?; + let node_defs = + validate_nodes(&pipeline, ®istry_guard, Some(perms), &mut errors, &mut warnings); + drop(registry_guard); + + check_mode(&pipeline, mode, &mut errors); + validate_connections(&pipeline, &node_defs, &mut errors); + collect_file_path_errors(&pipeline, &app_state.config.security, &mut errors); + + let graph = Some(ValidateGraph { + nodes: pipeline + .nodes + .iter() + .map(|(id, n)| ValidateGraphNode { + id: id.clone(), + kind: n.kind.clone(), + params: n.params.clone(), + }) + .collect(), + connections: pipeline + .connections + .iter() + .map(|c| ValidateGraphConnection { + from_node: c.from_node.clone(), + from_pin: c.from_pin.clone(), + to_node: c.to_node.clone(), + to_pin: c.to_pin.clone(), + }) + .collect(), + }); + + let valid = errors.is_empty(); + Ok(ValidateResponse { valid, errors, warnings, graph }) +} + +/// Run all file-path security checks against a pipeline. /// -/// Panics if the plugin manager fails to initialize. This can happen if: -/// - Plugin directories cannot be created due to filesystem permissions -/// - Plugin directories exist but are not accessible -/// - CORS configuration is invalid (wildcard with auth enabled) +/// # Errors /// -/// Since this occurs during application initialization, a panic here is acceptable -/// as the server cannot function without proper configuration. +/// Returns a human-readable error message if any path violates the security +/// policy. 
+pub fn check_file_path_security( + pipeline: &Pipeline, + security_config: &crate::config::SecurityConfig, +) -> Result<(), String> { + let mut msgs = Vec::new(); + for result in [ + validate_file_reader_paths(pipeline, security_config), + validate_file_writer_paths(pipeline, security_config), + validate_script_paths(pipeline, security_config), + ] { + if let Err(e) = result { + msgs.push(app_error_message(e)); + } + } + if msgs.is_empty() { + Ok(()) + } else { + Err(msgs.join("; ")) + } +} + +/// Error type returned by [`create_dynamic_session`]. +/// +/// Each variant carries enough semantic meaning for both HTTP and MCP callers +/// to map to the appropriate protocol-level error (e.g. status codes for HTTP, +/// `McpError` variants for MCP). +pub enum CreateSessionError { + /// Invalid input (YAML parse, compile, empty pipeline, synthetic nodes, + /// bad file paths). + InvalidInput(String), + /// Permission denied (node or plugin not allowed). + Forbidden(String), + /// Session name already taken. + Conflict(String), + /// Maximum concurrent-session limit reached. + LimitReached(String), + /// Internal failure (engine allocation, session insert, etc.). + Internal(String), +} + +/// Result returned by [`create_dynamic_session`] on success. +pub struct CreateSessionResult { + pub session_id: String, + pub name: Option, + pub created_at: String, +} + +/// Shared implementation for creating a dynamic pipeline session. +/// +/// Handles YAML parsing, compilation, permission checks, file-path security, +/// session-limit pre-flight, engine allocation, session insertion, pipeline +/// population, engine dispatch, and event broadcast. +/// +/// Callers are responsible for extracting auth and checking +/// `perms.create_sessions` before calling this function. +/// +/// # Errors +/// +/// Returns a [`CreateSessionError`] variant matching the failure category +/// (invalid input, permission denied, name conflict, session limit, or +/// internal error). 
+pub async fn create_dynamic_session( + app_state: &Arc, + yaml: &str, + name: Option, + role_name: String, + perms: &crate::permissions::Permissions, +) -> Result { + // Parse & compile + let user_pipeline: UserPipeline = streamkit_api::yaml::parse_yaml(yaml) + .map_err(|e| CreateSessionError::InvalidInput(format!("YAML parse error: {e}")))?; + + let engine_pipeline = compile(user_pipeline).map_err(|e| { + CreateSessionError::InvalidInput(format!("Pipeline compilation error: {e}")) + })?; + + if engine_pipeline.nodes.is_empty() { + return Err(CreateSessionError::InvalidInput( + "Pipeline is empty. Add some nodes before creating a session.".to_string(), + )); + } + + // Per-node permission and security checks. + for (node_id, node) in &engine_pipeline.nodes { + if is_synthetic_kind(&node.kind) { + return Err(CreateSessionError::InvalidInput(format!( + "Node '{node_id}' kind '{}' is oneshot-only and cannot be used in dynamic sessions", + node.kind + ))); + } + if !perms.is_node_allowed(&node.kind) { + return Err(CreateSessionError::Forbidden(format!( + "Permission denied: node '{node_id}' kind '{}' not allowed", + node.kind + ))); + } + if node.kind.starts_with("plugin::") && !perms.is_plugin_allowed(&node.kind) { + return Err(CreateSessionError::Forbidden(format!( + "Permission denied: node '{node_id}' plugin '{}' not allowed", + node.kind + ))); + } + } + + // File-path security — policy violations are permission denials, not + // malformed input (preserves the 403 FORBIDDEN status the old HTTP + // handler returned for AppError::Forbidden from validate_file_*_paths). + check_file_path_security(&engine_pipeline, &app_state.config.security) + .map_err(CreateSessionError::Forbidden)?; + + // Pre-flight: reject early if over the session limit or name is taken, + // avoiding wasted engine allocation. The checks are re-verified under + // the lock inside add_session for correctness. 
+ let (current_count, name_taken) = { + let sm = app_state.session_manager.lock().await; + (sm.session_count(), name.as_deref().is_some_and(|n| sm.is_name_taken(n))) + }; + if let Some(ref session_name) = name { + if name_taken { + return Err(CreateSessionError::Conflict(format!( + "Session with name '{session_name}' already exists" + ))); + } + } + if !app_state.config.permissions.can_accept_session(current_count) { + return Err(CreateSessionError::LimitReached( + "Maximum concurrent sessions limit reached".to_string(), + )); + } + + // Create session (engine allocation). + let session = crate::session::Session::create( + &app_state.engine, + &app_state.config, + name, + app_state.event_tx.clone(), + Some(role_name), + ) + .await + .map_err(|e| CreateSessionError::Internal(format!("Failed to create session: {e}")))?; + + // Insert under the lock (re-checks limit and name uniqueness). + let insert_result = { + let mut sm = app_state.session_manager.lock().await; + let count = sm.session_count(); + if app_state.config.permissions.can_accept_session(count) { + sm.add_session(session.clone()) + } else { + Err("Maximum concurrent sessions limit reached".to_string()) + } + }; + if let Err(msg) = insert_result { + warn!(error = %msg, "create_dynamic_session failed during insert"); + let _ = session.shutdown_and_wait().await; + if msg.contains("limit reached") { + return Err(CreateSessionError::LimitReached(msg)); + } + return Err(CreateSessionError::Internal(format!("Failed to create session: {msg}"))); + } + + let session_id = session.id.clone(); + let session_name = session.name.clone(); + let created_at = crate::session::system_time_to_rfc3339(session.created_at); + + // Populate pipeline and send to engine + populate_session_pipeline(&session, &engine_pipeline).await; + send_pipeline_to_engine(&session, &engine_pipeline).await; + + info!( + session_id = %session_id, + name = ?session_name, + nodes = engine_pipeline.nodes.len(), + connections = 
engine_pipeline.connections.len(), + "Created new session" + ); + + // Broadcast event + let event = ApiEvent { + message_type: MessageType::Event, + correlation_id: None, + payload: EventPayload::SessionCreated { + session_id: session_id.clone(), + name: session_name.clone(), + created_at: created_at.clone(), + }, + }; + if app_state.event_tx.send(crate::state::BroadcastEvent::to_all(event)).is_err() { + debug!("No WebSocket clients connected to receive SessionCreated event"); + } + + Ok(CreateSessionResult { session_id, name: session_name, created_at }) +} + +/// Validate a batch of operations against a session's pipeline without applying. +/// +/// Returns a list of validation errors. An empty list means all operations +/// are valid. Callers must perform session-level permission and ownership +/// checks before calling this function. +/// Check batch operations for duplicate node IDs by simulating the +/// Add/Remove sequence. Returns the IDs of nodes that would collide. +async fn check_batch_node_id_uniqueness( + session: &crate::session::Session, + operations: &[streamkit_api::BatchOperation], +) -> Vec { + let mut live_ids: std::collections::HashSet = + session.pipeline.lock().await.nodes.keys().cloned().collect(); + let mut duplicates = Vec::new(); + for op in operations { + match op { + streamkit_api::BatchOperation::AddNode { node_id, .. 
} => { + if !live_ids.insert(node_id.clone()) { + duplicates.push(node_id.clone()); + } + }, + streamkit_api::BatchOperation::RemoveNode { node_id } => { + live_ids.remove(node_id.as_str()); + }, + _ => {}, + } + } + duplicates +} + +pub async fn validate_batch_operations( + session: &crate::session::Session, + operations: &[streamkit_api::BatchOperation], + perms: &crate::permissions::Permissions, + security_config: &crate::config::SecurityConfig, +) -> Vec { + let mut errors: Vec = Vec::new(); + + for node_id in check_batch_node_id_uniqueness(session, operations).await { + errors.push(streamkit_api::ValidationError { + error_type: streamkit_api::ValidationErrorType::Error, + message: format!("Batch rejected: node '{node_id}' already exists in the pipeline"), + node_id: Some(node_id), + connection_id: None, + }); + } + + // Validate all AddNode operations against permission and security rules. + for op in operations { + if let streamkit_api::BatchOperation::AddNode { node_id, kind, params, .. } = op { + if let Some(message) = crate::websocket_handlers::validate_add_node_op( + kind, + params.as_ref(), + perms, + security_config, + ) { + errors.push(streamkit_api::ValidationError { + error_type: streamkit_api::ValidationErrorType::Error, + message, + node_id: Some(node_id.clone()), + connection_id: None, + }); + } + } + } + + errors +} + +/// Apply a batch of graph mutations atomically to a running session. +/// +/// Returns `Ok(())` on success, or `Err(message)` if pre-validation fails +/// (e.g. duplicate node IDs or forbidden node kinds). Callers must perform +/// session-level permission and ownership checks before calling this function. +/// +/// # Errors +/// +/// Returns an error string when a batch operation fails pre-validation +/// (duplicate node IDs or forbidden node kinds). 
+pub async fn apply_batch_operations( + session: &crate::session::Session, + operations: Vec, + perms: &crate::permissions::Permissions, + security_config: &crate::config::SecurityConfig, +) -> Result<(), String> { + // Pre-validate duplicate node_ids. + let duplicates = check_batch_node_id_uniqueness(session, &operations).await; + if let Some(node_id) = duplicates.first() { + return Err(format!("Batch rejected: node '{node_id}' already exists in the pipeline")); + } + + // Validate permissions for all AddNode operations. + for op in &operations { + if let streamkit_api::BatchOperation::AddNode { kind, params, .. } = op { + if let Some(message) = crate::websocket_handlers::validate_add_node_op( + kind, + params.as_ref(), + perms, + security_config, + ) { + return Err(message); + } + } + } + + // Apply all operations in order. + let mut engine_operations = Vec::new(); + { + let mut pipeline = session.pipeline.lock().await; + for op in operations { + match op { + streamkit_api::BatchOperation::AddNode { node_id, kind, params } => { + pipeline.nodes.insert( + node_id.clone(), + streamkit_api::Node { + kind: kind.clone(), + params: params.clone(), + state: None, + }, + ); + engine_operations.push( + streamkit_core::control::EngineControlMessage::AddNode { + node_id, + kind, + params, + }, + ); + }, + streamkit_api::BatchOperation::RemoveNode { node_id } => { + pipeline.nodes.shift_remove(&node_id); + pipeline + .connections + .retain(|conn| conn.from_node != node_id && conn.to_node != node_id); + engine_operations.push( + streamkit_core::control::EngineControlMessage::RemoveNode { node_id }, + ); + }, + streamkit_api::BatchOperation::Connect { + from_node, + from_pin, + to_node, + to_pin, + mode, + } => { + pipeline.connections.push(streamkit_api::Connection { + from_node: from_node.clone(), + from_pin: from_pin.clone(), + to_node: to_node.clone(), + to_pin: to_pin.clone(), + mode, + }); + let core_mode = match mode { + streamkit_api::ConnectionMode::Reliable => { + 
streamkit_core::control::ConnectionMode::Reliable + }, + streamkit_api::ConnectionMode::BestEffort => { + streamkit_core::control::ConnectionMode::BestEffort + }, + }; + engine_operations.push( + streamkit_core::control::EngineControlMessage::Connect { + from_node, + from_pin, + to_node, + to_pin, + mode: core_mode, + }, + ); + }, + streamkit_api::BatchOperation::Disconnect { + from_node, + from_pin, + to_node, + to_pin, + } => { + pipeline.connections.retain(|conn| { + !(conn.from_node == from_node + && conn.from_pin == from_pin + && conn.to_node == to_node + && conn.to_pin == to_pin) + }); + engine_operations.push( + streamkit_core::control::EngineControlMessage::Disconnect { + from_node, + from_pin, + to_node, + to_pin, + }, + ); + }, + } + } + drop(pipeline); + } + + // Send control messages to the engine. + for msg in engine_operations { + session.send_control_message(msg).await; + } + + Ok(()) +} + +/// Send a control message to a specific node in a running session. +/// +/// For `UpdateParams` messages, this function also validates file-path +/// security, updates the durable pipeline model, and broadcasts a +/// `NodeParamsChanged` event. Callers must perform session-level +/// permission and ownership checks before calling this function. +/// +/// # Errors +/// +/// Returns an error string when the security policy rejects the +/// `UpdateParams` payload. 
+pub async fn tune_session_node( + session: &crate::session::Session, + node_id: String, + message: streamkit_core::control::NodeControlMessage, + security_config: &crate::config::SecurityConfig, + event_tx: &tokio::sync::broadcast::Sender, +) -> Result<(), String> { + use streamkit_core::control::NodeControlMessage; + + if let NodeControlMessage::UpdateParams(ref params) = message { + let kind = { + let pipeline = session.pipeline.lock().await; + pipeline.nodes.get(&node_id).map(|n| n.kind.clone()) + }; + + if !crate::websocket_handlers::validate_update_params_security( + kind.as_deref(), + params, + security_config, + ) { + return Err("Security policy rejected the UpdateParams payload".to_string()); + } + + { + let mut durable_params = params.clone(); + if let serde_json::Value::Object(ref mut map) = durable_params { + map.retain(|k, _| !k.starts_with('_')); + } + let mut pipeline = session.pipeline.lock().await; + if let Some(node) = pipeline.nodes.get_mut(&node_id) { + node.params = Some(match node.params.take() { + Some(existing) => { + crate::websocket_handlers::deep_merge_json(existing, durable_params) + }, + None => durable_params, + }); + } + } + + let event = streamkit_api::Event { + message_type: streamkit_api::MessageType::Event, + correlation_id: None, + payload: streamkit_api::EventPayload::NodeParamsChanged { + session_id: session.id.clone(), + node_id: node_id.clone(), + params: params.clone(), + }, + }; + if let Err(e) = event_tx.send(crate::state::BroadcastEvent::to_all(event)) { + tracing::error!("Failed to broadcast NodeParamsChanged event: {}", e); + } + } + + let control_msg = streamkit_core::control::EngineControlMessage::TuneNode { node_id, message }; + session.send_control_message(control_msg).await; + + Ok(()) +} + +/// Build the shared [`AppState`] without constructing any HTTP router. +/// +/// This is the common initialisation path used by both the HTTP server +/// (`create_app`) and the STDIO MCP server (`start_mcp_stdio`). 
+/// +/// # Panics +/// +/// See the panic documentation on [`create_app`] — the same invariants apply. #[allow(clippy::expect_used)] -pub fn create_app( +pub fn create_app_state( mut config: Config, auth: Option>, -) -> (Router, Arc) { - // --- Create the shared application state --- +) -> Arc { let (event_tx, _) = tokio::sync::broadcast::channel(128); - // Create ResourceManager for shared resources (ML models, etc.) let resource_policy = streamkit_core::ResourcePolicy { keep_loaded: config.resources.keep_models_loaded, max_memory_mb: config.resources.max_memory_mb, }; let resource_manager = Arc::new(streamkit_core::ResourceManager::new(resource_policy)); - // Set node buffer configuration for codec/container nodes - // This must be done before any nodes are created let node_buffer_config = streamkit_core::NodeBufferConfig { codec_channel_capacity: config .engine @@ -4010,12 +4321,10 @@ pub fn create_app( }; streamkit_core::set_node_buffer_config(node_buffer_config); - // Create engine with resource management support let plugin_base_dir = std::path::PathBuf::from(&config.plugins.directory); let wasm_plugin_dir = plugin_base_dir.join("wasm"); let native_plugin_dir = plugin_base_dir.join("native"); - // Build server-level node constraints from config let mut constraints = streamkit_core::GlobalNodeConstraints::new(); #[cfg(feature = "script")] @@ -4052,8 +4361,6 @@ pub fn create_app( &constraints, )); - // Initialize plugin manager - panic on failure since we can't proceed without it - // This expect is justified and documented in the function's # Panics section #[allow(clippy::expect_used)] let plugin_manager = UnifiedPluginManager::new( Arc::clone(&engine), @@ -4068,7 +4375,6 @@ pub fn create_app( let plugin_asset_registry = crate::plugin_assets::PluginAssetRegistry::new(); - // Spawn background task to load plugins asynchronously to avoid blocking startup UnifiedPluginManager::spawn_load_existing( Arc::clone(&plugin_manager), config.resources.prewarm.clone(), 
@@ -4085,13 +4391,11 @@ pub fn create_app( #[cfg(feature = "moq")] let moq_gateway = { let gateway = Arc::new(crate::moq_gateway::MoqGateway::new()); - // Initialize global gateway registry so nodes can access it let trait_obj: Arc = gateway.clone(); streamkit_core::moq_gateway::init_moq_gateway(trait_obj); Some(gateway) }; - // Initialize MSE gateway for HTTP-based media streaming let mse_gateway = { let gateway = Arc::new(crate::mse_gateway::MseGateway::new()); let trait_obj: Arc = gateway.clone(); @@ -4099,17 +4403,13 @@ pub fn create_app( gateway }; - // Use provided auth state or create disabled auth let auth = auth.unwrap_or_else(|| Arc::new(crate::auth::AuthState::disabled())); - // When built-in auth is enabled, treat the injected role header as the trusted role source. - // - // SECURITY: This header is overwritten by `auth_guard_middleware` for every API request. if auth.is_enabled() { config.permissions.role_header = Some(BUILTIN_AUTH_ROLE_HEADER.to_string()); } - let app_state = Arc::new(AppState { + Arc::new(AppState { engine, session_manager: Arc::new(tokio::sync::Mutex::new(SessionManager::default())), config: Arc::new(config), @@ -4122,7 +4422,26 @@ pub fn create_app( #[cfg(feature = "moq")] moq_gateway, mse_gateway, - }); + }) +} + +/// Create the full Axum application router and shared application state. +/// +/// # Panics +/// +/// - The unified plugin manager cannot be initialized (missing plugin directories, etc.) +/// - Plugin directories cannot be created due to filesystem permissions +/// - Plugin directories exist but are not accessible +/// - CORS configuration is invalid (wildcard with auth enabled) +/// +/// Since this occurs during application initialization, a panic here is acceptable +/// as the server cannot function without proper configuration. 
+#[allow(clippy::expect_used)] +pub fn create_app( + config: Config, + auth: Option>, +) -> (Router, Arc) { + let app_state = create_app_state(config, auth); let mut oneshot_route = post(process_oneshot_pipeline_handler) // Use configurable body limit for oneshot processing @@ -4131,7 +4450,7 @@ pub fn create_app( oneshot_route = oneshot_route.layer(ConcurrencyLimitLayer::new(max)); } - #[cfg_attr(not(feature = "moq"), allow(unused_mut))] + #[cfg_attr(not(any(feature = "moq", feature = "mcp")), allow(unused_mut))] let mut router = Router::new() .route("/healthz", get(health_handler)) .route("/health", get(health_handler)) @@ -4221,6 +4540,35 @@ pub fn create_app( router = router.route("/certificate.sha256", get(get_certificate_sha256_handler)); } + // Mount MCP (Model Context Protocol) endpoint when the feature is enabled + // and the config has it turned on. + // + // SECURITY: The endpoint MUST live under /api/ so that auth_guard_middleware, + // origin_guard_middleware, CORS, tracing, and metrics all apply automatically. + // Endpoint validation is enforced at config-load time (McpConfig::validate). + #[cfg(feature = "mcp")] + { + if app_state.config.mcp.enabled { + info!( + endpoint = %app_state.config.mcp.endpoint, + "MCP endpoint enabled" + ); + router = router.nest_service( + &app_state.config.mcp.endpoint, + crate::mcp::streamable_http_service(Arc::clone(&app_state)), + ); + } + } + + // Warn if mcp.enabled is set but the binary was compiled without the mcp feature. + #[cfg(not(feature = "mcp"))] + if app_state.config.mcp.enabled { + warn!( + "mcp.enabled is true but the binary was compiled without the 'mcp' feature. \ + The MCP endpoint will not be available. Rebuild with --features mcp." + ); + } + // Add MSE streaming route. // SECURITY: Intentionally outside /api/ so auth_guard_middleware does not // apply — matches the MoQ WebTransport model. See mse_stream_handler doc comment. 
@@ -4789,6 +5137,85 @@ pub async fn start_server(config: &Config) -> Result<(), Box Result<(), Box> { + let app_state = create_app_state(config.clone(), None); + + let mcp = crate::mcp::StreamKitMcp::new(app_state); + + let ct = tokio_util::sync::CancellationToken::new(); + + // Listen for Ctrl-C / SIGTERM and cancel the token so the STDIO + // transport shuts down gracefully. + // These expect() calls are justified and documented in the function's # Panics section. + let ct_clone = ct.clone(); + tokio::spawn(async move { + let ctrl_c = tokio::signal::ctrl_c(); + + #[cfg(unix)] + { + #[allow(clippy::expect_used)] + let mut sigterm = + tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("failed to install SIGTERM handler"); + tokio::select! { + _ = ctrl_c => {}, + _ = sigterm.recv() => {}, + } + } + + #[cfg(not(unix))] + { + #[allow(clippy::expect_used)] + ctrl_c.await.expect("failed to install Ctrl+C handler"); + } + + info!("Received shutdown signal, stopping MCP STDIO server"); + ct_clone.cancel(); + }); + + info!("Starting MCP server over STDIO transport"); + + let service = rmcp::service::serve_server_with_ct(mcp, rmcp::transport::io::stdio(), ct) + .await + .map_err(|e| format!("Failed to initialize MCP STDIO server: {e}"))?; + + service.waiting().await?; + + info!("MCP STDIO server stopped"); + Ok(()) +} + // --- A simple error type for the Axum handler --- #[derive(Debug)] enum AppError { diff --git a/apps/skit/src/websocket_handlers.rs b/apps/skit/src/websocket_handlers.rs index dc4f2d32..c68a8377 100644 --- a/apps/skit/src/websocket_handlers.rs +++ b/apps/skit/src/websocket_handlers.rs @@ -14,8 +14,7 @@ use crate::state::{AppState, BroadcastEvent}; use opentelemetry::global; use std::sync::Arc; use streamkit_api::{ - Event as ApiEvent, EventPayload, MessageType, RequestPayload, ResponsePayload, ValidationError, - ValidationErrorType, + Event as ApiEvent, EventPayload, MessageType, RequestPayload, ResponsePayload, }; use 
streamkit_core::control::{EngineControlMessage, NodeControlMessage}; use streamkit_core::registry::NodeDefinition; @@ -41,7 +40,7 @@ fn can_access_session(session: &Session, role_name: &str, perms: &Permissions) - /// Returns `Some(error_message)` if the operation is not allowed, `None` if it passes. /// This is the single source of truth for AddNode validation, used by `handle_add_node`, /// `handle_validate_batch`, and `handle_apply_batch`. -fn validate_add_node_op( +pub fn validate_add_node_op( kind: &str, params: Option<&serde_json::Value>, perms: &Permissions, @@ -776,18 +775,16 @@ async fn handle_tune_node( perms: &Permissions, role_name: &str, ) -> Option { - // Check permission to tune nodes if !perms.tune_nodes { return Some(ResponsePayload::Error { message: "Permission denied: cannot tune nodes".to_string(), }); } - // Get session with SHORT lock hold to avoid blocking other operations let session = { let session_manager = app_state.session_manager.lock().await; session_manager.get_session_by_name_or_id(&session_id) - }; // Session manager lock released here + }; let Some(session) = session else { return Some(ResponsePayload::Error { @@ -795,115 +792,30 @@ async fn handle_tune_node( }); }; - // Check ownership (session is cloned, doesn't need lock) if !can_access_session(&session, role_name, perms) { return Some(ResponsePayload::Error { message: "Permission denied: you do not own this session".to_string(), }); } - // Handle UpdateParams specially for event broadcasting (and validate file paths) - if let NodeControlMessage::UpdateParams(ref params) = message { - let (kind, file_path, script_path) = { - let pipeline = session.pipeline.lock().await; - let kind = pipeline.nodes.get(&node_id).map(|n| n.kind.clone()); - let file_path = - params.get("path").and_then(serde_json::Value::as_str).map(str::to_string); - let script_path = - params.get("script_path").and_then(serde_json::Value::as_str).map(str::to_string); - drop(pipeline); - (kind, file_path, 
script_path) - }; - - let file_path = file_path.as_deref(); - let script_path = script_path.as_deref(); - - if kind.as_deref() == Some("core::file_reader") { - let Some(path) = file_path else { - return Some(ResponsePayload::Error { - message: "Invalid file_reader params: expected params.path to be a string" - .to_string(), - }); - }; - if let Err(e) = file_security::validate_file_path(path, &app_state.config.security) { - return Some(ResponsePayload::Error { message: format!("Invalid file path: {e}") }); - } - } - - if kind.as_deref() == Some("core::file_writer") { - if let Some(path) = file_path { - if let Err(e) = file_security::validate_write_path(path, &app_state.config.security) - { - return Some(ResponsePayload::Error { - message: format!("Invalid write path: {e}"), - }); - } - } - } - - if kind.as_deref() == Some("core::script") { - if let Some(path) = script_path { - if !path.trim().is_empty() { - if let Err(e) = - file_security::validate_file_path(path, &app_state.config.security) - { - return Some(ResponsePayload::Error { - message: format!("Invalid script_path: {e}"), - }); - } - } - } - } - - { - // Store sanitized params: strip transient sync metadata - // (_sender, _rev, etc.) for consistency with the - // fire-and-forget handler. - let mut durable_params = params.clone(); - if let serde_json::Value::Object(ref mut map) = durable_params { - map.retain(|k, _| !k.starts_with('_')); - } - let mut pipeline = session.pipeline.lock().await; - if let Some(node) = pipeline.nodes.get_mut(&node_id) { - // Deep-merge the partial update into existing params so - // sibling keys are preserved (mirrors the async handler). 
- node.params = Some(match node.params.take() { - Some(existing) => deep_merge_json(existing, durable_params), - None => durable_params, - }); - } else { - warn!( - node_id = %node_id, - "Attempted to tune params for non-existent node in pipeline model" - ); - } - } // Lock released here - - // Broadcast event to all clients - let event = ApiEvent { - message_type: MessageType::Event, - correlation_id: None, - payload: EventPayload::NodeParamsChanged { - session_id: session.id.clone(), - node_id: node_id.clone(), - params: params.clone(), - }, - }; - if let Err(e) = app_state.event_tx.send(BroadcastEvent::to_all(event)) { - error!("Failed to broadcast NodeParamsChanged event: {}", e); - } + match crate::server::tune_session_node( + &session, + node_id, + message, + &app_state.config.security, + &app_state.event_tx, + ) + .await + { + Ok(()) => Some(ResponsePayload::Success), + Err(message) => Some(ResponsePayload::Error { message }), } - - // Now safe to do async operations without holding session_manager lock - let control_msg = EngineControlMessage::TuneNode { node_id, message }; - session.send_control_message(control_msg).await; - Some(ResponsePayload::Success) } /// Validate file/script paths in UpdateParams against security policy. /// /// Returns `true` if the params are allowed, `false` if they should be rejected. 
-fn validate_update_params_security( +pub fn validate_update_params_security( kind: Option<&str>, params: &serde_json::Value, security: &crate::config::SecurityConfig, @@ -960,9 +872,7 @@ async fn handle_tune_node_fire_and_forget( ) -> Option { let action_label = "TuneNodeAsync"; - // Check permission to tune nodes if !perms.tune_nodes { - // For async operations, we don't send a response but we should still log warn!("Permission denied: attempted to tune node without permission via {action_label}"); return None; } @@ -970,10 +880,9 @@ async fn handle_tune_node_fire_and_forget( let session = { let session_manager = app_state.session_manager.lock().await; session_manager.get_session_by_name_or_id(&session_id) - }; // Session manager lock released here + }; if let Some(session) = session { - // Check ownership if !can_access_session(&session, role_name, perms) { warn!( session_id = %session_id, @@ -983,71 +892,21 @@ async fn handle_tune_node_fire_and_forget( return None; } - // Handle UpdateParams specially for pipeline model updates and event broadcasting - if let NodeControlMessage::UpdateParams(ref params) = message { - let kind = { - let pipeline = session.pipeline.lock().await; - pipeline.nodes.get(&node_id).map(|n| n.kind.clone()) - }; - - if !validate_update_params_security(kind.as_deref(), params, &app_state.config.security) - { - return None; - } - - { - // Store sanitized params: strip transient sync metadata - // (_sender, _rev, etc.) from the durable pipeline model. - // Top-level keys prefixed with `_` are reserved for - // in-flight metadata and must not leak into persistence - // or GetPipeline responses. 
- let mut durable_params = params.clone(); - if let serde_json::Value::Object(ref mut map) = durable_params { - map.retain(|k, _| !k.starts_with('_')); - } - let mut pipeline = session.pipeline.lock().await; - if let Some(node) = pipeline.nodes.get_mut(&node_id) { - // Deep-merge the partial update into existing params so - // sibling keys are preserved. Without this, a partial - // nested update like `{ properties: { show: false } }` - // would overwrite the entire params, losing keys such - // as `fps`, `width`, or `properties.name`. - node.params = Some(match node.params.take() { - Some(existing) => deep_merge_json(existing, durable_params), - None => durable_params, - }); - } else { - warn!( - node_id = %node_id, - "Attempted to tune params for non-existent node in pipeline model via {action_label}" - ); - } - } // Lock released here - - // Broadcast the *partial delta* (not merged state) to all clients. - // Correct deep-merge on receive depends on each client having a - // valid base state, which is guaranteed because every client - // fetches the full pipeline on connect. 
- let event = ApiEvent { - message_type: MessageType::Event, - correlation_id: None, - payload: EventPayload::NodeParamsChanged { - session_id: session.id.clone(), - node_id: node_id.clone(), - params: params.clone(), - }, - }; - if let Err(e) = app_state.event_tx.send(BroadcastEvent::to_all(event)) { - error!("Failed to broadcast NodeParamsChanged event: {}", e); - } + if let Err(e) = crate::server::tune_session_node( + &session, + node_id, + message, + &app_state.config.security, + &app_state.event_tx, + ) + .await + { + warn!("Security policy rejected tune via {action_label}: {e}"); } - - let control_msg = EngineControlMessage::TuneNode { node_id, message }; - session.send_control_message(control_msg).await; } else { warn!("Could not tune non-existent session '{session_id}' via {action_label}"); } - None // Do not send a response + None } /// Handle async node tuning (fire-and-forget, broadcasts to all). @@ -1136,14 +995,12 @@ async fn handle_validate_batch( perms: &Permissions, role_name: &str, ) -> ResponsePayload { - // Validate that user has permission for modify_sessions if !perms.modify_sessions { return ResponsePayload::Error { message: "Permission denied: cannot modify sessions".to_string(), }; } - // Verify session exists let session = { let session_manager = app_state.session_manager.lock().await; session_manager.get_session_by_name_or_id(&session_id) @@ -1153,56 +1010,19 @@ async fn handle_validate_batch( return ResponsePayload::Error { message: format!("Session '{session_id}' not found") }; }; - // Check ownership if !can_access_session(&session, role_name, perms) { return ResponsePayload::Error { message: "Permission denied: you do not own this session".to_string(), }; } - // Collect all validation errors so the caller sees every problem at once. - let mut errors: Vec = Vec::new(); - - // Pre-validate duplicate node_ids against the pipeline model, mirroring - // the same simulation that handle_apply_batch performs. 
- let mut live_ids: std::collections::HashSet = - session.pipeline.lock().await.nodes.keys().cloned().collect(); - for op in operations { - match op { - streamkit_api::BatchOperation::AddNode { node_id, .. } => { - if !live_ids.insert(node_id.clone()) { - errors.push(ValidationError { - error_type: ValidationErrorType::Error, - message: format!( - "Batch rejected: node '{node_id}' already exists in the pipeline" - ), - node_id: Some(node_id.clone()), - connection_id: None, - }); - } - }, - streamkit_api::BatchOperation::RemoveNode { node_id } => { - live_ids.remove(node_id.as_str()); - }, - _ => {}, - } - } - - // Validate all AddNode operations against permission and security rules. - for op in operations { - if let streamkit_api::BatchOperation::AddNode { node_id, kind, params, .. } = op { - if let Some(message) = - validate_add_node_op(kind, params.as_ref(), perms, &app_state.config.security) - { - errors.push(ValidationError { - error_type: ValidationErrorType::Error, - message, - node_id: Some(node_id.clone()), - connection_id: None, - }); - } - } - } + let errors = crate::server::validate_batch_operations( + &session, + operations, + perms, + &app_state.config.security, + ) + .await; info!( operation_count = operations.len(), @@ -1212,7 +1032,6 @@ async fn handle_validate_batch( ResponsePayload::ValidationResult { errors } } -#[allow(clippy::significant_drop_tightening)] async fn handle_apply_batch( session_id: String, operations: Vec, @@ -1220,18 +1039,16 @@ async fn handle_apply_batch( perms: &Permissions, role_name: &str, ) -> Option { - // Check permission to modify sessions if !perms.modify_sessions { return Some(ResponsePayload::Error { message: "Permission denied: cannot modify sessions".to_string(), }); } - // Get session with SHORT lock hold to avoid blocking other operations let session = { let session_manager = app_state.session_manager.lock().await; session_manager.get_session_by_name_or_id(&session_id) - }; // Session manager lock released here + 
}; let Some(session) = session else { return Some(ResponsePayload::Error { @@ -1239,142 +1056,26 @@ async fn handle_apply_batch( }); }; - // Check ownership (session is cloned, doesn't need lock) if !can_access_session(&session, role_name, perms) { return Some(ResponsePayload::Error { message: "Permission denied: you do not own this session".to_string(), }); } - // Pre-validate duplicate node_ids against the pipeline model. - // Simulate the batch's Add/Remove sequence so that Remove→Add for - // the same ID within the batch is allowed, but duplicate Adds - // (without intervening Remove) are rejected before any mutation. - { - let pipeline = session.pipeline.lock().await; - let mut live_ids: std::collections::HashSet<&str> = - pipeline.nodes.keys().map(String::as_str).collect(); - for op in &operations { - match op { - streamkit_api::BatchOperation::AddNode { node_id, .. } => { - if !live_ids.insert(node_id.as_str()) { - return Some(ResponsePayload::Error { - message: format!( - "Batch rejected: node '{node_id}' already exists in the pipeline" - ), - }); - } - }, - streamkit_api::BatchOperation::RemoveNode { node_id } => { - live_ids.remove(node_id.as_str()); - }, - _ => {}, - } - } - } // Pipeline lock released after pre-validation - - // Validate permissions for all operations. - for op in &operations { - if let streamkit_api::BatchOperation::AddNode { kind, params, .. 
} = op { - if let Some(message) = - validate_add_node_op(kind, params.as_ref(), perms, &app_state.config.security) - { - return Some(ResponsePayload::Error { message }); - } - } - } - - // Apply all operations in order - let mut engine_operations = Vec::new(); - + match crate::server::apply_batch_operations( + &session, + operations, + perms, + &app_state.config.security, + ) + .await { - let mut pipeline = session.pipeline.lock().await; - - for op in operations { - match op { - streamkit_api::BatchOperation::AddNode { node_id, kind, params } => { - pipeline.nodes.insert( - node_id.clone(), - streamkit_api::Node { - kind: kind.clone(), - params: params.clone(), - state: None, - }, - ); - engine_operations.push(EngineControlMessage::AddNode { node_id, kind, params }); - }, - streamkit_api::BatchOperation::RemoveNode { node_id } => { - pipeline.nodes.shift_remove(&node_id); - pipeline - .connections - .retain(|conn| conn.from_node != node_id && conn.to_node != node_id); - engine_operations.push(EngineControlMessage::RemoveNode { node_id }); - }, - streamkit_api::BatchOperation::Connect { - from_node, - from_pin, - to_node, - to_pin, - mode, - } => { - pipeline.connections.push(streamkit_api::Connection { - from_node: from_node.clone(), - from_pin: from_pin.clone(), - to_node: to_node.clone(), - to_pin: to_pin.clone(), - mode, - }); - let core_mode = match mode { - streamkit_api::ConnectionMode::Reliable => { - streamkit_core::control::ConnectionMode::Reliable - }, - streamkit_api::ConnectionMode::BestEffort => { - streamkit_core::control::ConnectionMode::BestEffort - }, - }; - engine_operations.push(EngineControlMessage::Connect { - from_node, - from_pin, - to_node, - to_pin, - mode: core_mode, - }); - }, - streamkit_api::BatchOperation::Disconnect { - from_node, - from_pin, - to_node, - to_pin, - } => { - pipeline.connections.retain(|conn| { - !(conn.from_node == from_node - && conn.from_pin == from_pin - && conn.to_node == to_node - && conn.to_pin == to_pin) - }); 
- engine_operations.push(EngineControlMessage::Disconnect { - from_node, - from_pin, - to_node, - to_pin, - }); - }, - } - } - drop(pipeline); - } // Release pipeline lock - - // Now safe to do async operations without holding session_manager lock - for msg in engine_operations { - session.send_control_message(msg).await; + Ok(()) => { + info!(session_id = %session_id, "Applied batch operations successfully"); + Some(ResponsePayload::BatchApplied { success: true, errors: Vec::new() }) + }, + Err(message) => Some(ResponsePayload::Error { message }), } - - info!( - session_id = %session_id, - "Applied batch operations successfully" - ); - - Some(ResponsePayload::BatchApplied { success: true, errors: Vec::new() }) } fn handle_get_permissions(perms: &Permissions, role_name: &str) -> ResponsePayload { @@ -1385,7 +1086,7 @@ fn handle_get_permissions(perms: &Permissions, role_name: &str) -> ResponsePaylo /// Recursively deep-merges `source` into `target`, returning the merged value. /// Only JSON objects are merged recursively; arrays and scalars in `source` /// replace the corresponding value in `target`. -fn deep_merge_json(target: serde_json::Value, source: serde_json::Value) -> serde_json::Value { +pub fn deep_merge_json(target: serde_json::Value, source: serde_json::Value) -> serde_json::Value { match (target, source) { (serde_json::Value::Object(mut t_map), serde_json::Value::Object(s_map)) => { for (key, s_val) in s_map { diff --git a/apps/skit/tests/mcp_integration_test.rs b/apps/skit/tests/mcp_integration_test.rs new file mode 100644 index 00000000..2b72107e --- /dev/null +++ b/apps/skit/tests/mcp_integration_test.rs @@ -0,0 +1,1487 @@ +// SPDX-FileCopyrightText: © 2025 StreamKit Contributors +// +// SPDX-License-Identifier: MPL-2.0 + +//! Integration tests for the embedded MCP (Model Context Protocol) server. +//! +//! These tests exercise the MCP endpoint through real HTTP requests, +//! verifying auth enforcement, tool routing, and permission filtering. +//! 
+//! Requires the `mcp` feature: `cargo test --features mcp -p streamkit-server` + +#![cfg(feature = "mcp")] +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::disallowed_macros, + clippy::uninlined_format_args +)] + +use axum::http::StatusCode; +use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; +use serde_json::json; +use std::net::SocketAddr; +use std::sync::Arc; +use streamkit_server::Config; +use tempfile::TempDir; +use tokio::net::TcpListener; +use tokio::time::Duration; + +/// Poll `/healthz` until the server is ready (up to 2 s). +async fn wait_for_healthz(addr: SocketAddr) { + let client = reqwest::Client::new(); + let url = format!("http://{addr}/healthz"); + for _ in 0..40 { + if client.get(&url).send().await.is_ok() { + return; + } + tokio::time::sleep(Duration::from_millis(50)).await; + } + panic!("Server at {addr} did not become healthy within 2 s"); +} + +/// Start a test server with MCP enabled and built-in auth. +/// +/// # Panics +/// +/// Panics if the TCP listener cannot be bound (including `PermissionDenied`). 
+async fn start_mcp_server() -> (SocketAddr, tokio::task::JoinHandle<()>, String, TempDir) { + let listener = + TcpListener::bind("127.0.0.1:0").await.expect("Failed to bind test server listener"); + let addr = listener.local_addr().unwrap(); + let temp_dir = TempDir::new().unwrap(); + + let mut config = Config::default(); + config.mcp.enabled = true; + config.auth.mode = streamkit_server::config::AuthMode::Enabled; + config.auth.state_dir = temp_dir.path().to_string_lossy().to_string(); + + let auth_state = streamkit_server::auth::AuthState::new(&config.auth, true) + .await + .expect("Failed to init auth state"); + let auth_state = Arc::new(auth_state); + + let admin_token_path = temp_dir.path().join("admin.token"); + let admin_token = + tokio::fs::read_to_string(&admin_token_path).await.expect("Missing admin.token"); + let admin_token = admin_token.trim().to_string(); + + let (app, _state) = streamkit_server::server::create_app(config, Some(auth_state)); + let server_handle = tokio::spawn(async move { + axum::serve(listener, app.into_make_service()).await.unwrap(); + }); + + wait_for_healthz(addr).await; + (addr, server_handle, admin_token, temp_dir) +} + +/// Start a test server with restricted permissions (no create_sessions). +async fn start_restricted_mcp_server() -> (SocketAddr, tokio::task::JoinHandle<()>, String, TempDir) +{ + let listener = + TcpListener::bind("127.0.0.1:0").await.expect("Failed to bind test server listener"); + let addr = listener.local_addr().unwrap(); + let temp_dir = TempDir::new().unwrap(); + + let mut config = Config::default(); + config.mcp.enabled = true; + config.auth.mode = streamkit_server::config::AuthMode::Enabled; + config.auth.state_dir = temp_dir.path().to_string_lossy().to_string(); + + // Restrict the admin role: disable create_sessions. + // The admin token generated by AuthState::new has role="admin". 
+ if let Some(admin_perms) = config.permissions.roles.get_mut("admin") { + admin_perms.create_sessions = false; + } + + let auth_state = streamkit_server::auth::AuthState::new(&config.auth, true) + .await + .expect("Failed to init auth state"); + let auth_state = Arc::new(auth_state); + + let admin_token_path = temp_dir.path().join("admin.token"); + let admin_token = + tokio::fs::read_to_string(&admin_token_path).await.expect("Missing admin.token"); + let admin_token = admin_token.trim().to_string(); + + let (app, _state) = streamkit_server::server::create_app(config, Some(auth_state)); + let server_handle = tokio::spawn(async move { + axum::serve(listener, app.into_make_service()).await.unwrap(); + }); + + wait_for_healthz(addr).await; + (addr, server_handle, admin_token, temp_dir) +} + +/// Send a JSON-RPC request to the MCP endpoint. +async fn mcp_post( + client: &reqwest::Client, + addr: SocketAddr, + body: &serde_json::Value, + auth_header: Option<&str>, +) -> reqwest::Response { + let mut req = client + .post(format!("http://{addr}/api/v1/mcp")) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") + .json(body); + + if let Some(token) = auth_header { + req = req.header(AUTHORIZATION, format!("Bearer {token}")); + } + + req.send().await.expect("Failed to send MCP request") +} + +/// Send a JSON-RPC request with a session ID. +async fn mcp_post_with_session( + client: &reqwest::Client, + addr: SocketAddr, + body: &serde_json::Value, + token: &str, + session_id: &str, +) -> reqwest::Response { + client + .post(format!("http://{addr}/api/v1/mcp")) + .header(CONTENT_TYPE, "application/json") + .header(ACCEPT, "application/json, text/event-stream") + .header(AUTHORIZATION, format!("Bearer {token}")) + .header("mcp-session-id", session_id) + .json(body) + .send() + .await + .expect("Failed to send MCP request") +} + +/// Initialize an MCP session and return the session ID. 
+async fn init_mcp_session(client: &reqwest::Client, addr: SocketAddr, token: &str) -> String { + let init = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { "name": "test", "version": "0.1" } + } + }); + let res = mcp_post(client, addr, &init, Some(token)).await; + assert_eq!(res.status(), StatusCode::OK); + + let session_id = res + .headers() + .get("mcp-session-id") + .expect("missing mcp-session-id header") + .to_str() + .unwrap() + .to_string(); + + // Send initialized notification + let initialized = json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + let _ = mcp_post_with_session(client, addr, &initialized, token, &session_id).await; + + session_id +} + +/// Extract JSON-RPC result from SSE response body. +fn extract_sse_json(body_text: &str) -> serde_json::Value { + let json_str = body_text + .lines() + .filter_map(|l| l.strip_prefix("data:")) + .map(str::trim) + .find(|s| s.starts_with('{')) + .expect("no SSE data line with JSON found in response"); + serde_json::from_str(json_str).expect("invalid JSON in SSE data") +} + +// ----------------------------------------------------------------------- +// Tests +// ----------------------------------------------------------------------- + +#[tokio::test] +async fn mcp_unauthenticated_request_is_rejected() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, _token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + + // JSON-RPC initialize request without auth + let body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { "name": "test", "version": "0.1" } + } + }); + + let res = mcp_post(&client, addr, &body, None).await; + assert_eq!(res.status(), StatusCode::UNAUTHORIZED); +} + +#[tokio::test] +async fn 
mcp_authenticated_initialize_succeeds() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + + let body = json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-03-26", + "capabilities": {}, + "clientInfo": { "name": "test", "version": "0.1" } + } + }); + + let res = mcp_post(&client, addr, &body, Some(&token)).await; + assert_eq!(res.status(), StatusCode::OK); +} + +#[tokio::test] +async fn mcp_validate_pipeline_returns_diagnostics() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + + let session_id = init_mcp_session(&client, addr, &token).await; + + // Call validate_pipeline with invalid YAML + let validate = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "validate_pipeline", + "arguments": { + "yaml": "not: valid: yaml: [[" + } + } + }); + let res = mcp_post_with_session(&client, addr, &validate, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + // Response is SSE stream — extract JSON-RPC result from `data:` lines. 
+ let body_text = res.text().await.expect("failed to read response body"); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result in response, got: {body}"); + + // The tool should return text content with validation diagnostics + let content = &result["content"][0]["text"]; + assert!(content.is_string(), "expected text content in result"); + let text = content.as_str().unwrap(); + let parsed: serde_json::Value = serde_json::from_str(text).expect("tool output not valid JSON"); + assert_eq!(parsed["valid"], false); + assert!(!parsed["errors"].as_array().unwrap().is_empty()); +} + +#[tokio::test] +async fn mcp_create_session_permission_denied() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_restricted_mcp_server().await; + let client = reqwest::Client::new(); + + let session_id = init_mcp_session(&client, addr, &token).await; + + // Try to create a session — should be denied by permissions + let create = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "create_session", + "arguments": { + "yaml": "nodes:\n tone:\n kind: streamkit::tone_generator\n params:\n frequency: 440\n sample_rate: 48000\n duration: 1.0", + "name": "test-denied" + } + } + }); + let res = mcp_post_with_session(&client, addr, &create, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.expect("failed to read response body"); + let body = extract_sse_json(&body_text); + let error = &body["error"]; + assert!(!error.is_null(), "expected error in response, got: {body}"); + let error_msg = error["message"].as_str().unwrap_or(""); + assert!( + error_msg.contains("Permission denied") || error_msg.contains("permission"), + "expected permission error, got: {error_msg}" + ); +} + +#[tokio::test] +async fn mcp_destroy_session_nonexistent_returns_error() { + let _ = tracing_subscriber::fmt::try_init(); 
+ + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + + let session_id = init_mcp_session(&client, addr, &token).await; + + // Try to destroy a session that doesn't exist + let destroy = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "destroy_session", + "arguments": { + "session_id": "nonexistent-session-id-12345" + } + } + }); + let res = mcp_post_with_session(&client, addr, &destroy, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.expect("failed to read response body"); + let body = extract_sse_json(&body_text); + let error = &body["error"]; + assert!(!error.is_null(), "expected error for nonexistent session destroy, got: {body}"); +} + +#[tokio::test] +async fn mcp_config_endpoint_validation() { + // Verify that McpConfig::validate rejects bad endpoints at config-load time + let mut mcp_config = streamkit_server::config::McpConfig { + enabled: false, + endpoint: "/bad/path".to_string(), + ..Default::default() + }; + assert!(mcp_config.validate().is_ok()); + + // Enabled with bad prefix should fail + mcp_config.enabled = true; + mcp_config.endpoint = "/bad/path".to_string(); + assert!(mcp_config.validate().is_err()); + + // Auth bypass path should fail + mcp_config.endpoint = "/api/v1/auth/mcp".to_string(); + assert!(mcp_config.validate().is_err()); + + // Path traversal should fail + mcp_config.endpoint = "/api/v1/../mcp".to_string(); + assert!(mcp_config.validate().is_err()); + + // Only /api/ prefix should fail (missing version/mcp segments) + mcp_config.endpoint = "/api/".to_string(); + assert!(mcp_config.validate().is_err()); + + // Non-versioned path should fail + mcp_config.endpoint = "/api/mcp".to_string(); + assert!(mcp_config.validate().is_err()); + + // Colliding with existing routes should fail (wrong third segment) + mcp_config.endpoint = "/api/v1/sessions".to_string(); + 
assert!(mcp_config.validate().is_err()); + + // Valid default endpoint should pass + mcp_config.endpoint = "/api/v1/mcp".to_string(); + assert!(mcp_config.validate().is_ok()); + + // Valid with subpath should pass + mcp_config.endpoint = "/api/v2/mcp/extra".to_string(); + assert!(mcp_config.validate().is_ok()); +} + +// ----------------------------------------------------------------------- +// Positive-path tests +// ----------------------------------------------------------------------- + +/// Minimal valid pipeline YAML for dynamic sessions. +const PASSTHROUGH_YAML: &str = "nodes:\n pass:\n kind: core::passthrough"; + +/// Minimal valid pipeline YAML for oneshot (steps-based). +const ONESHOT_PASSTHROUGH_YAML: &str = "\ +mode: oneshot\n\ +steps:\n\ + - kind: streamkit::http_input\n\ + - kind: core::passthrough\n\ + - kind: streamkit::http_output"; + +#[tokio::test] +async fn mcp_list_nodes_returns_definitions() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let list = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "list_nodes", + "arguments": {} + } + }); + let res = mcp_post_with_session(&client, addr, &list, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result, got: {body}"); + + let text = result["content"][0]["text"].as_str().expect("expected text content"); + let defs: Vec = serde_json::from_str(text).expect("expected JSON array"); + assert!(!defs.is_empty(), "list_nodes should return at least one definition"); + + // Verify passthrough and synthetic nodes are present + let kinds: Vec<&str> = defs.iter().filter_map(|d| d["kind"].as_str()).collect(); + 
assert!(kinds.contains(&"core::passthrough"), "missing core::passthrough in: {kinds:?}"); + assert!(kinds.contains(&"streamkit::http_input"), "missing synthetic http_input in: {kinds:?}"); +} + +#[tokio::test] +async fn mcp_create_list_get_destroy_session_round_trip() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + // 1. Create a session + let create = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "create_session", + "arguments": { + "yaml": PASSTHROUGH_YAML, + "name": "mcp-roundtrip-test" + } + } + }); + let res = mcp_post_with_session(&client, addr, &create, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result from create_session, got: {body}"); + + let text = result["content"][0]["text"].as_str().expect("expected text content"); + let created: serde_json::Value = serde_json::from_str(text).expect("expected JSON"); + let skit_session_id = created["session_id"].as_str().expect("missing session_id"); + assert_eq!(created["name"].as_str(), Some("mcp-roundtrip-test")); + assert!(created["created_at"].as_str().is_some()); + + // 2. 
List sessions — our session should appear + let list = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "list_sessions", + "arguments": {} + } + }); + let res = mcp_post_with_session(&client, addr, &list, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let text = body["result"]["content"][0]["text"].as_str().expect("expected text"); + let sessions: Vec = serde_json::from_str(text).expect("expected JSON array"); + assert!( + sessions.iter().any(|s| s["id"].as_str() == Some(skit_session_id)), + "created session not found in list_sessions" + ); + + // 3. Get pipeline — should return the passthrough node + let get = json!({ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": { + "name": "get_pipeline", + "arguments": { + "session_id": skit_session_id + } + } + }); + let res = mcp_post_with_session(&client, addr, &get, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let text = body["result"]["content"][0]["text"].as_str().expect("expected text"); + let pipeline: serde_json::Value = serde_json::from_str(text).expect("expected JSON"); + assert!(pipeline["nodes"]["pass"].is_object(), "expected 'pass' node in pipeline"); + + // 4. 
Destroy the session + let destroy = json!({ + "jsonrpc": "2.0", + "id": 5, + "method": "tools/call", + "params": { + "name": "destroy_session", + "arguments": { + "session_id": skit_session_id + } + } + }); + let res = mcp_post_with_session(&client, addr, &destroy, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result from destroy_session, got: {body}"); + let text = result["content"][0]["text"].as_str().expect("expected text"); + let destroyed: serde_json::Value = serde_json::from_str(text).expect("expected JSON"); + assert_eq!(destroyed["session_id"].as_str(), Some(skit_session_id)); + + // 5. Verify session is gone + let res = mcp_post_with_session(&client, addr, &list, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let text = body["result"]["content"][0]["text"].as_str().expect("expected text"); + let sessions: Vec = serde_json::from_str(text).expect("expected JSON array"); + assert!( + !sessions.iter().any(|s| s["id"].as_str() == Some(skit_session_id)), + "destroyed session should not appear in list_sessions" + ); +} + +// ----------------------------------------------------------------------- +// generate_oneshot_command tests +// ----------------------------------------------------------------------- + +#[tokio::test] +async fn mcp_generate_oneshot_command_curl() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let call = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "generate_oneshot_command", + "arguments": { + "yaml": ONESHOT_PASSTHROUGH_YAML, + 
"inputs": [{ "field": "media", "path": "/tmp/input.wav" }], + "output": "/tmp/output.wav", + "server_url": "http://localhost:4545", + "format": "curl" + } + } + }); + let res = mcp_post_with_session(&client, addr, &call, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result, got: {body}"); + + let text = result["content"][0]["text"].as_str().expect("expected text content"); + assert!(text.contains("curl"), "curl command expected in output: {text}"); + assert!(text.contains("/api/v1/process"), "endpoint expected in output: {text}"); + assert!(text.contains("config="), "config field expected in output: {text}"); + assert!(text.contains("mktemp"), "mktemp expected in output: {text}"); + assert!(text.contains("'media=@/tmp/input.wav'"), "input field expected in output: {text}"); + assert!(text.contains("-o '/tmp/output.wav'"), "output path expected in output: {text}"); +} + +#[tokio::test] +async fn mcp_generate_oneshot_command_skit_cli() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let call = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "generate_oneshot_command", + "arguments": { + "yaml": ONESHOT_PASSTHROUGH_YAML, + "inputs": [{ "field": "media", "path": "/tmp/input.wav" }], + "output": "/tmp/output.wav", + "server_url": "http://localhost:9999", + "format": "skit-cli" + } + } + }); + let res = mcp_post_with_session(&client, addr, &call, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected 
result, got: {body}"); + + let text = result["content"][0]["text"].as_str().expect("expected text content"); + assert!( + text.contains("streamkit-client oneshot"), + "skit-cli command expected in output: {text}" + ); + assert!(text.contains("'/tmp/input.wav'"), "input path expected in output: {text}"); + assert!(text.contains("'/tmp/output.wav'"), "output path expected in output: {text}"); + assert!( + text.contains("--server 'http://localhost:9999'"), + "server URL expected in output: {text}" + ); +} + +#[tokio::test] +async fn mcp_generate_oneshot_command_invalid_yaml() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let call = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "generate_oneshot_command", + "arguments": { + "yaml": "not: valid: yaml: [[", + "inputs": [{ "field": "media", "path": "/tmp/input.wav" }], + "output": "/tmp/output.wav" + } + } + }); + let res = mcp_post_with_session(&client, addr, &call, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result (validation error), got: {body}"); + + let text = result["content"][0]["text"].as_str().expect("expected text content"); + assert!(text.contains("validation failed"), "expected validation failure message, got: {text}"); +} + +#[tokio::test] +async fn mcp_generate_oneshot_command_permission_denied() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_restricted_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let call = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + 
"params": { + "name": "generate_oneshot_command", + "arguments": { + "yaml": ONESHOT_PASSTHROUGH_YAML, + "inputs": [{ "field": "media", "path": "/tmp/input.wav" }], + "output": "/tmp/output.wav" + } + } + }); + let res = mcp_post_with_session(&client, addr, &call, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let error = &body["error"]; + assert!(!error.is_null(), "expected error for permission denied, got: {body}"); + let error_msg = error["message"].as_str().unwrap_or(""); + assert!( + error_msg.contains("Permission denied") || error_msg.contains("permission"), + "expected permission error, got: {error_msg}" + ); +} + +// ----------------------------------------------------------------------- +// Prompt tests +// ----------------------------------------------------------------------- + +#[tokio::test] +async fn mcp_list_prompts_returns_both_prompts() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let list = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "prompts/list", + "params": {} + }); + let res = mcp_post_with_session(&client, addr, &list, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result from prompts/list, got: {body}"); + + let prompts = result["prompts"].as_array().expect("expected prompts array"); + assert_eq!(prompts.len(), 2, "expected exactly 2 prompts, got: {prompts:?}"); + + let names: Vec<&str> = prompts.iter().filter_map(|p| p["name"].as_str()).collect(); + assert!(names.contains(&"design_pipeline"), "missing design_pipeline in: {names:?}"); + 
assert!(names.contains(&"debug_pipeline"), "missing debug_pipeline in: {names:?}"); + + // Verify design_pipeline has optional description argument + let design = prompts.iter().find(|p| p["name"] == "design_pipeline").unwrap(); + let design_args = design["arguments"].as_array().expect("expected arguments array"); + assert_eq!(design_args.len(), 1); + assert_eq!(design_args[0]["name"], "description"); + assert_eq!(design_args[0]["required"], false); + + // Verify debug_pipeline has required session_id argument + let debug = prompts.iter().find(|p| p["name"] == "debug_pipeline").unwrap(); + let debug_args = debug["arguments"].as_array().expect("expected arguments array"); + assert_eq!(debug_args.len(), 1); + assert_eq!(debug_args[0]["name"], "session_id"); + assert_eq!(debug_args[0]["required"], true); +} + +#[tokio::test] +async fn mcp_get_prompt_design_pipeline_returns_node_info() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + // Call get_prompt without arguments (description is optional) + let get = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "prompts/get", + "params": { + "name": "design_pipeline" + } + }); + let res = mcp_post_with_session(&client, addr, &get, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result from prompts/get, got: {body}"); + + let messages = result["messages"].as_array().expect("expected messages array"); + assert!(!messages.is_empty(), "expected at least one message"); + + let text = messages[0]["content"]["text"].as_str().expect("expected text content"); + // Verify it contains key sections + assert!(text.contains("YAML Format"), "missing YAML format section"); + 
assert!(text.contains("Available Nodes"), "missing available nodes section"); + assert!(text.contains("Connection Rules"), "missing connection rules section"); + assert!(text.contains("core::passthrough"), "missing core::passthrough node"); + assert!(text.contains("validate_pipeline"), "missing workflow reference to validate_pipeline"); +} + +#[tokio::test] +async fn mcp_get_prompt_design_pipeline_with_description() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let get = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "prompts/get", + "params": { + "name": "design_pipeline", + "arguments": { + "description": "Build a pipeline that mixes two audio streams" + } + } + }); + let res = mcp_post_with_session(&client, addr, &get, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result, got: {body}"); + + let text = result["messages"][0]["content"]["text"].as_str().expect("expected text content"); + assert!( + text.contains("Build a pipeline that mixes two audio streams"), + "expected user description in prompt content" + ); + assert!(text.contains("User Request"), "expected User Request section header"); +} + +#[tokio::test] +async fn mcp_get_prompt_debug_pipeline_with_session() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + // 1. 
Create a session to debug + let create = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "create_session", + "arguments": { + "yaml": PASSTHROUGH_YAML, + "name": "debug-prompt-test" + } + } + }); + let res = mcp_post_with_session(&client, addr, &create, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let text = body["result"]["content"][0]["text"] + .as_str() + .expect("expected text content from create_session"); + let created: serde_json::Value = serde_json::from_str(text).expect("expected JSON"); + let skit_session_id = created["session_id"].as_str().expect("missing session_id"); + + // 2. Call debug_pipeline prompt with the session ID + let get = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "prompts/get", + "params": { + "name": "debug_pipeline", + "arguments": { + "session_id": skit_session_id + } + } + }); + let res = mcp_post_with_session(&client, addr, &get, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let result = &body["result"]; + assert!(!result.is_null(), "expected result from debug_pipeline prompt, got: {body}"); + + let messages = result["messages"].as_array().expect("expected messages array"); + assert!(!messages.is_empty(), "expected at least one message"); + + let text = messages[0]["content"]["text"].as_str().expect("expected text content"); + // Verify it includes pipeline state and diagnostic sections + assert!(text.contains("Current Pipeline State"), "missing pipeline state section"); + assert!(text.contains("Node States"), "missing node states section"); + assert!(text.contains("Diagnostic Checklist"), "missing diagnostic checklist"); + assert!(text.contains("core::passthrough"), "missing passthrough node reference"); + + // 3. 
Clean up — destroy the session + let destroy = json!({ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": { + "name": "destroy_session", + "arguments": { + "session_id": skit_session_id + } + } + }); + let _ = mcp_post_with_session(&client, addr, &destroy, &token, &session_id).await; +} + +#[tokio::test] +async fn mcp_get_prompt_debug_pipeline_missing_session_returns_error() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let get = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "prompts/get", + "params": { + "name": "debug_pipeline", + "arguments": { + "session_id": "nonexistent-session-12345" + } + } + }); + let res = mcp_post_with_session(&client, addr, &get, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let error = &body["error"]; + assert!(!error.is_null(), "expected error for nonexistent session, got: {body}"); + let error_msg = error["message"].as_str().unwrap_or(""); + assert!(error_msg.contains("not found"), "expected 'not found' in error, got: {error_msg}"); +} + +#[tokio::test] +async fn mcp_get_prompt_unknown_returns_error() { + let _ = tracing_subscriber::fmt::try_init(); + + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let session_id = init_mcp_session(&client, addr, &token).await; + + let get = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "prompts/get", + "params": { + "name": "nonexistent_prompt" + } + }); + let res = mcp_post_with_session(&client, addr, &get, &token, &session_id).await; + assert_eq!(res.status(), StatusCode::OK); + + let body_text = res.text().await.unwrap(); + let body = extract_sse_json(&body_text); + let error = &body["error"]; + assert!(!error.is_null(), 
"expected error for unknown prompt, got: {body}"); +} + +// ----------------------------------------------------------------------- +// Batch & Tune tests +// ----------------------------------------------------------------------- + +/// Helper: create a StreamKit session via MCP and return its session_id. +async fn create_skit_session( + client: &reqwest::Client, + addr: SocketAddr, + token: &str, + mcp_session: &str, + yaml: &str, +) -> String { + let create = json!({ + "jsonrpc": "2.0", + "id": 100, + "method": "tools/call", + "params": { + "name": "create_session", + "arguments": { "yaml": yaml } + } + }); + let res = mcp_post_with_session(client, addr, &create, token, mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("create_session text"); + let parsed: serde_json::Value = serde_json::from_str(text).unwrap(); + parsed["session_id"].as_str().expect("session_id").to_string() +} + +#[tokio::test] +async fn mcp_validate_batch_valid_operations() { + let _ = tracing_subscriber::fmt::try_init(); + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let mcp_session = init_mcp_session(&client, addr, &token).await; + + let skit_session = + create_skit_session(&client, addr, &token, &mcp_session, PASSTHROUGH_YAML).await; + + let validate = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "validate_batch", + "arguments": { + "session_id": skit_session, + "operations": [ + { "action": "addnode", "node_id": "new_pass", "kind": "core::passthrough" } + ] + } + } + }); + let res = mcp_post_with_session(&client, addr, &validate, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("validate_batch text"); + 
let errors: Vec = serde_json::from_str(text).unwrap(); + assert!(errors.is_empty(), "expected no validation errors, got: {errors:?}"); +} + +#[tokio::test] +async fn mcp_validate_batch_invalid_duplicate_node() { + let _ = tracing_subscriber::fmt::try_init(); + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let mcp_session = init_mcp_session(&client, addr, &token).await; + + let skit_session = + create_skit_session(&client, addr, &token, &mcp_session, PASSTHROUGH_YAML).await; + + // "pass" already exists in the pipeline — adding it again should fail. + let validate = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "validate_batch", + "arguments": { + "session_id": skit_session, + "operations": [ + { "action": "addnode", "node_id": "pass", "kind": "core::passthrough" } + ] + } + } + }); + let res = mcp_post_with_session(&client, addr, &validate, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("validate_batch text"); + let errors: Vec = serde_json::from_str(text).unwrap(); + assert!(!errors.is_empty(), "expected validation errors for duplicate node"); + assert!( + errors[0]["message"].as_str().unwrap().contains("already exists"), + "expected 'already exists' error" + ); +} + +#[tokio::test] +async fn mcp_apply_batch_add_node_round_trip() { + let _ = tracing_subscriber::fmt::try_init(); + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let mcp_session = init_mcp_session(&client, addr, &token).await; + + let skit_session = + create_skit_session(&client, addr, &token, &mcp_session, PASSTHROUGH_YAML).await; + + // Apply: add a new node + let apply = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "apply_batch", + "arguments": { + 
"session_id": skit_session, + "operations": [ + { "action": "addnode", "node_id": "extra", "kind": "core::passthrough" } + ] + } + } + }); + let res = mcp_post_with_session(&client, addr, &apply, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("apply_batch text"); + let result: serde_json::Value = serde_json::from_str(text).unwrap(); + assert_eq!(result["success"], true); + + // Verify via get_pipeline that "extra" exists + let get = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get_pipeline", + "arguments": { "session_id": skit_session } + } + }); + let res = mcp_post_with_session(&client, addr, &get, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("get_pipeline text"); + let pipeline: serde_json::Value = serde_json::from_str(text).unwrap(); + assert!(pipeline["nodes"]["extra"].is_object(), "expected 'extra' node in pipeline"); + assert!(pipeline["nodes"]["pass"].is_object(), "expected original 'pass' node in pipeline"); + + // Clean up + let destroy = json!({ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": { "name": "destroy_session", "arguments": { "session_id": skit_session } } + }); + let _ = mcp_post_with_session(&client, addr, &destroy, &token, &mcp_session).await; +} + +#[tokio::test] +async fn mcp_tune_node_update_params() { + let _ = tracing_subscriber::fmt::try_init(); + let (addr, _handle, token, _dir) = start_mcp_server().await; + let client = reqwest::Client::new(); + let mcp_session = init_mcp_session(&client, addr, &token).await; + + let skit_session = + create_skit_session(&client, addr, &token, &mcp_session, PASSTHROUGH_YAML).await; + + // Tune: send UpdateParams to the "pass" node + 
let tune = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "tune_node", + "arguments": { + "session_id": skit_session, + "node_id": "pass", + "message": { "UpdateParams": { "gain": 0.5 } } + } + } + }); + let res = mcp_post_with_session(&client, addr, &tune, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("tune_node text"); + let result: serde_json::Value = serde_json::from_str(text).unwrap(); + assert_eq!(result["success"], true); + + // Verify params persisted in pipeline model + let get = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "get_pipeline", + "arguments": { "session_id": skit_session } + } + }); + let res = mcp_post_with_session(&client, addr, &get, &token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + + let body = extract_sse_json(&res.text().await.unwrap()); + let text = body["result"]["content"][0]["text"].as_str().expect("get_pipeline text"); + let pipeline: serde_json::Value = serde_json::from_str(text).unwrap(); + let pass_params = &pipeline["nodes"]["pass"]["params"]; + assert_eq!(pass_params["gain"], 0.5, "expected tuned gain param"); + + // Clean up + let destroy = json!({ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": { "name": "destroy_session", "arguments": { "session_id": skit_session } } + }); + let _ = mcp_post_with_session(&client, addr, &destroy, &token, &mcp_session).await; +} + +#[tokio::test] +async fn mcp_modify_sessions_permission_denied() { + let _ = tracing_subscriber::fmt::try_init(); + + // Start a server where modify_sessions is disabled for admin + let listener = + TcpListener::bind("127.0.0.1:0").await.expect("Failed to bind test server listener"); + let addr = listener.local_addr().unwrap(); + let temp_dir = TempDir::new().unwrap(); + + let mut 
config = Config::default(); + config.mcp.enabled = true; + config.auth.mode = streamkit_server::config::AuthMode::Enabled; + config.auth.state_dir = temp_dir.path().to_string_lossy().to_string(); + + if let Some(admin_perms) = config.permissions.roles.get_mut("admin") { + admin_perms.modify_sessions = false; + admin_perms.tune_nodes = false; + } + + let auth_state = streamkit_server::auth::AuthState::new(&config.auth, true) + .await + .expect("Failed to init auth state"); + let auth_state = Arc::new(auth_state); + + let admin_token_path = temp_dir.path().join("admin.token"); + let admin_token = + tokio::fs::read_to_string(&admin_token_path).await.expect("Missing admin.token"); + let admin_token = admin_token.trim().to_string(); + + let (app, _state) = streamkit_server::server::create_app(config, Some(auth_state)); + let _server_handle = tokio::spawn(async move { + axum::serve(listener, app.into_make_service()).await.unwrap(); + }); + + wait_for_healthz(addr).await; + + let client = reqwest::Client::new(); + let mcp_session = init_mcp_session(&client, addr, &admin_token).await; + + // validate_batch should be denied + let validate = json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "validate_batch", + "arguments": { + "session_id": "any", + "operations": [] + } + } + }); + let res = mcp_post_with_session(&client, addr, &validate, &admin_token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + let body = extract_sse_json(&res.text().await.unwrap()); + let error = &body["error"]; + assert!(!error.is_null(), "expected error for validate_batch, got: {body}"); + assert!( + error["message"].as_str().unwrap_or("").contains("Permission denied"), + "expected permission denied" + ); + + // apply_batch should be denied + let apply = json!({ + "jsonrpc": "2.0", + "id": 3, + "method": "tools/call", + "params": { + "name": "apply_batch", + "arguments": { + "session_id": "any", + "operations": [] + } + } + }); + let res = 
mcp_post_with_session(&client, addr, &apply, &admin_token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + let body = extract_sse_json(&res.text().await.unwrap()); + let error = &body["error"]; + assert!(!error.is_null(), "expected error for apply_batch, got: {body}"); + assert!( + error["message"].as_str().unwrap_or("").contains("Permission denied"), + "expected permission denied" + ); + + // tune_node should be denied + let tune = json!({ + "jsonrpc": "2.0", + "id": 4, + "method": "tools/call", + "params": { + "name": "tune_node", + "arguments": { + "session_id": "any", + "node_id": "any", + "message": { "UpdateParams": {} } + } + } + }); + let res = mcp_post_with_session(&client, addr, &tune, &admin_token, &mcp_session).await; + assert_eq!(res.status(), StatusCode::OK); + let body = extract_sse_json(&res.text().await.unwrap()); + let error = &body["error"]; + assert!(!error.is_null(), "expected error for tune_node, got: {body}"); + assert!( + error["message"].as_str().unwrap_or("").contains("Permission denied"), + "expected permission denied" + ); +} + +// --------------------------------------------------------------------------- +// STDIO transport tests +// --------------------------------------------------------------------------- + +/// Spawn the `skit mcp` binary as a child process and verify it responds to +/// a JSON-RPC `initialize` request over STDIO. +#[tokio::test] +async fn test_mcp_stdio_initialize() { + use std::process::Stdio; + use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; + use tokio::process::Command; + + // Locate the skit binary built alongside the test binary. 
+ let skit_bin = std::path::PathBuf::from(env!("CARGO_BIN_EXE_skit")); + + let mut child = Command::new(&skit_bin) + .arg("mcp") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to spawn skit mcp process"); + + let stdin = child.stdin.as_mut().expect("Failed to open stdin"); + let stdout = child.stdout.take().expect("Failed to open stdout"); + let mut reader = BufReader::new(stdout); + + // Send a JSON-RPC initialize request (MCP protocol). + let init_request = serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { + "name": "test-client", + "version": "0.1.0" + } + } + }); + let msg = serde_json::to_string(&init_request).unwrap(); + stdin.write_all(msg.as_bytes()).await.unwrap(); + stdin.write_all(b"\n").await.unwrap(); + stdin.flush().await.unwrap(); + + // Read the response line (with a timeout). + let mut response_line = String::new(); + let read_result = + tokio::time::timeout(Duration::from_secs(30), reader.read_line(&mut response_line)).await; + + assert!(read_result.is_ok(), "Timed out waiting for initialize response"); + let bytes_read = read_result.unwrap().expect("Failed to read response"); + assert!(bytes_read > 0, "Empty response from MCP STDIO server"); + + let response: serde_json::Value = + serde_json::from_str(response_line.trim()).expect("Invalid JSON response"); + + // Verify it's a successful JSON-RPC response. 
+ assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 1); + assert!( + response.get("result").is_some(), + "Expected 'result' in initialize response, got: {}", + response + ); + + let result = &response["result"]; + assert_eq!(result["protocolVersion"], "2024-11-05"); + assert!( + result["serverInfo"]["name"].as_str().unwrap_or("").contains("streamkit"), + "Expected serverInfo.name to contain 'streamkit', got: {}", + result["serverInfo"] + ); + + // Verify capabilities include tools. + assert!( + result["capabilities"]["tools"].is_object(), + "Expected tools capability, got: {}", + result["capabilities"] + ); + + // Clean up: kill the child process. + child.kill().await.ok(); +} + +/// After initializing, invoke `tools/call` with `list_nodes` over STDIO to +/// verify that the admin-fallback auth path grants tool access. +#[tokio::test] +async fn test_mcp_stdio_tool_call() { + use std::process::Stdio; + use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader}; + use tokio::process::Command; + + let skit_bin = std::path::PathBuf::from(env!("CARGO_BIN_EXE_skit")); + + let mut child = Command::new(&skit_bin) + .arg("mcp") + .stdin(Stdio::piped()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .expect("Failed to spawn skit mcp process"); + + let stdin = child.stdin.as_mut().expect("Failed to open stdin"); + let stdout = child.stdout.take().expect("Failed to open stdout"); + let mut reader = BufReader::new(stdout); + + // 1. 
Initialize + let init_request = serde_json::json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2024-11-05", + "capabilities": {}, + "clientInfo": { "name": "test-client", "version": "0.1.0" } + } + }); + let msg = serde_json::to_string(&init_request).unwrap(); + stdin.write_all(msg.as_bytes()).await.unwrap(); + stdin.write_all(b"\n").await.unwrap(); + stdin.flush().await.unwrap(); + + let mut line = String::new(); + let read_result = tokio::time::timeout(Duration::from_secs(30), reader.read_line(&mut line)) + .await + .expect("Timed out waiting for initialize response") + .expect("Failed to read initialize response"); + assert!(read_result > 0, "Empty initialize response"); + + // 2. Send initialized notification (required by MCP protocol) + let initialized = serde_json::json!({ + "jsonrpc": "2.0", + "method": "notifications/initialized" + }); + let msg = serde_json::to_string(&initialized).unwrap(); + stdin.write_all(msg.as_bytes()).await.unwrap(); + stdin.write_all(b"\n").await.unwrap(); + stdin.flush().await.unwrap(); + + // 3. 
Call list_nodes tool + let tool_call = serde_json::json!({ + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + "params": { + "name": "list_nodes", + "arguments": {} + } + }); + let msg = serde_json::to_string(&tool_call).unwrap(); + stdin.write_all(msg.as_bytes()).await.unwrap(); + stdin.write_all(b"\n").await.unwrap(); + stdin.flush().await.unwrap(); + + let mut response_line = String::new(); + let read_result = + tokio::time::timeout(Duration::from_secs(30), reader.read_line(&mut response_line)) + .await + .expect("Timed out waiting for list_nodes response") + .expect("Failed to read list_nodes response"); + assert!(read_result > 0, "Empty list_nodes response"); + + let response: serde_json::Value = + serde_json::from_str(response_line.trim()).expect("Invalid JSON response"); + + assert_eq!(response["jsonrpc"], "2.0"); + assert_eq!(response["id"], 2); + assert!( + response.get("result").is_some(), + "Expected 'result' in list_nodes response, got: {}", + response + ); + + // The result should contain node definitions as text content. 
+ let content = &response["result"]["content"]; + assert!(content.is_array(), "Expected content array, got: {content}"); + let text = content[0]["text"].as_str().unwrap_or(""); + assert!(!text.is_empty(), "Expected non-empty node listing from list_nodes"); + + child.kill().await.ok(); +} diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index ea133ef5..1cc3a07b 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -18,6 +18,7 @@ serde = { version = "1.0.228", features = ["derive", "rc"] } serde_json = "1.0" serde-saphyr = "0.0.23" ts-rs = { version = "12.0.1" } +schemars = { version = "1.2.0", features = ["derive"] } indexmap = { version = "2.14", features = ["serde"] } [[bin]] diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs index ec204f85..9d567730 100644 --- a/crates/api/src/lib.rs +++ b/crates/api/src/lib.rs @@ -234,7 +234,7 @@ pub enum RequestPayload { GetPermissions, } -#[derive(Serialize, Deserialize, Debug, Clone, TS)] +#[derive(Serialize, Deserialize, Debug, Clone, TS, schemars::JsonSchema)] #[ts(export)] #[serde(tag = "action")] #[serde(rename_all = "lowercase")] diff --git a/crates/core/src/control.rs b/crates/core/src/control.rs index f73e90dd..e36a085d 100644 --- a/crates/core/src/control.rs +++ b/crates/core/src/control.rs @@ -15,7 +15,7 @@ use serde::{Deserialize, Serialize}; use ts_rs::TS; /// A message sent to a specific, running node to tune its parameters or control its lifecycle. -#[derive(Debug, Deserialize, Serialize, TS)] +#[derive(Debug, Deserialize, Serialize, TS, schemars::JsonSchema)] #[ts(export)] pub enum NodeControlMessage { UpdateParams(#[ts(type = "JsonValue")] serde_json::Value), @@ -28,7 +28,9 @@ pub enum NodeControlMessage { } /// Specifies how a connection handles backpressure from slow consumers. 
-#[derive(Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, Default, TS)] +#[derive( + Debug, Deserialize, Serialize, Clone, Copy, PartialEq, Eq, Default, TS, schemars::JsonSchema, +)] #[ts(export)] #[serde(rename_all = "snake_case")] pub enum ConnectionMode { diff --git a/justfile b/justfile index 5b5354f6..7cf6bba6 100644 --- a/justfile +++ b/justfile @@ -143,6 +143,10 @@ skit-profiling *args='': @echo "Note: Heap profiling configuration is embedded in the binary" @RUSTFLAGS="-C force-frame-pointers=yes -C target-cpu=native" cargo run --profile release-lto {{moq_features}} {{profiling_features}} -p streamkit-server --bin skit -- {{args}} +# Start the MCP server over STDIO (for MCP client integration) +skit-mcp *args='': check-ui-dist + @cargo run --features mcp {{moq_features}} {{extra_features}} -p streamkit-server --bin skit -- mcp {{args}} + # Start the skit server with tokio-console support skit-console *args='': @echo "Starting skit with tokio-console support..." @@ -217,6 +221,7 @@ test-skit: @echo "Testing skit..." @cargo test --workspace -- --skip gpu_tests:: @cargo test -p streamkit-server --features "moq" + @cargo test -p streamkit-server --features "mcp" # Run GPU tests (requires a machine with a GPU) test-skit-gpu: @@ -231,6 +236,7 @@ lint-skit: @echo "Linting skit..." @cargo fmt --all -- --check @cargo clippy -p streamkit-server --all-targets --features "moq" -- -D warnings + @cargo clippy -p streamkit-server --all-targets --features "mcp" -- -D warnings @cargo clippy --workspace --exclude streamkit-server --all-targets -- -D warnings @mkdir -p target @HOST=$(rustc -vV | sed -n 's/^host: //p'); \