diff --git a/api/activity.go b/api/activity.go index 745799c..3a57259 100644 --- a/api/activity.go +++ b/api/activity.go @@ -2,6 +2,7 @@ package main import ( "context" + "log" "strings" "time" @@ -59,3 +60,53 @@ func (s *server) markSpritzActivity(ctx context.Context, namespace, name string, return s.client.Status().Update(ctx, current) }) } + +func spritzActivityRefreshInterval(spec spritzv1.SpritzSpec, fallback time.Duration) time.Duration { + interval := fallback + if interval <= 0 { + interval = time.Minute + } + if raw := strings.TrimSpace(spec.IdleTTL); raw != "" { + if idleTTL, err := time.ParseDuration(raw); err == nil && idleTTL > 0 { + candidate := idleTTL / 2 + if candidate <= 0 { + candidate = idleTTL + } + if candidate > 0 && candidate < interval { + interval = candidate + } + } + } + if interval <= 0 { + return time.Minute + } + return interval +} + +func (s *server) startSpritzActivityLoop(ctx context.Context, spritz *spritzv1.Spritz, fallback time.Duration, source string) { + if s == nil || spritz == nil { + return + } + record := func(when time.Time) { + refreshCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.recordSpritzActivity(refreshCtx, spritz.Namespace, spritz.Name, when); err != nil { + log.Printf("spritz %s: failed to refresh activity name=%s namespace=%s err=%v", source, spritz.Name, spritz.Namespace, err) + } + } + record(time.Now()) + + interval := spritzActivityRefreshInterval(spritz.Spec, fallback) + go func() { + ticker := time.NewTicker(interval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case tick := <-ticker.C: + record(tick) + } + } + }() +} diff --git a/api/main.go b/api/main.go index 81e977a..bc8029a 100644 --- a/api/main.go +++ b/api/main.go @@ -44,6 +44,7 @@ type server struct { routeModel spritzv1.SharedHostRouteModel instanceProxy instanceProxyConfig terminal terminalConfig + portForward portForwardConfig sshGateway sshGatewayConfig sshDefaults sshDefaults sshMintLimiter *sshMintLimiter @@ -64,7 +65,7 @@ type server struct { nameGeneratorFactory func(context.Context, string, string) (func() string, error) activityRecorder func(context.Context, string, string, time.Time) error findRunningPodFunc func(context.Context, string, string, string) (*corev1.Pod, error) - openSSHPortForwardFunc func(context.Context, *corev1.Pod, uint32) (net.Conn, io.Closer, error) + openPodPortForwardFunc func(context.Context, *corev1.Pod, uint32) (net.Conn, io.Closer, error) } func main() { @@ -114,6 +115,7 @@ func main() { routeModel := spritzRouteModelFromEnv() instanceProxy := newInstanceProxyConfig() terminal := newTerminalConfig() + portForward := newPortForwardConfig() acp := newACPConfig() extensions, err := newExtensionRegistry() if err != nil { @@ -185,6 +187,7 @@ func main() { routeModel: routeModel, instanceProxy: instanceProxy, terminal: terminal, + portForward: portForward, sshGateway: sshGateway, sshDefaults: sshDefaults, sshMintLimiter: sshMintLimiter, @@ -301,6 +304,9 @@ func (s *server) registerRoutes(e *echo.Echo) { if s.terminal.enabled { group.GET("/spritzes/:name/terminal", s.openTerminal) } + if s.portForward.enabled { + group.GET("/spritzes/:name/port-forward", s.openPortForward) + } if s.instanceProxy.enabled { rootSecured := e.Group("", s.authMiddleware()) prefix := s.instanceProxy.pathPrefix(s.routeModel) diff --git a/api/port_forward.go b/api/port_forward.go new file mode 100644 index 0000000..9fd22ae --- /dev/null +++ b/api/port_forward.go @@ -0,0 +1,283 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "net" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/labstack/echo/v4" + corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" +) + +type portForwardConfig struct { + enabled bool + containerName string + allowedOrigins map[string]struct{} + activityRefresh time.Duration +} + +type portForwardControlMessage struct { + Type string `json:"type"` +} + +var errPortForwardHalfClose = errors.New("port-forward half-close") + +func newPortForwardConfig() portForwardConfig { + return portForwardConfig{ + enabled: parseBoolEnv("SPRITZ_PORT_FORWARD_ENABLED", true), + containerName: envOrDefault("SPRITZ_PORT_FORWARD_CONTAINER", "spritz"), + allowedOrigins: splitSet(os.Getenv("SPRITZ_PORT_FORWARD_ORIGINS")), + activityRefresh: parseDurationEnv("SPRITZ_PORT_FORWARD_ACTIVITY_REFRESH", time.Minute), + } +} + +func (p portForwardConfig) allowOrigin(r *http.Request) bool { + origin := strings.TrimSpace(r.Header.Get("Origin")) + if len(p.allowedOrigins) == 0 { + if origin == "" { + return false + } + parsed, err := url.Parse(origin) + if err != nil { + return false + } + return strings.EqualFold(parsed.Host, r.Host) + } + if origin == "" { + return false + } + return hasSetValue(p.allowedOrigins, origin) +} + +func parsePortForwardQueryPort(c echo.Context) (uint32, error) { + value := strings.TrimSpace(c.QueryParam("port")) + if value == "" { + return 0, fmt.Errorf("remote port required") + } + port, err := strconv.Atoi(value) + if err != nil || port < 1 || port > 65535 { + return 0, fmt.Errorf("invalid remote port") + } + return uint32(port), nil +} + +func (s *server) openPortForward(c echo.Context) error { + if !s.portForward.enabled { + return writeError(c, http.StatusNotFound, "port forward disabled") + } + principal, err := requestPrincipal(c, s.auth) + if err != nil { + return writeAuthError(c, err) + } + if err := ensureAuthenticated(principal, s.auth.enabled()); err != nil { + return writeAuthError(c, err) + } + + name := strings.TrimSpace(c.Param("name")) + if name == "" { + return writeError(c, http.StatusBadRequest, "spritz name required") + } + remotePort, err := parsePortForwardQueryPort(c) + if err != nil { + return writeError(c, http.StatusBadRequest, err.Error()) + } + + namespace := s.requestNamespace(c) + if namespace == "" { + namespace = "default" + } + spritz, err := s.getAuthorizedSpritz(c.Request().Context(), principal, namespace, name) + if err != nil { + if apierrors.IsNotFound(err) { + return writeError(c, http.StatusNotFound, "spritz not found") + } + if errors.Is(err, errForbidden) { + return writeForbidden(c) + } + return writeError(c, http.StatusInternalServerError, err.Error()) + } + + pod, err := s.findPortForwardPod(c.Request().Context(), namespace, name, s.portForward.containerName) + if err != nil { + log.Printf("spritz port-forward: pod not ready name=%s namespace=%s user_id=%s err=%v", name, namespace, principal.ID, err) + return writeError(c, http.StatusConflict, "spritz not ready") + } + + upgrader := websocket.Upgrader{CheckOrigin: s.portForward.allowOrigin} + conn, err := upgrader.Upgrade(c.Response(), c.Request(), nil) + if err != nil { + return err + } + defer func() { + _ = conn.Close() + }() + + ctx, cancel := context.WithCancel(c.Request().Context()) + defer cancel() + s.startSpritzActivityLoop(ctx, spritz, s.portForward.activityRefresh, "port-forward") + + upstream, cleanup, err := s.openPodPortForward(ctx, pod, remotePort) + if err != nil { + log.Printf("spritz port-forward: open failed name=%s namespace=%s port=%d user_id=%s err=%v", name, namespace, remotePort, principal.ID, err) + _ = conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseInternalServerErr, "port forward unavailable"), time.Now().Add(500*time.Millisecond)) + return nil + } + defer func() { + _ = upstream.Close() + _ = cleanup.Close() + }() + + if err := proxyWebSocketNetConn(conn, upstream); err != nil { + if errors.Is(err, context.Canceled) || websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { + return nil + } + return err + } + return nil +} + +func proxyWebSocketNetConn(ws *websocket.Conn, upstream net.Conn) error { + errCh := make(chan error, 2) + var once sync.Once + closeAll := func() { + once.Do(func() { + _ = ws.Close() + _ = upstream.Close() + }) + } + + go func() { + errCh <- copyWebSocketToNetConn(ws, upstream) + }() + go func() { + errCh <- copyNetConnToWebSocket(upstream, ws) + }() + + halfClosed := 0 + for completed := 0; completed < 2; completed++ { + err := <-errCh + switch { + case err == nil: + closeAll() + return nil + case errors.Is(err, errPortForwardHalfClose): + halfClosed++ + if halfClosed == 2 { + closeAll() + return nil + } + case errors.Is(err, io.EOF), errors.Is(err, context.Canceled), websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway): + closeAll() + return nil + case func() bool { + ne, ok := err.(net.Error) + return ok && ne.Timeout() + }(): + closeAll() + return nil + default: + closeAll() + return err + } + } + closeAll() + return nil +} + +func (s *server) findPortForwardPod(ctx context.Context, namespace, name, container string) (*corev1.Pod, error) { + if s.findRunningPodFunc != nil { + return s.findRunningPodFunc(ctx, namespace, name, container) + } + return s.findRunningPod(ctx, namespace, name, container) +} + +func copyWebSocketToNetConn(ws *websocket.Conn, upstream net.Conn) error { + for { + msgType, payload, err := ws.ReadMessage() + if err != nil { + return err + } + if msgType == websocket.TextMessage { + control, err := parsePortForwardControl(payload) + if err != nil { + return err + } + if control.Type == "eof" { + if err := closeConnWrite(upstream); err != nil { + return err + } + return errPortForwardHalfClose + } + continue + } + if msgType != websocket.BinaryMessage { + continue + } + if len(payload) == 0 { + continue + } + if _, err := upstream.Write(payload); err != nil { + return err + } + } +} + +func copyNetConnToWebSocket(upstream net.Conn, ws *websocket.Conn) error { + buffer := make([]byte, 32*1024) + for { + n, err := upstream.Read(buffer) + if n > 0 { + if writeErr := ws.WriteMessage(websocket.BinaryMessage, buffer[:n]); writeErr != nil { + return writeErr + } + } + if err != nil { + if errors.Is(err, io.EOF) { + if writeErr := ws.WriteMessage(websocket.TextMessage, mustMarshalPortForwardControl(portForwardControlMessage{Type: "eof"})); writeErr != nil { + return writeErr + } + return errPortForwardHalfClose + } + return err + } + } +} + +func parsePortForwardControl(payload []byte) (portForwardControlMessage, error) { + var message portForwardControlMessage + if err := json.Unmarshal(payload, &message); err != nil { + return portForwardControlMessage{}, fmt.Errorf("invalid port-forward control: %w", err) + } + return message, nil +} + +func mustMarshalPortForwardControl(message portForwardControlMessage) []byte { + payload, err := json.Marshal(message) + if err != nil { + panic(err) + } + return payload +} + +func closeConnWrite(conn net.Conn) error { + type closeWriter interface { + CloseWrite() error + } + if writer, ok := conn.(closeWriter); ok { + return writer.CloseWrite() + } + return conn.Close() +} diff --git a/api/port_forward_test.go b/api/port_forward_test.go new file mode 100644 index 0000000..d34caa0 --- /dev/null +++ b/api/port_forward_test.go @@ -0,0 +1,332 @@ +package main + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/gorilla/websocket" + "github.com/labstack/echo/v4" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + ctrlclientfake "sigs.k8s.io/controller-runtime/pkg/client/fake" + + spritzv1 "spritz.sh/operator/api/v1" +) + +const portForwardEOFControl = `{"type":"eof"}` + +func TestOpenPortForwardRejectsInvalidRemotePort(t *testing.T) { + s := newCreateSpritzTestServer(t) + s.portForward = portForwardConfig{enabled: true, containerName: "spritz"} + e := echo.New() + e.GET("/api/spritzes/:name/port-forward", s.openPortForward) + + req := httptest.NewRequest(http.MethodGet, "/api/spritzes/devbox1/port-forward?port=99999", nil) + req.Header.Set("X-Spritz-User-Id", "user-1") + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + + if rec.Code != http.StatusBadRequest { + t.Fatalf("expected 400 for invalid port, got %d: %s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), "invalid remote port") { + t.Fatalf("expected invalid remote port error, got %q", rec.Body.String()) + } +} + +func TestOpenPortForwardProxiesToInjectedUpstream(t *testing.T) { + scheme := newTestSpritzScheme(t) + spritz := &spritzv1.Spritz{ + ObjectMeta: metav1.ObjectMeta{ + Name: "tidal-falcon", + Namespace: "spritz-test", + }, + Spec: spritzv1.SpritzSpec{ + Owner: spritzv1.SpritzOwner{ID: "user-1"}, + }, + } + + var activityCalls atomic.Int32 + var forwardedPort atomic.Int32 + s := &server{ + client: ctrlclientfake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(spritz). + Build(), + scheme: scheme, + namespace: "spritz-test", + auth: authConfig{ + mode: authModeHeader, + headerID: "X-Spritz-User-Id", + headerDefaultType: principalTypeHuman, + }, + internalAuth: internalAuthConfig{enabled: false}, + portForward: portForwardConfig{enabled: true, containerName: "spritz"}, + activityRecorder: func(ctx context.Context, namespace, name string, when time.Time) error { + activityCalls.Add(1) + return nil + }, + findRunningPodFunc: func(ctx context.Context, namespace, name, container string) (*corev1.Pod, error) { + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "tidal-falcon-pod", + Namespace: namespace, + }, + }, nil + }, + openPodPortForwardFunc: func(ctx context.Context, pod *corev1.Pod, remotePort uint32) (net.Conn, io.Closer, error) { + forwardedPort.Store(int32(remotePort)) + clientConn, serverConn := net.Pipe() + go func() { + defer serverConn.Close() + _, _ = io.Copy(serverConn, serverConn) + }() + return clientConn, closeFunc(func() error { return nil }), nil + }, + } + + e := echo.New() + e.GET("/api/spritzes/:name/port-forward", s.openPortForward) + srv := httptest.NewServer(e) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/api/spritzes/tidal-falcon/port-forward?port=3000" + headers := http.Header{} + headers.Set("X-Spritz-User-Id", "user-1") + headers.Set("Origin", srv.URL) + conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer conn.Close() + + if err := conn.WriteMessage(websocket.BinaryMessage, []byte("ping")); err != nil { + t.Fatalf("write websocket: %v", err) + } + _, payload, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read websocket: %v", err) + } + if string(payload) != "ping" { + t.Fatalf("unexpected echoed payload %q", string(payload)) + } + if got := forwardedPort.Load(); got != 3000 { + t.Fatalf("forwarded remote port = %d, want 3000", got) + } + if activityCalls.Load() == 0 { + t.Fatal("expected port forwarding to refresh activity") + } +} + +func TestOpenPortForwardPreservesEOFFramedExchange(t *testing.T) { + scheme := newTestSpritzScheme(t) + spritz := &spritzv1.Spritz{ + ObjectMeta: metav1.ObjectMeta{ + Name: "tidal-falcon", + Namespace: "spritz-test", + }, + Spec: spritzv1.SpritzSpec{ + Owner: spritzv1.SpritzOwner{ID: "user-1"}, + }, + } + + upstreamListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen upstream: %v", err) + } + defer upstreamListener.Close() + + upstreamDone := make(chan error, 1) + go func() { + conn, err := upstreamListener.Accept() + if err != nil { + upstreamDone <- err + return + } + defer conn.Close() + payload, err := io.ReadAll(conn) + if err != nil { + upstreamDone <- err + return + } + if string(payload) != "ping" { + upstreamDone <- io.ErrUnexpectedEOF + return + } + if _, err := conn.Write([]byte("pong")); err != nil { + upstreamDone <- err + return + } + upstreamDone <- nil + }() + + s := &server{ + client: ctrlclientfake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(spritz). + Build(), + scheme: scheme, + namespace: "spritz-test", + auth: authConfig{ + mode: authModeHeader, + headerID: "X-Spritz-User-Id", + headerDefaultType: principalTypeHuman, + }, + internalAuth: internalAuthConfig{enabled: false}, + portForward: portForwardConfig{enabled: true, containerName: "spritz"}, + findRunningPodFunc: func(ctx context.Context, namespace, name, container string) (*corev1.Pod, error) { + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "tidal-falcon-pod", + Namespace: namespace, + }, + }, nil + }, + openPodPortForwardFunc: func(ctx context.Context, pod *corev1.Pod, remotePort uint32) (net.Conn, io.Closer, error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", upstreamListener.Addr().String()) + if err != nil { + return nil, nil, err + } + return conn, closeFunc(func() error { return nil }), nil + }, + } + + e := echo.New() + e.GET("/api/spritzes/:name/port-forward", s.openPortForward) + srv := httptest.NewServer(e) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/api/spritzes/tidal-falcon/port-forward?port=3000" + headers := http.Header{} + headers.Set("X-Spritz-User-Id", "user-1") + headers.Set("Origin", srv.URL) + conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer conn.Close() + + if err := conn.WriteMessage(websocket.BinaryMessage, []byte("ping")); err != nil { + t.Fatalf("write websocket payload: %v", err) + } + if err := conn.WriteMessage(websocket.TextMessage, []byte(portForwardEOFControl)); err != nil { + t.Fatalf("write websocket eof: %v", err) + } + + messageType, payload, err := conn.ReadMessage() + if err != nil { + t.Fatalf("read websocket response: %v", err) + } + if messageType != websocket.BinaryMessage || string(payload) != "pong" { + t.Fatalf("unexpected websocket response type=%d payload=%q", messageType, string(payload)) + } + + messageType, payload, err = conn.ReadMessage() + if err != nil { + t.Fatalf("read websocket eof: %v", err) + } + if messageType != websocket.TextMessage || string(payload) != portForwardEOFControl { + t.Fatalf("unexpected websocket eof type=%d payload=%q", messageType, string(payload)) + } + + if err := <-upstreamDone; err != nil { + t.Fatalf("upstream exchange failed: %v", err) + } +} + +func TestOpenPortForwardClosesUpstreamWhenWebSocketCloses(t *testing.T) { + scheme := newTestSpritzScheme(t) + spritz := &spritzv1.Spritz{ + ObjectMeta: metav1.ObjectMeta{ + Name: "tidal-falcon", + Namespace: "spritz-test", + }, + Spec: spritzv1.SpritzSpec{ + Owner: spritzv1.SpritzOwner{ID: "user-1"}, + }, + } + + upstreamListener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen upstream: %v", err) + } + defer upstreamListener.Close() + + upstreamDone := make(chan error, 1) + go func() { + conn, err := upstreamListener.Accept() + if err != nil { + upstreamDone <- err + return + } + defer conn.Close() + _ = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + buffer := make([]byte, 1) + _, err = conn.Read(buffer) + if err == nil || !errors.Is(err, io.EOF) { + upstreamDone <- err + return + } + upstreamDone <- nil + }() + + s := &server{ + client: ctrlclientfake.NewClientBuilder(). + WithScheme(scheme). + WithObjects(spritz). + Build(), + scheme: scheme, + namespace: "spritz-test", + auth: authConfig{ + mode: authModeHeader, + headerID: "X-Spritz-User-Id", + headerDefaultType: principalTypeHuman, + }, + internalAuth: internalAuthConfig{enabled: false}, + portForward: portForwardConfig{enabled: true, containerName: "spritz"}, + findRunningPodFunc: func(ctx context.Context, namespace, name, container string) (*corev1.Pod, error) { + return &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "tidal-falcon-pod", + Namespace: namespace, + }, + }, nil + }, + openPodPortForwardFunc: func(ctx context.Context, pod *corev1.Pod, remotePort uint32) (net.Conn, io.Closer, error) { + conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", upstreamListener.Addr().String()) + if err != nil { + return nil, nil, err + } + return conn, closeFunc(func() error { return nil }), nil + }, + } + + e := echo.New() + e.GET("/api/spritzes/:name/port-forward", s.openPortForward) + srv := httptest.NewServer(e) + defer srv.Close() + + wsURL := "ws" + strings.TrimPrefix(srv.URL, "http") + "/api/spritzes/tidal-falcon/port-forward?port=3000" + headers := http.Header{} + headers.Set("X-Spritz-User-Id", "user-1") + headers.Set("Origin", srv.URL) + conn, _, err := websocket.DefaultDialer.Dial(wsURL, headers) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + + if err := conn.Close(); err != nil { + t.Fatalf("close websocket: %v", err) + } + if err := <-upstreamDone; err != nil { + t.Fatalf("expected upstream to close after websocket exit: %v", err) + } +} diff --git a/api/ssh_gateway.go b/api/ssh_gateway.go index ca022c2..38ba2a5 100644 --- a/api/ssh_gateway.go +++ b/api/ssh_gateway.go @@ -201,7 +201,7 @@ func (s *server) handleSSHPortForward(srv *sshserver.Server, conn *gossh.ServerC } s.ensureSSHActivityLoop(ctx, spritz) - upstream, cleanup, err := s.openSSHPortForward(ctx, pod, request.DestPort) + upstream, cleanup, err := s.openPodPortForward(ctx, pod, request.DestPort) if err != nil { log.Printf("spritz ssh: forward open failed name=%s namespace=%s port=%d err=%v", name, namespace, request.DestPort, err) newChan.Reject(gossh.ConnectionFailed, "port forward unavailable") @@ -277,9 +277,9 @@ func (s *server) findSSHGatewayPod(ctx context.Context, namespace, name, contain return s.findRunningPod(ctx, namespace, name, container) } -func (s *server) openSSHPortForward(ctx context.Context, pod *corev1.Pod, remotePort uint32) (net.Conn, io.Closer, error) { - if s.openSSHPortForwardFunc != nil { - return s.openSSHPortForwardFunc(ctx, pod, remotePort) +func (s *server) openPodPortForward(ctx context.Context, pod *corev1.Pod, remotePort uint32) (net.Conn, io.Closer, error) { + if s.openPodPortForwardFunc != nil { + return s.openPodPortForwardFunc(ctx, pod, remotePort) } if s.clientset == nil || s.restConfig == nil { return nil, nil, errors.New("ssh port forwarding is not configured") @@ -364,53 +364,11 @@ func (s *server) openSSHPortForward(ctx context.Context, pod *corev1.Pod, remote } func sshActivityRefreshInterval(spec spritzv1.SpritzSpec, fallback time.Duration) time.Duration { - interval := fallback - if interval <= 0 { - interval = time.Minute - } - if raw := strings.TrimSpace(spec.IdleTTL); raw != "" { - if idleTTL, err := time.ParseDuration(raw); err == nil && idleTTL > 0 { - candidate := idleTTL / 2 - if candidate <= 0 { - candidate = idleTTL - } - if candidate > 0 && candidate < interval { - interval = candidate - } - } - } - if interval <= 0 { - return time.Minute - } - return interval + return spritzActivityRefreshInterval(spec, fallback) } func (s *server) startSSHActivityLoop(ctx context.Context, spritz *spritzv1.Spritz) { - if s == nil || spritz == nil { - return - } - record := func(when time.Time) { - refreshCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if err := s.recordSpritzActivity(refreshCtx, spritz.Namespace, spritz.Name, when); err != nil { - log.Printf("spritz ssh: failed to refresh activity name=%s namespace=%s err=%v", spritz.Name, spritz.Namespace, err) - } - } - record(time.Now()) - - interval := sshActivityRefreshInterval(spritz.Spec, s.sshGateway.activityRefresh) - go func() { - ticker := time.NewTicker(interval) - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case tick := <-ticker.C: - record(tick) - } - } - }() + s.startSpritzActivityLoop(ctx, spritz, s.sshGateway.activityRefresh, "ssh") } func (s *server) streamSSH(ctx context.Context, pod *corev1.Pod, sess sshserver.Session, hasPty bool, sizeQueue *terminalSizeQueue) error { diff --git a/api/ssh_gateway_test.go b/api/ssh_gateway_test.go index b420d2f..c51dda1 100644 --- a/api/ssh_gateway_test.go +++ b/api/ssh_gateway_test.go @@ -145,7 +145,7 @@ func TestSSHGatewayPortForwardProxiesToInjectedUpstream(t *testing.T) { }, }, nil }, - openSSHPortForwardFunc: func(ctx context.Context, pod *corev1.Pod, remotePort uint32) (net.Conn, io.Closer, error) { + openPodPortForwardFunc: func(ctx context.Context, pod *corev1.Pod, remotePort uint32) (net.Conn, io.Closer, error) { forwardedPort.Store(int32(remotePort)) conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", echoListener.Addr().String()) if err != nil { diff --git a/cli/src/index.ts b/cli/src/index.ts index 4d2ef90..6c765cb 100644 --- a/cli/src/index.ts +++ b/cli/src/index.ts @@ -3,6 +3,7 @@ import { spawn, spawnSync } from 'node:child_process'; import { closeSync, openSync, readlinkSync, writeFileSync, writeSync } from 'node:fs'; import { chmod, mkdtemp, mkdir, readFile, rm, writeFile } from 'node:fs/promises'; +import net from 'node:net'; import os from 'node:os'; import path from 'node:path'; import readline from 'node:readline/promises'; @@ -507,7 +508,7 @@ Usage: spritz open [--namespace ] spritz terminal [--namespace ] [--session ] [--transport ] [--print] spritz ssh [--namespace ] [--session ] [--transport ] [--print] - spritz port-forward [--namespace ] --local --remote [--print] + spritz port-forward [--namespace ] --local --remote [--transport ] [--print] spritz chat send (--instance | --conversation ) --message [--reason ] [--cwd ] [--title ] [--namespace <ns>] [--json] spritz profile list spritz profile current @@ -748,6 +749,12 @@ function resolveTransport(): 'ws' | 'ssh' { return normalizeTransport(terminalTransportDefault); } +function resolvePortForwardTransport(): 'auto' | 'ws' | 'ssh' { + const flag = argValue('--transport'); + if (flag) return normalizeTransport(flag); + return 'auto'; +} + function isJSend(payload: any): payload is { status: string; data?: any; message?: string } { return payload && typeof payload === 'object' && typeof payload.status === 'string'; } @@ -1026,6 +1033,73 @@ function terminalWsUrl(apiBase: string, name: string, namespace?: string, sessio return { url: baseUrl.toString(), origin }; } +function portForwardWsUrl( + apiBase: string, + name: string, + remotePort: number, + namespace?: string +): { url: string; origin: string } { + const baseUrl = new URL(apiBase); + const basePath = baseUrl.pathname.replace(/\/$/, ''); + baseUrl.pathname = `${basePath}/spritzes/${encodeURIComponent(name)}/port-forward`; + if (namespace) { + baseUrl.searchParams.set('namespace', namespace); + } + baseUrl.searchParams.set('port', String(remotePort)); + const origin = baseUrl.origin; + baseUrl.protocol = baseUrl.protocol === 'https:' ? 'wss:' : 'ws:'; + return { url: baseUrl.toString(), origin }; +} + +function writePortForwardOutput(socket: net.Socket, data: RawData) { + if (typeof data === 'string') { + socket.write(data); + return; + } + if (Array.isArray(data)) { + data.forEach((chunk) => socket.write(chunk)); + return; + } + if (data instanceof ArrayBuffer) { + socket.write(Buffer.from(data)); + return; + } + socket.write(data); +} + +function portForwardDescription(localPort: number, remotePort: number, url: string): string { + return `127.0.0.1:${localPort} -> 127.0.0.1:${remotePort} via ${url}`; +} + +function portForwardControl(type: 'eof'): string { + return JSON.stringify({ type }); +} + +function parsePortForwardControl(data: RawData): { type: 'eof' } | null { + const text = + typeof data === 'string' + ? data + : Buffer.isBuffer(data) + ? data.toString('utf8') + : data instanceof ArrayBuffer + ? Buffer.from(data).toString('utf8') + : Array.isArray(data) + ? Buffer.concat(data.map((chunk) => Buffer.from(chunk))).toString('utf8') + : null; + if (text == null) { + return null; + } + try { + const parsed = JSON.parse(text); + if (parsed?.type === 'eof') { + return { type: 'eof' }; + } + } catch { + return null; + } + return null; +} + function terminalResizePayload(): string { const cols = process.stdout.columns ?? 80; const rows = process.stdout.rows ?? 24; @@ -1214,9 +1288,9 @@ async function openTerminalSSH(name: string, namespace: string | undefined, prin } /** - * Opens a local loopback port forward to a loopback port inside one instance. + * Opens a local loopback port forward to a loopback port inside one instance over legacy SSH. */ -async function openPortForward( +async function openPortForwardSSH( name: string, namespace: string | undefined, printOnly: boolean, @@ -1227,6 +1301,232 @@ async function openPortForward( await openSSHConnection(name, namespace, printOnly, ['-N', '-L', forwardSpec]); } +async function openPortForwardWs( + name: string, + namespace: string | undefined, + printOnly: boolean, + localPort: number, + remotePort: number, +) { + const apiBase = await resolveApiBase(); + const { url, origin } = portForwardWsUrl(apiBase, name, remotePort, namespace); + if (printOnly) { + console.log(portForwardDescription(localPort, remotePort, url)); + return; + } + + const headers: Record<string, string> = { + ...(await authHeaders()), + Origin: origin, + }; + await validatePortForwardWebSocket(url, headers); + const sockets = new Set<net.Socket>(); + const server = net.createServer({ allowHalfOpen: true }, (socket) => { + sockets.add(socket); + socket.pause(); + socket.on('close', () => { + sockets.delete(socket); + }); + void bridgePortForwardSocket(socket, url, headers).catch(() => { + socket.destroy(); + }); + }); + + await new Promise<void>((resolve, reject) => { + const onError = (err: Error) => { + server.off('listening', onListening); + reject(err); + }; + const onListening = () => { + server.off('error', onError); + resolve(); + }; + server.once('error', onError); + server.once('listening', onListening); + server.listen(localPort, '127.0.0.1'); + }); + + console.error(`[spz] forwarding 127.0.0.1:${localPort} -> ${name}:127.0.0.1:${remotePort} over websocket`); + + await new Promise<void>((resolve, reject) => { + const cleanup = () => { + process.off('SIGINT', shutdown); + process.off('SIGTERM', shutdown); + process.off('SIGHUP', shutdown); + server.off('error', onServerError); + }; + const onServerError = (err: Error) => { + cleanup(); + reject(err); + }; + const shutdown = () => { + cleanup(); + sockets.forEach((socket) => socket.destroy()); + server.close((err) => { + if (err) { + reject(err); + return; + } + resolve(); + }); + }; + process.on('SIGINT', shutdown); + process.on('SIGTERM', shutdown); + process.on('SIGHUP', shutdown); + server.on('error', onServerError); + }); +} + +async function validatePortForwardWebSocket(url: string, headers: Record<string, string>) { + const ws = new WebSocket(url, { + headers, + handshakeTimeout: Number.isFinite(requestTimeoutMs) ? requestTimeoutMs : 10000, + }); + + await new Promise<void>((resolve, reject) => { + let settled = false; + let opened = false; + let readyTimer: NodeJS.Timeout | undefined; + const finish = (err?: Error) => { + if (settled) return; + settled = true; + if (readyTimer) { + clearTimeout(readyTimer); + } + ws.off('open', onOpen); + ws.off('close', onClose); + ws.off('error', onError); + ws.off('unexpected-response', onUnexpectedResponse); + if (!err && (ws.readyState === WebSocket.OPEN || ws.readyState === WebSocket.CONNECTING)) { + ws.close(); + } + if (err) { + reject(err); + return; + } + resolve(); + }; + const onOpen = () => { + opened = true; + readyTimer = setTimeout(() => finish(), 250); + }; + const onClose = (code: number, reason: Buffer) => { + if (!opened) { + finish(new Error(`port-forward validation failed: websocket closed (${code})`)); + return; + } + if (code !== 1000) { + const suffix = reason.length > 0 ? ` ${reason.toString()}` : ''; + finish(new Error(`port-forward validation failed: websocket closed (${code})${suffix}`)); + } + }; + const onError = (err: Error) => { + finish(err); + }; + const onUnexpectedResponse = (_req: any, res: any) => { + const status = typeof res?.statusCode === 'number' ? String(res.statusCode) : 'unknown'; + const text = typeof res?.statusMessage === 'string' && res.statusMessage.trim() ? ` ${res.statusMessage.trim()}` : ''; + finish(new Error(`port-forward validation failed: ${status}${text}`)); + }; + + ws.on('open', onOpen); + ws.on('close', onClose); + ws.on('error', onError); + ws.on('unexpected-response', onUnexpectedResponse); + }); +} + +async function bridgePortForwardSocket(socket: net.Socket, url: string, headers: Record<string, string>) { + const ws = new WebSocket(url, { + headers, + handshakeTimeout: Number.isFinite(requestTimeoutMs) ? requestTimeoutMs : 10000, + }); + ws.binaryType = 'nodebuffer'; + + await new Promise<void>((resolve, reject) => { + let opened = false; + let finished = false; + let socketEnded = false; + let wsEnded = false; + const maybeFinish = () => { + if (!socketEnded || !wsEnded) { + return; + } + finish(); + }; + const finish = (err?: Error) => { + if (finished) return; + finished = true; + socket.off('data', onSocketData); + socket.off('end', onSocketEnd); + socket.off('error', onSocketError); + socket.off('close', onSocketClose); + ws.off('open', onWsOpen); + ws.off('message', onWsMessage); + ws.off('close', onWsClose); + ws.off('error', onWsError); + if (socket.writable) { + socket.end(); + } + if (ws.readyState === WebSocket.OPEN || ws.readyState === WebSocket.CONNECTING) { + ws.close(); + } + if (err) { + reject(err); + return; + } + resolve(); + }; + const onSocketData = (chunk: Buffer) => { + if (ws.readyState === WebSocket.OPEN) { + ws.send(chunk); + } + }; + const onSocketEnd = () => { + socketEnded = true; + if (ws.readyState === WebSocket.OPEN) { + ws.send(portForwardControl('eof')); + } + maybeFinish(); + }; + const onSocketError = (err: Error) => finish(err); + const onSocketClose = () => finish(); + const onWsOpen = () => { + opened = true; + socket.resume(); + }; + const onWsMessage = (data: RawData, isBinary: boolean) => { + const control = isBinary ? null : parsePortForwardControl(data); + if (control?.type === 'eof') { + wsEnded = true; + if (socket.writable) { + socket.end(); + } + maybeFinish(); + return; + } + writePortForwardOutput(socket, data); + }; + const onWsClose = () => finish(); + const onWsError = (err: Error) => { + if (!opened) { + finish(err); + return; + } + socket.destroy(err); + }; + + socket.on('data', onSocketData); + socket.on('end', onSocketEnd); + socket.on('error', onSocketError); + socket.on('close', onSocketClose); + ws.on('open', onWsOpen); + ws.on('message', onWsMessage); + ws.on('close', onWsClose); + ws.on('error', onWsError); + }); +} + /** * Resolve namespace from CLI flags or active profile. */ @@ -1589,14 +1889,33 @@ async function main() { if (argValueInfo('--session').present) { throw new Error('--session is not supported with port-forward'); } - if (argValueInfo('--transport').present) { - throw new Error('--transport is not supported with port-forward'); - } const ns = await resolveNamespace(); const printOnly = hasFlag('--print'); const localPort = parsePortFlag('--local'); const remotePort = parsePortFlag('--remote'); - await openPortForward(name, ns, printOnly, localPort, remotePort); + const transport = resolvePortForwardTransport(); + if (transport === 'ssh') { + if (!printOnly) { + console.error('Using legacy SSH port forwarding.'); + } + await openPortForwardSSH(name, ns, printOnly, localPort, remotePort); + return; + } + if (transport === 'ws') { + await openPortForwardWs(name, ns, printOnly, localPort, remotePort); + return; + } + try { + await openPortForwardWs(name, ns, printOnly, localPort, remotePort); + return; + } catch (err) { + if (printOnly) { + throw err; + } + const message = err instanceof Error ? err.message : String(err); + console.error(`Websocket port-forward unavailable; falling back to legacy SSH: ${message}`); + await openPortForwardSSH(name, ns, false, localPort, remotePort); + } return; } diff --git a/cli/test/help.test.ts b/cli/test/help.test.ts index 1eb8ae8..773c5d6 100644 --- a/cli/test/help.test.ts +++ b/cli/test/help.test.ts @@ -53,6 +53,6 @@ test('top-level help lists port-forward command', async () => { assert.equal(result.code, 0, result.stderr); assert.match( result.stdout, - /spritz port-forward <name> \[--namespace <ns>\] --local <port> --remote <port> \[--print\]/, + /spritz port-forward <name> \[--namespace <ns>\] --local <port> --remote <port> \[--transport <ws\|ssh>\] \[--print\]/, ); }); diff --git a/cli/test/port-forward.test.ts b/cli/test/port-forward.test.ts index 63f5367..835be94 100644 --- a/cli/test/port-forward.test.ts +++ b/cli/test/port-forward.test.ts @@ -2,10 +2,12 @@ import assert from 'node:assert/strict'; import { spawn } from 'node:child_process'; import { mkdtempSync, writeFileSync, readFileSync } from 'node:fs'; import http from 'node:http'; +import net from 'node:net'; import os from 'node:os'; import test from 'node:test'; import path from 'node:path'; import { fileURLToPath } from 'node:url'; +import { WebSocketServer } from 'ws'; const __filename = fileURLToPath(import.meta.url); const __dirname = path.dirname(__filename); @@ -37,7 +39,64 @@ function spawnCli(args: string[], env: NodeJS.ProcessEnv) { }); } -test('port-forward --print prints the SSH command for the requested mapping', async (t) => { +async function getFreePort() { + return await new Promise<number>((resolve, reject) => { + const server = net.createServer(); + server.once('error', reject); + server.listen(0, '127.0.0.1', () => { + const address = server.address(); + if (!address || typeof address !== 'object') { + reject(new Error('failed to allocate port')); + return; + } + const port = address.port; + server.close((err) => { + if (err) reject(err); + else resolve(port); + }); + }); + }); +} + +async function waitForPattern(buffer: { value: string }, pattern: RegExp, timeoutMs = 5000) { + const deadline = Date.now() + timeoutMs; + while (Date.now() < deadline) { + if (pattern.test(buffer.value)) return; + await new Promise((resolve) => setTimeout(resolve, 25)); + } + throw new Error(`timed out waiting for ${pattern}: ${buffer.value}`); +} + +test('port-forward --print describes the default websocket mapping', async () => { + const localPort = await getFreePort(); + let requestHeaders: http.IncomingHttpHeaders | null = null; + const child = spawnCli( + ['port-forward', 'devbox1', '--local', String(localPort), '--remote', '4000', '--print'], + buildTestEnv('http://127.0.0.1:38080/api'), + ); + + let stdout = ''; + let stderr = ''; + child.stdout.on('data', (chunk) => { + stdout += chunk.toString(); + }); + child.stderr.on('data', (chunk) => { + stderr += chunk.toString(); + }); + + const exitCode = await new Promise<number | null>((resolve) => child.on('exit', resolve)); + assert.equal(exitCode, 0, `spz port-forward --print should succeed: ${stderr}`); + + assert.equal(requestHeaders, null); + assert.match( + stdout.trim(), + new RegExp( + `^127\\.0\\.0\\.1:${localPort} -> 127\\.0\\.0\\.1:4000 via ws://127\\.0\\.0\\.1:38080/api/spritzes/devbox1/port-forward\\?port=4000$` + ) + ); +}); + +test('port-forward --transport ssh --print prints the SSH command for the requested mapping', async (t) => { let requestHeaders: http.IncomingHttpHeaders | null = null; let requestPath = ''; let requestMethod = ''; @@ -73,6 +132,9 @@ test('port-forward --print prints the SSH command for the requested mapping', as const tempDir = mkdtempSync(path.join(os.tmpdir(), 'spz-port-forward-')); const fakeKeygen = path.join(tempDir, 'ssh-keygen'); + const fakeSsh = path.join(tempDir, 'ssh'); + const sshArgsLog = path.join(tempDir, 'ssh-args.log'); + writeExecutable( fakeKeygen, `#!/usr/bin/env bash @@ -89,11 +151,18 @@ done printf '%s\\n' 'PRIVATE KEY' > "$target" printf '%s\\n' 'ssh-ed25519 AAAATEST generated@test' > "\${target}.pub" chmod 600 "$target" "\${target}.pub" +`, + ); + writeExecutable( + fakeSsh, + `#!/usr/bin/env bash +set -euo pipefail +printf '%s\\n' "$@" > "$SSH_ARGS_LOG" `, ); const child = spawnCli( - ['port-forward', 'devbox1', '--local', '3000', '--remote', '4000', '--print'], + ['port-forward', 'devbox1', '--transport', 'ssh', '--local', '3000', '--remote', '4000', '--print'], buildTestEnv(`http://127.0.0.1:${address.port}/api`, { SPRITZ_SSH_KEYGEN: fakeKeygen, }), @@ -109,7 +178,7 @@ chmod 600 "$target" "\${target}.pub" }); const exitCode = await new Promise<number | null>((resolve) => child.on('exit', resolve)); - assert.equal(exitCode, 0, `spz port-forward --print should succeed: ${stderr}`); + assert.equal(exitCode, 0, `spz port-forward should succeed: ${stderr}`); assert.equal(requestHeaders?.authorization, 'Bearer service-token'); assert.equal(requestMethod, 'POST'); @@ -123,7 +192,7 @@ chmod 600 "$target" "\${target}.pub" assert.match(stdout, /spritz@127\.0\.0\.1/); }); -test('port-forward executes the SSH client with the expected loopback mapping', async (t) => { +test('port-forward --transport ssh executes the SSH client with the expected loopback mapping', async (t) => { const server = http.createServer((req, res) => { const chunks: Buffer[] = []; req.on('data', (chunk) => chunks.push(Buffer.from(chunk))); @@ -180,7 +249,7 @@ printf '%s\\n' "$@" > "$SSH_ARGS_LOG" ); const child = spawnCli( - ['port-forward', 'devbox1', '--namespace', 'spritz', '--local', '3000', '--remote', '4000'], + ['port-forward', 'devbox1', '--namespace', 'spritz', '--local', '3000', '--remote', '4000', '--transport', 'ssh'], buildTestEnv(`http://127.0.0.1:${address.port}/api`, { SPRITZ_SSH_KEYGEN: fakeKeygen, SPRITZ_SSH_BINARY: fakeSsh, @@ -207,6 +276,329 @@ printf '%s\\n' "$@" > "$SSH_ARGS_LOG" assert.equal(args.at(-1), 'spritz@127.0.0.1'); }); +test('port-forward proxies localhost traffic over websocket by default', async (t) => { + const localPort = await getFreePort(); + let upgradePath = ''; + let authorization = ''; + let origin = ''; + const server = http.createServer(); + const wss = new WebSocketServer({ noServer: true }); + await listen(server); + t.after(() => { + wss.close(); + server.close(); + }); + const address = server.address(); + assert.ok(address && typeof address === 'object'); + + server.on('upgrade', (req, socket, head) => { + upgradePath = req.url || ''; + authorization = req.headers.authorization || ''; + origin = req.headers.origin || ''; + wss.handleUpgrade(req, socket, head, (ws) => { + wss.emit('connection', ws, req); + }); + }); + wss.on('connection', (ws) => { + ws.on('message', (payload) => { + ws.send(payload); + }); + }); + + const child = spawnCli( + ['port-forward', 'devbox1', '--namespace', 'spritz', '--local', String(localPort), '--remote', '4000'], + buildTestEnv(`http://127.0.0.1:${address.port}/api`), + ); + let stderr = ''; + const stderrBuffer = { value: '' }; + child.stderr.on('data', (chunk) => { + const text = chunk.toString(); + stderr += text; + stderrBuffer.value += text; + }); + t.after(() => { + child.kill('SIGTERM'); + }); + + await waitForPattern(stderrBuffer, new RegExp(`forwarding 127\\.0\\.0\\.1:${localPort}`)); + + const client = net.connect(localPort, '127.0.0.1'); + t.after(() => { + client.destroy(); + }); + const replyPromise = new Promise<Buffer>((resolve) => { + client.once('data', (chunk) => resolve(Buffer.from(chunk))); + }); + client.write('ping'); + const reply = await replyPromise; + assert.equal(reply.toString(), 'ping'); + assert.equal(authorization, 'Bearer service-token'); + assert.equal(origin, `http://127.0.0.1:${address.port}`); + assert.equal(upgradePath, '/api/spritzes/devbox1/port-forward?namespace=spritz&port=4000'); + + child.kill('SIGTERM'); + const exitCode = await new Promise<number | null>((resolve) => child.on('exit', resolve)); + assert.equal(exitCode, 0, `spz port-forward should exit cleanly: ${stderr}`); +}); + +test('port-forward preserves EOF-framed exchanges over websocket', async (t) => { + const localPort = await getFreePort(); + const server = http.createServer(); + const wss = new WebSocketServer({ noServer: true }); + await listen(server); + t.after(() => { + wss.close(); + server.close(); + }); + const address = server.address(); + assert.ok(address && typeof address === 'object'); + + server.on('upgrade', (req, socket, head) => { + wss.handleUpgrade(req, socket, head, (ws) => { + wss.emit('connection', ws, req); + }); + }); + wss.on('connection', (ws) => { + const chunks: Buffer[] = []; + ws.on('message', (payload, isBinary) => { + if (isBinary) { + chunks.push(Buffer.from(payload as Buffer)); + return; + } + const control = JSON.parse(payload.toString('utf8')); + assert.equal(control.type, 'eof'); + assert.equal(Buffer.concat(chunks).toString('utf8'), 'ping'); + ws.send(Buffer.from('pong')); + ws.send(JSON.stringify({ type: 'eof' })); + }); + }); + + const child = spawnCli( + ['port-forward', 'devbox1', '--namespace', 'spritz', '--local', String(localPort), '--remote', '4000'], + buildTestEnv(`http://127.0.0.1:${address.port}/api`), + ); + let stderr = ''; + const stderrBuffer = { value: '' }; + child.stderr.on('data', (chunk) => { + const text = chunk.toString(); + stderr += text; + stderrBuffer.value += text; + }); + t.after(() => { + child.kill('SIGTERM'); + }); + + await waitForPattern(stderrBuffer, new RegExp(`forwarding 127\\.0\\.0\\.1:${localPort}`)); + + const client = net.connect(localPort, '127.0.0.1'); + const replyPromise = new Promise<string>((resolve, reject) => { + let payload = ''; + client.on('data', (chunk) => { + payload += chunk.toString(); + }); + client.on('end', () => resolve(payload)); + client.on('error', reject); + }); + client.write('ping'); + client.end(); + + const reply = await replyPromise; + assert.equal(reply, 'pong'); + + child.kill('SIGTERM'); + const exitCode = await new Promise<number | null>((resolve) => child.on('exit', resolve)); + assert.equal(exitCode, 0, `spz port-forward should exit cleanly: ${stderr}`); +}); + +test('port-forward closes the local client when the websocket tunnel drops after startup', async (t) => { + const localPort = await getFreePort(); + const server = http.createServer(); + const wss = new WebSocketServer({ noServer: true }); + await listen(server); + t.after(() => { + wss.close(); + server.close(); + }); + const address = server.address(); + assert.ok(address && typeof address === 'object'); + + let connections = 0; + server.on('upgrade', (req, socket, head) => { + wss.handleUpgrade(req, socket, head, (ws) => { + wss.emit('connection', ws, req); + }); + }); + wss.on('connection', (ws) => { + connections += 1; + if (connections === 1) { + return; + } + ws.close(1011, 'boom'); + }); + + const child = spawnCli( + ['port-forward', 'devbox1', '--transport', 'ws', '--namespace', 'spritz', '--local', String(localPort), '--remote', '4000'], + buildTestEnv(`http://127.0.0.1:${address.port}/api`), + ); + let stderr = ''; + const stderrBuffer = { value: '' }; + child.stderr.on('data', (chunk) => { + const text = chunk.toString(); + stderr += text; + stderrBuffer.value += text; + }); + t.after(() => { + child.kill('SIGTERM'); + }); + + await waitForPattern(stderrBuffer, new RegExp(`forwarding 127\\.0\\.0\\.1:${localPort}`)); + + const client = net.connect(localPort, '127.0.0.1'); + const clientClosed = new Promise<void>((resolve, reject) => { + const timer = setTimeout(() => reject(new Error('timed out waiting for local socket to close')), 2000); + client.on('error', () => { + clearTimeout(timer); + resolve(); + }); + client.on('close', () => { + clearTimeout(timer); + resolve(); + }); + }); + + await clientClosed; + client.destroy(); + + child.kill('SIGTERM'); + const exitCode = await new Promise<number | null>((resolve) => child.on('exit', resolve)); + assert.equal(exitCode, 0, `spz port-forward should exit cleanly: ${stderr}`); +}); + +test('port-forward falls back to SSH when websocket startup validation is rejected by default', async (t) => { + const tempDir = mkdtempSync(path.join(os.tmpdir(), 'spz-port-forward-')); + const fakeKeygen = path.join(tempDir, 'ssh-keygen'); + const fakeSsh = path.join(tempDir, 'ssh'); + const sshArgsLog = path.join(tempDir, 'ssh-args.log'); + const server = http.createServer((req, res) => { + if ((req.url || '').includes('/ssh')) { + const chunks: Buffer[] = []; + req.on('data', (chunk) => chunks.push(Buffer.from(chunk))); + req.on('end', () => { + res.writeHead(200, { 'Content-Type': 'application/json' }); + res.end(JSON.stringify({ + status: 'success', + data: { + host: '127.0.0.1', + user: 'spritz', + cert: 'ssh-ed25519-cert-v01@openssh.com AAAATEST', + port: 2201, + known_hosts: '[127.0.0.1]:2201 ssh-ed25519 AAAAKNOWNHOST', + }, + })); + }); + return; + } + res.writeHead(404, { 'Content-Type': 'text/plain' }); + res.end('missing'); + }); + await listen(server); + t.after(() => { + server.close(); + }); + const address = server.address(); + assert.ok(address && typeof address === 'object'); + + writeExecutable( + fakeKeygen, + `#!/usr/bin/env bash +set -euo pipefail +target="" +while (($#)); do + if [[ "$1" == "-f" ]]; then + target="$2" + shift 2 + continue + fi + shift +done +printf '%s\\n' 'PRIVATE KEY' > "$target" +printf '%s\\n' 'ssh-ed25519 AAAATEST generated@test' > "\${target}.pub" +chmod 600 "$target" "\${target}.pub" +`, + ); + writeExecutable( + fakeSsh, + `#!/usr/bin/env bash +set -euo pipefail +printf '%s\\n' "$@" > "$SSH_ARGS_LOG" +`, + ); + + const child = spawnCli( + ['port-forward', 'devbox1', '--namespace', 'spritz', '--local', '3000', '--remote', '4000'], + buildTestEnv(`http://127.0.0.1:${address.port}/api`, { + SPRITZ_SSH_KEYGEN: fakeKeygen, + SPRITZ_SSH_BINARY: fakeSsh, + SSH_ARGS_LOG: sshArgsLog, + }), + ); + + let stderr = ''; + child.stderr.on('data', (chunk) => { + stderr += chunk.toString(); + }); + + const exitCode = await new Promise<number | null>((resolve) => child.on('exit', resolve)); + assert.equal(exitCode, 0, `spz port-forward should fall back to SSH: ${stderr}`); + assert.match(stderr, /Websocket port-forward unavailable; falling back to legacy SSH/); + + const args = readFileSync(sshArgsLog, 'utf8').trim().split('\n'); + assert.ok(args.includes('-N')); + const localIndex = args.indexOf('-L'); + assert.notEqual(localIndex, -1); + assert.equal(args[localIndex + 1], '127.0.0.1:3000:127.0.0.1:4000'); +}); + +test('port-forward fails during startup when websocket validation is rejected for explicit websocket transport', async (t) => { + const localPort = await getFreePort(); + const server = http.createServer((req, res) => { + res.writeHead(503, { 'Content-Type': 'text/plain' }); + res.end('unavailable'); + }); + await listen(server); + t.after(() => { + server.close(); + }); + const address = server.address(); + assert.ok(address && typeof address === 'object'); + + const child = spawnCli( + ['port-forward', 'devbox1', '--transport', 'ws', '--local', String(localPort), '--remote', '4000'], + buildTestEnv(`http://127.0.0.1:${address.port}/api`), + ); + + let stderr = ''; + child.stderr.on('data', (chunk) => { + stderr += chunk.toString(); + }); + + const exitCode = await new Promise<number | null>((resolve) => child.on('exit', resolve)); + assert.notEqual(exitCode, 0, 'spz port-forward should fail before listening'); + assert.match(stderr, /port-forward validation failed: 503/); + + await assert.rejects( + () => + new Promise<void>((resolve, reject) => { + const socket = net.connect(localPort, '127.0.0.1'); + socket.once('connect', () => { + socket.destroy(); + resolve(); + }); + socket.once('error', reject); + }), + ); +}); + test('port-forward rejects missing remote port', async () => { const child = spawnCli( ['port-forward', 'devbox1', '--local', '3000'], diff --git a/docs/2026-04-06-spz-port-forward-architecture.md b/docs/2026-04-06-spz-port-forward-architecture.md index ca3a09c..93a1efc 100644 --- a/docs/2026-04-06-spz-port-forward-architecture.md +++ b/docs/2026-04-06-spz-port-forward-architecture.md @@ -2,7 +2,7 @@ date: 2026-04-06 author: Onur Solmaz <onur@textcortex.com> title: spz Port-Forward Architecture -tags: [spritz, spz, cli, ssh, port-forwarding, architecture] +tags: [spritz, spz, cli, ssh, websocket, https, port-forwarding, architecture] --- ## Overview @@ -20,15 +20,22 @@ spz port-forward <instance> --local <port> --remote <port> This should be the canonical Spritz CLI shape for forwarding a local loopback port to a private port inside one Spritz instance. -It should be implemented as an instance-scoped control-plane feature, not as a -Kubernetes workflow, not as a browser preview product, and not as a -deployment-specific convenience alias. +It should be implemented as an instance-scoped control-plane feature. The +preferred public transport should be the authenticated Spritz control plane +over HTTPS/WebSocket on port `443`, not raw public SSH on port `22`. + +It should not be framed as a Kubernetes workflow, not as a browser preview +product, and not as a deployment-specific convenience alias. ## TL;DR - Spritz core should add `spz port-forward`. -- The command should be explicit and transport-agnostic in meaning, while the - first implementation should reuse the existing SSH credential minting path. +- The command should be explicit and transport-agnostic in meaning. +- The preferred public transport should be authenticated HTTPS/WebSocket over + the existing Spritz control plane on port `443`. +- The current SSH-backed implementation may remain as a fallback, but SSH + should be treated as deprecated for default public use unless somebody + explicitly asks for it. - The command should default to: - local bind host `127.0.0.1` - remote target host `127.0.0.1` @@ -97,8 +104,12 @@ This should be the core primitive. It should mean: - forward to one private remote port inside one named instance - keep the tunnel alive until interrupted -The first implementation should be built on the existing SSH certificate mint -flow already used by `spz ssh`. +The public-facing design should prefer an authenticated control-plane tunnel +over HTTPS/WebSocket on `443`. + +The existing SSH certificate mint flow may still be used as an implementation +fallback where raw TCP is available or explicitly desired, but Spritz should +not require public raw SSH exposure as the default internet-facing transport. The product contract, however, should be described as "instance port forwarding", not as "raw SSH with custom flags". That distinction matters: @@ -107,6 +118,26 @@ forwarding", not as "raw SSH with custom flags". That distinction matters: - the control plane remains the owner of authorization and target resolution - future transports may change without renaming the user-facing intent +## Preferred Public Transport + +The ideal Spritz system should not require a public inbound instance port at +all. + +For public usage, `spz port-forward` should terminate on the existing Spritz +control plane host over HTTPS/WebSocket on `443`, and the control plane should +perform the pod-scoped forwarding inside the cluster. + +This is the preferred architecture because it: + +- works in environments where raw public TCP may be restricted +- keeps auth, policy, rate limiting, and audit under the control plane +- avoids depending on cloud-specific behavior for arbitrary inbound TCP +- gives one clean public access surface instead of separate web and SSH entry + points + +SSH may still exist as a transport, but it should not be the primary public +story. + ## Why `spz port-forward` Instead Of `spz preview` `preview` is the wrong upstream primitive because it encodes application @@ -266,21 +297,28 @@ Press Ctrl+C to stop. ## Relationship To `spz ssh` -`spz ssh` should remain the raw shell-access command. +`spz ssh` should remain the raw shell-access command when explicitly needed. `spz port-forward` should be a sibling command with a narrower purpose: - `spz ssh`: interactive shell access - `spz port-forward`: local access to one private instance port -The implementation may share most of the credential plumbing. +The implementation may still share credential plumbing, but the public default +for `spz port-forward` should not be "SSH unless proven otherwise". -That is good: +That split is still good: - less duplicated auth logic -- one consistent trust and host-verification path - one clear control-plane contract for instance access +For now: + +- `spz ssh` remains available +- SSH-backed `spz port-forward` may remain available +- SSH should be considered deprecated as the default public transport unless a + deployment or operator explicitly asks for it + ## Downstream Wrappers Spritz core should stop at the generic primitive. @@ -297,6 +335,8 @@ contract. The holy grail shape is: - Spritz core provides `spz port-forward` +- Spritz routes public interactive access through one authenticated control + plane on `443` - downstream deployments compose deployment-specific UX on top of it That keeps Spritz portable while still enabling polished local workflows where @@ -307,7 +347,7 @@ needed. ### Phase 1: Core CLI Primitive - add `spz port-forward` -- reuse the existing SSH credential minting path +- keep the user-facing contract transport-agnostic - support one local forward per command invocation - keep both local and remote hosts pinned to loopback @@ -318,11 +358,35 @@ Acceptance criteria: - the command requires no Kubernetes credentials - the command targets one named instance, not a pod -### Phase 2: CLI Help And Tests +### Phase 2: Public Control-Plane Transport + +- implement forwarding over the authenticated Spritz control plane on + HTTPS/WebSocket +- make that path the preferred public transport +- avoid requiring any public raw TCP listener on the instance gateway + +Acceptance criteria: + +- the standard public path works over `443` +- no public per-instance or per-feature raw TCP exposure is required +- authorization remains owned by the Spritz control plane + +### Phase 3: SSH Fallback + +- keep the SSH-backed transport available for private networks, operators, or + deployments that explicitly want it +- document that SSH is a fallback transport, not the preferred public one + +Acceptance criteria: + +- SSH remains available when explicitly requested +- SSH is no longer the default public transport assumption in docs or UX + +### Phase 4: CLI Help And Tests - document the new command in CLI help - add help tests for the new usage line -- add command tests for printed SSH execution shape or equivalent command +- add command tests for printed transport execution shape or equivalent command plumbing Acceptance criteria: @@ -330,7 +394,7 @@ Acceptance criteria: - the new command is discoverable through `spz --help` - printed guidance stays generic and deployment-agnostic -### Phase 3: Downstream Composition +### Phase 5: Downstream Composition - allow downstreams to add wrappers without changing the core primitive - document that application auth remains outside Spritz forwarding @@ -346,10 +410,13 @@ This architecture is successful when all of the following are true: - the standard path for instance port access is `spz port-forward`, not raw `ssh -L` +- the preferred public path runs through the authenticated Spritz control + plane on `443` - the caller does not need Kubernetes credentials - the command works by instance identity rather than pod identity - the default bind scope is local loopback only - the application behind the forwarded port can keep its own auth model +- SSH remains optional rather than mandatory for public use - downstream wrappers can exist without forcing Spritz core to become app-specific