Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions internal/rss/rss.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (

"github.com/PuerkitoBio/goquery"
"github.com/mmcdole/gofeed"

"github.com/Hyaxia/blogwatcher/internal/safeclient"
)

type FeedArticle struct {
Expand All @@ -27,8 +29,7 @@ func (e FeedParseError) Error() string {
}

func ParseFeed(feedURL string, timeout time.Duration) ([]FeedArticle, error) {
client := &http.Client{Timeout: timeout}
response, err := client.Get(feedURL)
response, err := safeclient.SafeGet(feedURL, timeout)
if err != nil {
return nil, FeedParseError{Message: fmt.Sprintf("failed to fetch feed: %v", err)}
}
Expand Down Expand Up @@ -61,8 +62,7 @@ func ParseFeed(feedURL string, timeout time.Duration) ([]FeedArticle, error) {
}

func DiscoverFeedURL(blogURL string, timeout time.Duration) (string, error) {
client := &http.Client{Timeout: timeout}
response, err := client.Get(blogURL)
response, err := safeclient.SafeGet(blogURL, timeout)
if err != nil {
return "", nil
}
Expand Down Expand Up @@ -130,8 +130,7 @@ func DiscoverFeedURL(blogURL string, timeout time.Duration) (string, error) {
}

func isValidFeed(feedURL string, timeout time.Duration) (bool, error) {
client := &http.Client{Timeout: timeout}
response, err := client.Get(feedURL)
response, err := safeclient.SafeGet(feedURL, timeout)
if err != nil {
return false, err
}
Expand Down
8 changes: 8 additions & 0 deletions internal/rss/rss_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,18 @@ package rss
import (
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

"github.com/Hyaxia/blogwatcher/internal/safeclient"
)

// TestMain disables the safeclient SSRF deny-list for this package's tests:
// httptest.NewServer binds to 127.0.0.1, which the deny-list would otherwise
// reject. The flag stays set for the lifetime of this test binary only.
func TestMain(m *testing.M) {
	safeclient.SetTestAllowPrivate(true)
	os.Exit(m.Run())
}

const sampleFeed = `<?xml version="1.0" encoding="UTF-8" ?>
<rss version="2.0">
<channel>
Expand Down
147 changes: 147 additions & 0 deletions internal/safeclient/safeclient.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package safeclient

import (
"context"
"errors"
"fmt"
"net"
"net/http"
"net/url"
"time"
)

// denyList contains CIDR ranges that must never be contacted.
var denyList []*net.IPNet

// testAllowPrivate is a global flag that, when true, disables the IP deny-list
// for all SafeGet calls. It is set via SetTestAllowPrivate and intended for
// test infrastructure only (e.g. tests using httptest.NewServer on 127.0.0.1).
var testAllowPrivate bool

// SetTestAllowPrivate globally disables (or re-enables) the IP deny-list.
// This is intended for use in TestMain of packages that test against
// httptest.NewServer, which binds to loopback.
func SetTestAllowPrivate(allow bool) {
testAllowPrivate = allow
}

func init() {
for _, cidr := range []string{
// IPv4
"127.0.0.0/8", // loopback
"10.0.0.0/8", // RFC 1918
"172.16.0.0/12", // RFC 1918
"192.168.0.0/16", // RFC 1918
"169.254.0.0/16", // link-local / cloud metadata
"0.0.0.0/8", // "this" network

// IPv6
"::1/128", // loopback
"fc00::/7", // unique local
"fe80::/10", // link-local
} {
_, network, _ := net.ParseCIDR(cidr)
denyList = append(denyList, network)
}
}

// SSRFError is returned when a URL fails SSRF validation: a blocked IP
// range, an unsupported scheme, or an unresolvable/empty hostname.
type SSRFError struct {
	URL     string
	Message string
}

// Error implements the error interface.
func (e *SSRFError) Error() string {
	return fmt.Sprintf("ssrf blocked: %s — %s", e.URL, e.Message)
}

// IsSSRFError reports whether err (or any error in its chain) is an SSRFError.
func IsSSRFError(err error) bool {
	var ssrfErr *SSRFError
	return errors.As(err, &ssrfErr)
}

// Option configures SafeGet behaviour.
type Option func(*options)

// options holds the resolved configuration for a single SafeGet call.
type options struct {
	allowPrivate bool
}

// AllowPrivate disables the IP deny-list check.
// This is intended for use in tests only.
func AllowPrivate() Option {
	return func(cfg *options) {
		cfg.allowPrivate = true
	}
}

// SafeGet performs an HTTP GET after validating the URL against SSRF rules.
//
// It enforces:
//   - http or https scheme only
//   - hostname must not resolve to a private/reserved IP range
//
// The deny-list is enforced twice: once up front via checkHost (for a clear
// early error), and again at dial time for every connection the client
// opens. The dial-time guard dials only vetted IPs, so neither a redirect
// to an internal host nor a DNS answer that changes between the pre-check
// and the fetch (DNS rebinding) can bypass the filter.
func SafeGet(rawURL string, timeout time.Duration, opts ...Option) (*http.Response, error) {
	cfg := &options{}
	for _, o := range opts {
		o(cfg)
	}

	parsed, err := url.Parse(rawURL)
	if err != nil {
		return nil, &SSRFError{URL: rawURL, Message: "malformed URL"}
	}
	if parsed.Scheme != "http" && parsed.Scheme != "https" {
		return nil, &SSRFError{URL: rawURL, Message: fmt.Sprintf("unsupported scheme %q", parsed.Scheme)}
	}

	hostname := parsed.Hostname()
	if hostname == "" {
		return nil, &SSRFError{URL: rawURL, Message: "empty hostname"}
	}

	client := &http.Client{Timeout: timeout}

	if !cfg.allowPrivate && !testAllowPrivate {
		if err := checkHost(hostname, timeout); err != nil {
			return nil, err
		}
		// Re-validate at connection time. client.Get re-resolves DNS and
		// transparently follows redirects; without this guard either could
		// reach a deny-listed address after the check above passed. Errors
		// returned here surface to the caller wrapped in *url.Error, which
		// errors.As — and therefore IsSSRFError — unwraps.
		client.Transport = &http.Transport{
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				host, port, splitErr := net.SplitHostPort(addr)
				if splitErr != nil {
					return nil, splitErr
				}
				dialer := &net.Dialer{}
				if ip := net.ParseIP(host); ip != nil {
					if isBlocked(ip) {
						return nil, &SSRFError{URL: addr, Message: fmt.Sprintf("IP %s is in a blocked range", ip)}
					}
					return dialer.DialContext(ctx, network, addr)
				}
				ips, lookupErr := net.DefaultResolver.LookupIPAddr(ctx, host)
				if lookupErr != nil {
					return nil, &SSRFError{URL: addr, Message: fmt.Sprintf("DNS resolution failed: %v", lookupErr)}
				}
				var lastDialErr error
				for _, candidate := range ips {
					if isBlocked(candidate.IP) {
						continue
					}
					// Dial the vetted IP directly so the connection cannot
					// go anywhere the check did not approve; the transport
					// still uses the original hostname for TLS (SNI) and
					// the Host header.
					conn, dialErr := dialer.DialContext(ctx, network, net.JoinHostPort(candidate.IP.String(), port))
					if dialErr == nil {
						return conn, nil
					}
					lastDialErr = dialErr
				}
				if lastDialErr != nil {
					return nil, lastDialErr
				}
				return nil, &SSRFError{URL: addr, Message: "all resolved IPs are in blocked ranges"}
			},
		}
	}

	return client.Get(rawURL)
}

// checkHost resolves hostname and rejects it if ANY resolved IP falls in
// the deny-list. Rejecting on any blocked address (rather than requiring
// all of them to be blocked, as before) matters because the HTTP client
// connects to whichever address the resolver returns first — a DNS answer
// mixing one public and one private IP must not be allowed through.
func checkHost(hostname string, timeout time.Duration) error {
	// Fast path: if hostname is already an IP literal, check it directly.
	if ip := net.ParseIP(hostname); ip != nil {
		if isBlocked(ip) {
			return &SSRFError{URL: hostname, Message: fmt.Sprintf("IP %s is in a blocked range", ip)}
		}
		return nil
	}

	// Bound the DNS lookup by the caller's timeout so a slow resolver
	// cannot stall SafeGet beyond its budget.
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	addrs, err := net.DefaultResolver.LookupIPAddr(ctx, hostname)
	if err != nil {
		return &SSRFError{URL: hostname, Message: fmt.Sprintf("DNS resolution failed: %v", err)}
	}
	if len(addrs) == 0 {
		return &SSRFError{URL: hostname, Message: "DNS returned no addresses"}
	}

	for _, addr := range addrs {
		if isBlocked(addr.IP) {
			return &SSRFError{URL: hostname, Message: fmt.Sprintf("resolved IP %s is in a blocked range", addr.IP)}
		}
	}
	return nil
}

// isBlocked reports whether ip falls inside any deny-listed CIDR range.
func isBlocked(ip net.IP) bool {
	for _, blocked := range denyList {
		if blocked.Contains(ip) {
			return true
		}
	}
	return false
}
112 changes: 112 additions & 0 deletions internal/safeclient/safeclient_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package safeclient

import (
	"fmt"
	"net/http"
	"net/http/httptest"
	"testing"
	"time"
)

// TestSafeGet_PublicURL verifies a plain successful GET through SafeGet.
func TestSafeGet_PublicURL(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("ok"))
	}))
	defer srv.Close()

	// httptest.NewServer binds to 127.0.0.1, so we must AllowPrivate for the test to reach it.
	resp, err := SafeGet(srv.URL, 2*time.Second, AllowPrivate())
	if err != nil {
		t.Fatalf("expected no error, got %v", err)
	}
	defer resp.Body.Close()
	if got := resp.StatusCode; got != http.StatusOK {
		t.Fatalf("expected 200, got %d", got)
	}
}

// TestSafeGet_BlocksLoopback ensures a literal IPv4 loopback URL is rejected.
func TestSafeGet_BlocksLoopback(t *testing.T) {
	_, err := SafeGet("http://127.0.0.1/secret", 2*time.Second)
	switch {
	case err == nil:
		t.Fatal("expected error for loopback address")
	case !IsSSRFError(err):
		t.Fatalf("expected SSRFError, got %T: %v", err, err)
	}
}

// TestSafeGet_BlocksIPv6Loopback ensures a literal [::1] URL is rejected.
func TestSafeGet_BlocksIPv6Loopback(t *testing.T) {
	const target = "http://[::1]/secret"
	_, err := SafeGet(target, 2*time.Second)
	switch {
	case err == nil:
		t.Fatal("expected error for IPv6 loopback")
	case !IsSSRFError(err):
		t.Fatalf("expected SSRFError, got %T: %v", err, err)
	}
}

// TestSafeGet_BlocksMetadataEndpoint ensures the link-local cloud metadata
// address is rejected.
func TestSafeGet_BlocksMetadataEndpoint(t *testing.T) {
	const target = "http://169.254.169.254/latest/meta-data/"
	_, err := SafeGet(target, 2*time.Second)
	switch {
	case err == nil:
		t.Fatal("expected error for metadata endpoint")
	case !IsSSRFError(err):
		t.Fatalf("expected SSRFError, got %T: %v", err, err)
	}
}

// TestSafeGet_BlocksRFC1918 ensures each RFC 1918 private range is rejected.
func TestSafeGet_BlocksRFC1918(t *testing.T) {
	private := []string{
		"http://10.0.0.1/",
		"http://172.16.0.1/",
		"http://192.168.1.1/",
	}
	for _, addr := range private {
		if _, err := SafeGet(addr, 2*time.Second); err == nil {
			t.Fatalf("expected error for private address %s", addr)
		} else if !IsSSRFError(err) {
			t.Fatalf("expected SSRFError for %s, got %T: %v", addr, err, err)
		}
	}
}

// TestSafeGet_BlocksNonHTTPScheme ensures non-http(s) schemes are rejected.
func TestSafeGet_BlocksNonHTTPScheme(t *testing.T) {
	badURLs := []string{
		"ftp://example.com/file",
		"file:///etc/passwd",
		"gopher://example.com",
	}
	for _, badURL := range badURLs {
		if _, err := SafeGet(badURL, 2*time.Second); err == nil {
			t.Fatalf("expected error for scheme in %s", badURL)
		} else if !IsSSRFError(err) {
			t.Fatalf("expected SSRFError for %s, got %T: %v", badURL, err, err)
		}
	}
}

// TestSafeGet_AllowPrivateBypassesDenyList verifies that the AllowPrivate
// option lets SafeGet reach a loopback server, and that the request actually
// succeeds — the original assertion only checked for "no error", so a
// non-200 response would have passed.
func TestSafeGet_AllowPrivateBypassesDenyList(t *testing.T) {
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))
	defer server.Close()

	resp, err := SafeGet(server.URL, 2*time.Second, AllowPrivate())
	if err != nil {
		t.Fatalf("AllowPrivate should bypass deny list, got error: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		t.Fatalf("expected 200 from bypassed request, got %d", resp.StatusCode)
	}
}

// TestIsSSRFError covers three cases: a direct *SSRFError, an *SSRFError
// wrapped with %w (errors.As must walk the chain — the original test never
// exercised this, which is IsSSRFError's main purpose), and nil.
func TestIsSSRFError(t *testing.T) {
	err := &SSRFError{URL: "http://127.0.0.1", Message: "blocked"}
	if !IsSSRFError(err) {
		t.Fatal("expected IsSSRFError to return true")
	}

	wrapped := fmt.Errorf("fetching feed: %w", err)
	if !IsSSRFError(wrapped) {
		t.Fatal("expected IsSSRFError to unwrap a wrapped SSRFError")
	}

	if IsSSRFError(nil) {
		t.Fatal("expected IsSSRFError(nil) to return false")
	}
}
7 changes: 7 additions & 0 deletions internal/scanner/scanner_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,21 @@ package scanner
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"

"github.com/Hyaxia/blogwatcher/internal/model"
"github.com/Hyaxia/blogwatcher/internal/safeclient"
"github.com/Hyaxia/blogwatcher/internal/storage"
)

// TestMain disables the safeclient SSRF deny-list for this package's tests:
// httptest.NewServer binds to 127.0.0.1, which the deny-list would otherwise
// reject. The flag stays set for the lifetime of this test binary only.
func TestMain(m *testing.M) {
	safeclient.SetTestAllowPrivate(true)
	os.Exit(m.Run())
}

const sampleFeed = `<?xml version="1.0" encoding="UTF-8" ?>
<rss version="2.0">
<channel>
Expand Down
6 changes: 3 additions & 3 deletions internal/scraper/scraper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ package scraper
import (
"errors"
"fmt"
"net/http"
"net/url"
"strings"
"time"

"github.com/PuerkitoBio/goquery"

"github.com/Hyaxia/blogwatcher/internal/safeclient"
)

type ScrapedArticle struct {
Expand All @@ -26,8 +27,7 @@ func (e ScrapeError) Error() string {
}

func ScrapeBlog(blogURL string, selector string, timeout time.Duration) ([]ScrapedArticle, error) {
client := &http.Client{Timeout: timeout}
response, err := client.Get(blogURL)
response, err := safeclient.SafeGet(blogURL, timeout)
if err != nil {
return nil, ScrapeError{Message: fmt.Sprintf("failed to fetch page: %v", err)}
}
Expand Down
8 changes: 8 additions & 0 deletions internal/scraper/scraper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,18 @@ package scraper
import (
"net/http"
"net/http/httptest"
"os"
"testing"
"time"

"github.com/Hyaxia/blogwatcher/internal/safeclient"
)

// TestMain disables the safeclient SSRF deny-list for this package's tests:
// httptest.NewServer binds to 127.0.0.1, which the deny-list would otherwise
// reject. The flag stays set for the lifetime of this test binary only.
func TestMain(m *testing.M) {
	safeclient.SetTestAllowPrivate(true)
	os.Exit(m.Run())
}

func TestScrapeBlog(t *testing.T) {
html := `<!DOCTYPE html>
<html>
Expand Down