Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ module commander

go 1.25.4

// TEMP: local replace for unpublished squadron-wire OAuth proxy messages.
// Revert + publish a new squadron-wire tag before merging.
replace github.com/mlund01/squadron-wire => ../squadron-wire

require (
github.com/gorilla/websocket v1.5.3
github.com/mlund01/squadron-wire v0.0.40
Expand Down
197 changes: 197 additions & 0 deletions internal/api/oauth.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package api

import (
"encoding/json"
"fmt"
"html"
"log"
"net/http"
"time"

"github.com/mlund01/squadron-wire/protocol"

"commander/internal/hub"
)

// HandleOAuthCallback serves GET /oauth/callback, the public URL IdPs
// redirect the user's browser to after authorization. The callback is
// routed to the right squadron instance via the cryptographic `state`
// value (which squadron reserved in advance via OAuthRegisterFlow).
//
// This handler is intentionally unauthenticated — IdPs do not carry
// commander session cookies. Security comes from the state value being
// unguessable and single-use.
func HandleOAuthCallback(h *hub.Hub) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
state := q.Get("state")
code := q.Get("code")
idpErr := q.Get("error")
if idpErrDesc := q.Get("error_description"); idpErrDesc != "" && idpErr != "" {
idpErr = idpErr + ": " + idpErrDesc
}

if state == "" {
writeOAuthErrorPage(w, "callback missing state parameter")
return
}

flow, ok := h.PendingFlows().Claim(state)
if !ok {
writeOAuthErrorPage(w, "no matching OAuth flow (it may have expired)")
return
}

// Forward to the originating squadron.
env, err := protocol.NewRequest(protocol.TypeOAuthCallbackDelivery, &protocol.OAuthCallbackDeliveryPayload{
State: state,
Code: code,
Error: idpErr,
})
if err != nil {
writeOAuthErrorPage(w, "internal error building delivery: "+err.Error())
return
}
resp, err := h.SendRequest(flow.InstanceID, env, 30*time.Second)
if err != nil {
writeOAuthErrorPage(w, "failed to deliver callback to squadron: "+err.Error())
return
}
if resp.Type == protocol.TypeError {
var perr protocol.ErrorPayload
_ = protocol.DecodePayload(resp, &perr)
writeOAuthErrorPage(w, "squadron rejected callback: "+perr.Message)
return
}

// Notify any open commander tabs for this instance.
success := idpErr == "" && code != ""
noteType := "oauth_completed"
if !success {
noteType = "oauth_failed"
}
h.Notifications().Publish(flow.InstanceID, hub.Notification{
Type: noteType,
Data: map[string]interface{}{
"mcpName": flow.McpName,
"error": idpErr,
},
})

if success {
writeOAuthSuccessPage(w, flow.McpName)
} else {
writeOAuthErrorPage(w, idpErr)
}
}
}

// HandleStartOAuth kicks off a commander-initiated OAuth login for the
// named MCP server on the specified squadron. Returns the authorization URL
// for the browser to open in a new tab.
func HandleStartOAuth(h *hub.Hub) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
instanceID := r.PathValue("id")
mcpName := r.PathValue("name")

env, err := protocol.NewRequest(protocol.TypeStartMCPLogin, &protocol.StartMCPLoginPayload{
McpName: mcpName,
})
if err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
}
resp, err := h.SendRequest(instanceID, env, 30*time.Second)
if err != nil {
writeJSON(w, http.StatusBadGateway, map[string]string{"error": err.Error()})
return
}
if resp.Type == protocol.TypeError {
var perr protocol.ErrorPayload
_ = protocol.DecodePayload(resp, &perr)
writeJSON(w, http.StatusBadGateway, map[string]string{"error": perr.Message})
return
}
var ack protocol.StartMCPLoginAckPayload
if err := protocol.DecodePayload(resp, &ack); err != nil {
writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()})
return
}
if !ack.Accepted {
writeJSON(w, http.StatusBadRequest, map[string]string{"error": ack.Reason})
return
}
writeJSON(w, http.StatusOK, map[string]string{"authUrl": ack.AuthURL})
}
}

