Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 61 additions & 9 deletions cmd/zoekt-webserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ func main() {
logRefresh := flag.Duration("log_refresh", 24*time.Hour, "if using --log_dir, start writing a new file this often.")

listen := flag.String("listen", ":6070", "listen on this address.")
listenUnix := flag.String("listen_unix", "", "listen on this Unix socket path instead of TCP")
indexDir := flag.String("index", index.DefaultDir, "set index directory to use")
html := flag.Bool("html", true, "enable HTML interface")
enableRPC := flag.Bool("rpc", false, "enable go/net RPC")
Expand Down Expand Up @@ -182,6 +183,9 @@ func main() {
// caller to divert stderr output if necessary.
go divertLogs(*logDir, *logRefresh)
}
if *listenUnix != "" && (*sslCert != "" || *sslKey != "") {
log.Fatal("-listen_unix cannot be combined with -ssl_cert or -ssl_key")
}

// Tune GOMAXPROCS to match Linux container CPU quota.
_, _ = maxprocs.Set()
Expand Down Expand Up @@ -278,10 +282,15 @@ func main() {
if *sslCert != "" || *sslKey != "" {
watchdogAddr = "https://" + *listen
}
watchdogUnixSocket := ""
if *listenUnix != "" {
watchdogUnixSocket = *listenUnix
watchdogAddr = "http://unix"
}
watchdogAddr += "/healthz"

if watchdogErrCount > 0 && watchdogTick > 0 {
go watchdog(watchdogTick, watchdogErrCount, watchdogAddr)
go watchdog(watchdogTick, watchdogErrCount, watchdogAddr, watchdogUnixSocket)
} else {
log.Println("watchdog disabled")
}
Expand All @@ -299,19 +308,15 @@ func main() {
Handler: handler,
}

serveErrCh := make(chan error, 1)
go func() {
sglog.Scoped("server").Info("starting server", sglog.Stringp("address", listen))
var err error
if *sslCert != "" || *sslKey != "" {
err = srv.ListenAndServeTLS(*sslCert, *sslKey)
} else {
err = srv.ListenAndServe()
}
err := serveHTTP(srv, *listenUnix, *sslCert, *sslKey)

if err != http.ErrServerClosed {
// Fatal otherwise shutdownOnSignal will block
log.Fatalf("ListenAndServe: %v", err)
}
serveErrCh <- err
}()

if s.RPC {
Expand All @@ -326,6 +331,47 @@ func main() {
log.Fatalf("http.Server.Shutdown: %v", err)
}
}
<-serveErrCh
}

func serveHTTP(srv *http.Server, unixSocket, sslCert, sslKey string) error {
logger := sglog.Scoped("server")
if unixSocket != "" {
l, err := listenUnixSocket(unixSocket)
if err != nil {
return err
}
logger.Info("starting server", sglog.String("unixSocket", unixSocket))
return srv.Serve(l)
}

logger.Info("starting server", sglog.String("address", srv.Addr))
if sslCert != "" || sslKey != "" {
return srv.ListenAndServeTLS(sslCert, sslKey)
}
return srv.ListenAndServe()
}

func listenUnixSocket(socket string) (net.Listener, error) {
// We cannot bind a socket to an existing pathname.
if err := os.Remove(socket); err != nil && !os.IsNotExist(err) {
return nil, fmt.Errorf("error removing socket file: %s", socket)
}

l, err := net.Listen("unix", socket)
if err != nil {
return nil, fmt.Errorf("failed to listen on socket %s: %w", socket, err)
}

// nginx and zoekt-webserver often run as different users. Make the socket
// broadly connectable like the indexserver socket used by this binary's
// reverse proxy support.
if err := os.Chmod(socket, 0o777); err != nil {
_ = l.Close()
return nil, fmt.Errorf("failed to change permission of socket %s: %w", socket, err)
}

return l, nil
}

// addProxyHandler adds a handler to "mux" that proxies all requests with base
Expand Down Expand Up @@ -424,10 +470,16 @@ func watchdogOnce(ctx context.Context, client *http.Client, addr string) error {
return nil
}

func watchdog(dt time.Duration, maxErrCount int, addr string) {
func watchdog(dt time.Duration, maxErrCount int, addr, unixSocket string) {
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
if unixSocket != "" {
tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", unixSocket)
}
}
client := &http.Client{
Transport: tr,
}
Expand Down
96 changes: 96 additions & 0 deletions cmd/zoekt-webserver/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
package main

import (
"context"
"errors"
"io"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"
)

func TestServeHTTPUnixSocket(t *testing.T) {
socket := filepath.Join(t.TempDir(), "zoekt.sock")
if err := os.WriteFile(socket, []byte("stale"), 0o600); err != nil {
t.Fatal(err)
}

srv := &http.Server{
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/healthz" {
http.NotFound(w, r)
return
}
_, _ = io.WriteString(w, "ok")
}),
}

errCh := make(chan error, 1)
go func() {
errCh <- serveHTTP(srv, socket, "", "")
}()

client := &http.Client{
Timeout: time.Second,
Transport: &http.Transport{
DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
var d net.Dialer
return d.DialContext(ctx, "unix", socket)
},
},
}

var resp *http.Response
var err error
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
resp, err = client.Get("http://unix/healthz")
if err == nil {
break
}
time.Sleep(10 * time.Millisecond)
}
if err != nil {
t.Fatalf("GET over unix socket failed: %v", err)
}

if resp.StatusCode != http.StatusOK {
t.Fatalf("got status %d, want %d", resp.StatusCode, http.StatusOK)
}
body, err := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if err != nil {
t.Fatal(err)
}
if string(body) != "ok" {
t.Fatalf("got body %q, want %q", string(body), "ok")
}
if mode := socketMode(t, socket); mode != 0o777 {
t.Fatalf("got socket mode %o, want 777", mode)
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := srv.Shutdown(ctx); err != nil {
t.Fatal(err)
}

if err := <-errCh; !errors.Is(err, http.ErrServerClosed) {
t.Fatalf("serveHTTP returned %v, want %v", err, http.ErrServerClosed)
}
if _, err := os.Stat(socket); !os.IsNotExist(err) {
t.Fatalf("socket was not removed after shutdown: %v", err)
}
}

func socketMode(t *testing.T, socket string) os.FileMode {
t.Helper()
fi, err := os.Stat(socket)
if err != nil {
t.Fatal(err)
}
return fi.Mode().Perm()
}
Loading