From a40b73b079abb4c557cd183711fd840d1f837c68 Mon Sep 17 00:00:00 2001 From: Adi <6841988+DeepSpace2@users.noreply.github.com> Date: Tue, 16 Jun 2026 20:24:27 +0300 Subject: [PATCH 1/4] fix: close pi-hole sessions closes #50 --- main.go | 1 + pkg/clients/pihole/errors.go | 10 ++++++ pkg/clients/pihole/pihole.go | 59 +++++++++++++++++++++---------- pkg/clients/pihole/pihole_test.go | 56 +++++++++++++++++++++++++++++ pkg/processor/processor.go | 8 +++++ 5 files changed, 115 insertions(+), 19 deletions(-) create mode 100644 pkg/clients/pihole/errors.go diff --git a/main.go b/main.go index 894390d..646f141 100644 --- a/main.go +++ b/main.go @@ -78,6 +78,7 @@ func main() { <-ctx.Done() log.Info("Shutdown signal received, exiting gracefully.") + proc.Shutdown() wg.Wait() log.Info("Shutdown complete.") } diff --git a/pkg/clients/pihole/errors.go b/pkg/clients/pihole/errors.go new file mode 100644 index 0000000..a60e052 --- /dev/null +++ b/pkg/clients/pihole/errors.go @@ -0,0 +1,10 @@ +package pihole + +import ( + "errors" +) + +var ( + errAuthRefreshFailed = errors.New("failed to refresh Pi-Hole authentication") + errMissingSessionId = errors.New("missing Pi-Hole session ID") +) diff --git a/pkg/clients/pihole/pihole.go b/pkg/clients/pihole/pihole.go index e10502e..488ac35 100644 --- a/pkg/clients/pihole/pihole.go +++ b/pkg/clients/pihole/pihole.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "os" "strings" "github.com/deepspace2/plugnpin/pkg/clients/common" @@ -55,6 +54,23 @@ func (p *Client) Login(password string) error { return nil } +func (p *Client) Logout() error { + if p.sid == "" { + return nil + } + headers["X-FTL-SID"] = p.sid + _, statusCode, err := common.Delete(&p.Client, p.baseURL+"/auth", headers) + if err != nil { + return err + } + p.sid = "" + p.password = "" + if statusCode >= 400 { + log.Warn("Pi-Hole logout returned non-success status", "status", statusCode) + } + return nil +} + func rawDnsRecordToRecord(rawDnsRecord string) (DomainName, IP, error) { splitRawDnsRecord := strings.Split(rawDnsRecord, " ") if len(splitRawDnsRecord) == 2 { @@ -72,8 +88,7 @@ func dnsRecordToRaw(domain DomainName, ip IP) string { func (p *Client) GetDnsRecords() (DnsRecords, error) { if p.sid == "" { - log.Error("Missing Pi-Hole session ID") - os.Exit(1) + return nil, errMissingSessionId } headers["X-FTL-SID"] = p.sid configResponseString, _, err := common.Get(&p.Client, p.baseURL+"/config", headers) @@ -127,8 +142,7 @@ func (p *Client) AddDnsRecords(domains []string, ip string) (numOfAddedDnsRecord } if p.sid == "" { - log.Error("Missing Pi-Hole session ID") - os.Exit(1) + return 0, errMissingSessionId } headers["X-FTL-SID"] = p.sid resp, statusCode, err := common.Patch(&p.Client, p.baseURL+"/config", headers, string(payloadString)) @@ -137,7 +151,9 @@ func (p *Client) AddDnsRecords(domains []string, ip string) (numOfAddedDnsRecord } if statusCode == 401 { - p.refreshAuth() + if err := p.refreshAuth(); err != nil { + return 0, errors.Join(errAuthRefreshFailed, err) + } return p.AddDnsRecords(domains, ip) } @@ -183,8 +199,7 @@ func (p *Client) DeleteDnsRecords(domains []string) (numOfDeletedDnsRecords int, } if p.sid == "" { - log.Error("Missing Pi-Hole session ID") - os.Exit(1) + return 0, errMissingSessionId } headers["X-FTL-SID"] = p.sid resp, statusCode, err := common.Patch(&p.Client, p.baseURL+"/config", headers, string(payloadString)) @@ -193,7 +208,9 @@ func (p *Client) DeleteDnsRecords(domains []string) (numOfDeletedDnsRecords int, } if statusCode == 401 { - p.refreshAuth() + if err := p.refreshAuth(); err != nil { + return 0, errors.Join(errAuthRefreshFailed, err) + } return p.DeleteDnsRecords(domains) } @@ -223,8 +240,7 @@ func cNameRecordToRaw(domain DomainName, target Target) string { func (p *Client) getCNameRecords() (CNameRecords, error) { if p.sid == "" { - log.Error("Missing Pi-Hole session ID") - os.Exit(1) + return nil, errMissingSessionId } headers["X-FTL-SID"] = p.sid configResponseString, _, err := common.Get(&p.Client, p.baseURL+"/config", headers) @@ -278,8 +294,7 @@ func (p *Client) AddCNameRecords(domains []string, target string) (numOfAddedCNa } if p.sid == "" { - log.Error("Missing Pi-Hole session ID") - os.Exit(1) + return numOfAddedCNameRecords, errMissingSessionId } headers["X-FTL-SID"] = p.sid resp, statusCode, err := common.Patch(&p.Client, p.baseURL+"/config", headers, string(payloadString)) @@ -288,7 +303,9 @@ func (p *Client) AddCNameRecords(domains []string, target string) (numOfAddedCNa } if statusCode == 401 { - p.refreshAuth() + if err := p.refreshAuth(); err != nil { + return numOfAddedCNameRecords, errors.Join(errAuthRefreshFailed, err) + } return p.AddCNameRecords(domains, target) } @@ -334,8 +351,7 @@ func (p *Client) DeleteCNameRecords(domains []string) (numOfDeletedCNameRecords } if p.sid == "" { - log.Error("Missing Pi-Hole session ID") - os.Exit(1) + return numOfDeletedCNameRecords, errMissingSessionId } headers["X-FTL-SID"] = p.sid resp, statusCode, err := common.Patch(&p.Client, p.baseURL+"/config", headers, string(payloadString)) @@ -344,7 +360,9 @@ func (p *Client) DeleteCNameRecords(domains []string) (numOfDeletedCNameRecords } if statusCode == 401 { - p.refreshAuth() + if err := p.refreshAuth(); err != nil { + return numOfDeletedCNameRecords, errors.Join(errAuthRefreshFailed, err) + } return p.DeleteCNameRecords(domains) } @@ -357,7 +375,10 @@ func (p *Client) DeleteCNameRecords(domains []string) (numOfDeletedCNameRecords return len(deletedDomains), nil } -func (p *Client) refreshAuth() { +func (p *Client) refreshAuth() error { log.Info("Refreshing Pi-Hole authentication") - p.Login(p.password) + if err := p.Logout(); err != nil { + log.Warn("Failed to logout old Pi-Hole session", "error", err) + } + return p.Login(p.password) } diff --git a/pkg/clients/pihole/pihole_test.go b/pkg/clients/pihole/pihole_test.go index 76aea43..f09270e 100644 --- a/pkg/clients/pihole/pihole_test.go +++ b/pkg/clients/pihole/pihole_test.go @@ -60,6 +60,62 @@ func TestLogin(t *testing.T) { }) } +func TestLogout(t *testing.T) { + t.Run("successful logout", func(t *testing.T) { + deleteCalled := false + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/auth", r.URL.Path) + assert.Equal(t, "DELETE", r.Method) + assert.Equal(t, "test-sid", r.Header.Get("X-FTL-SID")) + deleteCalled = true + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, `{"session": {"sid": "", "message": "Session deleted"}}`) + }) + client, server := setupTestServer(handler) + defer server.Close() + client.sid = "test-sid" + client.password = "test-password" + + err := client.Logout() + + assert.NoError(t, err) + assert.True(t, deleteCalled, "DELETE /api/auth was not called") + assert.Empty(t, client.sid) + assert.Empty(t, client.password) + }) + + t.Run("no-op when already logged out", func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Fatal("Unexpected request when sid is empty") + }) + client, server := setupTestServer(handler) + defer server.Close() + + err := client.Logout() + + assert.NoError(t, err) + }) + + t.Run("non-success status is non-fatal", func(t *testing.T) { + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/api/auth", r.URL.Path) + assert.Equal(t, "DELETE", r.Method) + w.WriteHeader(http.StatusNotFound) + fmt.Fprint(w, `{}`) + }) + client, server := setupTestServer(handler) + defer server.Close() + client.sid = "test-sid" + client.password = "test-password" + + err := client.Logout() + + assert.NoError(t, err) + assert.Empty(t, client.sid) + assert.Empty(t, client.password) + }) +} + func TestAddDnsRecords(t *testing.T) { t.Run("successful add multiple", func(t *testing.T) { // This handler needs to handle two requests in sequence diff --git a/pkg/processor/processor.go b/pkg/processor/processor.go index 668888a..8b36cc2 100644 --- a/pkg/processor/processor.go +++ b/pkg/processor/processor.go @@ -290,6 +290,14 @@ func (p *Processor) handleNpm(ctx context.Context, containerEvent events.Action, } } +func (p *Processor) Shutdown() { + if p.piholeClient != nil { + if err := p.piholeClient.Logout(); err != nil { + log.Warn("Failed to logout from Pi-Hole", "error", err) + } + } +} + func (p *Processor) processContainer(ctx context.Context, containerEvent events.Action, containerId string, dockerClient *docker.Client, containerName, ip string, urls []string, port int, opts *docker.ClientOptions) { log := log.With( "container", containerName, From 2d692fa7980a6490ede43129fa706df4bd5ebb10 Mon Sep 17 00:00:00 2001 From: Adi <6841988+DeepSpace2@users.noreply.github.com> Date: Thu, 25 Jun 2026 20:21:41 +0300 Subject: [PATCH 2/4] fix bug in pihole.Logout - storing password on pihole client --- e2e_tests/e2e_test.go | 4 ++-- pkg/clients/clients.go | 4 ++-- pkg/clients/pihole/pihole.go | 15 +++++++------- pkg/clients/pihole/pihole_test.go | 34 ++++++++++++++----------------- 4 files changed, 26 insertions(+), 31 deletions(-) diff --git a/e2e_tests/e2e_test.go b/e2e_tests/e2e_test.go index 8cea682..089b670 100644 --- a/e2e_tests/e2e_test.go +++ b/e2e_tests/e2e_test.go @@ -232,7 +232,7 @@ func setClients(t *testing.T, containers []Container) (*docker.Client, *pihole.C } } - piholeClient := pihole.NewClient(piholeURL) + piholeClient := pihole.NewClient(piholeURL, "password") logger.Info("Waiting for Pi-hole to be ready...") piholeLoginTimeout := time.After(60 * time.Second) piholeLoginTicker := time.NewTicker(3 * time.Second) @@ -243,7 +243,7 @@ PiholeLoginLoop: case <-piholeLoginTimeout: t.Fatalf("Timed out waiting for Pi-hole to be ready at %s", piholeURL) case <-piholeLoginTicker.C: - err = piholeClient.Login("password") + err = piholeClient.Login() if err == nil { logger.Info("Successfully logged into Pi-hole") break PiholeLoginLoop diff --git a/pkg/clients/clients.go b/pkg/clients/clients.go index c9b6c0f..778b1b6 100644 --- a/pkg/clients/clients.go +++ b/pkg/clients/clients.go @@ -19,8 +19,8 @@ func GetClients(cliFlags cli.Flags, config *config.Config) (map[string]*docker.C if !cliFlags.DryRun { if !config.PiholeDisabled { - piholeClient = pihole.NewClient(config.PiholeHost) - err := piholeClient.Login(config.PiholePassword) + piholeClient = pihole.NewClient(config.PiholeHost, config.PiholePassword) + err := piholeClient.Login() if err != nil { log.Error("Failed to login to Pi-Hole", "error", err) return nil, nil, nil, nil, err diff --git a/pkg/clients/pihole/pihole.go b/pkg/clients/pihole/pihole.go index 488ac35..fb702a7 100644 --- a/pkg/clients/pihole/pihole.go +++ b/pkg/clients/pihole/pihole.go @@ -26,18 +26,19 @@ var headers map[string]string = map[string]string{ "content-type": "application/json", } -func NewClient(baseURL string) *Client { +func NewClient(baseURL, password string) *Client { return &Client{ Client: http.Client{ Transport: common.NewInstrumentedRoundTripper(metrics.PI_HOLE, metrics.ObserveApiRequestDuration), }, - baseURL: fmt.Sprintf("%v/api", baseURL), - sid: "", + baseURL: fmt.Sprintf("%v/api", baseURL), + password: password, + sid: "", } } -func (p *Client) Login(password string) error { - loginPayload := fmt.Sprintf(`{"password": "%v"}`, password) +func (p *Client) Login() error { + loginPayload := fmt.Sprintf(`{"password": "%v"}`, p.password) loginResponseString, statusCode, err := common.Post(&p.Client, p.baseURL+"/auth", headers, &loginPayload) if err != nil { return err @@ -49,7 +50,6 @@ func (p *Client) Login(password string) error { return errors.New(resp.Session.Message) } - p.password = password p.sid = resp.Session.Sid return nil } @@ -64,7 +64,6 @@ func (p *Client) Logout() error { return err } p.sid = "" - p.password = "" if statusCode >= 400 { log.Warn("Pi-Hole logout returned non-success status", "status", statusCode) } @@ -380,5 +379,5 @@ func (p *Client) refreshAuth() error { if err := p.Logout(); err != nil { log.Warn("Failed to logout old Pi-Hole session", "error", err) } - return p.Login(p.password) + return p.Login() } diff --git a/pkg/clients/pihole/pihole_test.go b/pkg/clients/pihole/pihole_test.go index f09270e..2027991 100644 --- a/pkg/clients/pihole/pihole_test.go +++ b/pkg/clients/pihole/pihole_test.go @@ -13,9 +13,9 @@ import ( ) // setupTestServer creates a new test server and a client pointing to it. -func setupTestServer(handler http.Handler) (*Client, *httptest.Server) { +func setupTestServer(handler http.Handler, password string) (*Client, *httptest.Server) { server := httptest.NewServer(handler) - client := NewClient(server.URL) + client := NewClient(server.URL, password) client.Client = *server.Client() // Replace the default client with the test server's client return client, server } @@ -35,10 +35,10 @@ func TestLogin(t *testing.T) { // The actual API response is more complex, so we mock the whole thing fmt.Fprint(w, `{"session": {"sid": "test-sid", "message": "Login successful"}}`) }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "test-password") defer server.Close() - err := client.Login("test-password") + err := client.Login() assert.NoError(t, err) assert.Equal(t, "test-sid", client.sid) @@ -49,10 +49,10 @@ func TestLogin(t *testing.T) { w.WriteHeader(http.StatusUnauthorized) fmt.Fprint(w, `{"session": {"sid": "", "message": "Invalid password"}}`) }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "wrong-password") defer server.Close() - err := client.Login("wrong-password") + err := client.Login() assert.Error(t, err) assert.Equal(t, "Invalid password", err.Error()) @@ -71,24 +71,22 @@ func TestLogout(t *testing.T) { w.WriteHeader(http.StatusOK) fmt.Fprint(w, `{"session": {"sid": "", "message": "Session deleted"}}`) }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "test-password") defer server.Close() client.sid = "test-sid" - client.password = "test-password" err := client.Logout() assert.NoError(t, err) assert.True(t, deleteCalled, "DELETE /api/auth was not called") assert.Empty(t, client.sid) - assert.Empty(t, client.password) }) t.Run("no-op when already logged out", func(t *testing.T) { handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { t.Fatal("Unexpected request when sid is empty") }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "test-password") defer server.Close() err := client.Logout() @@ -103,16 +101,14 @@ func TestLogout(t *testing.T) { w.WriteHeader(http.StatusNotFound) fmt.Fprint(w, `{}`) }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "test-password") defer server.Close() client.sid = "test-sid" - client.password = "test-password" err := client.Logout() assert.NoError(t, err) assert.Empty(t, client.sid) - assert.Empty(t, client.password) }) } @@ -150,7 +146,7 @@ func TestAddDnsRecords(t *testing.T) { t.Fatalf("Received unexpected request: %s %s", r.Method, r.URL.Path) }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "test-password") defer server.Close() // Manually set the session ID that the login step would have provided @@ -193,7 +189,7 @@ func TestDeleteDnsRecords(t *testing.T) { t.Fatalf("Received unexpected request: %s %s", r.Method, r.URL.Path) }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "test-password") defer server.Close() client.sid = "test-sid" @@ -217,7 +213,7 @@ func TestDeleteDnsRecords(t *testing.T) { } }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "test-password") defer server.Close() client.sid = "test-sid" @@ -304,7 +300,7 @@ func TestAddCNameRecords(t *testing.T) { t.Fatalf("Received unexpected request: %s %s", r.Method, r.URL.Path) }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "test-password") defer server.Close() // Manually set the session ID that the login step would have provided @@ -347,7 +343,7 @@ func TestDeleteCNameRecords(t *testing.T) { t.Fatalf("Received unexpected request: %s %s", r.Method, r.URL.Path) }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "test-password") defer server.Close() client.sid = "test-sid" @@ -371,7 +367,7 @@ func TestDeleteCNameRecords(t *testing.T) { } }) - client, server := setupTestServer(handler) + client, server := setupTestServer(handler, "test-password") defer server.Close() client.sid = "test-sid" From b48027f0fbae3ffb0f26b767db2034571173bed6 Mon Sep 17 00:00:00 2001 From: Adi <6841988+DeepSpace2@users.noreply.github.com> Date: Thu, 25 Jun 2026 21:04:43 +0300 Subject: [PATCH 3/4] add timeout to pihole.Logout --- pkg/clients/common/common.go | 36 ++++++++++++++++++++++++++++-------- pkg/clients/pihole/pihole.go | 7 ++++++- 2 files changed, 34 insertions(+), 9 deletions(-) diff --git a/pkg/clients/common/common.go b/pkg/clients/common/common.go index c1ccbb6..d8a88b4 100644 --- a/pkg/clients/common/common.go +++ b/pkg/clients/common/common.go @@ -1,6 +1,7 @@ package common import ( + "context" "fmt" "io" "net/http" @@ -126,6 +127,23 @@ func Patch(client *http.Client, path string, headers map[string]string, data str return string(body), resp.StatusCode, nil } +func doDeleteRequest(req *http.Request, client *http.Client, headers map[string]string) (string, int, error) { + setHeaders(req, headers) + + resp, err := client.Do(req) + if err != nil { + return "", 0, err + } + + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", 0, err + } + + return string(body), resp.StatusCode, nil +} + func Delete(client *http.Client, path string, headers map[string]string) (string, int, error) { req, err := http.NewRequest( http.MethodDelete, @@ -136,17 +154,19 @@ func Delete(client *http.Client, path string, headers map[string]string) (string return "", 0, err } - setHeaders(req, headers) + return doDeleteRequest(req, client, headers) +} - resp, err := client.Do(req) +func DeleteWithContext(ctx context.Context, client *http.Client, path string, headers map[string]string) (string, int, error) { + req, err := http.NewRequestWithContext( + ctx, + http.MethodDelete, + path, + nil, + ) if err != nil { return "", 0, err } - defer resp.Body.Close() - body, err := io.ReadAll(resp.Body) - if err != nil { - return "", 0, err - } - return string(body), resp.StatusCode, nil + return doDeleteRequest(req, client, headers) } diff --git a/pkg/clients/pihole/pihole.go b/pkg/clients/pihole/pihole.go index fb702a7..d1e9d0a 100644 --- a/pkg/clients/pihole/pihole.go +++ b/pkg/clients/pihole/pihole.go @@ -1,11 +1,13 @@ package pihole import ( + "context" "encoding/json" "errors" "fmt" "net/http" "strings" + "time" "github.com/deepspace2/plugnpin/pkg/clients/common" "github.com/deepspace2/plugnpin/pkg/logging" @@ -58,8 +60,11 @@ func (p *Client) Logout() error { if p.sid == "" { return nil } + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + headers["X-FTL-SID"] = p.sid - _, statusCode, err := common.Delete(&p.Client, p.baseURL+"/auth", headers) + _, statusCode, err := common.DeleteWithContext(ctx, &p.Client, p.baseURL+"/auth", headers) if err != nil { return err } From 3e21710e517f09e56ca598c4cafb5058910a8e6c Mon Sep 17 00:00:00 2001 From: Adi <6841988+DeepSpace2@users.noreply.github.com> Date: Thu, 25 Jun 2026 21:30:47 +0300 Subject: [PATCH 4/4] fix usage of proc.Shutdown --- main.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/main.go b/main.go index 646f141..58c53d5 100644 --- a/main.go +++ b/main.go @@ -46,6 +46,7 @@ func main() { } proc := processor.New(dockerClients, adguardHomeClient, piholeClient, npmClient, cliFlags.DryRun) + defer proc.Shutdown() if config.RunInterval == 0 { log.Info("RUN_INTERVAL is 0, will run once") @@ -78,7 +79,6 @@ func main() { <-ctx.Done() log.Info("Shutdown signal received, exiting gracefully.") - proc.Shutdown() wg.Wait() log.Info("Shutdown complete.") }