Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 30 additions & 4 deletions internal/check/probe_dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package check

import (
"context"
"errors"
"fmt"
"net"
"time"
Expand All @@ -19,12 +20,37 @@ type DNSProbe struct {
resolver DNSResolver
}

// NewDNSProbe creates a new DNS probe for the given resolver and domain.
func NewDNSProbe(dnsResolver, domain string) *DNSProbe {
// DefaultDNSPort is the default DNS resolver port.
const DefaultDNSPort = "53"

var (
// ErrDNSMissingDomain is returned when no domain is specified.
ErrDNSMissingDomain = errors.New("DNS probe missing domain")
// ErrDNSMissingResolver is returned when no resolver is specified.
ErrDNSMissingResolver = errors.New("DNS probe missing resolver")
)

// NewDNSProbe creates a new DNS probe for the given resolver host (host:port
// or host-only, port defaults to 53) and domain.
func NewDNSProbe(host, domain string) (*DNSProbe, error) {
if domain == "" {
return nil, ErrDNSMissingDomain
}

hostname, port, err := net.SplitHostPort(host)
if err != nil {
hostname = host
port = DefaultDNSPort
}

if hostname == "" {
return nil, ErrDNSMissingResolver
}

return &DNSProbe{
DNSResolver: dnsResolver,
DNSResolver: net.JoinHostPort(hostname, port),
Domain: domain,
}
}, nil
}

// Scheme returns the protocol scheme (dns).
Expand Down
59 changes: 59 additions & 0 deletions internal/check/probe_dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ import (
"net"
"strings"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

type fakeResolver struct {
Expand All @@ -18,6 +21,62 @@ func (r *fakeResolver) LookupHost(_ context.Context, _ string) ([]string, error)
return r.result, r.err
}

const testDomain = "example.com"

func TestNewDNSProbe_Valid(t *testing.T) {
tests := []struct {
name string
host string
domain string
wantAddr string
wantDomain string
}{
{host: "1.1.1.1", domain: testDomain, wantAddr: "1.1.1.1:53", wantDomain: testDomain},
{
host: "1.1.1.1:5353",
domain: testDomain,
wantAddr: "1.1.1.1:5353",
wantDomain: testDomain,
},
{
host: "[::1]:5353",
domain: testDomain,
wantAddr: "[::1]:5353",
wantDomain: testDomain,
},
}

for _, tt := range tests {
t.Run(tt.host, func(t *testing.T) {
probe, err := NewDNSProbe(tt.host, tt.domain)
require.NoError(t, err)
assert.Equal(t, tt.wantAddr, probe.DNSResolver)
assert.Equal(t, tt.wantDomain, probe.Domain)
})
}
}

func TestNewDNSProbe_Error(t *testing.T) {
tests := []struct {
name string
host string
domain string
wantErr error
}{
{name: "missing domain", host: "1.1.1.1", domain: "", wantErr: ErrDNSMissingDomain},
{name: "missing resolver", host: "", domain: testDomain, wantErr: ErrDNSMissingResolver},
{name: "port only", host: ":53", domain: testDomain, wantErr: ErrDNSMissingResolver},
{name: "both missing", host: "", domain: "", wantErr: ErrDNSMissingDomain},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewDNSProbe(tt.host, tt.domain)
require.ErrorIs(t, err, tt.wantErr)
})
}
}

