Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions cmd/root/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"log/slog"
"net"
"os"
"time"

Expand Down Expand Up @@ -92,6 +93,7 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) (commandErr
defer lnCleanup()

out.Println("Listening on", ln.Addr().String())
warnIfNotLoopback(out, ln.Addr())

slog.Debug("Starting server", "agents", agentsPath, "addr", ln.Addr().String())

Expand Down Expand Up @@ -123,3 +125,22 @@ func (f *apiFlags) runAPICommand(cmd *cobra.Command, args []string) (commandErr

return s.Serve(ctx, ln)
}

// warnIfNotLoopback prints a security warning when the API server is bound to
// an address other than loopback. The default --listen value is 127.0.0.1, so
// reaching this code path means the operator was explicit about exposing the
// API; we just remind them that the API has no authentication.
func warnIfNotLoopback(out *cli.Printer, addr net.Addr) {
tcpAddr, ok := addr.(*net.TCPAddr)
if !ok {
// Unix sockets and named pipes rely on filesystem permissions.
return
}
if tcpAddr.IP.IsLoopback() {
return
}
out.Println("WARNING: API server is listening on a non-loopback address.")
out.Println(" The API has no authentication; anyone able to reach")
out.Println(" this address can run agents and access all sessions.")
slog.Warn("API server bound to non-loopback address", "addr", tcpAddr.String())
}
114 changes: 113 additions & 1 deletion pkg/config/sources.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,15 @@ import (
"fmt"
"io"
"log/slog"
"net"
"net/http"
"net/url"
"os"
"path/filepath"
"slices"
"strings"
"syscall"
"time"

"github.com/docker/docker-agent/pkg/content"
"github.com/docker/docker-agent/pkg/environment"
Expand Down Expand Up @@ -192,6 +195,11 @@ func hasLocalArtifact(store *content.Store, storeKey string) bool {
type urlSource struct {
url string
envProvider environment.Provider
// unsafe disables the HTTPS-only and SSRF dial-time checks. It is set
// only by the test-only constructor newURLSourceForTest (defined in
// sources_test.go), which exists because tests use httptest.NewServer
// (plain HTTP, 127.0.0.1).
unsafe bool
}

// NewURLSource creates a new URL source. If envProvider is non-nil, it will be used
Expand All @@ -217,6 +225,12 @@ func getURLCacheDir() string {
}

func (a urlSource) Read(ctx context.Context) ([]byte, error) {
if !a.unsafe {
if err := validateAgentURL(a.url); err != nil {
return nil, err
}
}

cacheDir := getURLCacheDir()
urlHash := hashURL(a.url)
cachePath := filepath.Join(cacheDir, urlHash)
Expand All @@ -241,7 +255,12 @@ func (a urlSource) Read(ctx context.Context) ([]byte, error) {
// Add GitHub token authorization for GitHub URLs
a.addGitHubAuth(ctx, req)

resp, err := httpclient.NewHTTPClient(ctx).Do(req)
client := httpclient.NewHTTPClient(ctx)
if !a.unsafe {
client = ssrfSafeHTTPClient()
}

resp, err := client.Do(req)
if err != nil {
// Network error - try to use cached version
if cachedData, cacheErr := os.ReadFile(cachePath); cacheErr == nil {
Expand Down Expand Up @@ -344,3 +363,96 @@ func hashURL(rawURL string) string {
func IsURLReference(input string) bool {
return strings.HasPrefix(input, "http://") || strings.HasPrefix(input, "https://")
}

// validateAgentURL enforces that an agent URL uses HTTPS. SSRF protection
// (rejecting connections to loopback / private / link-local addresses) is
// done at dial time by ssrfSafeHTTPClient so that DNS rebinding cannot be
// used to bypass it.
func validateAgentURL(rawURL string) error {
u, err := url.Parse(rawURL)
if err != nil {
return fmt.Errorf("invalid URL %q: %w", rawURL, err)
}
if u.Scheme != "https" {
return fmt.Errorf("refusing to load agent from %q: only https:// URLs are allowed (got scheme %q)", rawURL, u.Scheme)
}
if u.Host == "" {
return fmt.Errorf("invalid URL %q: missing host", rawURL)
}
return nil
}

// ssrfSafeHTTPClient returns an http.Client whose dialer rejects connections
// to non-public IP ranges (loopback, private, link-local, multicast,
// unspecified). The check happens after DNS resolution and before the TCP
// handshake, so DNS rebinding to a private IP is also blocked.
//
// Redirects are re-validated through CheckRedirect so that an https://
// origin cannot transparently downgrade to http:// or to a different scheme.
// SSRF protection on the redirect target is provided by the same dialer.
func ssrfSafeHTTPClient() *http.Client {
dialer := &net.Dialer{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
Control: ssrfDialControl,
}
return &http.Client{
Timeout: 60 * time.Second,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[HIGH] SSRF bypass: http.ProxyFromEnvironment defeats dial-level protection

ssrfSafeHTTPClient sets Proxy: http.ProxyFromEnvironment, which means if HTTP_PROXY or HTTPS_PROXY is configured in the environment, all requests are routed through that proxy. The ssrfDialControl hook only fires for TCP dials made by this process — once a legitimate external proxy is connected, the proxy makes its own outbound connection to the target, entirely bypassing the dial-control SSRF protection.

Attack scenario: an operator's environment has HTTPS_PROXY=https://corp-proxy.example.com; an attacker-controlled agent YAML specifies url: https://169.254.169.254/latest/meta-data/.... The dial to corp-proxy.example.com passes (public IP), and the proxy faithfully fetches the metadata endpoint.

Fix: disable proxy support for this SSRF-sensitive client by setting Proxy: nil (or simply omitting Proxy — the zero value for the field disables proxying):

Transport: &http.Transport{
    // Proxy: nil  ← no proxy; ssrfDialControl only works for direct dials
    DialContext: dialer.DialContext,
    ...
}

DialContext: dialer.DialContext,
ForceAttemptHTTP2: true,
MaxIdleConns: 10,
IdleConnTimeout: 30 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ResponseHeaderTimeout: 30 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
},
CheckRedirect: ssrfCheckRedirect,
}
}

// ssrfCheckRedirect is the http.Client CheckRedirect hook used by
// ssrfSafeHTTPClient. It rejects redirects to non-https URLs (defeating
// TLS downgrade) and bounds the redirect chain. SSRF on the redirect
// target itself is enforced by the dialer's Control hook.
func ssrfCheckRedirect(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
}
if req.URL.Scheme != "https" {
return fmt.Errorf("refusing redirect to non-https URL %q", req.URL.Redacted())
}
return nil
}

// ssrfDialControl is invoked by net.Dialer after DNS resolution but before the
// TCP handshake. It rejects addresses that are not safe to fetch from over
// the public internet.
func ssrfDialControl(_, address string, _ syscall.RawConn) error {
host, _, err := net.SplitHostPort(address)
if err != nil {
return fmt.Errorf("parsing dial address %q: %w", address, err)
}
ip := net.ParseIP(host)
if ip == nil {
return fmt.Errorf("refusing to dial %q: not a valid IP", host)
}
if !isPublicIP(ip) {
return fmt.Errorf("refusing to dial non-public address %s", ip)
}
return nil
}

// isPublicIP reports whether ip is a routable public address. It rejects
// loopback (127/8, ::1), RFC1918 private ranges, link-local (incl. the
// 169.254.169.254 cloud metadata endpoint), multicast and the unspecified
// address (0.0.0.0, ::).
func isPublicIP(ip net.IP) bool {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[MEDIUM] isPublicIP does not block 100.64.0.0/10 (RFC 6598 Shared Address Space / CGNAT)

Go's net.IP.IsPrivate() (Go 1.17+) covers only RFC 1918 (10/8, 172.16/12, 192.168/16) and RFC 4193 (fc00::/7). It does not include the RFC 6598 Carrier-Grade NAT range 100.64.0.0/10. In cloud and ISP environments, addresses in this range are often used for internal-only services (e.g., hypervisor APIs, internal load-balancer health endpoints) that are unreachable from the public internet but reachable from within the hosting environment.

An agent URL pointing to 100.64.x.x would pass isPublicIP, allowing the SSRF protection to be bypassed in those environments.

Fix: add an explicit check for the CGNAT range:

var cgnatRange = func() *net.IPNet {
    _, n, _ := net.ParseCIDR("100.64.0.0/10")
    return n
}()

func isPublicIP(ip net.IP) bool {
    return !ip.IsLoopback() &&
        !ip.IsPrivate() &&
        !cgnatRange.Contains(ip) &&  // RFC 6598
        !ip.IsLinkLocalUnicast() &&
        !ip.IsLinkLocalMulticast() &&
        !ip.IsMulticast() &&
        !ip.IsUnspecified()
}

return !ip.IsLoopback() &&
!ip.IsPrivate() &&
!ip.IsLinkLocalUnicast() &&
!ip.IsLinkLocalMulticast() &&
!ip.IsMulticast() &&
!ip.IsUnspecified()
}
Loading
Loading