diff --git a/examples/stdio_mcp_server.cpp b/examples/stdio_mcp_server.cpp index fb32ac0..255a2da 100644 --- a/examples/stdio_mcp_server.cpp +++ b/examples/stdio_mcp_server.cpp @@ -10,6 +10,7 @@ int main() using Json = nlohmann::json; fastmcpp::tools::ToolManager tm; + int counter_value = 0; fastmcpp::tools::Tool add{ "add", Json{{"type", "object"}, @@ -29,8 +30,27 @@ int main() }}; tm.register_tool(add); + fastmcpp::tools::Tool counter{ + "counter", + Json{{"type", "object"}, {"properties", Json::object()}}, + Json{{"type", "array"}, + {"items", + Json::array({Json{{"type", "object"}, + {"properties", Json{{"type", Json{{"type", "string"}}}, + {"text", Json{{"type", "string"}}}}}, + {"required", Json::array({"type", "text"})}}})}}, + [&counter_value](const Json&) -> Json + { + counter_value += 1; + return Json{{"content", + Json::array({Json{{"type", "text"}, {"text", std::to_string(counter_value)}}})}}; + }}; + tm.register_tool(counter); + auto handler = - fastmcpp::mcp::make_mcp_handler("demo_stdio", "0.1.0", tm, {{"add", "Add two numbers"}}); + fastmcpp::mcp::make_mcp_handler("demo_stdio", "0.1.0", tm, + {{"add", "Add two numbers"}, + {"counter", "Increment and return an in-process counter"}}); fastmcpp::server::StdioServerWrapper server(handler); server.run(); return 0; diff --git a/include/fastmcpp/client/transports.hpp b/include/fastmcpp/client/transports.hpp index 3f7b751..2edf03f 100644 --- a/include/fastmcpp/client/transports.hpp +++ b/include/fastmcpp/client/transports.hpp @@ -53,8 +53,9 @@ class WebSocketTransport : public ITransport std::string url_; }; -// Launches an MCP stdio server as a subprocess and performs -// a single JSON-RPC request/response per call. +// Launches an MCP stdio server as a subprocess and performs JSON-RPC requests +// over its stdin/stdout. By default, the subprocess is kept alive between calls +// to better match Python fastmcp behavior; pass keep_alive=false to spawn per call. class StdioTransport : public ITransport { public: @@ -64,29 +65,44 @@ class StdioTransport : public ITransport /// @param log_file Optional path where subprocess stderr will be written. /// If provided, stderr is redirected to this file in append mode. /// If not provided, stderr is captured and included in error messages. + /// @param keep_alive Whether to keep the subprocess alive between calls. Defaults to true. explicit StdioTransport(std::string command, std::vector args = {}, - std::optional log_file = std::nullopt) - : command_(std::move(command)), args_(std::move(args)), log_file_(std::move(log_file)) - { - } + std::optional log_file = std::nullopt, + bool keep_alive = true); /// Construct with ostream pointer for stderr (v2.13.0+) /// @param command The command to execute /// @param args Command-line arguments /// @param log_stream Stream pointer where subprocess stderr will be written /// Caller retains ownership; must remain valid during request() - StdioTransport(std::string command, std::vector args, std::ostream* log_stream) - : command_(std::move(command)), args_(std::move(args)), log_stream_(log_stream) + /// @param keep_alive Whether to keep the subprocess alive between calls. Defaults to true. + StdioTransport(std::string command, std::vector args, std::ostream* log_stream, + bool keep_alive = true); + + StdioTransport(const StdioTransport&) = delete; + StdioTransport& operator=(const StdioTransport&) = delete; + StdioTransport(StdioTransport&&) noexcept; + StdioTransport& operator=(StdioTransport&&) noexcept; + + ~StdioTransport(); + + fastmcpp::Json request(const std::string& route, const fastmcpp::Json& payload) override; + + bool keep_alive() const noexcept { + return keep_alive_; } - fastmcpp::Json request(const std::string& route, const fastmcpp::Json& payload); - private: std::string command_; std::vector args_; std::optional log_file_; std::ostream* log_stream_ = nullptr; + bool keep_alive_{true}; + int64_t next_id_{1}; + + struct State; + std::unique_ptr state_; }; /// SSE client transport for connecting to MCP servers using Server-Sent Events protocol. diff --git a/src/cli/main.cpp b/src/cli/main.cpp index e71b569..d54654f 100644 --- a/src/cli/main.cpp +++ b/src/cli/main.cpp @@ -52,6 +52,7 @@ static int tasks_usage(int exit_code = 1) std::cout << " --ws WebSocket URL (e.g. ws://127.0.0.1:8765)\n"; std::cout << " --stdio Spawn an MCP stdio server\n"; std::cout << " --stdio-arg Repeatable args for --stdio\n"; + std::cout << " --stdio-one-shot Spawn a fresh process per request (disables keep-alive)\n"; std::cout << "\n"; std::cout << "Notes:\n"; std::cout << " - Python fastmcp's `tasks` CLI is for Docket (distributed workers/Redis).\n"; @@ -75,6 +76,7 @@ struct TasksConnection std::string url_or_command; std::string mcp_path = "/mcp"; std::vector stdio_args; + bool stdio_keep_alive = true; }; static bool is_flag(const std::string& s) @@ -158,6 +160,8 @@ static std::optional parse_tasks_connection(std::vector(conn.url_or_command)); case TasksConnection::Kind::Stdio: - return Client(std::make_unique(conn.url_or_command, conn.stdio_args)); + return Client(std::make_unique(conn.url_or_command, conn.stdio_args, + std::nullopt, conn.stdio_keep_alive)); } throw std::runtime_error("Unsupported transport kind"); } diff --git a/src/client/transports.cpp b/src/client/transports.cpp index 548e03e..e3094ed 100644 --- a/src/client/transports.cpp +++ b/src/client/transports.cpp @@ -1,12 +1,15 @@ #include "fastmcpp/client/transports.hpp" #include "fastmcpp/exceptions.hpp" -#include "fastmcpp/util/json.hpp" +#include "fastmcpp/util/json.hpp" #include -#include +#include +#include +#include #include #include +#include #include #include #ifdef FASTMCPP_POST_STREAMING @@ -16,12 +19,28 @@ #include #endif -namespace fastmcpp::client -{ +namespace fastmcpp::client +{ -namespace +struct StdioTransport::State { -struct ParsedUrl +#ifdef TINY_PROCESS_LIB_AVAILABLE + std::unique_ptr process; + std::ofstream log_file_stream; + std::ostream* stderr_target{nullptr}; + + std::mutex request_mutex; + std::mutex mutex; + std::condition_variable cv; + std::string stdout_partial; + std::deque stdout_lines; + std::string stderr_data; +#endif +}; + +namespace +{ +struct ParsedUrl { std::string scheme; // "http" or "https" std::string host; @@ -508,6 +527,20 @@ void WebSocketTransport::request_stream(const std::string& route, const fastmcpp ws->close(); } +StdioTransport::StdioTransport(std::string command, std::vector args, + std::optional log_file, bool keep_alive) + : command_(std::move(command)), args_(std::move(args)), log_file_(std::move(log_file)), + keep_alive_(keep_alive) +{ +} + +StdioTransport::StdioTransport(std::string command, std::vector args, + std::ostream* log_stream, bool keep_alive) + : command_(std::move(command)), args_(std::move(args)), log_stream_(log_stream), + keep_alive_(keep_alive) +{ +} + fastmcpp::Json StdioTransport::request(const std::string& route, const fastmcpp::Json& payload) { // Use TinyProcessLibrary (fetched via CMake) for cross-platform subprocess handling @@ -519,6 +552,131 @@ fastmcpp::Json StdioTransport::request(const std::string& route, const fastmcpp: #ifdef TINY_PROCESS_LIB_AVAILABLE using namespace TinyProcessLib; + + if (keep_alive_) + { + if (!state_) + { + state_ = std::make_unique(); + + if (log_file_.has_value()) + { + state_->log_file_stream.open(log_file_.value(), std::ios::app); + if (state_->log_file_stream.is_open()) + state_->stderr_target = &state_->log_file_stream; + } + else if (log_stream_ != nullptr) + { + state_->stderr_target = log_stream_; + } + + auto stdout_callback = [st_ptr = state_.get()](const char* bytes, size_t n) + { + std::lock_guard lock(st_ptr->mutex); + st_ptr->stdout_partial.append(bytes, n); + + for (;;) + { + auto pos = st_ptr->stdout_partial.find('\n'); + if (pos == std::string::npos) + break; + + std::string line = st_ptr->stdout_partial.substr(0, pos); + if (!line.empty() && line.back() == '\r') + line.pop_back(); + st_ptr->stdout_lines.push_back(std::move(line)); + st_ptr->stdout_partial.erase(0, pos + 1); + } + + st_ptr->cv.notify_all(); + }; + + auto stderr_callback = [st_ptr = state_.get()](const char* bytes, size_t n) + { + std::lock_guard lock(st_ptr->mutex); + if (st_ptr->stderr_target != nullptr) + { + st_ptr->stderr_target->write(bytes, n); + st_ptr->stderr_target->flush(); + } + st_ptr->stderr_data.append(bytes, n); + }; + + state_->process = std::make_unique(cmd.str(), "", stdout_callback, + stderr_callback, /*open_stdin*/ true); + } + + auto* st = state_.get(); + std::lock_guard request_lock(st->request_mutex); + + const int64_t id = next_id_++; + fastmcpp::Json request = { + {"jsonrpc", "2.0"}, + {"id", id}, + {"method", route}, + {"params", payload}, + }; + + { + std::lock_guard lock(st->mutex); + st->stderr_data.clear(); + } + + if (!st->process->write(request.dump() + "\n")) + throw fastmcpp::TransportError("StdioTransport: failed to write request"); + + // Wait for a response matching this ID. + // Note: stdio servers may emit notifications or logs; ignore non-matching lines. + for (;;) + { + int exit_status = 0; + if (st->process->try_get_exit_status(exit_status)) + { + std::lock_guard lock(st->mutex); + throw fastmcpp::TransportError( + "StdioTransport process exited with code: " + + std::to_string(exit_status) + + (st->stderr_data.empty() ? std::string("") + : ("; stderr: ") + st->stderr_data)); + } + + std::unique_lock lock(st->mutex); + if (!st->cv.wait_for(lock, std::chrono::seconds(30), + [&]() { return !st->stdout_lines.empty(); })) + { + throw fastmcpp::TransportError("StdioTransport: timed out waiting for response"); + } + + while (!st->stdout_lines.empty()) + { + auto line = std::move(st->stdout_lines.front()); + st->stdout_lines.pop_front(); + lock.unlock(); + + if (line.empty()) + { + lock.lock(); + continue; + } + + try + { + auto parsed = fastmcpp::util::json::parse(line); + if (parsed.contains("id") && parsed["id"].is_number_integer() && + parsed["id"].get() == id) + { + return parsed; + } + } + catch (...) + { + // Ignore non-JSON stdout lines (e.g., server logs). + } + + lock.lock(); + } + } + } std::string stdout_data; std::string stderr_data; @@ -582,6 +740,29 @@ fastmcpp::Json StdioTransport::request(const std::string& route, const fastmcpp: #endif } +StdioTransport::StdioTransport(StdioTransport&&) noexcept = default; +StdioTransport& StdioTransport::operator=(StdioTransport&&) noexcept = default; + +StdioTransport::~StdioTransport() +{ +#ifdef TINY_PROCESS_LIB_AVAILABLE + if (state_ && state_->process) + { + state_->process->close_stdin(); + + int exit_status = 0; + for (int i = 0; i < 10; i++) + { + if (state_->process->try_get_exit_status(exit_status)) + return; + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + + state_->process->kill(false); + } +#endif +} + // ============================================================================= // SseClientTransport implementation // ============================================================================= diff --git a/tests/transports/stdio_client.cpp b/tests/transports/stdio_client.cpp index bb38cc3..eaba023 100644 --- a/tests/transports/stdio_client.cpp +++ b/tests/transports/stdio_client.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include static std::string find_stdio_server_binary() @@ -64,6 +65,33 @@ int main() std::cout << "[PASS] tools/call add returned 7" << std::endl; } + auto call_counter = [](StdioTransport& transport) -> int + { + Json params = Json{{"name", "counter"}, {"arguments", Json::object()}}; + auto resp = transport.request("tools/call", params); + auto content = resp["result"]["content"]; + std::string text = content.at(0).value("text", std::string()); + return std::stoi(text); + }; + + // keep_alive (default): state persists across calls + { + int first = call_counter(tx); + int second = call_counter(tx); + assert(second == first + 1); + std::cout << "[PASS] keep_alive preserved counter state" << std::endl; + } + + // keep_alive=false: each request spawns a fresh server (counter resets) + { + StdioTransport one_shot{find_stdio_server_binary(), {}, std::nullopt, false}; + int first = call_counter(one_shot); + int second = call_counter(one_shot); + assert(first == 1); + assert(second == 1); + std::cout << "[PASS] keep_alive=false resets counter state" << std::endl; + } + std::cout << "\n[OK] stdio client conformance passed" << std::endl; return 0; }