Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions cli/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ gateway start [flags]

- `--addr` - Address and port for the gateway server (e.g., ':9090', '127.0.0.1:8080') (default: ":9090")
- `--config` - Path to YAML file with gateway configuration (default: "./gateway.yaml")
- `--servers` - Comma-separated list of additional server URLs for Swagger UI (e.g., 'https://dev1.example.com,https://dev2.example.com')
- `--servers` - Comma-separated list of server URLs used for Swagger UI and CORS for SSE (e.g., 'https://dev1.example.com,https://dev2.example.com')
- `--connection-string` - Database connection string (DSN) for direct database connection
- `--disable-swagger` - Disable Swagger UI documentation (default: "false")
- `--mcp` - Start MCP SSE server (default: "true")
Expand Down Expand Up @@ -188,7 +188,7 @@ gateway start stdio [flags]
- `--raw` - Enable raw protocol mode optimized for AI agents (default: "false")
- `--addr` - Address and port for the gateway server (e.g., ':9090', '127.0.0.1:8080') (default: ":9090")
- `--config` - Path to YAML file with gateway configuration (default: "./gateway.yaml")
- `--servers` - Comma-separated list of additional server URLs for Swagger UI (e.g., 'https://dev1.example.com,https://dev2.example.com')
- `--servers` - Comma-separated list of server URLs used for Swagger UI and CORS for SSE (e.g., 'https://dev1.example.com,https://dev2.example.com')



