diff --git a/cmd/zoekt-webserver/main.go b/cmd/zoekt-webserver/main.go index d9e5c5773..79c82337f 100644 --- a/cmd/zoekt-webserver/main.go +++ b/cmd/zoekt-webserver/main.go @@ -131,6 +131,7 @@ func main() { logRefresh := flag.Duration("log_refresh", 24*time.Hour, "if using --log_dir, start writing a new file this often.") listen := flag.String("listen", ":6070", "listen on this address.") + listenUnix := flag.String("listen_unix", "", "listen on this Unix socket path instead of TCP") indexDir := flag.String("index", index.DefaultDir, "set index directory to use") html := flag.Bool("html", true, "enable HTML interface") enableRPC := flag.Bool("rpc", false, "enable go/net RPC") @@ -182,6 +183,9 @@ func main() { // caller to divert stderr output if necessary. go divertLogs(*logDir, *logRefresh) } + if *listenUnix != "" && (*sslCert != "" || *sslKey != "") { + log.Fatal("-listen_unix cannot be combined with -ssl_cert or -ssl_key") + } // Tune GOMAXPROCS to match Linux container CPU quota. _, _ = maxprocs.Set() @@ -278,10 +282,15 @@ func main() { if *sslCert != "" || *sslKey != "" { watchdogAddr = "https://" + *listen } + watchdogUnixSocket := "" + if *listenUnix != "" { + watchdogUnixSocket = *listenUnix + watchdogAddr = "http://unix" + } watchdogAddr += "/healthz" if watchdogErrCount > 0 && watchdogTick > 0 { - go watchdog(watchdogTick, watchdogErrCount, watchdogAddr) + go watchdog(watchdogTick, watchdogErrCount, watchdogAddr, watchdogUnixSocket) } else { log.Println("watchdog disabled") } @@ -299,19 +308,15 @@ func main() { Handler: handler, } + serveErrCh := make(chan error, 1) go func() { - sglog.Scoped("server").Info("starting server", sglog.Stringp("address", listen)) - var err error - if *sslCert != "" || *sslKey != "" { - err = srv.ListenAndServeTLS(*sslCert, *sslKey) - } else { - err = srv.ListenAndServe() - } + err := serveHTTP(srv, *listenUnix, *sslCert, *sslKey) if err != http.ErrServerClosed { // Fatal otherwise shutdownOnSignal will block log.Fatalf("ListenAndServe: %v", err) } + serveErrCh <- err }() if s.RPC { @@ -326,6 +331,47 @@ func main() { log.Fatalf("http.Server.Shutdown: %v", err) } } + <-serveErrCh +} + +func serveHTTP(srv *http.Server, unixSocket, sslCert, sslKey string) error { + logger := sglog.Scoped("server") + if unixSocket != "" { + l, err := listenUnixSocket(unixSocket) + if err != nil { + return err + } + logger.Info("starting server", sglog.String("unixSocket", unixSocket)) + return srv.Serve(l) + } + + logger.Info("starting server", sglog.String("address", srv.Addr)) + if sslCert != "" || sslKey != "" { + return srv.ListenAndServeTLS(sslCert, sslKey) + } + return srv.ListenAndServe() +} + +func listenUnixSocket(socket string) (net.Listener, error) { + // We cannot bind a socket to an existing pathname. + if err := os.Remove(socket); err != nil && !os.IsNotExist(err) { + return nil, fmt.Errorf("error removing socket file: %s", socket) + } + + l, err := net.Listen("unix", socket) + if err != nil { + return nil, fmt.Errorf("failed to listen on socket %s: %w", socket, err) + } + + // nginx and zoekt-webserver often run as different users. Make the socket + // broadly connectable like the indexserver socket used by this binary's + // reverse proxy support. + if err := os.Chmod(socket, 0o777); err != nil { + _ = l.Close() + return nil, fmt.Errorf("failed to change permission of socket %s: %w", socket, err) + } + + return l, nil } // addProxyHandler adds a handler to "mux" that proxies all requests with base @@ -424,10 +470,16 @@ func watchdogOnce(ctx context.Context, client *http.Client, addr string) error { return nil } -func watchdog(dt time.Duration, maxErrCount int, addr string) { +func watchdog(dt time.Duration, maxErrCount int, addr, unixSocket string) { tr := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } + if unixSocket != "" { + tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", unixSocket) + } + } client := &http.Client{ Transport: tr, } diff --git a/cmd/zoekt-webserver/main_test.go b/cmd/zoekt-webserver/main_test.go new file mode 100644 index 000000000..b622274a8 --- /dev/null +++ b/cmd/zoekt-webserver/main_test.go @@ -0,0 +1,96 @@ +package main + +import ( + "context" + "errors" + "io" + "net" + "net/http" + "os" + "path/filepath" + "testing" + "time" +) + +func TestServeHTTPUnixSocket(t *testing.T) { + socket := filepath.Join(t.TempDir(), "zoekt.sock") + if err := os.WriteFile(socket, []byte("stale"), 0o600); err != nil { + t.Fatal(err) + } + + srv := &http.Server{ + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/healthz" { + http.NotFound(w, r) + return + } + _, _ = io.WriteString(w, "ok") + }), + } + + errCh := make(chan error, 1) + go func() { + errCh <- serveHTTP(srv, socket, "", "") + }() + + client := &http.Client{ + Timeout: time.Second, + Transport: &http.Transport{ + DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, "unix", socket) + }, + }, + } + + var resp *http.Response + var err error + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + resp, err = client.Get("http://unix/healthz") + if err == nil { + break + } + time.Sleep(10 * time.Millisecond) + } + if err != nil { + t.Fatalf("GET over unix socket failed: %v", err) + } + + if resp.StatusCode != http.StatusOK { + t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK) + } + body, err := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(body) != "ok" { + t.Fatalf("got body %q, want %q", string(body), "ok") + } + if mode := socketMode(t, socket); mode != 0o777 { + t.Fatalf("got socket mode %o, want 777", mode) + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + if err := srv.Shutdown(ctx); err != nil { + t.Fatal(err) + } + + if err := <-errCh; !errors.Is(err, http.ErrServerClosed) { + t.Fatalf("serveHTTP returned %v, want %v", err, http.ErrServerClosed) + } + if _, err := os.Stat(socket); !os.IsNotExist(err) { + t.Fatalf("socket was not removed after shutdown: %v", err) + } +} + +func socketMode(t *testing.T, socket string) os.FileMode { + t.Helper() + fi, err := os.Stat(socket) + if err != nil { + t.Fatal(err) + } + return fi.Mode().Perm() +}