diff --git a/internal/rss/rss.go b/internal/rss/rss.go
index a162811..824313e 100644
--- a/internal/rss/rss.go
+++ b/internal/rss/rss.go
@@ -10,6 +10,8 @@ import (
 
 	"github.com/PuerkitoBio/goquery"
 	"github.com/mmcdole/gofeed"
+
+	"github.com/Hyaxia/blogwatcher/internal/safeclient"
 )
 
 type FeedArticle struct {
@@ -27,8 +29,7 @@ func (e FeedParseError) Error() string {
 }
 
 func ParseFeed(feedURL string, timeout time.Duration) ([]FeedArticle, error) {
-	client := &http.Client{Timeout: timeout}
-	response, err := client.Get(feedURL)
+	response, err := safeclient.SafeGet(feedURL, timeout)
 	if err != nil {
 		return nil, FeedParseError{Message: fmt.Sprintf("failed to fetch feed: %v", err)}
 	}
@@ -61,8 +62,7 @@ func ParseFeed(feedURL string, timeout time.Duration) ([]FeedArticle, error) {
 }
 
 func DiscoverFeedURL(blogURL string, timeout time.Duration) (string, error) {
-	client := &http.Client{Timeout: timeout}
-	response, err := client.Get(blogURL)
+	response, err := safeclient.SafeGet(blogURL, timeout)
 	if err != nil {
 		return "", nil
 	}
@@ -130,8 +130,7 @@ func DiscoverFeedURL(blogURL string, timeout time.Duration) (string, error) {
 }
 
 func isValidFeed(feedURL string, timeout time.Duration) (bool, error) {
-	client := &http.Client{Timeout: timeout}
-	response, err := client.Get(feedURL)
+	response, err := safeclient.SafeGet(feedURL, timeout)
 	if err != nil {
 		return false, err
 	}
diff --git a/internal/rss/rss_test.go b/internal/rss/rss_test.go
index 38766f8..66c3e00 100644
--- a/internal/rss/rss_test.go
+++ b/internal/rss/rss_test.go
@@ -3,10 +3,18 @@ package rss
 import (
 	"net/http"
 	"net/http/httptest"
+	"os"
 	"testing"
 	"time"
+
+	"github.com/Hyaxia/blogwatcher/internal/safeclient"
 )
 
+func TestMain(m *testing.M) {
+	safeclient.SetTestAllowPrivate(true)
+	os.Exit(m.Run())
+}
+
 const sampleFeed = `
diff --git a/internal/safeclient/safeclient.go b/internal/safeclient/safeclient.go
new file mode 100644
index 0000000..be69fcc
--- /dev/null
+++ b/internal/safeclient/safeclient.go
@@ -0,0 +1,180 @@
+package safeclient
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"net/http"
+	"net/url"
+	"syscall"
+	"time"
+)
+
+// denyList contains CIDR ranges that must never be contacted.
+var denyList []*net.IPNet
+
+// testAllowPrivate is a global flag that, when true, disables the IP deny-list
+// for all SafeGet calls. It is set via SetTestAllowPrivate and intended for
+// test infrastructure only (e.g. tests using httptest.NewServer on 127.0.0.1).
+var testAllowPrivate bool
+
+// SetTestAllowPrivate globally disables (or re-enables) the IP deny-list.
+// This is intended for use in TestMain of packages that test against
+// httptest.NewServer, which binds to loopback.
+func SetTestAllowPrivate(allow bool) {
+	testAllowPrivate = allow
+}
+
+func init() {
+	for _, cidr := range []string{
+		// IPv4
+		"127.0.0.0/8",    // loopback
+		"10.0.0.0/8",     // RFC 1918
+		"172.16.0.0/12",  // RFC 1918
+		"192.168.0.0/16", // RFC 1918
+		"169.254.0.0/16", // link-local / cloud metadata
+		"0.0.0.0/8",      // "this" network
+
+		// IPv6
+		"::1/128",   // loopback
+		"fc00::/7",  // unique local
+		"fe80::/10", // link-local
+	} {
+		_, network, _ := net.ParseCIDR(cidr)
+		denyList = append(denyList, network)
+	}
+}
+
+// SSRFError is returned when a URL resolves to a blocked IP range.
+type SSRFError struct {
+	URL     string
+	Message string
+}
+
+func (e *SSRFError) Error() string {
+	return fmt.Sprintf("ssrf blocked: %s — %s", e.URL, e.Message)
+}
+
+// IsSSRFError reports whether err (or any error in its chain) is an SSRFError.
+func IsSSRFError(err error) bool {
+	var target *SSRFError
+	return errors.As(err, &target)
+}
+
+// Option configures SafeGet behaviour.
+type Option func(*options)
+
+type options struct {
+	allowPrivate bool
+}
+
+// AllowPrivate disables the IP deny-list check.
+// This is intended for use in tests only.
+func AllowPrivate() Option {
+	return func(o *options) { o.allowPrivate = true }
+}
+
+// SafeGet performs an HTTP GET after validating the URL against SSRF rules.
+//
+// It enforces:
+//   - http or https scheme only
+//   - hostname must not resolve to a private/reserved IP range
+//   - every connection (including redirect targets) must dial a non-blocked
+//     IP, closing DNS-rebinding and redirect bypasses of the pre-flight check
+func SafeGet(rawURL string, timeout time.Duration, opts ...Option) (*http.Response, error) {
+	cfg := &options{}
+	for _, o := range opts {
+		o(cfg)
+	}
+
+	parsed, err := url.Parse(rawURL)
+	if err != nil {
+		return nil, &SSRFError{URL: rawURL, Message: "malformed URL"}
+	}
+	if parsed.Scheme != "http" && parsed.Scheme != "https" {
+		return nil, &SSRFError{URL: rawURL, Message: fmt.Sprintf("unsupported scheme %q", parsed.Scheme)}
+	}
+
+	hostname := parsed.Hostname()
+	if hostname == "" {
+		return nil, &SSRFError{URL: rawURL, Message: "empty hostname"}
+	}
+
+	bypass := cfg.allowPrivate || testAllowPrivate
+	if !bypass {
+		// Pre-flight check: fail fast with a descriptive error before dialing.
+		if err := checkHost(hostname, timeout); err != nil {
+			return nil, err
+		}
+	}
+
+	client := &http.Client{
+		Timeout:   timeout,
+		Transport: newTransport(bypass),
+	}
+	return client.Get(rawURL)
+}
+
+// newTransport builds an http.Transport whose dialer re-validates the IP of
+// every connection it opens. Unlike the pre-flight DNS check, this cannot be
+// bypassed by DNS rebinding (a second resolution returning a private IP) or
+// by an HTTP redirect pointing at an internal address.
+func newTransport(bypass bool) *http.Transport {
+	dialer := &net.Dialer{}
+	if !bypass {
+		dialer.Control = func(network, address string, _ syscall.RawConn) error {
+			host, _, err := net.SplitHostPort(address)
+			if err != nil {
+				return &SSRFError{URL: address, Message: "invalid dial address"}
+			}
+			ip := net.ParseIP(host)
+			if ip == nil || isBlocked(ip) {
+				return &SSRFError{URL: address, Message: fmt.Sprintf("IP %s is in a blocked range", host)}
+			}
+			return nil
+		}
+	}
+	return &http.Transport{DialContext: dialer.DialContext}
+}
+
+// checkHost resolves hostname and rejects it if every IP falls in the deny-list.
+func checkHost(hostname string, timeout time.Duration) error {
+	// Fast path: if hostname is already an IP literal, check it directly.
+	if ip := net.ParseIP(hostname); ip != nil {
+		if isBlocked(ip) {
+			return &SSRFError{URL: hostname, Message: fmt.Sprintf("IP %s is in a blocked range", ip)}
+		}
+		return nil
+	}
+
+	ctx, cancel := context.WithTimeout(context.Background(), timeout)
+	defer cancel()
+
+	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
+	if err != nil {
+		return &SSRFError{URL: hostname, Message: fmt.Sprintf("DNS resolution failed: %v", err)}
+	}
+	if len(addrs) == 0 {
+		return &SSRFError{URL: hostname, Message: "DNS returned no addresses"}
+	}
+
+	// Lenient on purpose: a multi-homed host with at least one public IP
+	// passes here; the transport-level dial check in newTransport still
+	// blocks any private IP the dialer actually picks.
+	for _, addr := range addrs {
+		if !isBlocked(addr.IP) {
+			return nil // at least one public IP — allow
+		}
+	}
+	return &SSRFError{URL: hostname, Message: "all resolved IPs are in blocked ranges"}
+}
+
+func isBlocked(ip net.IP) bool {
+	for _, network := range denyList {
+		if network.Contains(ip) {
+			return true
+		}
+	}
+	return false
+}
diff --git a/internal/safeclient/safeclient_test.go b/internal/safeclient/safeclient_test.go
new file mode 100644
index 0000000..346452f
--- /dev/null
+++ b/internal/safeclient/safeclient_test.go
@@ -0,0 +1,112 @@
+package safeclient
+
+import (
+	"net/http"
+	"net/http/httptest"
+	"testing"
+	"time"
+)
+
+func TestSafeGet_PublicURL(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("ok"))
+	}))
+	defer server.Close()
+
+	// httptest.NewServer binds to 127.0.0.1, so we must AllowPrivate for the test to reach it.
+	resp, err := SafeGet(server.URL, 2*time.Second, AllowPrivate())
+	if err != nil {
+		t.Fatalf("expected no error, got %v", err)
+	}
+	defer resp.Body.Close()
+	if resp.StatusCode != http.StatusOK {
+		t.Fatalf("expected 200, got %d", resp.StatusCode)
+	}
+}
+
+func TestSafeGet_BlocksLoopback(t *testing.T) {
+	_, err := SafeGet("http://127.0.0.1/secret", 2*time.Second)
+	if err == nil {
+		t.Fatal("expected error for loopback address")
+	}
+	if !IsSSRFError(err) {
+		t.Fatalf("expected SSRFError, got %T: %v", err, err)
+	}
+}
+
+func TestSafeGet_BlocksIPv6Loopback(t *testing.T) {
+	_, err := SafeGet("http://[::1]/secret", 2*time.Second)
+	if err == nil {
+		t.Fatal("expected error for IPv6 loopback")
+	}
+	if !IsSSRFError(err) {
+		t.Fatalf("expected SSRFError, got %T: %v", err, err)
+	}
+}
+
+func TestSafeGet_BlocksMetadataEndpoint(t *testing.T) {
+	_, err := SafeGet("http://169.254.169.254/latest/meta-data/", 2*time.Second)
+	if err == nil {
+		t.Fatal("expected error for metadata endpoint")
+	}
+	if !IsSSRFError(err) {
+		t.Fatalf("expected SSRFError, got %T: %v", err, err)
+	}
+}
+
+func TestSafeGet_BlocksRFC1918(t *testing.T) {
+	for _, addr := range []string{
+		"http://10.0.0.1/",
+		"http://172.16.0.1/",
+		"http://192.168.1.1/",
+	} {
+		_, err := SafeGet(addr, 2*time.Second)
+		if err == nil {
+			t.Fatalf("expected error for private address %s", addr)
+		}
+		if !IsSSRFError(err) {
+			t.Fatalf("expected SSRFError for %s, got %T: %v", addr, err, err)
+		}
+	}
+}
+
+func TestSafeGet_BlocksNonHTTPScheme(t *testing.T) {
+	for _, u := range []string{
+		"ftp://example.com/file",
+		"file:///etc/passwd",
+		"gopher://example.com",
+	} {
+		_, err := SafeGet(u, 2*time.Second)
+		if err == nil {
+			t.Fatalf("expected error for scheme in %s", u)
+		}
+		if !IsSSRFError(err) {
+			t.Fatalf("expected SSRFError for %s, got %T: %v", u, err, err)
+		}
+	}
+}
+
+func TestSafeGet_AllowPrivateBypassesDenyList(t *testing.T) {
+	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		w.WriteHeader(http.StatusOK)
+	}))
+	defer server.Close()
+
+	resp, err := SafeGet(server.URL, 2*time.Second, AllowPrivate())
+	if err != nil {
+		t.Fatalf("AllowPrivate should bypass deny list, got error: %v", err)
+	}
+	resp.Body.Close()
+}
+
+func TestIsSSRFError(t *testing.T) {
+	err := &SSRFError{URL: "http://127.0.0.1", Message: "blocked"}
+	if !IsSSRFError(err) {
+		t.Fatal("expected IsSSRFError to return true")
+	}
+
+	if IsSSRFError(nil) {
+		t.Fatal("expected IsSSRFError(nil) to return false")
+	}
+}
diff --git a/internal/scanner/scanner_test.go b/internal/scanner/scanner_test.go
index 376a9a2..e04e5b4 100644
--- a/internal/scanner/scanner_test.go
+++ b/internal/scanner/scanner_test.go
@@ -3,14 +3,21 @@ package scanner
 import (
 	"net/http"
 	"net/http/httptest"
+	"os"
 	"path/filepath"
 	"testing"
 	"time"
 
 	"github.com/Hyaxia/blogwatcher/internal/model"
+	"github.com/Hyaxia/blogwatcher/internal/safeclient"
 	"github.com/Hyaxia/blogwatcher/internal/storage"
 )
 
+func TestMain(m *testing.M) {
+	safeclient.SetTestAllowPrivate(true)
+	os.Exit(m.Run())
+}
+
 const sampleFeed = `
diff --git a/internal/scraper/scraper.go b/internal/scraper/scraper.go
index df25020..7f7d114 100644
--- a/internal/scraper/scraper.go
+++ b/internal/scraper/scraper.go
@@ -3,12 +3,13 @@ package scraper
 import (
 	"errors"
 	"fmt"
-	"net/http"
 	"net/url"
 	"strings"
 	"time"
 
 	"github.com/PuerkitoBio/goquery"
+
+	"github.com/Hyaxia/blogwatcher/internal/safeclient"
 )
 
 type ScrapedArticle struct {
@@ -26,8 +27,7 @@ func (e ScrapeError) Error() string {
 }
 
 func ScrapeBlog(blogURL string, selector string, timeout time.Duration) ([]ScrapedArticle, error) {
-	client := &http.Client{Timeout: timeout}
-	response, err := client.Get(blogURL)
+	response, err := safeclient.SafeGet(blogURL, timeout)
 	if err != nil {
 		return nil, ScrapeError{Message: fmt.Sprintf("failed to fetch page: %v", err)}
 	}
diff --git a/internal/scraper/scraper_test.go b/internal/scraper/scraper_test.go
index 83c46cd..0f95bcd 100644
--- a/internal/scraper/scraper_test.go
+++ b/internal/scraper/scraper_test.go
@@ -3,10 +3,18 @@ package scraper
 import (
 	"net/http"
 	"net/http/httptest"
+	"os"
 	"testing"
 	"time"
+
+	"github.com/Hyaxia/blogwatcher/internal/safeclient"
 )
 
+func TestMain(m *testing.M) {
+	safeclient.SetTestAllowPrivate(true)
+	os.Exit(m.Run())
+}
+
 func TestScrapeBlog(t *testing.T) {
	html := `