func TestDnsProbe(t *testing.T) {
t.Run("returns the first resolved IP address if the request is successful", func(t *testing.T) {
resolver := &fakeResolver{result: []string{"1.2.3.4"}}
Expand Down
6 changes: 3 additions & 3 deletions internal/check/probe_http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ func TestHttpProbe_Success(t *testing.T) {
}

report := probe.Execute(context.Background(), testTimeout)
require.NoError(t, report.Error())
assert.Equal(t, testOKStatus, report.Response())
require.NoError(t, report.error)
assert.Equal(t, testOKStatus, report.response)
}

func TestHttpProbe_UserAgentHeader(t *testing.T) {
Expand All @@ -70,7 +70,7 @@ func TestHttpProbe_UserAgentHeader(t *testing.T) {
}

report := probe.Execute(context.Background(), testTimeout)
require.NoError(t, report.Error())
require.NoError(t, report.error)
assert.Equal(t, "upd/dev", gotUA)
}

Expand Down
30 changes: 0 additions & 30 deletions internal/check/report.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,33 +41,3 @@ func (r *Report) LogAttrs() []any {

return attrs
}

// Protocol returns the protocol used for the probe (e.g., "http", "tcp", "dns").
func (r *Report) Protocol() string {
return r.protocol
}

// Response returns the response from the probe.
// For HTTP requests, this is the HTTP status code.
// For TCP probes, this is the local address (IP:port).
// For DNS probes, this is the resolved IP address and DNS resolver.
// Returns empty string if there was an error.
func (r *Report) Response() string {
return r.response
}

// Elapsed returns the time taken to complete the probe.
func (r *Report) Elapsed() time.Duration {
return r.elapsed
}

// Error returns any error that occurred during the probe.
// Returns nil if the probe was successful.
func (r *Report) Error() error {
return r.error
}

// IsError returns true if the probe encountered an error.
func (r *Report) IsError() bool {
return r.error != nil
}
16 changes: 7 additions & 9 deletions internal/check/report_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func TestProtocol(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
report := &Report{protocol: tt.protocol}
assert.Equal(t, tt.protocol, report.Protocol())
assert.Equal(t, tt.protocol, report.protocol)
})
}
}
Expand All @@ -90,16 +90,15 @@ func TestResponse(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
report := &Report{response: tt.response}
assert.Equal(t, tt.response, report.Response())
assert.Equal(t, tt.response, report.response)
})
}
}

func TestResponse_WithErrors(t *testing.T) {
err := errors.New("network error")
report := &Report{error: err}
resp := report.Response()
assert.Empty(t, resp, "Response should be empty when there's an error")
assert.Empty(t, report.response, "Response should be empty when there's an error")
}

func TestElapsed(t *testing.T) {
Expand All @@ -115,7 +114,7 @@ func TestElapsed(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
report := &Report{elapsed: tt.elapsed}
assert.Equal(t, tt.elapsed, report.Elapsed())
assert.Equal(t, tt.elapsed, report.elapsed)
})
}
}
Expand All @@ -133,9 +132,8 @@ func TestError(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
report := &Report{error: tt.err}
err := report.Error()
assert.Equal(t, tt.err, err)
assert.Equal(t, tt.wantErr, err != nil)
assert.Equal(t, tt.err, report.error)
assert.Equal(t, tt.wantErr, report.error != nil)
})
}
}
Expand All @@ -154,7 +152,7 @@ func TestIsError(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
report := &Report{error: tt.err}
assert.Equal(t, tt.isError, report.IsError())
assert.Equal(t, tt.isError, report.error != nil)
})
}
}
12 changes: 6 additions & 6 deletions internal/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ const (
AppName = "upd"
// AppShort is the application short description.
AppShort = "Tool to monitor if the network connection is up."
// ExitCodeError is the exit code for errors.
ExitCodeError = 1
// ErrChanSize is the buffer size for error channels.
ErrChanSize = 1
// SighupChanSize is the buffer size for SIGHUP channels.
Expand All @@ -35,8 +33,6 @@ const (
ConfigConfig string = "config"
// ConfigDebug is the debug flag name.
ConfigDebug string = "debug"
// ConfigDump is the dump flag name.
ConfigDump string = "dump"
)

// SetupLoop initializes the loop with configuration from the given file.
Expand Down Expand Up @@ -97,7 +93,7 @@ func Run(appCtx context.Context, cmd *cli.Command) error {

select {
case <-rootCtx.Done():
logger.L.Info("shutting down", "component", "app")
logger.L.Info("shutting down", logger.LogComponent, logger.LogComponentApp)
cancelCurrentWorker()
<-done

Expand All @@ -112,7 +108,11 @@ func Run(appCtx context.Context, cmd *cli.Command) error {

<-done
case <-sighupCh:
logger.L.Info("SIGHUP received: reloading configuration", "component", "app")
logger.L.Info(
"SIGHUP received: reloading configuration",
logger.LogComponent,
logger.LogComponentApp,
)
cancelCurrentWorker()
<-done
}
Expand Down
Loading
Loading