From 6c5f351dcf7f20a340b4075a3717d1d3fac4a310 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 25 Mar 2026 20:05:50 +0800 Subject: [PATCH] dns: preserve legacy address-filter pre-match semantics Legacy DNS address-filter mode still accepts destination-side IP predicates with a deprecation warning, but the recent evaluate/ match_response refactor started evaluating those predicates during pre-response Match(). That broke rules whose transport selection must be deferred until MatchAddressLimit() can inspect the upstream reply. Restore the old defer behavior by reintroducing an internal IgnoreDestinationIPCIDRMatch flag on InboundContext and using it only for legacy pre-response DNS matching. Default and logical DNS rules now carry the legacy mode bit, set the ignore flag on metadata copies while performing pre-response Match(), and explicitly clear it again for match_response and MatchAddressLimit() so response-phase matching still checks the returned addresses. Add regression coverage for direct legacy destination-IP rules, rule_set-backed CIDR rules, logical wrappers, and the legacy Lookup router path, including fallback after a rejected response. This keeps legacy configs working without changing new-mode evaluate semantics. Tests: go test ./route/rule ./dns Tests: make --- adapter/inbound.go | 11 +-- dns/router_test.go | 117 +++++++++++++++++++++++++- route/rule/rule_abstract.go | 2 +- route/rule/rule_dns.go | 22 ++++- route/rule/rule_set_semantics_test.go | 104 +++++++++++++++++++++++ 5 files changed, 244 insertions(+), 12 deletions(-) diff --git a/adapter/inbound.go b/adapter/inbound.go index 5bc147436..048699f6d 100644 --- a/adapter/inbound.go +++ b/adapter/inbound.go @@ -99,11 +99,12 @@ type InboundContext struct { IPCIDRMatchSource bool IPCIDRAcceptEmpty bool - SourceAddressMatch bool - SourcePortMatch bool - DestinationAddressMatch bool - DestinationPortMatch bool - DidMatch bool + SourceAddressMatch bool + SourcePortMatch bool + DestinationAddressMatch bool + DestinationPortMatch bool + DidMatch bool + IgnoreDestinationIPCIDRMatch bool } func (c *InboundContext) ResetRuleCache() { diff --git a/dns/router_test.go b/dns/router_test.go index f5e06cba0..8176c96b4 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -65,6 +65,7 @@ func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, str type fakeDNSClient struct { beforeExchange func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg) exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) + lookup func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) } type fakeDeprecatedManager struct { @@ -84,8 +85,24 @@ func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTrans return c.exchange(transport, message) } -func (c *fakeDNSClient) Lookup(context.Context, adapter.DNSTransport, string, adapter.DNSQueryOptions, func(*mDNS.Msg) bool) ([]netip.Addr, error) { - return nil, errors.New("unused client lookup") +func (c *fakeDNSClient) Lookup(_ context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(*mDNS.Msg) bool) ([]netip.Addr, error) { + if c.lookup == nil { + return nil, errors.New("unused client lookup") + } + addresses, response, err := c.lookup(transport, domain, options) + if err != nil { + return nil, err + } + if response == nil { + response = FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), addresses, 60) + } + if responseChecker != nil && !responseChecker(response) { + return nil, ErrResponseRejected + } + if addresses != nil { + return addresses, nil + } + return MessageToAddresses(response), nil } func (c *fakeDNSClient) ClearCache() {} @@ -185,6 +202,102 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) { require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") } +func TestLookupLegacyModeDefersDirectDestinationIPMatch(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP} + client := &fakeDNSClient{ + lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + require.Equal(t, "example.com", domain) + require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy) + switch transport.Tag() { + case "private": + response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) + return MessageToAddresses(response), response, nil + case "default": + t.Fatal("default transport should not be used when legacy rule matches after response") + } + return nil, nil, errors.New("unexpected transport") + }, + } + router := newTestRouter(t, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + IPIsPrivate: true, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "private"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "private": privateTransport, + }, + }, client) + + require.True(t, router.legacyAddressFilterMode) + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ + LookupStrategy: C.DomainStrategyIPv4Only, + }) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) +} + +func TestLookupLegacyModeFallsBackAfterRejectedAddressLimitResponse(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP} + var lookups []string + client := &fakeDNSClient{ + lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + require.Equal(t, "example.com", domain) + require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy) + lookups = append(lookups, transport.Tag()) + switch transport.Tag() { + case "private": + response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60) + return MessageToAddresses(response), response, nil + case "default": + response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60) + return MessageToAddresses(response), response, nil + } + return nil, nil, errors.New("unexpected transport") + }, + } + router := newTestRouter(t, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + IPIsPrivate: true, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "private"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "private": privateTransport, + }, + }, client) + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ + LookupStrategy: C.DomainStrategyIPv4Only, + }) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses) + require.Equal(t, []string{"private", "default"}, lookups) +} + func TestDNSResponseAddressesMatchesMessageToAddressesForHTTPSHints(t *testing.T) { t.Parallel() diff --git a/route/rule/rule_abstract.go b/route/rule/rule_abstract.go index 8ec57aac3..ca508330b 100644 --- a/route/rule/rule_abstract.go +++ b/route/rule/rule_abstract.go @@ -60,7 +60,7 @@ func (r *abstractDefaultRule) destinationIPCIDRMatchesSource(metadata *adapter.I } func (r *abstractDefaultRule) destinationIPCIDRMatchesDestination(metadata *adapter.InboundContext) bool { - return !metadata.IPCIDRMatchSource && len(r.destinationIPCIDRItems) > 0 + return !metadata.IgnoreDestinationIPCIDRMatch && !metadata.IPCIDRMatchSource && len(r.destinationIPCIDRItems) > 0 } func (r *abstractDefaultRule) requiresSourceAddressMatch(metadata *adapter.InboundContext) bool { diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index f535844a7..d643755bb 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -54,7 +54,8 @@ var _ adapter.DNSRule = (*DefaultDNSRule)(nil) type DefaultDNSRule struct { abstractDefaultRule - matchResponse bool + matchResponse bool + legacyAddressFilter bool } func (r *DefaultDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { @@ -67,7 +68,8 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op invert: options.Invert, action: NewDNSRuleAction(logger, options.DNSRuleAction), }, - matchResponse: options.MatchResponse, + matchResponse: options.MatchResponse, + legacyAddressFilter: legacyAddressFilter, } if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) @@ -361,16 +363,21 @@ func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r return 0 } matchMetadata := *metadata + matchMetadata.IgnoreDestinationIPCIDRMatch = false matchMetadata.DestinationAddressMatchFromResponse = true return r.abstractDefaultRule.matchStates(&matchMetadata) } matchMetadata := *metadata + if r.legacyAddressFilter { + matchMetadata.IgnoreDestinationIPCIDRMatch = true + } return r.abstractDefaultRule.matchStates(&matchMetadata) } func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext, response *dns.Msg) bool { matchMetadata := *metadata matchMetadata.DNSResponse = response + matchMetadata.IgnoreDestinationIPCIDRMatch = false matchMetadata.DestinationAddressMatchFromResponse = true return !r.abstractDefaultRule.matchStates(&matchMetadata).isEmpty() } @@ -379,6 +386,7 @@ var _ adapter.DNSRule = (*LogicalDNSRule)(nil) type LogicalDNSRule struct { abstractLogicalRule + legacyAddressFilter bool } func (r *LogicalDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { @@ -397,11 +405,15 @@ func matchDNSHeadlessRuleStatesForMatch(rule adapter.HeadlessRule, metadata *ada } func (r *LogicalDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet { + matchMetadata := *metadata + if r.legacyAddressFilter { + matchMetadata.IgnoreDestinationIPCIDRMatch = true + } var stateSet ruleMatchStateSet if r.mode == C.LogicalTypeAnd { stateSet = emptyRuleMatchState() for _, rule := range r.rules { - nestedMetadata := *metadata + nestedMetadata := matchMetadata nestedMetadata.ResetRuleCache() nestedStateSet := matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata) if nestedStateSet.isEmpty() { @@ -414,7 +426,7 @@ func (r *LogicalDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r } } else { for _, rule := range r.rules { - nestedMetadata := *metadata + nestedMetadata := matchMetadata nestedMetadata.ResetRuleCache() stateSet = stateSet.merge(matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata)) } @@ -438,6 +450,7 @@ func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options op invert: options.Invert, action: NewDNSRuleAction(logger, options.DNSRuleAction), }, + legacyAddressFilter: legacyAddressFilter, } switch options.Mode { case C.LogicalTypeAnd: @@ -484,6 +497,7 @@ func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool { func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext, response *dns.Msg) bool { matchMetadata := *metadata matchMetadata.DNSResponse = response + matchMetadata.IgnoreDestinationIPCIDRMatch = false matchMetadata.DestinationAddressMatchFromResponse = true return !r.abstractLogicalRule.matchStates(&matchMetadata).isEmpty() } diff --git a/route/rule/rule_set_semantics_test.go b/route/rule/rule_set_semantics_test.go index f599adc3b..3a9d4ca9d 100644 --- a/route/rule/rule_set_semantics_test.go +++ b/route/rule/rule_set_semantics_test.go @@ -742,6 +742,110 @@ func TestDNSAddressLimitIgnoresDestinationAddresses(t *testing.T) { } } +func TestDNSLegacyAddressLimitPreLookupDefersDirectRules(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + build func(*testing.T, *abstractDefaultRule) + matchedResponse *mDNS.Msg + unmatchedResponse *mDNS.Msg + }{ + { + name: "ip_cidr", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }, + matchedResponse: dnsResponseForTest(netip.MustParseAddr("203.0.113.1")), + unmatchedResponse: dnsResponseForTest(netip.MustParseAddr("8.8.8.8")), + }, + { + name: "ip_is_private", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPIsPrivateItem(rule) + }, + matchedResponse: dnsResponseForTest(netip.MustParseAddr("10.0.0.1")), + unmatchedResponse: dnsResponseForTest(netip.MustParseAddr("8.8.8.8")), + }, + { + name: "ip_accept_any", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPAcceptAnyItem(rule) + }, + matchedResponse: dnsResponseForTest(netip.MustParseAddr("203.0.113.1")), + unmatchedResponse: dnsResponseForTest(), + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + testCase.build(t, rule) + }) + rule.legacyAddressFilter = true + + preLookupMetadata := testMetadata("lookup.example") + require.True(t, rule.Match(&preLookupMetadata)) + + matchedMetadata := testMetadata("lookup.example") + require.True(t, rule.MatchAddressLimit(&matchedMetadata, testCase.matchedResponse)) + + unmatchedMetadata := testMetadata("lookup.example") + require.False(t, rule.MatchAddressLimit(&unmatchedMetadata, testCase.unmatchedResponse)) + }) + } +} + +func TestDNSLegacyAddressLimitPreLookupDefersRuleSetDestinationCIDR(t *testing.T) { + t.Parallel() + + ruleSet := newLocalRuleSetForTest("dns-legacy-ipcidr", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + })) + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + }) + rule.legacyAddressFilter = true + + preLookupMetadata := testMetadata("lookup.example") + require.True(t, rule.Match(&preLookupMetadata)) + + matchedMetadata := testMetadata("lookup.example") + require.True(t, rule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(netip.MustParseAddr("203.0.113.1")))) + + unmatchedMetadata := testMetadata("lookup.example") + require.False(t, rule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(netip.MustParseAddr("8.8.8.8")))) +} + +func TestDNSLegacyLogicalAddressLimitPreLookupDefersNestedRules(t *testing.T) { + t.Parallel() + + nestedRule := dnsRuleForTest(func(rule *abstractDefaultRule) { + addDestinationIPIsPrivateItem(rule) + }) + logicalRule := &LogicalDNSRule{ + abstractLogicalRule: abstractLogicalRule{ + rules: []adapter.HeadlessRule{nestedRule}, + mode: C.LogicalTypeAnd, + }, + legacyAddressFilter: true, + } + + preLookupMetadata := testMetadata("lookup.example") + require.True(t, logicalRule.Match(&preLookupMetadata)) + + matchedMetadata := testMetadata("lookup.example") + require.True(t, logicalRule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(netip.MustParseAddr("10.0.0.1")))) + + unmatchedMetadata := testMetadata("lookup.example") + require.False(t, logicalRule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(netip.MustParseAddr("8.8.8.8")))) +} + func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) { t.Parallel() testCases := []struct {