Expand Down
2 changes: 1 addition & 1 deletion cli/start.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ Upon successful startup, the terminal will display URLs for both services.`,
for _, plug := range plugs {
plug.EnrichMCP(srv)
}
sse := srv.ServeSSE(serverAddresses[0], prefix)
sse := srv.ServeSSE(serverAddresses[0], prefix, serverAddresses)
mux.Handle(path.Join("/", prefix, "sse"), sse)
mux.Handle(path.Join("/", prefix, "message"), sse)
// Set up SSE (Server-Sent Events) endpoints for real-time event streaming
Expand Down
4 changes: 2 additions & 2 deletions mcpgenerator/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ func (s *MCPServer) SetConnector(connector connectors.Connector) error {
return nil
}

func (s *MCPServer) ServeSSE(addr string, prefix string) *server.SSEServer {
return server.NewSSEServer(s.server, addr, prefix)
func (s *MCPServer) ServeSSE(addr string, prefix string, origins []string) *server.SSEServer {
return server.NewSSEServer(s.server, addr, prefix, origins...)
}

func (s *MCPServer) ServeStdio() *server.StdioServer {
Expand Down
54 changes: 44 additions & 10 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,12 @@ import (
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
// It provides real-time communication capabilities over HTTP using the SSE protocol.
type SSEServer struct {
server *MCPServer
baseURL string
sessions sync.Map
srv *http.Server
prefix string
server *MCPServer
baseURL string
sessions sync.Map
srv *http.Server
prefix string
allowedOrigin []string
}

// sseSession represents an active SSE connection.
Expand All @@ -33,14 +34,32 @@ type sseSession struct {
}

// NewSSEServer creates a new SSE server instance with the given MCP server and base URL.
func NewSSEServer(server *MCPServer, baseURL string, prefix string) *SSEServer {
// Optional origins can be provided to restrict CORS headers.
func NewSSEServer(server *MCPServer, baseURL string, prefix string, origins ...string) *SSEServer {
return &SSEServer{
server: server,
baseURL: baseURL,
prefix: prefix,
server: server,
baseURL: baseURL,
prefix: prefix,
allowedOrigin: origins,
}
}

func (s *SSEServer) setCORSHeaders(w http.ResponseWriter, r *http.Request) {
origin := r.Header.Get("Origin")
if len(s.allowedOrigin) == 0 {
w.Header().Set("Access-Control-Allow-Origin", "*")
} else {
for _, o := range s.allowedOrigin {
if o == origin {
w.Header().Set("Access-Control-Allow-Origin", origin)
break
}
}
}
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
w.Header().Set("Access-Control-Allow-Headers", "Content-Type")
}

// NewTestServer creates a test server for testing purposes
func NewTestServer(server *MCPServer) *httptest.Server {
sseServer := &SSEServer{
Expand Down Expand Up @@ -88,10 +107,10 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
return
}

s.setCORSHeaders(w, r)
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Access-Control-Allow-Origin", "*")

if s.server.NeedAuth(r) {
http.Error(w, "Unauthorized", http.StatusUnauthorized)
Expand Down Expand Up @@ -172,6 +191,8 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
return
}

s.setCORSHeaders(w, r)

sessionID := r.URL.Query().Get("sessionId")
if sessionID == "" {
s.writeJSONRPCError(w, nil, mcp.INVALID_PARAMS, "Missing sessionId")
Expand Down Expand Up @@ -228,6 +249,11 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
}
}

func (s *SSEServer) handleOptions(w http.ResponseWriter, r *http.Request) {
s.setCORSHeaders(w, r)
w.WriteHeader(http.StatusOK)
}

// writeJSONRPCError writes a JSON-RPC error response with the given error details.
func (s *SSEServer) writeJSONRPCError(
w http.ResponseWriter,
Expand Down Expand Up @@ -273,8 +299,16 @@ func (s *SSEServer) SendEventToSession(
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/" + path.Join(s.prefix, "sse"):
if r.Method == http.MethodOptions {
s.handleOptions(w, r)
return
}
s.handleSSE(w, r)
case "/" + path.Join(s.prefix, "message"):
if r.Method == http.MethodOptions {
s.handleOptions(w, r)
return
}
s.handleMessage(w, r)
default:
http.NotFound(w, r)
Expand Down
36 changes: 32 additions & 4 deletions server/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
func TestSSEServer(t *testing.T) {
t.Run("Can instantiate", func(t *testing.T) {
mcpServer := NewMCPServer("test", "1.0.0")
sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "")
sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "", "http://localhost:8080")

if sseServer == nil {
t.Error("SSEServer should not be nil")
Expand Down Expand Up @@ -234,7 +234,7 @@ func TestSSEServer(t *testing.T) {

t.Run("Can be used as http.Handler", func(t *testing.T) {
mcpServer := NewMCPServer("test", "1.0.0")
sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "")
sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "", "http://localhost:8080")

ts := httptest.NewServer(sseServer)
defer ts.Close()
Expand Down Expand Up @@ -287,7 +287,7 @@ func TestSSEServer(t *testing.T) {

t.Run("Works with middleware", func(t *testing.T) {
mcpServer := NewMCPServer("test", "1.0.0")
sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "")
sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "", "http://localhost:8080")

middleware := func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -336,7 +336,7 @@ func TestSSEServer(t *testing.T) {

t.Run("Works with custom mux", func(t *testing.T) {
mcpServer := NewMCPServer("test", "1.0.0")
sseServer := NewSSEServer(mcpServer, "", "")
sseServer := NewSSEServer(mcpServer, "", "", "")

mux := http.NewServeMux()
mux.Handle("/mcp/", http.StripPrefix("/mcp", sseServer))
Expand Down Expand Up @@ -405,4 +405,32 @@ func TestSSEServer(t *testing.T) {
// Clean up SSE connection
cancel()
})

t.Run("Handles OPTIONS with CORS", func(t *testing.T) {
mcpServer := NewMCPServer("test", "1.0.0")
sseServer := NewSSEServer(mcpServer, "http://localhost:8080", "", "http://localhost:8080")

ts := httptest.NewServer(sseServer)
defer ts.Close()

req, err := http.NewRequest("OPTIONS", fmt.Sprintf("%s/sse", ts.URL), nil)
if err != nil {
t.Fatalf("Failed to create request: %v", err)
}
req.Header.Set("Origin", "http://localhost:8080")

resp, err := http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("Failed to send request: %v", err)
}
defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
t.Errorf("Expected status 200, got %d", resp.StatusCode)
}

if resp.Header.Get("Access-Control-Allow-Origin") != "http://localhost:8080" {
t.Errorf("Unexpected CORS origin: %s", resp.Header.Get("Access-Control-Allow-Origin"))
}
})
}