diff --git a/internal/model/upstream.go b/internal/model/upstream.go index ba71865..b8d6c0b 100644 --- a/internal/model/upstream.go +++ b/internal/model/upstream.go @@ -177,6 +177,36 @@ func (up *Upstream) InitConnectionPool(bootstrap func(host string) (net.IP, erro } } +// isPrivateIP checks if an IP address is in the private IP range +// Private IP ranges: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 +func isPrivateIP(ip net.IP) bool { + if ip == nil { + return false + } + // Convert to 4-byte representation for IPv4 + ip4 := ip.To4() + if ip4 != nil { + // 10.0.0.0/8 + if ip4[0] == 10 { + return true + } + // 172.16.0.0/12 + if ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31 { + return true + } + // 192.168.0.0/16 + if ip4[0] == 192 && ip4[1] == 168 { + return true + } + return false + } + // IPv6 private ranges (fc00::/7 - Unique Local Addresses) + if len(ip) == 16 && (ip[0]&0xfe) == 0xfc { + return true + } + return false +} + func (up *Upstream) IsValidMsg(r *dns.Msg) bool { domain := GetDomainNameFromDnsMsg(r) inBlacklist := utils.HasMatchedRule(up.config.BlacklistSplited, domain) @@ -192,6 +222,16 @@ func (up *Upstream) IsValidMsg(r *dns.Msg) bool { } ip = typeAAAA.AAAA } + + // Private IPs should always be considered valid regardless of primary/non-primary status + // IPv4 private ranges: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 + // IPv6 private ranges: fc00::/7 (Unique Local Addresses) + // They are used for internal networks and are not subject to geographical restrictions + if isPrivateIP(ip) { + up.logger.Printf("checkPrimary result %s: %s@%s -> private IP, skipping primary check", up.Address, domain, ip) + continue + } + isPrimary, err := up.ipRanger.Contains(ip) if err != nil { up.logger.Printf("ipRanger query ip %s failed: %s", ip, err) diff --git a/internal/model/upstream_test.go b/internal/model/upstream_test.go index f9ecad6..616e912 100644 --- a/internal/model/upstream_test.go +++ b/internal/model/upstream_test.go @@ -2,10 +2,14 @@ package model import ( "index/suffixarray" + "net" "strings" "testing" + "github.com/miekg/dns" + "github.com/naiba/nbdns/pkg/logger" "github.com/naiba/nbdns/pkg/utils" + "github.com/yl2chen/cidranger" ) var primaryLocations = []string{"中国", "省", "市", "自治区"} @@ -118,3 +122,161 @@ func checkPrimaryStringsContains(str string) bool { } return false } + +// TestIsPrivateIP tests the isPrivateIP function +func TestIsPrivateIP(t *testing.T) { + tests := []struct { + name string + ip string + expected bool + }{ + // Private IPv4 ranges + {"10.0.0.0", "10.0.0.0", true}, + {"10.255.255.255", "10.255.255.255", true}, + {"10.1.2.3", "10.1.2.3", true}, + {"172.16.0.0", "172.16.0.0", true}, + {"172.31.255.255", "172.31.255.255", true}, + {"172.20.1.1", "172.20.1.1", true}, + {"192.168.0.0", "192.168.0.0", true}, + {"192.168.255.255", "192.168.255.255", true}, + {"192.168.1.1", "192.168.1.1", true}, + + // Public IPv4 addresses (not private) + {"8.8.8.8", "8.8.8.8", false}, + {"1.1.1.1", "1.1.1.1", false}, + {"172.15.0.1", "172.15.0.1", false}, // Just before 172.16.0.0/12 + {"172.32.0.1", "172.32.0.1", false}, // Just after 172.31.255.255 + {"192.167.1.1", "192.167.1.1", false}, // Not 192.168 + {"192.169.1.1", "192.169.1.1", false}, // Not 192.168 + {"11.0.0.1", "11.0.0.1", false}, // Not 10.x.x.x + + // IPv6 private (Unique Local Addresses fc00::/7) + {"fc00::1", "fc00::1", true}, + {"fd00::1", "fd00::1", true}, + {"fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", true}, + + // IPv6 public (not private) + {"2001:4860:4860::8888", "2001:4860:4860::8888", false}, + {"fe80::1", "fe80::1", false}, // Link-local, not ULA + + // Edge cases + {"nil IP", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var ip net.IP + if tt.ip != "" { + ip = net.ParseIP(tt.ip) + } + result := isPrivateIP(ip) + if result != tt.expected { + t.Errorf("isPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected) + } + }) + } +} + +// TestIsValidMsgWithPrivateIP tests that private IPs are not dropped +func TestIsValidMsgWithPrivateIP(t *testing.T) { + // Create a simple IP ranger with a test IP range (e.g., 1.0.0.0/8) + ipRanger := cidranger.NewPCTrieRanger() + _, network, _ := net.ParseCIDR("1.0.0.0/8") + ipRanger.Insert(cidranger.NewBasicRangerEntry(*network)) + + log := logger.New(false) + + tests := []struct { + name string + isPrimary bool + ip string + shouldPass bool + reason string + }{ + { + name: "Primary DNS with private IP 10.x", + isPrimary: true, + ip: "10.0.0.1", + shouldPass: true, + reason: "Private IPs should always be valid", + }, + { + name: "Primary DNS with private IP 172.16.x", + isPrimary: true, + ip: "172.16.0.1", + shouldPass: true, + reason: "Private IPs should always be valid", + }, + { + name: "Primary DNS with private IP 192.168.x", + isPrimary: true, + ip: "192.168.1.1", + shouldPass: true, + reason: "Private IPs should always be valid", + }, + { + name: "Primary DNS with public non-primary IP", + isPrimary: true, + ip: "8.8.8.8", + shouldPass: false, + reason: "Public non-primary IPs should be rejected by primary DNS", + }, + { + name: "Primary DNS with public primary IP", + isPrimary: true, + ip: "1.0.0.1", + shouldPass: true, + reason: "Primary IPs should be valid for primary DNS", + }, + { + name: "Non-primary DNS with private IP", + isPrimary: false, + ip: "10.0.0.1", + shouldPass: true, + reason: "Private IPs should always be valid", + }, + { + name: "Non-primary DNS with public IP", + isPrimary: false, + ip: "8.8.8.8", + shouldPass: true, + reason: "Non-primary DNS accepts any IP", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + config := &Config{ + BlacklistSplited: [][]string{}, + } + + upstream := &Upstream{ + IsPrimary: tt.isPrimary, + Address: "test://example.com:53", + config: config, + ipRanger: ipRanger, + logger: log, + } + + // Create a DNS message with an A record + msg := &dns.Msg{ + Answer: []dns.RR{ + &dns.A{ + Hdr: dns.RR_Header{ + Name: "example.com.", + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: net.ParseIP(tt.ip).To4(), + }, + }, + } + + result := upstream.IsValidMsg(msg) + if result != tt.shouldPass { + t.Errorf("%s: IsValidMsg() = %v, want %v. Reason: %s", tt.name, result, tt.shouldPass, tt.reason) + } + }) + } +}