// HandleNotifications opens an SSE stream of per-instance notifications
// (e.g. oauth_completed). Used by the commander SPA to surface toasts.
func HandleNotifications(h *hub.Hub) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
instanceID := r.PathValue("id")
flusher, ok := w.(http.Flusher)
if !ok {
http.Error(w, "streaming unsupported", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "text/event-stream")
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no")

ch, cleanup := h.Notifications().Subscribe(instanceID)
defer cleanup()

// Initial comment line so the connection is flushed immediately.
fmt.Fprint(w, ": connected\n\n")
flusher.Flush()

keepalive := time.NewTicker(30 * time.Second)
defer keepalive.Stop()

for {
select {
case <-r.Context().Done():
return
case note, ok := <-ch:
if !ok {
return
}
data, err := json.Marshal(note)
if err != nil {
log.Printf("notification marshal: %v", err)
continue
}
fmt.Fprintf(w, "data: %s\n\n", data)
flusher.Flush()
case <-keepalive.C:
fmt.Fprint(w, ": keepalive\n\n")
flusher.Flush()
}
}
}
}

func writeOAuthSuccessPage(w http.ResponseWriter, mcpName string) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
_, _ = fmt.Fprintf(w, `<!doctype html>
<html><head><title>Authorized</title></head>
<body style="font-family:system-ui;padding:3rem;max-width:40rem;margin:auto">
<h1>Authorization complete</h1>
<p>%s is now connected. You can close this window.</p>
<script>setTimeout(function(){window.close();},2000);</script>
</body></html>`, html.EscapeString(mcpName))
}

func writeOAuthErrorPage(w http.ResponseWriter, msg string) {
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.WriteHeader(http.StatusBadRequest)
_, _ = fmt.Fprintf(w, `<!doctype html>
<html><head><title>Authorization failed</title></head>
<body style="font-family:system-ui;padding:3rem;max-width:40rem;margin:auto">
<h1>Authorization failed</h1>
<p>%s</p>
<p>You can close this window and try again from the command center UI.</p>
</body></html>`, html.EscapeString(msg))
}
6 changes: 6 additions & 0 deletions internal/api/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ func RegisterRoutes(mux *http.ServeMux, h *hub.Hub, ka *keepalive.KeepAlive) {
mux.HandleFunc("GET /api/instances/{id}/agents/{name}/chats", handleChatHistory(h))
mux.HandleFunc("GET /api/instances/{id}/chats/{sessionId}/messages", handleChatMessages(h))
mux.HandleFunc("DELETE /api/instances/{id}/chats/{sessionId}", handleArchiveChat(h))

// OAuth proxy: start a login flow, stream completion notifications.
// The public callback endpoint (/oauth/callback) is registered separately
// on the outer mux so IdPs can reach it without auth.
mux.HandleFunc("POST /api/instances/{id}/mcp/{name}/oauth/start", HandleStartOAuth(h))
mux.HandleFunc("GET /api/instances/{id}/notifications", HandleNotifications(h))
}

func handleListInstances(h *hub.Hub) http.HandlerFunc {
Expand Down
33 changes: 33 additions & 0 deletions internal/hub/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"time"

"github.com/gorilla/websocket"
"github.com/mlund01/squadron-wire/protocol"

Check failure on line 11 in internal/hub/connection.go

View workflow job for this annotation

GitHub Actions / test

github.com/mlund01/squadron-wire@v0.0.40: replacement directory ../squadron-wire does not exist
)

const (
Expand Down Expand Up @@ -371,11 +371,44 @@
c.fanOutChatEvent(env)
case protocol.TypeChatComplete:
c.fanOutChatComplete(env)
case protocol.TypeOAuthRegisterFlow:
c.handleOAuthRegisterFlow(env)
default:
log.Printf("Unhandled message type: %s", env.Type)
}
}

