diff --git a/cli/README.md b/cli/README.md index 7b5cccd..322a5a2 100644 --- a/cli/README.md +++ b/cli/README.md @@ -160,7 +160,7 @@ gateway start [flags] - `--addr` - Address and port for the gateway server (e.g., ':9090', '127.0.0.1:8080') (default: ":9090") - `--config` - Path to YAML file with gateway configuration (default: "./gateway.yaml") -- `--servers` - Comma-separated list of additional server URLs for Swagger UI (e.g., 'https://dev1.example.com,https://dev2.example.com') +- `--servers` - Comma-separated list of server URLs used for Swagger UI and CORS for SSE (e.g., 'https://dev1.example.com,https://dev2.example.com') - `--connection-string` - Database connection string (DSN) for direct database connection - `--disable-swagger` - Disable Swagger UI documentation (default: "false") - `--mcp` - Start MCP SSE server (default: "true") @@ -188,7 +188,7 @@ gateway start stdio [flags] - `--raw` - Enable raw protocol mode optimized for AI agents (default: "false") - `--addr` - Address and port for the gateway server (e.g., ':9090', '127.0.0.1:8080') (default: ":9090") - `--config` - Path to YAML file with gateway configuration (default: "./gateway.yaml") -- `--servers` - Comma-separated list of additional server URLs for Swagger UI (e.g., 'https://dev1.example.com,https://dev2.example.com') +- `--servers` - Comma-separated list of server URLs used for Swagger UI and CORS for SSE (e.g., 'https://dev1.example.com,https://dev2.example.com') diff --git a/cli/start.go b/cli/start.go index 8fe638b..27738ce 100644 --- a/cli/start.go +++ b/cli/start.go @@ -156,7 +156,7 @@ Upon successful startup, the terminal will display URLs for both services.`, for _, plug := range plugs { plug.EnrichMCP(srv) } - sse := srv.ServeSSE(serverAddresses[0], prefix) + sse := srv.ServeSSE(serverAddresses[0], prefix, serverAddresses) mux.Handle(path.Join("/", prefix, "sse"), sse) mux.Handle(path.Join("/", prefix, "message"), sse) // Set up SSE (Server-Sent Events) endpoints for real-time event streaming diff --git a/mcpgenerator/server.go b/mcpgenerator/server.go index c584b93..2c2c73f 100644 --- a/mcpgenerator/server.go +++ b/mcpgenerator/server.go @@ -48,8 +48,8 @@ func (s *MCPServer) SetConnector(connector connectors.Connector) error { return nil } -func (s *MCPServer) ServeSSE(addr string, prefix string) *server.SSEServer { - return server.NewSSEServer(s.server, addr, prefix) +func (s *MCPServer) ServeSSE(addr string, prefix string, origins []string) *server.SSEServer { + return server.NewSSEServer(s.server, addr, prefix, origins...) } func (s *MCPServer) ServeStdio() *server.StdioServer { diff --git a/server/sse.go b/server/sse.go index 2c4fb21..071949b 100644 --- a/server/sse.go +++ b/server/sse.go @@ -17,11 +17,12 @@ import ( // SSEServer implements a Server-Sent Events (SSE) based MCP server. // It provides real-time communication capabilities over HTTP using the SSE protocol. type SSEServer struct { - server *MCPServer - baseURL string - sessions sync.Map - srv *http.Server - prefix string + server *MCPServer + baseURL string + sessions sync.Map + srv *http.Server + prefix string + allowedOrigin []string } // sseSession represents an active SSE connection. @@ -33,14 +34,32 @@ type sseSession struct { } // NewSSEServer creates a new SSE server instance with the given MCP server and base URL. -func NewSSEServer(server *MCPServer, baseURL string, prefix string) *SSEServer { +// Optional origins can be provided to restrict CORS headers. +func NewSSEServer(server *MCPServer, baseURL string, prefix string, origins ...string) *SSEServer { return &SSEServer{ - server: server, - baseURL: baseURL, - prefix: prefix, + server: server, + baseURL: baseURL, + prefix: prefix, + allowedOrigin: origins, } } +func (s *SSEServer) setCORSHeaders(w http.ResponseWriter, r *http.Request) { + origin := r.Header.Get("Origin") + if len(s.allowedOrigin) == 0 { + w.Header().Set("Access-Control-Allow-Origin", "*") + } else { + for _, o := range s.allowedOrigin { + if o == origin { + w.Header().Set("Access-Control-Allow-Origin", origin) + break + } + } + } + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") +} + // NewTestServer creates a test server for testing purposes func NewTestServer(server *MCPServer) *httptest.Server { sseServer := &SSEServer{ @@ -88,10 +107,10 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) { return } + s.setCORSHeaders(w, r) w.Header().Set("Content-Type", "text/event-stream") w.Header().Set("Cache-Control", "no-cache") w.Header().Set("Connection", "keep-alive") - w.Header().Set("Access-Control-Allow-Origin", "*") if s.server.NeedAuth(r) { http.Error(w, "Unauthorized", http.StatusUnauthorized) @@ -172,6 +191,8 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { return } + s.setCORSHeaders(w, r) + sessionID := r.URL.Query().Get("sessionId") if sessionID == "" { s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId") @@ -228,6 +249,11 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) { } } +func (s *SSEServer) handleOptions(w http.ResponseWriter, r *http.Request) { + s.setCORSHeaders(w, r) + w.WriteHeader(http.StatusOK) +} + // writeJSONRPCError writes a JSON-RPC error response with the given error details. func (s *SSEServer) writeJSONRPCError( w http.ResponseWriter, @@ -273,8 +299,16 @@ func (s *SSEServer) SendEventToSession( func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/" + path.Join(s.prefix, "sse"): + if r.Method == http.MethodOptions { + s.handleOptions(w, r) + return + } s.handleSSE(w, r) case "/" + path.Join(s.prefix, "message"): + if r.Method == http.MethodOptions { + s.handleOptions(w, r) + return + } s.handleMessage(w, r) default: http.NotFound(w, r) diff --git a/server/sse_test.go b/server/sse_test.go index ee694a4..33d0833 100644 --- a/server/sse_test.go +++ b/server/sse_test.go @@ -16,7 +16,7 @@ import ( func TestSSEServer(t *testing.T) { t.Run("Can instantiate", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") - sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "") + sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "", "http://localhost:8080") if sseServer == nil { t.Error("SSEServer should not be nil") @@ -234,7 +234,7 @@ func TestSSEServer(t *testing.T) { t.Run("Can be used as http.Handler", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") - sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "") + sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "", "http://localhost:8080") ts := httptest.NewServer(sseServer) defer ts.Close() @@ -287,7 +287,7 @@ func TestSSEServer(t *testing.T) { t.Run("Works with middleware", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") - sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "") + sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "", "http://localhost:8080") middleware := func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -336,7 +336,7 @@ func TestSSEServer(t *testing.T) { t.Run("Works with custom mux", func(t *testing.T) { mcpServer := NewMCPServer("test", "1.0.0") - sseServer := NewSSEServer(mcpServer, "", "") + sseServer := NewSSEServer(mcpServer, "", "", "") mux := http.NewServeMux() mux.Handle("/mcp/", http.StripPrefix("/mcp", sseServer)) @@ -405,4 +405,32 @@ func TestSSEServer(t *testing.T) { // Clean up SSE connection cancel() }) + + t.Run("Handles OPTIONS with CORS", func(t *testing.T) { + mcpServer := NewMCPServer("test", "1.0.0") + sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "", "http://localhost:8080") + + ts := httptest.NewServer(sseServer) + defer ts.Close() + + req, err := http.NewRequest("OPTIONS", fmt.Sprintf("%s/sse", ts.URL), nil) + if err != nil { + t.Fatalf("Failed to create request: %v", err) + } + req.Header.Set("Origin", "http://localhost:8080") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("Failed to send request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + t.Errorf("Expected status 200, got %d", resp.StatusCode) + } + + if resp.Header.Get("Access-Control-Allow-Origin") != "http://localhost:8080" { + t.Errorf("Unexpected CORS origin: %s", resp.Header.Get("Access-Control-Allow-Origin")) + } + }) }