Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions e2e_tests/e2e_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ func setClients(t *testing.T, containers []Container) (*docker.Client, *pihole.C
}
}

piholeClient := pihole.NewClient(piholeURL)
piholeClient := pihole.NewClient(piholeURL, "password")
logger.Info("Waiting for Pi-hole to be ready...")
piholeLoginTimeout := time.After(60 * time.Second)
piholeLoginTicker := time.NewTicker(3 * time.Second)
Expand All @@ -243,7 +243,7 @@ PiholeLoginLoop:
case <-piholeLoginTimeout:
t.Fatalf("Timed out waiting for Pi-hole to be ready at %s", piholeURL)
case <-piholeLoginTicker.C:
err = piholeClient.Login("password")
err = piholeClient.Login()
if err == nil {
logger.Info("Successfully logged into Pi-hole")
break PiholeLoginLoop
Expand Down
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ func main() {
}

proc := processor.New(dockerClients, adguardHomeClient, piholeClient, npmClient, cliFlags.DryRun)
defer proc.Shutdown()

if config.RunInterval == 0 {
log.Info("RUN_INTERVAL is 0, will run once")
Expand Down
4 changes: 2 additions & 2 deletions pkg/clients/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ func GetClients(cliFlags cli.Flags, config *config.Config) (map[string]*docker.C

if !cliFlags.DryRun {
if !config.PiholeDisabled {
piholeClient = pihole.NewClient(config.PiholeHost)
err := piholeClient.Login(config.PiholePassword)
piholeClient = pihole.NewClient(config.PiholeHost, config.PiholePassword)
err := piholeClient.Login()
if err != nil {
log.Error("Failed to login to Pi-Hole", "error", err)
return nil, nil, nil, nil, err
Expand Down
36 changes: 28 additions & 8 deletions pkg/clients/common/common.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package common

import (
"context"
"fmt"
"io"
"net/http"
Expand Down Expand Up @@ -126,6 +127,23 @@ func Patch(client *http.Client, path string, headers map[string]string, data str
return string(body), resp.StatusCode, nil
}

func doDeleteRequest(req *http.Request, client *http.Client, headers map[string]string) (string, int, error) {
setHeaders(req, headers)

resp, err := client.Do(req)
if err != nil {
return "", 0, err
}

defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", 0, err
}

return string(body), resp.StatusCode, nil
}

func Delete(client *http.Client, path string, headers map[string]string) (string, int, error) {
req, err := http.NewRequest(
http.MethodDelete,
Expand All @@ -136,17 +154,19 @@ func Delete(client *http.Client, path string, headers map[string]string) (string
return "", 0, err
}

setHeaders(req, headers)
return doDeleteRequest(req, client, headers)
}

resp, err := client.Do(req)
func DeleteWithContext(ctx context.Context, client *http.Client, path string, headers map[string]string) (string, int, error) {
req, err := http.NewRequestWithContext(
ctx,
http.MethodDelete,
path,
nil,
)
if err != nil {
return "", 0, err
}

defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return "", 0, err
}
return string(body), resp.StatusCode, nil
return doDeleteRequest(req, client, headers)
}
10 changes: 10 additions & 0 deletions pkg/clients/pihole/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package pihole

import (
"errors"
)

var (
errAuthRefreshFailed = errors.New("failed to refresh Pi-Hole authentication")
errMissingSessionId = errors.New("missing Pi-Hole session ID")
)
75 changes: 50 additions & 25 deletions pkg/clients/pihole/pihole.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package pihole

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"os"
"strings"
"time"

"github.com/deepspace2/plugnpin/pkg/clients/common"
"github.com/deepspace2/plugnpin/pkg/logging"
Expand All @@ -27,18 +28,19 @@ var headers map[string]string = map[string]string{
"content-type": "application/json",
}

func NewClient(baseURL string) *Client {
func NewClient(baseURL, password string) *Client {
return &Client{
Client: http.Client{
Transport: common.NewInstrumentedRoundTripper(metrics.PI_HOLE, metrics.ObserveApiRequestDuration),
},
baseURL: fmt.Sprintf("%v/api", baseURL),
sid: "",
baseURL: fmt.Sprintf("%v/api", baseURL),
password: password,
sid: "",
}
}

func (p *Client) Login(password string) error {
loginPayload := fmt.Sprintf(`{"password": "%v"}`, password)
func (p *Client) Login() error {
loginPayload := fmt.Sprintf(`{"password": "%v"}`, p.password)
loginResponseString, statusCode, err := common.Post(&p.Client, p.baseURL+"/auth", headers, &loginPayload)
if err != nil {
return err
Expand All @@ -50,11 +52,29 @@ func (p *Client) Login(password string) error {
return errors.New(resp.Session.Message)
}

p.password = password
p.sid = resp.Session.Sid
return nil
}

func (p *Client) Logout() error {
if p.sid == "" {
return nil
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()

headers["X-FTL-SID"] = p.sid
_, statusCode, err := common.DeleteWithContext(ctx, &p.Client, p.baseURL+"/auth", headers)
if err != nil {
return err
}
p.sid = ""
if statusCode >= 400 {
log.Warn("Pi-Hole logout returned non-success status", "status", statusCode)
}
return nil
}

func rawDnsRecordToRecord(rawDnsRecord string) (DomainName, IP, error) {
splitRawDnsRecord := strings.Split(rawDnsRecord, " ")
if len(splitRawDnsRecord) == 2 {
Expand All @@ -72,8 +92,7 @@ func dnsRecordToRaw(domain DomainName, ip IP) string {

func (p *Client) GetDnsRecords() (DnsRecords, error) {
if p.sid == "" {
log.Error("Missing Pi-Hole session ID")
os.Exit(1)
return nil, errMissingSessionId
}
headers["X-FTL-SID"] = p.sid
configResponseString, _, err := common.Get(&p.Client, p.baseURL+"/config", headers)
Expand Down Expand Up @@ -127,8 +146,7 @@ func (p *Client) AddDnsRecords(domains []string, ip string) (numOfAddedDnsRecord
}

if p.sid == "" {
log.Error("Missing Pi-Hole session ID")
os.Exit(1)
return 0, errMissingSessionId
}
headers["X-FTL-SID"] = p.sid
resp, statusCode, err := common.Patch(&p.Client, p.baseURL+"/config", headers, string(payloadString))
Expand All @@ -137,7 +155,9 @@ func (p *Client) AddDnsRecords(domains []string, ip string) (numOfAddedDnsRecord
}

if statusCode == 401 {
p.refreshAuth()
if err := p.refreshAuth(); err != nil {
return 0, errors.Join(errAuthRefreshFailed, err)
}
return p.AddDnsRecords(domains, ip)
}

Expand Down Expand Up @@ -183,8 +203,7 @@ func (p *Client) DeleteDnsRecords(domains []string) (numOfDeletedDnsRecords int,
}

if p.sid == "" {
log.Error("Missing Pi-Hole session ID")
os.Exit(1)
return 0, errMissingSessionId
}
headers["X-FTL-SID"] = p.sid
resp, statusCode, err := common.Patch(&p.Client, p.baseURL+"/config", headers, string(payloadString))
Expand All @@ -193,7 +212,9 @@ func (p *Client) DeleteDnsRecords(domains []string) (numOfDeletedDnsRecords int,
}

if statusCode == 401 {
p.refreshAuth()
if err := p.refreshAuth(); err != nil {
return 0, errors.Join(errAuthRefreshFailed, err)
}
return p.DeleteDnsRecords(domains)
}

Expand Down Expand Up @@ -223,8 +244,7 @@ func cNameRecordToRaw(domain DomainName, target Target) string {

func (p *Client) getCNameRecords() (CNameRecords, error) {
if p.sid == "" {
log.Error("Missing Pi-Hole session ID")
os.Exit(1)
return nil, errMissingSessionId
}
headers["X-FTL-SID"] = p.sid
configResponseString, _, err := common.Get(&p.Client, p.baseURL+"/config", headers)
Expand Down Expand Up @@ -278,8 +298,7 @@ func (p *Client) AddCNameRecords(domains []string, target string) (numOfAddedCNa
}

if p.sid == "" {
log.Error("Missing Pi-Hole session ID")
os.Exit(1)
return numOfAddedCNameRecords, errMissingSessionId
}
headers["X-FTL-SID"] = p.sid
resp, statusCode, err := common.Patch(&p.Client, p.baseURL+"/config", headers, string(payloadString))
Expand All @@ -288,7 +307,9 @@ func (p *Client) AddCNameRecords(domains []string, target string) (numOfAddedCNa
}

if statusCode == 401 {
p.refreshAuth()
if err := p.refreshAuth(); err != nil {
return numOfAddedCNameRecords, errors.Join(errAuthRefreshFailed, err)
}
return p.AddCNameRecords(domains, target)
}

Expand Down Expand Up @@ -334,8 +355,7 @@ func (p *Client) DeleteCNameRecords(domains []string) (numOfDeletedCNameRecords
}

if p.sid == "" {
log.Error("Missing Pi-Hole session ID")
os.Exit(1)
return numOfDeletedCNameRecords, errMissingSessionId
}
headers["X-FTL-SID"] = p.sid
resp, statusCode, err := common.Patch(&p.Client, p.baseURL+"/config", headers, string(payloadString))
Expand All @@ -344,7 +364,9 @@ func (p *Client) DeleteCNameRecords(domains []string) (numOfDeletedCNameRecords
}

if statusCode == 401 {
p.refreshAuth()
if err := p.refreshAuth(); err != nil {
return numOfDeletedCNameRecords, errors.Join(errAuthRefreshFailed, err)
}
return p.DeleteCNameRecords(domains)
}

Expand All @@ -357,7 +379,10 @@ func (p *Client) DeleteCNameRecords(domains []string) (numOfDeletedCNameRecords
return len(deletedDomains), nil
}

func (p *Client) refreshAuth() {
func (p *Client) refreshAuth() error {
log.Info("Refreshing Pi-Hole authentication")
p.Login(p.password)
if err := p.Logout(); err != nil {
log.Warn("Failed to logout old Pi-Hole session", "error", err)
}
return p.Login()
}
Loading
Loading