diff --git a/adapter/dns.go b/adapter/dns.go index 8f065e2e8..017feb596 100644 --- a/adapter/dns.go +++ b/adapter/dns.go @@ -25,8 +25,8 @@ type DNSRouter interface { type DNSClient interface { Start() - Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) - Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) + Exchange(ctx context.Context, transport DNSTransport, message *dns.Msg, options DNSQueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) + Lookup(ctx context.Context, transport DNSTransport, domain string, options DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) ClearCache() } diff --git a/adapter/inbound.go b/adapter/inbound.go index c8f9224ee..048699f6d 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -4,11 +4,13 @@ import ( "context" "net" "net/netip" + "strings" "time" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + "github.com/sagernet/sing/common" M "github.com/sagernet/sing/common/metadata" "github.com/miekg/dns" @@ -126,17 +128,27 @@ func (c *InboundContext) DestinationAddressesForMatch() []netip.Addr { return c.DestinationAddresses } +func (c *InboundContext) DNSResponseAddressesForMatch() []netip.Addr { + return DNSResponseAddresses(c.DNSResponse) +} + func DNSResponseAddresses(response *dns.Msg) []netip.Addr { if response == nil || response.Rcode != dns.RcodeSuccess { return nil } - var addresses []netip.Addr + addresses := make([]netip.Addr, 0, len(response.Answer)) for _, rawRecord := range response.Answer { switch record := rawRecord.(type) { case *dns.A: addresses = append(addresses, M.AddrFromIP(record.A)) case *dns.AAAA: addresses = append(addresses, M.AddrFromIP(record.AAAA)) + case *dns.HTTPS: + for _, value := range record.SVCB.Value { + if value.Key() == dns.SVCB_IPV4HINT || value.Key() == dns.SVCB_IPV6HINT { + addresses = append(addresses, common.Map(strings.Split(value.String(), ","), M.ParseAddr)...) + } + } } } return addresses diff --git a/adapter/rule.go b/adapter/rule.go index f8ee797d4..31ed9b424 100644 --- a/adapter/rule.go +++ b/adapter/rule.go @@ -2,6 +2,8 @@ package adapter import ( C "github.com/sagernet/sing-box/constant" + + "github.com/miekg/dns" ) type HeadlessRule interface { @@ -19,7 +21,7 @@ type Rule interface { type DNSRule interface { Rule WithAddressLimit() bool - MatchAddressLimit(metadata *InboundContext) bool + MatchAddressLimit(metadata *InboundContext, response *dns.Msg) bool } type RuleAction interface { diff --git a/dns/client.go b/dns/client.go index 70b53c951..e86a90977 100644 --- a/dns/client.go +++ b/dns/client.go @@ -5,7 +5,6 @@ import ( "errors" "net" "net/netip" - "strings" "time" "github.com/sagernet/sing-box/adapter" @@ -14,7 +13,6 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" - M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/task" "github.com/sagernet/sing/contrab/freelru" "github.com/sagernet/sing/contrab/maphash" @@ -109,7 +107,7 @@ func extractNegativeTTL(response *dns.Msg) (uint32, bool) { return 0, false } -func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) (*dns.Msg, error) { +func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, message *dns.Msg, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) (*dns.Msg, error) { if len(message.Question) == 0 { if c.logger != nil { c.logger.WarnContext(ctx, "bad question size: ", len(message.Question)) @@ -239,13 +237,10 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m disableCache = disableCache || (response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError) if responseChecker != nil { var rejected bool - // TODO: add accept_any rule and support to check response instead of addresses if response.Rcode != dns.RcodeSuccess && response.Rcode != dns.RcodeNameError { rejected = true - } else if len(response.Answer) == 0 { - rejected = !responseChecker(nil) } else { - rejected = !responseChecker(MessageToAddresses(response)) + rejected = !responseChecker(response) } if rejected { if !disableCache && c.rdrc != nil { @@ -315,7 +310,7 @@ func (c *Client) Exchange(ctx context.Context, transport adapter.DNSTransport, m return response, nil } -func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { +func (c *Client) Lookup(ctx context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) { domain = FqdnToDomain(domain) dnsName := dns.Fqdn(domain) var strategy C.DomainStrategy @@ -400,7 +395,7 @@ func (c *Client) storeCache(transport adapter.DNSTransport, question dns.Questio } } -func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(responseAddrs []netip.Addr) bool) ([]netip.Addr, error) { +func (c *Client) lookupToExchange(ctx context.Context, transport adapter.DNSTransport, name string, qType uint16, options adapter.DNSQueryOptions, responseChecker func(response *dns.Msg) bool) ([]netip.Addr, error) { question := dns.Question{ Name: name, Qtype: qType, @@ -515,25 +510,7 @@ func (c *Client) loadResponse(question dns.Question, transport adapter.DNSTransp } func MessageToAddresses(response *dns.Msg) []netip.Addr { - if response == nil || response.Rcode != dns.RcodeSuccess { - return nil - } - addresses := make([]netip.Addr, 0, len(response.Answer)) - for _, rawAnswer := range response.Answer { - switch answer := rawAnswer.(type) { - case *dns.A: - addresses = append(addresses, M.AddrFromIP(answer.A)) - case *dns.AAAA: - addresses = append(addresses, M.AddrFromIP(answer.AAAA)) - case *dns.HTTPS: - for _, value := range answer.SVCB.Value { - if value.Key() == dns.SVCB_IPV4HINT || value.Key() == dns.SVCB_IPV6HINT { - addresses = append(addresses, common.Map(strings.Split(value.String(), ","), M.ParseAddr)...) - } - } - } - } - return addresses + return adapter.DNSResponseAddresses(response) } func wrapError(err error) error { diff --git a/dns/router.go b/dns/router.go index f07d897c2..96abc1238 100644 --- a/dns/router.go +++ b/dns/router.go @@ -145,6 +145,7 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, continue } metadata.ResetRuleCache() + metadata.DestinationAddressMatchFromResponse = false if currentRule.Match(metadata) { displayRuleIndex := currentRuleIndex if displayRuleIndex != -1 { @@ -285,6 +286,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio for currentRuleIndex, currentRule := range r.rules { metadata.ResetRuleCache() metadata.DNSResponse = savedResponse + metadata.DestinationAddressMatchFromResponse = false if !currentRule.Match(metadata) { continue } @@ -481,6 +483,8 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte ctx, metadata = adapter.ExtendContext(ctx) metadata.Destination = M.Socksaddr{} metadata.QueryType = message.Question[0].Qtype + metadata.DNSResponse = nil + metadata.DestinationAddressMatchFromResponse = false switch metadata.QueryType { case mDNS.TypeA: metadata.IPVersion = 4 @@ -596,6 +600,8 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ ctx, metadata := adapter.ExtendContext(ctx) metadata.Destination = M.Socksaddr{} metadata.Domain = FqdnToDomain(domain) + metadata.DNSResponse = nil + metadata.DestinationAddressMatchFromResponse = false if options.Transport != nil { transport := options.Transport r.applyTransportDefaults(transport, &options) @@ -666,15 +672,15 @@ func isAddressQuery(message *mDNS.Msg) bool { return false } -func addressLimitResponseCheck(rule adapter.DNSRule, metadata *adapter.InboundContext) func(responseAddrs []netip.Addr) bool { +func addressLimitResponseCheck(rule adapter.DNSRule, metadata *adapter.InboundContext) func(response *mDNS.Msg) bool { if rule == nil || !rule.WithAddressLimit() { return nil } responseMetadata := *metadata - return func(responseAddrs []netip.Addr) bool { + return func(response *mDNS.Msg) bool { checkMetadata := responseMetadata - checkMetadata.DestinationAddresses = responseAddrs - return rule.MatchAddressLimit(&checkMetadata) + checkMetadata.DNSResponse = response + return rule.MatchAddressLimit(&checkMetadata, response) } } diff --git a/dns/router_test.go b/dns/router_test.go index 1fbb4c977..b83114bfd 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -3,6 +3,7 @@ package dns import ( "context" "errors" + "net" "net/netip" "testing" @@ -76,14 +77,14 @@ func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) { func (c *fakeDNSClient) Start() {} -func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func([]netip.Addr) bool) (*mDNS.Msg, error) { +func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg, _ adapter.DNSQueryOptions, _ func(*mDNS.Msg) bool) (*mDNS.Msg, error) { if c.beforeExchange != nil { c.beforeExchange(ctx, transport, message) } return c.exchange(transport, message) } -func (c *fakeDNSClient) Lookup(context.Context, adapter.DNSTransport, string, adapter.DNSQueryOptions, func([]netip.Addr) bool) ([]netip.Addr, error) { +func (c *fakeDNSClient) Lookup(context.Context, adapter.DNSTransport, string, adapter.DNSQueryOptions, func(*mDNS.Msg) bool) ([]netip.Addr, error) { return nil, errors.New("unused client lookup") } @@ -121,6 +122,49 @@ func mustRecord(t *testing.T, record string) option.DNSRecordOptions { return value } +func fixedHTTPSHintResponse(question mDNS.Question, addresses ...netip.Addr) *mDNS.Msg { + response := &mDNS.Msg{ + MsgHdr: mDNS.MsgHdr{ + Response: true, + Rcode: mDNS.RcodeSuccess, + }, + Question: []mDNS.Question{question}, + Answer: []mDNS.RR{ + &mDNS.HTTPS{ + SVCB: mDNS.SVCB{ + Hdr: mDNS.RR_Header{ + Name: question.Name, + Rrtype: mDNS.TypeHTTPS, + Class: mDNS.ClassINET, + Ttl: 60, + }, + Priority: 1, + Target: ".", + }, + }, + }, + } + https := response.Answer[0].(*mDNS.HTTPS) + var ( + hints4 []net.IP + hints6 []net.IP + ) + for _, address := range addresses { + if address.Is4() { + hints4 = append(hints4, net.IP(append([]byte(nil), address.AsSlice()...))) + } else { + hints6 = append(hints6, net.IP(append([]byte(nil), address.AsSlice()...))) + } + } + if len(hints4) > 0 { + https.SVCB.Value = append(https.SVCB.Value, &mDNS.SVCBIPv4Hint{Hint: hints4}) + } + if len(hints6) > 0 { + https.SVCB.Value = append(https.SVCB.Value, &mDNS.SVCBIPv6Hint{Hint: hints6}) + } + return response +} + func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) { t.Parallel() @@ -141,6 +185,17 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) { require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") } +func TestDNSResponseAddressesMatchesMessageToAddressesForHTTPSHints(t *testing.T) { + t.Parallel() + + response := fixedHTTPSHintResponse(fixedQuestion("example.com", mDNS.TypeHTTPS), + netip.MustParseAddr("1.1.1.1"), + netip.MustParseAddr("2001:db8::1"), + ) + + require.Equal(t, MessageToAddresses(response), adapter.DNSResponseAddresses(response)) +} + func TestExchangeNewModeEvaluateMatchResponseRoute(t *testing.T) { t.Parallel() @@ -259,6 +314,65 @@ func TestExchangeNewModeEvaluateMatchResponseRouteIgnoresTTL(t *testing.T) { require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response)) } +func TestExchangeNewModeEvaluateMatchResponseRouteWithHTTPSHints(t *testing.T) { + t.Parallel() + + transportManager := &fakeDNSTransportManager{ + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + transports: map[string]adapter.DNSTransport{ + "upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP}, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + }, + } + client := &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + switch transport.Tag() { + case "upstream": + return fixedHTTPSHintResponse(message.Question[0], netip.MustParseAddr("1.1.1.1")), nil + case "selected": + return fixedHTTPSHintResponse(message.Question[0], netip.MustParseAddr("8.8.8.8")), nil + default: + return nil, errors.New("unexpected transport") + } + }, + } + rules := []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "upstream"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + MatchResponse: true, + IPCIDR: badoption.Listable[string]{"1.1.1.0/24"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, + }, + } + router := newTestRouter(t, rules, transportManager, client) + + response, err := router.Exchange(context.Background(), &mDNS.Msg{ + Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeHTTPS)}, + }, adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("8.8.8.8")}, MessageToAddresses(response)) +} + func TestExchangeNewModeEvaluateDoesNotLeakAddressesToNextQuery(t *testing.T) { t.Parallel() diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index f5914f89e..545d88f60 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -11,6 +11,8 @@ import ( "github.com/sagernet/sing/common" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/service" + + "github.com/miekg/dns" ) func NewDNSRule(ctx context.Context, logger log.ContextLogger, options option.DNSRule, checkServer bool, legacyAddressFilter bool) (adapter.DNSRule, error) { @@ -367,8 +369,11 @@ func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r return r.abstractDefaultRule.matchStates(&matchMetadata) } -func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool { - return !r.matchStates(metadata).isEmpty() +func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext, response *dns.Msg) bool { + matchMetadata := *metadata + matchMetadata.DNSResponse = response + matchMetadata.DestinationAddressMatchFromResponse = true + return !r.abstractDefaultRule.matchStates(&matchMetadata).isEmpty() } var _ adapter.DNSRule = (*LogicalDNSRule)(nil) @@ -477,6 +482,9 @@ func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool { return !r.matchStatesForMatch(metadata).isEmpty() } -func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool { - return !r.matchStates(metadata).isEmpty() +func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext, response *dns.Msg) bool { + matchMetadata := *metadata + matchMetadata.DNSResponse = response + matchMetadata.DestinationAddressMatchFromResponse = true + return !r.abstractLogicalRule.matchStates(&matchMetadata).isEmpty() } diff --git a/route/rule/rule_item_cidr.go b/route/rule/rule_item_cidr.go index 321b6e5d2..61612f88f 100644 --- a/route/rule/rule_item_cidr.go +++ b/route/rule/rule_item_cidr.go @@ -77,7 +77,7 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool { return r.ipSet.Contains(metadata.Source.Addr) } if metadata.DestinationAddressMatchFromResponse { - for _, address := range metadata.DestinationAddressesForMatch() { + for _, address := range metadata.DNSResponseAddressesForMatch() { if r.ipSet.Contains(address) { return true } @@ -87,7 +87,7 @@ func (r *IPCIDRItem) Match(metadata *adapter.InboundContext) bool { if metadata.Destination.IsIP() { return r.ipSet.Contains(metadata.Destination.Addr) } - addresses := metadata.DestinationAddressesForMatch() + addresses := metadata.DestinationAddresses if len(addresses) > 0 { for _, address := range addresses { if r.ipSet.Contains(address) { diff --git a/route/rule/rule_item_ip_accept_any.go b/route/rule/rule_item_ip_accept_any.go index 1ca712573..fceebc186 100644 --- a/route/rule/rule_item_ip_accept_any.go +++ b/route/rule/rule_item_ip_accept_any.go @@ -13,6 +13,9 @@ func NewIPAcceptAnyItem() *IPAcceptAnyItem { } func (r *IPAcceptAnyItem) Match(metadata *adapter.InboundContext) bool { + if metadata.DestinationAddressMatchFromResponse { + return len(metadata.DNSResponseAddressesForMatch()) > 0 + } return len(metadata.DestinationAddresses) > 0 } diff --git a/route/rule/rule_item_ip_is_private.go b/route/rule/rule_item_ip_is_private.go index b71255bae..c96887739 100644 --- a/route/rule/rule_item_ip_is_private.go +++ b/route/rule/rule_item_ip_is_private.go @@ -20,7 +20,7 @@ func (r *IPIsPrivateItem) Match(metadata *adapter.InboundContext) bool { return !N.IsPublicAddr(metadata.Source.Addr) } if metadata.DestinationAddressMatchFromResponse { - for _, destinationAddress := range metadata.DestinationAddressesForMatch() { + for _, destinationAddress := range metadata.DNSResponseAddressesForMatch() { if !N.IsPublicAddr(destinationAddress) { return true } @@ -30,7 +30,7 @@ func (r *IPIsPrivateItem) Match(metadata *adapter.InboundContext) bool { if metadata.Destination.Addr.IsValid() { return !N.IsPublicAddr(metadata.Destination.Addr) } - for _, destinationAddress := range metadata.DestinationAddressesForMatch() { + for _, destinationAddress := range metadata.DestinationAddresses { if !N.IsPublicAddr(destinationAddress) { return true } diff --git a/route/rule/rule_set_semantics_test.go b/route/rule/rule_set_semantics_test.go index 17a7da9d5..553c2bf3e 100644 --- a/route/rule/rule_set_semantics_test.go +++ b/route/rule/rule_set_semantics_test.go @@ -583,7 +583,7 @@ func TestDNSRuleSetSemantics(t *testing.T) { addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) }) - require.True(t, rule.MatchAddressLimit(&metadata)) + require.True(t, rule.MatchAddressLimit(&metadata, dnsResponseForTest(netip.MustParseAddr("203.0.113.1")))) }) t.Run("dns keeps ruleset or semantics", func(t *testing.T) { t.Parallel() @@ -598,7 +598,7 @@ func TestDNSRuleSetSemantics(t *testing.T) { addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{emptyStateSet, destinationStateSet}}) addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) }) - require.True(t, rule.MatchAddressLimit(&metadata)) + require.True(t, rule.MatchAddressLimit(&metadata, dnsResponseForTest(netip.MustParseAddr("203.0.113.1")))) }) t.Run("ruleset ip cidr flags stay scoped", func(t *testing.T) { t.Parallel() @@ -612,7 +612,7 @@ func TestDNSRuleSetSemantics(t *testing.T) { ipCidrAcceptEmpty: true, }) }) - require.True(t, rule.MatchAddressLimit(&metadata)) + require.True(t, rule.MatchAddressLimit(&metadata, dnsResponseForTest())) require.False(t, metadata.IPCIDRMatchSource) require.False(t, metadata.IPCIDRAcceptEmpty) }) @@ -639,6 +639,62 @@ func TestDNSMatchResponseRuleSetDestinationCIDRUsesDNSResponse(t *testing.T) { require.False(t, rule.Match(&unmatchedMetadata)) } +func TestDNSAddressLimitIgnoresDestinationAddresses(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + build func(*testing.T, *abstractDefaultRule) + matchedResponse *mDNS.Msg + unmatchedResponse *mDNS.Msg + }{ + { + name: "ip_cidr", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }, + matchedResponse: dnsResponseForTest(netip.MustParseAddr("203.0.113.1")), + unmatchedResponse: dnsResponseForTest(netip.MustParseAddr("8.8.8.8")), + }, + { + name: "ip_is_private", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPIsPrivateItem(rule) + }, + matchedResponse: dnsResponseForTest(netip.MustParseAddr("10.0.0.1")), + unmatchedResponse: dnsResponseForTest(netip.MustParseAddr("8.8.8.8")), + }, + { + name: "ip_accept_any", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPAcceptAnyItem(rule) + }, + matchedResponse: dnsResponseForTest(netip.MustParseAddr("203.0.113.1")), + unmatchedResponse: dnsResponseForTest(), + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + testCase.build(t, rule) + }) + + mismatchMetadata := testMetadata("lookup.example") + mismatchMetadata.DestinationAddresses = []netip.Addr{netip.MustParseAddr("203.0.113.1")} + require.False(t, rule.MatchAddressLimit(&mismatchMetadata, testCase.unmatchedResponse)) + + matchMetadata := testMetadata("lookup.example") + matchMetadata.DestinationAddresses = []netip.Addr{netip.MustParseAddr("8.8.8.8")} + require.True(t, rule.MatchAddressLimit(&matchMetadata, testCase.matchedResponse)) + }) + } +} + func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) { t.Parallel() testCases := []struct { @@ -688,11 +744,11 @@ func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) { matchedMetadata := testMetadata("lookup.example") matchedMetadata.DestinationAddresses = testCase.matchedAddrs - require.False(t, rule.MatchAddressLimit(&matchedMetadata)) + require.False(t, rule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(testCase.matchedAddrs...))) unmatchedMetadata := testMetadata("lookup.example") unmatchedMetadata.DestinationAddresses = testCase.unmatchedAddrs - require.True(t, rule.MatchAddressLimit(&unmatchedMetadata)) + require.True(t, rule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(testCase.unmatchedAddrs...))) }) } t.Run("mixed resolved and deferred fields keep old pre lookup false", func(t *testing.T) {