mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
dns: use response-only address matching
This commit is contained in:
@@ -25,8 +25,8 @@ type DNSRouter interface {
|
||||
|
||||
type DNSClient interface {
|
||||
Start()
|
||||
Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error)
|
||||
Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error)
|
||||
Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error)
|
||||
Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error)
|
||||
ClearCache()
|
||||
}
|
||||
|
||||
|
||||
@@ -4,11 +4,13 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
"github.com/sagernet/sing-box/log"
|
||||
"github.com/sagernet/sing-box/option"
|
||||
"github.com/sagernet/sing/common"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
@@ -126,17 +128,27 @@ func (c *InboundContext) DestinationAddressesForMatch() []netip.Addr {
|
||||
return c.DestinationAddresses
|
||||
}
|
||||
|
||||
func (c *InboundContext) DNSResponseAddressesForMatch() []netip.Addr {
|
||||
return DNSResponseAddresses(c.DNSResponse)
|
||||
}
|
||||
|
||||
func DNSResponseAddresses(response *dns.Msg) []netip.Addr {
|
||||
if response == nil || response.Rcode != dns.RcodeSuccess {
|
||||
return nil
|
||||
}
|
||||
var addresses []netip.Addr
|
||||
addresses := make([]netip.Addr, 0, len(response.Answer))
|
||||
for _, rawRecord := range response.Answer {
|
||||
switch record := rawRecord.(type) {
|
||||
case *dns.A:
|
||||
addresses = append(addresses, M.AddrFromIP(record.A))
|
||||
case *dns.AAAA:
|
||||
addresses = append(addresses, M.AddrFromIP(record.AAAA))
|
||||
case *dns.HTTPS:
|
||||
for _, value := range record.SVCB.Value {
|
||||
if value.Key() == dns.SVCB_IPV4HINT || value.Key() == dns.SVCB_IPV6HINT {
|
||||
addresses = append(addresses, common.Map(strings.Split(value.String(), ","), M.ParseAddr)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
|
||||
@@ -2,6 +2,8 @@ package adapter
|
||||
|
||||
import (
|
||||
C "github.com/sagernet/sing-box/constant"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
type HeadlessRule interface {
|
||||
@@ -19,7 +21,7 @@ type Rule interface {
|
||||
type DNSRule interface {
|
||||
Rule
|
||||
WithAddressLimit() bool
|
||||
MatchAddressLimit(metadata *InboundContext) bool
|
||||
MatchAddressLimit(metadata *InboundContext, response *dns.Msg) bool
|
||||
}
|
||||
|
||||
type RuleAction interface {
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
@@ -14,7 +13,6 @@ import (
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/logger"
|
||||
M "github.com/sagernet/sing/common/metadata"
|
||||
"github.com/sagernet/sing/common/task"
|
||||
"github.com/sagernet/sing/contrab/freelru"
|
||||
"github.com/sagernet/sing/contrab/maphash"
|
||||
@@ -109,7 +107,7 @@ func extractNegativeTTL(response *dns.Msg) (uint32, bool) {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) {
|
||||
func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) {
|
||||
if len(message.Question) == 0 {
|
||||
if c.logger != nil {
|
||||
c.logger.WarnContext(ctx, "bad question size: ", len(message.Question))
|
||||
@@ -239,13 +237,10 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
||||
disableCache = disableCache || (response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError)
|
||||
if responseChecker != nil {
|
||||
var rejected bool
|
||||
// TODO: add accept_any rule and support to check response instead of addresses
|
||||
if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError {
|
||||
rejected = true
|
||||
} else if len(response.Answer) == 0 {
|
||||
rejected = !responseChecker(nil)
|
||||
} else {
|
||||
rejected = !responseChecker(MessageToAddresses(response))
|
||||
rejected = !responseChecker(response)
|
||||
}
|
||||
if rejected {
|
||||
if !disableCache && c.rdrc != nil {
|
||||
@@ -315,7 +310,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
|
||||
func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) {
|
||||
domain = FqdnToDomain(domain)
|
||||
dnsName := dns.Fqdn(domain)
|
||||
var strategy C.DomainStrategy
|
||||
@@ -400,7 +395,7 @@ func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Questio
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) {
|
||||
func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) {
|
||||
question := dns.Question{
|
||||
Name: name,
|
||||
Qtype: qType,
|
||||
@@ -515,25 +510,7 @@ func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransp
|
||||
}
|
||||
|
||||
func MessageToAddresses(response *dns.Msg) []netip.Addr {
|
||||
if response == nil || response.Rcode != dns.RcodeSuccess {
|
||||
return nil
|
||||
}
|
||||
addresses := make([]netip.Addr, 0, len(response.Answer))
|
||||
for _, rawAnswer := range response.Answer {
|
||||
switch answer := rawAnswer.(type) {
|
||||
case *dns.A:
|
||||
addresses = append(addresses, M.AddrFromIP(answer.A))
|
||||
case *dns.AAAA:
|
||||
addresses = append(addresses, M.AddrFromIP(answer.AAAA))
|
||||
case *dns.HTTPS:
|
||||
for _, value := range answer.SVCB.Value {
|
||||
if value.Key() == dns.SVCB_IPV4HINT || value.Key() == dns.SVCB_IPV6HINT {
|
||||
addresses = append(addresses, common.Map(strings.Split(value.String(), ","), M.ParseAddr)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
return adapter.DNSResponseAddresses(response)
|
||||
}
|
||||
|
||||
func wrapError(err error) error {
|
||||
|
||||
@@ -145,6 +145,7 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int,
|
||||
continue
|
||||
}
|
||||
metadata.ResetRuleCache()
|
||||
metadata.DestinationAddressMatchFromResponse = false
|
||||
if currentRule.Match(metadata) {
|
||||
displayRuleIndex := currentRuleIndex
|
||||
if displayRuleIndex != -1 {
|
||||
@@ -285,6 +286,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
|
||||
for currentRuleIndex, currentRule := range r.rules {
|
||||
metadata.ResetRuleCache()
|
||||
metadata.DNSResponse = savedResponse
|
||||
metadata.DestinationAddressMatchFromResponse = false
|
||||
if !currentRule.Match(metadata) {
|
||||
continue
|
||||
}
|
||||
@@ -481,6 +483,8 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
ctx, metadata = adapter.ExtendContext(ctx)
|
||||
metadata.Destination = M.Socksaddr{}
|
||||
metadata.QueryType = message.Question[0].Qtype
|
||||
metadata.DNSResponse = nil
|
||||
metadata.DestinationAddressMatchFromResponse = false
|
||||
switch metadata.QueryType {
|
||||
case mDNS.TypeA:
|
||||
metadata.IPVersion = 4
|
||||
@@ -596,6 +600,8 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ
|
||||
ctx, metadata := adapter.ExtendContext(ctx)
|
||||
metadata.Destination = M.Socksaddr{}
|
||||
metadata.Domain = FqdnToDomain(domain)
|
||||
metadata.DNSResponse = nil
|
||||
metadata.DestinationAddressMatchFromResponse = false
|
||||
if options.Transport != nil {
|
||||
transport := options.Transport
|
||||
r.applyTransportDefaults(transport, &options)
|
||||
@@ -666,15 +672,15 @@ func isAddressQuery(message *mDNS.Msg) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func addressLimitResponseCheck(rule adapter.DNSRule, metadata *adapter.InboundContext) func(responseAddrs []netip.Addr) bool {
|
||||
func addressLimitResponseCheck(rule adapter.DNSRule, metadata *adapter.InboundContext) func(response *mDNS.Msg) bool {
|
||||
if rule == nil || !rule.WithAddressLimit() {
|
||||
return nil
|
||||
}
|
||||
responseMetadata := *metadata
|
||||
return func(responseAddrs []netip.Addr) bool {
|
||||
return func(response *mDNS.Msg) bool {
|
||||
checkMetadata := responseMetadata
|
||||
checkMetadata.DestinationAddresses = responseAddrs
|
||||
return rule.MatchAddressLimit(&checkMetadata)
|
||||
checkMetadata.DNSResponse = response
|
||||
return rule.MatchAddressLimit(&checkMetadata, response)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package dns
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
@@ -76,14 +77,14 @@ func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) {
|
||||
|
||||
func (c *fakeDNSClient) Start() {}
|
||||
|
||||
func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func([]netip.Addr) bool) (*mDNS.Msg, error) {
|
||||
func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func(*mDNS.Msg) bool) (*mDNS.Msg, error) {
|
||||
if c.beforeExchange != nil {
|
||||
c.beforeExchange(ctx, transport, message)
|
||||
}
|
||||
return c.exchange(transport, message)
|
||||
}
|
||||
|
||||
func (c *fakeDNSClient) Lookup(context.Context, adapter.DNSTransport, string, adapter.DNSQueryOptions, func([]netip.Addr) bool) ([]netip.Addr, error) {
|
||||
func (c *fakeDNSClient) Lookup(context.Context, adapter.DNSTransport, string, adapter.DNSQueryOptions, func(*mDNS.Msg) bool) ([]netip.Addr, error) {
|
||||
return nil, errors.New("unused client lookup")
|
||||
}
|
||||
|
||||
@@ -121,6 +122,49 @@ func mustRecord(t *testing.T, record string) option.DNSRecordOptions {
|
||||
return value
|
||||
}
|
||||
|
||||
func fixedHTTPSHintResponse(question mDNS.Question, addresses ...netip.Addr) *mDNS.Msg {
|
||||
response := &mDNS.Msg{
|
||||
MsgHdr: mDNS.MsgHdr{
|
||||
Response: true,
|
||||
Rcode: mDNS.RcodeSuccess,
|
||||
},
|
||||
Question: []mDNS.Question{question},
|
||||
Answer: []mDNS.RR{
|
||||
&mDNS.HTTPS{
|
||||
SVCB: mDNS.SVCB{
|
||||
Hdr: mDNS.RR_Header{
|
||||
Name: question.Name,
|
||||
Rrtype: mDNS.TypeHTTPS,
|
||||
Class: mDNS.ClassINET,
|
||||
Ttl: 60,
|
||||
},
|
||||
Priority: 1,
|
||||
Target: ".",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
https := response.Answer[0].(*mDNS.HTTPS)
|
||||
var (
|
||||
hints4 []net.IP
|
||||
hints6 []net.IP
|
||||
)
|
||||
for _, address := range addresses {
|
||||
if address.Is4() {
|
||||
hints4 = append(hints4, net.IP(append([]byte(nil), address.AsSlice()...)))
|
||||
} else {
|
||||
hints6 = append(hints6, net.IP(append([]byte(nil), address.AsSlice()...)))
|
||||
}
|
||||
}
|
||||
if len(hints4) > 0 {
|
||||
https.SVCB.Value = append(https.SVCB.Value, &mDNS.SVCBIPv4Hint{Hint: hints4})
|
||||
}
|
||||
if len(hints6) > 0 {
|
||||
https.SVCB.Value = append(https.SVCB.Value, &mDNS.SVCBIPv6Hint{Hint: hints6})
|
||||
}
|
||||
return response
|
||||
}
|
||||
|
||||
func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -141,6 +185,17 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) {
|
||||
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
|
||||
}
|
||||
|
||||
func TestDNSResponseAddressesMatchesMessageToAddressesForHTTPSHints(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
response := fixedHTTPSHintResponse(fixedQuestion("example.com", mDNS.TypeHTTPS),
|
||||
netip.MustParseAddr("1.1.1.1"),
|
||||
netip.MustParseAddr("2001:db8::1"),
|
||||
)
|
||||
|
||||
require.Equal(t, MessageToAddresses(response), adapter.DNSResponseAddresses(response))
|
||||
}
|
||||
|
||||
func TestExchangeNewModeEvaluateMatchResponseRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -259,6 +314,65 @@ func TestExchangeNewModeEvaluateMatchResponseRouteIgnoresTTL(t *testing.T) {
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response))
|
||||
}
|
||||
|
||||
func TestExchangeNewModeEvaluateMatchResponseRouteWithHTTPSHints(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transportManager := &fakeDNSTransportManager{
|
||||
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP},
|
||||
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
|
||||
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||
},
|
||||
}
|
||||
client := &fakeDNSClient{
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
switch transport.Tag() {
|
||||
case "upstream":
|
||||
return fixedHTTPSHintResponse(message.Question[0], netip.MustParseAddr("1.1.1.1")), nil
|
||||
case "selected":
|
||||
return fixedHTTPSHintResponse(message.Question[0], netip.MustParseAddr("8.8.8.8")), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
rules := []option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
Domain: badoption.Listable[string]{"example.com"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeEvaluate,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "upstream"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
MatchResponse: true,
|
||||
IPCIDR: badoption.Listable[string]{"1.1.1.0/24"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := newTestRouter(t, rules, transportManager, client)
|
||||
|
||||
response, err := router.Exchange(context.Background(), &mDNS.Msg{
|
||||
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeHTTPS)},
|
||||
}, adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response))
|
||||
}
|
||||
|
||||
func TestExchangeNewModeEvaluateDoesNotLeakAddressesToNextQuery(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/service"
|
||||
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
func NewDNSRule(ctx context.Context, logger log.ContextLogger, options option.DNSRule, checkServer bool, legacyAddressFilter bool) (adapter.DNSRule, error) {
|
||||
@@ -367,8 +369,11 @@ func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r
|
||||
return r.abstractDefaultRule.matchStates(&matchMetadata)
|
||||
}
|
||||
|
||||
func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {
|
||||
return !r.matchStates(metadata).isEmpty()
|
||||
func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext, response *dns.Msg) bool {
|
||||
matchMetadata := *metadata
|
||||
matchMetadata.DNSResponse = response
|
||||
matchMetadata.DestinationAddressMatchFromResponse = true
|
||||
return !r.abstractDefaultRule.matchStates(&matchMetadata).isEmpty()
|
||||
}
|
||||
|
||||
var _ adapter.DNSRule = (*LogicalDNSRule)(nil)
|
||||
@@ -477,6 +482,9 @@ func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
|
||||
return !r.matchStatesForMatch(metadata).isEmpty()
|
||||
}
|
||||
|
||||
func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {
|
||||
return !r.matchStates(metadata).isEmpty()
|
||||
func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext, response *dns.Msg) bool {
|
||||
matchMetadata := *metadata
|
||||
matchMetadata.DNSResponse = response
|
||||
matchMetadata.DestinationAddressMatchFromResponse = true
|
||||
return !r.abstractLogicalRule.matchStates(&matchMetadata).isEmpty()
|
||||
}
|
||||
|
||||
@@ -77,7 +77,7 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool {
|
||||
return r.ipSet.Contains(metadata.Source.Addr)
|
||||
}
|
||||
if metadata.DestinationAddressMatchFromResponse {
|
||||
for _, address := range metadata.DestinationAddressesForMatch() {
|
||||
for _, address := range metadata.DNSResponseAddressesForMatch() {
|
||||
if r.ipSet.Contains(address) {
|
||||
return true
|
||||
}
|
||||
@@ -87,7 +87,7 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool {
|
||||
if metadata.Destination.IsIP() {
|
||||
return r.ipSet.Contains(metadata.Destination.Addr)
|
||||
}
|
||||
addresses := metadata.DestinationAddressesForMatch()
|
||||
addresses := metadata.DestinationAddresses
|
||||
if len(addresses) > 0 {
|
||||
for _, address := range addresses {
|
||||
if r.ipSet.Contains(address) {
|
||||
|
||||
@@ -13,6 +13,9 @@ func NewIPAcceptAnyItem() *IPAcceptAnyItem {
|
||||
}
|
||||
|
||||
func (r *IPAcceptAnyItem) Match(metadata *adapter.InboundContext) bool {
|
||||
if metadata.DestinationAddressMatchFromResponse {
|
||||
return len(metadata.DNSResponseAddressesForMatch()) > 0
|
||||
}
|
||||
return len(metadata.DestinationAddresses) > 0
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ func (r *IPIsPrivateItem) Match(metadata *adapter.InboundContext) bool {
|
||||
return !N.IsPublicAddr(metadata.Source.Addr)
|
||||
}
|
||||
if metadata.DestinationAddressMatchFromResponse {
|
||||
for _, destinationAddress := range metadata.DestinationAddressesForMatch() {
|
||||
for _, destinationAddress := range metadata.DNSResponseAddressesForMatch() {
|
||||
if !N.IsPublicAddr(destinationAddress) {
|
||||
return true
|
||||
}
|
||||
@@ -30,7 +30,7 @@ func (r *IPIsPrivateItem) Match(metadata *adapter.InboundContext) bool {
|
||||
if metadata.Destination.Addr.IsValid() {
|
||||
return !N.IsPublicAddr(metadata.Destination.Addr)
|
||||
}
|
||||
for _, destinationAddress := range metadata.DestinationAddressesForMatch() {
|
||||
for _, destinationAddress := range metadata.DestinationAddresses {
|
||||
if !N.IsPublicAddr(destinationAddress) {
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -583,7 +583,7 @@ func TestDNSRuleSetSemantics(t *testing.T) {
|
||||
addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}})
|
||||
addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"})
|
||||
})
|
||||
require.True(t, rule.MatchAddressLimit(&metadata))
|
||||
require.True(t, rule.MatchAddressLimit(&metadata, dnsResponseForTest(netip.MustParseAddr("203.0.113.1"))))
|
||||
})
|
||||
t.Run("dns keeps ruleset or semantics", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -598,7 +598,7 @@ func TestDNSRuleSetSemantics(t *testing.T) {
|
||||
addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{emptyStateSet, destinationStateSet}})
|
||||
addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"})
|
||||
})
|
||||
require.True(t, rule.MatchAddressLimit(&metadata))
|
||||
require.True(t, rule.MatchAddressLimit(&metadata, dnsResponseForTest(netip.MustParseAddr("203.0.113.1"))))
|
||||
})
|
||||
t.Run("ruleset ip cidr flags stay scoped", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
@@ -612,7 +612,7 @@ func TestDNSRuleSetSemantics(t *testing.T) {
|
||||
ipCidrAcceptEmpty: true,
|
||||
})
|
||||
})
|
||||
require.True(t, rule.MatchAddressLimit(&metadata))
|
||||
require.True(t, rule.MatchAddressLimit(&metadata, dnsResponseForTest()))
|
||||
require.False(t, metadata.IPCIDRMatchSource)
|
||||
require.False(t, metadata.IPCIDRAcceptEmpty)
|
||||
})
|
||||
@@ -639,6 +639,62 @@ func TestDNSMatchResponseRuleSetDestinationCIDRUsesDNSResponse(t *testing.T) {
|
||||
require.False(t, rule.Match(&unmatchedMetadata))
|
||||
}
|
||||
|
||||
func TestDNSAddressLimitIgnoresDestinationAddresses(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
build func(*testing.T, *abstractDefaultRule)
|
||||
matchedResponse *mDNS.Msg
|
||||
unmatchedResponse *mDNS.Msg
|
||||
}{
|
||||
{
|
||||
name: "ip_cidr",
|
||||
build: func(t *testing.T, rule *abstractDefaultRule) {
|
||||
t.Helper()
|
||||
addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"})
|
||||
},
|
||||
matchedResponse: dnsResponseForTest(netip.MustParseAddr("203.0.113.1")),
|
||||
unmatchedResponse: dnsResponseForTest(netip.MustParseAddr("8.8.8.8")),
|
||||
},
|
||||
{
|
||||
name: "ip_is_private",
|
||||
build: func(t *testing.T, rule *abstractDefaultRule) {
|
||||
t.Helper()
|
||||
addDestinationIPIsPrivateItem(rule)
|
||||
},
|
||||
matchedResponse: dnsResponseForTest(netip.MustParseAddr("10.0.0.1")),
|
||||
unmatchedResponse: dnsResponseForTest(netip.MustParseAddr("8.8.8.8")),
|
||||
},
|
||||
{
|
||||
name: "ip_accept_any",
|
||||
build: func(t *testing.T, rule *abstractDefaultRule) {
|
||||
t.Helper()
|
||||
addDestinationIPAcceptAnyItem(rule)
|
||||
},
|
||||
matchedResponse: dnsResponseForTest(netip.MustParseAddr("203.0.113.1")),
|
||||
unmatchedResponse: dnsResponseForTest(),
|
||||
},
|
||||
}
|
||||
for _, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
rule := dnsRuleForTest(func(rule *abstractDefaultRule) {
|
||||
testCase.build(t, rule)
|
||||
})
|
||||
|
||||
mismatchMetadata := testMetadata("lookup.example")
|
||||
mismatchMetadata.DestinationAddresses = []netip.Addr{netip.MustParseAddr("203.0.113.1")}
|
||||
require.False(t, rule.MatchAddressLimit(&mismatchMetadata, testCase.unmatchedResponse))
|
||||
|
||||
matchMetadata := testMetadata("lookup.example")
|
||||
matchMetadata.DestinationAddresses = []netip.Addr{netip.MustParseAddr("8.8.8.8")}
|
||||
require.True(t, rule.MatchAddressLimit(&matchMetadata, testCase.matchedResponse))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) {
|
||||
t.Parallel()
|
||||
testCases := []struct {
|
||||
@@ -688,11 +744,11 @@ func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) {
|
||||
|
||||
matchedMetadata := testMetadata("lookup.example")
|
||||
matchedMetadata.DestinationAddresses = testCase.matchedAddrs
|
||||
require.False(t, rule.MatchAddressLimit(&matchedMetadata))
|
||||
require.False(t, rule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(testCase.matchedAddrs...)))
|
||||
|
||||
unmatchedMetadata := testMetadata("lookup.example")
|
||||
unmatchedMetadata.DestinationAddresses = testCase.unmatchedAddrs
|
||||
require.True(t, rule.MatchAddressLimit(&unmatchedMetadata))
|
||||
require.True(t, rule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(testCase.unmatchedAddrs...)))
|
||||
})
|
||||
}
|
||||
t.Run("mixed resolved and deferred fields keep old pre lookup false", func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user