// handleOAuthRegisterFlow records a pending OAuth flow for later callback
// routing. Called when a squadron kicks off an MCP login and asks commander
// to reserve the `state` value.
func (c *Connection) handleOAuthRegisterFlow(env *protocol.Envelope) {
var payload protocol.OAuthRegisterFlowPayload
if err := protocol.DecodePayload(env, &payload); err != nil {
log.Printf("Invalid oauth_register_flow payload: %v", err)
ack, _ := protocol.NewError(env.RequestID, "decode_error", err.Error())
c.Send(ack)
return
}
if c.instanceID == "" {
ack, _ := protocol.NewError(env.RequestID, "not_registered", "instance not registered yet")
c.Send(ack)
return
}
if payload.State == "" {
ack, _ := protocol.NewResponse(env.RequestID, protocol.TypeOAuthRegisterFlowAck, &protocol.OAuthRegisterFlowAckPayload{
Accepted: false,
Reason: "state is required",
})
c.Send(ack)
return
}
c.hub.PendingFlows().Register(payload.State, c.instanceID, payload.McpName)
ack, _ := protocol.NewResponse(env.RequestID, protocol.TypeOAuthRegisterFlowAck, &protocol.OAuthRegisterFlowAckPayload{
Accepted: true,
})
c.Send(ack)
}

func (c *Connection) handleRegister(env *protocol.Envelope) {
var payload protocol.RegisterPayload
if err := protocol.DecodePayload(env, &payload); err != nil {
Expand Down
12 changes: 12 additions & 0 deletions internal/hub/hub.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (

"github.com/gorilla/websocket"
"github.com/mlund01/squadron-wire/protocol"

oauthflows "commander/internal/oauth"
)

var upgrader = websocket.Upgrader{
Expand All @@ -19,6 +21,8 @@ type Hub struct {
mu sync.RWMutex
connections map[string]*Connection // instanceID → connection
registry *Registry
pendingFlows *oauthflows.PendingFlows
notifications *Notifications
AllowConfigEdit bool
}

Expand All @@ -27,10 +31,18 @@ func New(allowConfigEdit bool) *Hub {
return &Hub{
connections: make(map[string]*Connection),
registry: NewRegistry(),
pendingFlows: oauthflows.New(),
notifications: NewNotifications(),
AllowConfigEdit: allowConfigEdit,
}
}

// PendingFlows returns the OAuth flow store.
func (h *Hub) PendingFlows() *oauthflows.PendingFlows { return h.pendingFlows }

// Notifications returns the per-instance notification fan-out.
func (h *Hub) Notifications() *Notifications { return h.notifications }

// Start initializes background tasks (heartbeat, cleanup, etc.).
func (h *Hub) Start() {
// TODO: Start heartbeat ticker
Expand Down
68 changes: 68 additions & 0 deletions internal/hub/notifications.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package hub

import (
"sync"
"time"
)

// Notification is a generic per-instance event pushed to any open browser
// tab subscribed to that instance. Initially used to confirm OAuth-proxy
// MCP logins; designed to accept future types without schema churn.
type Notification struct {
Type string `json:"type"` // e.g. "oauth_completed"
Timestamp time.Time `json:"timestamp"`
Data map[string]interface{} `json:"data,omitempty"`
}

// Notifications fans out per-instance notifications to SSE subscribers.
// Unlike the mission-event fan-out on Connection, notifications are keyed
// by instanceID (not missionID) and have no buffer — they are ephemeral
// hints, not reliable history. Subscribers that aren't listening when an
// event fires will miss it.
type Notifications struct {
mu sync.Mutex
subs map[string][]chan Notification // instanceID → subscribers
}

// NewNotifications creates an empty fan-out.
func NewNotifications() *Notifications {
return &Notifications{subs: make(map[string][]chan Notification)}
}

// Subscribe returns a channel for the given instance's notifications and a
// cleanup function to remove the subscription.
func (n *Notifications) Subscribe(instanceID string) (chan Notification, func()) {
ch := make(chan Notification, 16)
n.mu.Lock()
n.subs[instanceID] = append(n.subs[instanceID], ch)
n.mu.Unlock()
return ch, func() {
n.mu.Lock()
defer n.mu.Unlock()
subs := n.subs[instanceID]
for i, s := range subs {
if s == ch {
n.subs[instanceID] = append(subs[:i], subs[i+1:]...)
break
}
}
close(ch)
}
}

// Publish delivers a notification to all subscribers for the instance.
// Slow subscribers are skipped (no blocking).
func (n *Notifications) Publish(instanceID string, note Notification) {
if note.Timestamp.IsZero() {
note.Timestamp = time.Now()
}
n.mu.Lock()
subs := append([]chan Notification(nil), n.subs[instanceID]...)
n.mu.Unlock()
for _, ch := range subs {
select {
case ch <- note:
default:
}
}
}
Loading
Loading