Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions internal/model/upstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,36 @@ func (up *Upstream) InitConnectionPool(bootstrap func(host string) (net.IP, erro
}
}

// isPrivateIP checks if an IP address is in the private IP range
// Private IP ranges: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
func isPrivateIP(ip net.IP) bool {
if ip == nil {
return false
}
// Convert to 4-byte representation for IPv4
ip4 := ip.To4()
if ip4 != nil {
// 10.0.0.0/8
if ip4[0] == 10 {
return true
}
// 172.16.0.0/12
if ip4[0] == 172 && ip4[1] >= 16 && ip4[1] <= 31 {
return true
}
// 192.168.0.0/16
if ip4[0] == 192 && ip4[1] == 168 {
return true
}
return false
}
// IPv6 private ranges (fc00::/7 - Unique Local Addresses)
if len(ip) == 16 && (ip[0]&0xfe) == 0xfc {
return true
}
return false
}

func (up *Upstream) IsValidMsg(r *dns.Msg) bool {
domain := GetDomainNameFromDnsMsg(r)
inBlacklist := utils.HasMatchedRule(up.config.BlacklistSplited, domain)
Expand All @@ -192,6 +222,16 @@ func (up *Upstream) IsValidMsg(r *dns.Msg) bool {
}
ip = typeAAAA.AAAA
}

// Private IPs should always be considered valid regardless of primary/non-primary status
// IPv4 private ranges: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
// IPv6 private ranges: fc00::/7 (Unique Local Addresses)
// They are used for internal networks and are not subject to geographical restrictions
if isPrivateIP(ip) {
up.logger.Printf("checkPrimary result %s: %s@%s -> private IP, skipping primary check", up.Address, domain, ip)
continue
}

isPrimary, err := up.ipRanger.Contains(ip)
if err != nil {
up.logger.Printf("ipRanger query ip %s failed: %s", ip, err)
Expand Down
162 changes: 162 additions & 0 deletions internal/model/upstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ package model

import (
"index/suffixarray"
"net"
"strings"
"testing"

"github.com/miekg/dns"
"github.com/naiba/nbdns/pkg/logger"
"github.com/naiba/nbdns/pkg/utils"
"github.com/yl2chen/cidranger"
)

var primaryLocations = []string{"中国", "省", "市", "自治区"}
Expand Down Expand Up @@ -118,3 +122,161 @@ func checkPrimaryStringsContains(str string) bool {
}
return false
}

// TestIsPrivateIP tests the isPrivateIP function
func TestIsPrivateIP(t *testing.T) {
tests := []struct {
name string
ip string
expected bool
}{
// Private IPv4 ranges
{"10.0.0.0", "10.0.0.0", true},
{"10.255.255.255", "10.255.255.255", true},
{"10.1.2.3", "10.1.2.3", true},
{"172.16.0.0", "172.16.0.0", true},
{"172.31.255.255", "172.31.255.255", true},
{"172.20.1.1", "172.20.1.1", true},
{"192.168.0.0", "192.168.0.0", true},
{"192.168.255.255", "192.168.255.255", true},
{"192.168.1.1", "192.168.1.1", true},

// Public IPv4 addresses (not private)
{"8.8.8.8", "8.8.8.8", false},
{"1.1.1.1", "1.1.1.1", false},
{"172.15.0.1", "172.15.0.1", false}, // Just before 172.16.0.0/12
{"172.32.0.1", "172.32.0.1", false}, // Just after 172.31.255.255
{"192.167.1.1", "192.167.1.1", false}, // Not 192.168
{"192.169.1.1", "192.169.1.1", false}, // Not 192.168
{"11.0.0.1", "11.0.0.1", false}, // Not 10.x.x.x

// IPv6 private (Unique Local Addresses fc00::/7)
{"fc00::1", "fc00::1", true},
{"fd00::1", "fd00::1", true},
{"fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", "fdff:ffff:ffff:ffff:ffff:ffff:ffff:ffff", true},

// IPv6 public (not private)
{"2001:4860:4860::8888", "2001:4860:4860::8888", false},
{"fe80::1", "fe80::1", false}, // Link-local, not ULA

// Edge cases
{"nil IP", "", false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var ip net.IP
if tt.ip != "" {
ip = net.ParseIP(tt.ip)
}
result := isPrivateIP(ip)
if result != tt.expected {
t.Errorf("isPrivateIP(%s) = %v, want %v", tt.ip, result, tt.expected)
}
})
}
}

// TestIsValidMsgWithPrivateIP tests that private IPs are not dropped
func TestIsValidMsgWithPrivateIP(t *testing.T) {
// Create a simple IP ranger with a test IP range (e.g., 1.0.0.0/8)
ipRanger := cidranger.NewPCTrieRanger()
_, network, _ := net.ParseCIDR("1.0.0.0/8")
ipRanger.Insert(cidranger.NewBasicRangerEntry(*network))

log := logger.New(false)

tests := []struct {
name string
isPrimary bool
ip string
shouldPass bool
reason string
}{
{
name: "Primary DNS with private IP 10.x",
isPrimary: true,
ip: "10.0.0.1",
shouldPass: true,
reason: "Private IPs should always be valid",
},
{
name: "Primary DNS with private IP 172.16.x",
isPrimary: true,
ip: "172.16.0.1",
shouldPass: true,
reason: "Private IPs should always be valid",
},
{
name: "Primary DNS with private IP 192.168.x",
isPrimary: true,
ip: "192.168.1.1",
shouldPass: true,
reason: "Private IPs should always be valid",
},
{
name: "Primary DNS with public non-primary IP",
isPrimary: true,
ip: "8.8.8.8",
shouldPass: false,
reason: "Public non-primary IPs should be rejected by primary DNS",
},
{
name: "Primary DNS with public primary IP",
isPrimary: true,
ip: "1.0.0.1",
shouldPass: true,
reason: "Primary IPs should be valid for primary DNS",
},
{
name: "Non-primary DNS with private IP",
isPrimary: false,
ip: "10.0.0.1",
shouldPass: true,
reason: "Private IPs should always be valid",
},
{
name: "Non-primary DNS with public IP",
isPrimary: false,
ip: "8.8.8.8",
shouldPass: true,
reason: "Non-primary DNS accepts any IP",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
config := &Config{
BlacklistSplited: [][]string{},
}

upstream := &Upstream{
IsPrimary: tt.isPrimary,
Address: "test://example.com:53",
config: config,
ipRanger: ipRanger,
logger: log,
}

// Create a DNS message with an A record
msg := &dns.Msg{
Answer: []dns.RR{
&dns.A{
Hdr: dns.RR_Header{
Name: "example.com.",
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: net.ParseIP(tt.ip).To4(),
},
},
}

result := upstream.IsValidMsg(msg)
if result != tt.shouldPass {
t.Errorf("%s: IsValidMsg() = %v, want %v. Reason: %s", tt.name, result, tt.shouldPass, tt.reason)
}
})
}
}