dns: preserve legacy address-filter pre-match semantics

Legacy DNS address-filter mode still accepts destination-side IP
predicates with a deprecation warning, but the recent evaluate/
match_response refactor started evaluating those predicates during
pre-response Match(). That broke rules whose transport selection must
be deferred until MatchAddressLimit() can inspect the upstream reply.

Restore the old defer behavior by reintroducing an internal
IgnoreDestinationIPCIDRMatch flag on InboundContext and using it only
for legacy pre-response DNS matching. Default and logical DNS rules now
carry the legacy mode bit, set the ignore flag on metadata copies while
performing pre-response Match(), and explicitly clear it again for
match_response and MatchAddressLimit() so response-phase matching still
checks the returned addresses.

Add regression coverage for direct legacy destination-IP rules,
rule_set-backed CIDR rules, logical wrappers, and the legacy Lookup
router path, including fallback after a rejected response. This keeps
legacy configs working without changing new-mode evaluate semantics.

Tests: go test ./route/rule ./dns
Tests: make
This commit is contained in:
世界
2026-03-25 20:05:50 +08:00
parent 354e2bbff7
commit 6c5f351dcf
5 changed files with 244 additions and 12 deletions

View File

@@ -99,11 +99,12 @@ type InboundContext struct {
IPCIDRMatchSource bool
IPCIDRAcceptEmpty bool
SourceAddressMatch bool
SourcePortMatch bool
DestinationAddressMatch bool
DestinationPortMatch bool
DidMatch bool
SourceAddressMatch bool
SourcePortMatch bool
DestinationAddressMatch bool
DestinationPortMatch bool
DidMatch bool
IgnoreDestinationIPCIDRMatch bool
}
func (c *InboundContext) ResetRuleCache() {

View File

@@ -65,6 +65,7 @@ func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, str
type fakeDNSClient struct {
beforeExchange func(ctx context.Context, transport adapter.DNSTransport, message *mDNS.Msg)
exchange func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error)
lookup func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error)
}
type fakeDeprecatedManager struct {
@@ -84,8 +85,24 @@ func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTrans
return c.exchange(transport, message)
}
func (c *fakeDNSClient) Lookup(context.Context, adapter.DNSTransport, string, adapter.DNSQueryOptions, func(*mDNS.Msg) bool) ([]netip.Addr, error) {
return nil, errors.New("unused client lookup")
func (c *fakeDNSClient) Lookup(_ context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(*mDNS.Msg) bool) ([]netip.Addr, error) {
if c.lookup == nil {
return nil, errors.New("unused client lookup")
}
addresses, response, err := c.lookup(transport, domain, options)
if err != nil {
return nil, err
}
if response == nil {
response = FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), addresses, 60)
}
if responseChecker != nil && !responseChecker(response) {
return nil, ErrResponseRejected
}
if addresses != nil {
return addresses, nil
}
return MessageToAddresses(response), nil
}
func (c *fakeDNSClient) ClearCache() {}
@@ -185,6 +202,102 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) {
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
}
func TestLookupLegacyModeDefersDirectDestinationIPMatch(t *testing.T) {
t.Parallel()
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP}
client := &fakeDNSClient{
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
require.Equal(t, "example.com", domain)
require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy)
switch transport.Tag() {
case "private":
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
return MessageToAddresses(response), response, nil
case "default":
t.Fatal("default transport should not be used when legacy rule matches after response")
}
return nil, nil, errors.New("unexpected transport")
},
}
router := newTestRouter(t, []option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
IPIsPrivate: true,
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "private"},
},
},
}}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
"private": privateTransport,
},
}, client)
require.True(t, router.legacyAddressFilterMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
LookupStrategy: C.DomainStrategyIPv4Only,
})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
}
func TestLookupLegacyModeFallsBackAfterRejectedAddressLimitResponse(t *testing.T) {
t.Parallel()
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP}
var lookups []string
client := &fakeDNSClient{
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
require.Equal(t, "example.com", domain)
require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy)
lookups = append(lookups, transport.Tag())
switch transport.Tag() {
case "private":
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60)
return MessageToAddresses(response), response, nil
case "default":
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60)
return MessageToAddresses(response), response, nil
}
return nil, nil, errors.New("unexpected transport")
},
}
router := newTestRouter(t, []option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
IPIsPrivate: true,
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "private"},
},
},
}}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
"private": privateTransport,
},
}, client)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
LookupStrategy: C.DomainStrategyIPv4Only,
})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses)
require.Equal(t, []string{"private", "default"}, lookups)
}
func TestDNSResponseAddressesMatchesMessageToAddressesForHTTPSHints(t *testing.T) {
t.Parallel()

View File

@@ -60,7 +60,7 @@ func (r *abstractDefaultRule) destinationIPCIDRMatchesSource(metadata *adapter.I
}
func (r *abstractDefaultRule) destinationIPCIDRMatchesDestination(metadata *adapter.InboundContext) bool {
return !metadata.IPCIDRMatchSource && len(r.destinationIPCIDRItems) > 0
return !metadata.IgnoreDestinationIPCIDRMatch && !metadata.IPCIDRMatchSource && len(r.destinationIPCIDRItems) > 0
}
func (r *abstractDefaultRule) requiresSourceAddressMatch(metadata *adapter.InboundContext) bool {

View File

@@ -54,7 +54,8 @@ var _ adapter.DNSRule = (*DefaultDNSRule)(nil)
type DefaultDNSRule struct {
abstractDefaultRule
matchResponse bool
matchResponse bool
legacyAddressFilter bool
}
func (r *DefaultDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet {
@@ -67,7 +68,8 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op
invert: options.Invert,
action: NewDNSRuleAction(logger, options.DNSRuleAction),
},
matchResponse: options.MatchResponse,
matchResponse: options.MatchResponse,
legacyAddressFilter: legacyAddressFilter,
}
if len(options.Inbound) > 0 {
item := NewInboundRule(options.Inbound)
@@ -361,16 +363,21 @@ func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r
return 0
}
matchMetadata := *metadata
matchMetadata.IgnoreDestinationIPCIDRMatch = false
matchMetadata.DestinationAddressMatchFromResponse = true
return r.abstractDefaultRule.matchStates(&matchMetadata)
}
matchMetadata := *metadata
if r.legacyAddressFilter {
matchMetadata.IgnoreDestinationIPCIDRMatch = true
}
return r.abstractDefaultRule.matchStates(&matchMetadata)
}
func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext, response *dns.Msg) bool {
matchMetadata := *metadata
matchMetadata.DNSResponse = response
matchMetadata.IgnoreDestinationIPCIDRMatch = false
matchMetadata.DestinationAddressMatchFromResponse = true
return !r.abstractDefaultRule.matchStates(&matchMetadata).isEmpty()
}
@@ -379,6 +386,7 @@ var _ adapter.DNSRule = (*LogicalDNSRule)(nil)
type LogicalDNSRule struct {
abstractLogicalRule
legacyAddressFilter bool
}
func (r *LogicalDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet {
@@ -397,11 +405,15 @@ func matchDNSHeadlessRuleStatesForMatch(rule adapter.HeadlessRule, metadata *ada
}
func (r *LogicalDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet {
matchMetadata := *metadata
if r.legacyAddressFilter {
matchMetadata.IgnoreDestinationIPCIDRMatch = true
}
var stateSet ruleMatchStateSet
if r.mode == C.LogicalTypeAnd {
stateSet = emptyRuleMatchState()
for _, rule := range r.rules {
nestedMetadata := *metadata
nestedMetadata := matchMetadata
nestedMetadata.ResetRuleCache()
nestedStateSet := matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata)
if nestedStateSet.isEmpty() {
@@ -414,7 +426,7 @@ func (r *LogicalDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r
}
} else {
for _, rule := range r.rules {
nestedMetadata := *metadata
nestedMetadata := matchMetadata
nestedMetadata.ResetRuleCache()
stateSet = stateSet.merge(matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata))
}
@@ -438,6 +450,7 @@ func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options op
invert: options.Invert,
action: NewDNSRuleAction(logger, options.DNSRuleAction),
},
legacyAddressFilter: legacyAddressFilter,
}
switch options.Mode {
case C.LogicalTypeAnd:
@@ -484,6 +497,7 @@ func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext, response *dns.Msg) bool {
matchMetadata := *metadata
matchMetadata.DNSResponse = response
matchMetadata.IgnoreDestinationIPCIDRMatch = false
matchMetadata.DestinationAddressMatchFromResponse = true
return !r.abstractLogicalRule.matchStates(&matchMetadata).isEmpty()
}

View File

@@ -742,6 +742,110 @@ func TestDNSAddressLimitIgnoresDestinationAddresses(t *testing.T) {
}
}
func TestDNSLegacyAddressLimitPreLookupDefersDirectRules(t *testing.T) {
t.Parallel()
testCases := []struct {
name string
build func(*testing.T, *abstractDefaultRule)
matchedResponse *mDNS.Msg
unmatchedResponse *mDNS.Msg
}{
{
name: "ip_cidr",
build: func(t *testing.T, rule *abstractDefaultRule) {
t.Helper()
addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"})
},
matchedResponse: dnsResponseForTest(netip.MustParseAddr("203.0.113.1")),
unmatchedResponse: dnsResponseForTest(netip.MustParseAddr("8.8.8.8")),
},
{
name: "ip_is_private",
build: func(t *testing.T, rule *abstractDefaultRule) {
t.Helper()
addDestinationIPIsPrivateItem(rule)
},
matchedResponse: dnsResponseForTest(netip.MustParseAddr("10.0.0.1")),
unmatchedResponse: dnsResponseForTest(netip.MustParseAddr("8.8.8.8")),
},
{
name: "ip_accept_any",
build: func(t *testing.T, rule *abstractDefaultRule) {
t.Helper()
addDestinationIPAcceptAnyItem(rule)
},
matchedResponse: dnsResponseForTest(netip.MustParseAddr("203.0.113.1")),
unmatchedResponse: dnsResponseForTest(),
},
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
t.Parallel()
rule := dnsRuleForTest(func(rule *abstractDefaultRule) {
testCase.build(t, rule)
})
rule.legacyAddressFilter = true
preLookupMetadata := testMetadata("lookup.example")
require.True(t, rule.Match(&preLookupMetadata))
matchedMetadata := testMetadata("lookup.example")
require.True(t, rule.MatchAddressLimit(&matchedMetadata, testCase.matchedResponse))
unmatchedMetadata := testMetadata("lookup.example")
require.False(t, rule.MatchAddressLimit(&unmatchedMetadata, testCase.unmatchedResponse))
})
}
}
func TestDNSLegacyAddressLimitPreLookupDefersRuleSetDestinationCIDR(t *testing.T) {
t.Parallel()
ruleSet := newLocalRuleSetForTest("dns-legacy-ipcidr", headlessDefaultRule(t, func(rule *abstractDefaultRule) {
addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"})
}))
rule := dnsRuleForTest(func(rule *abstractDefaultRule) {
addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}})
})
rule.legacyAddressFilter = true
preLookupMetadata := testMetadata("lookup.example")
require.True(t, rule.Match(&preLookupMetadata))
matchedMetadata := testMetadata("lookup.example")
require.True(t, rule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(netip.MustParseAddr("203.0.113.1"))))
unmatchedMetadata := testMetadata("lookup.example")
require.False(t, rule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(netip.MustParseAddr("8.8.8.8"))))
}
func TestDNSLegacyLogicalAddressLimitPreLookupDefersNestedRules(t *testing.T) {
t.Parallel()
nestedRule := dnsRuleForTest(func(rule *abstractDefaultRule) {
addDestinationIPIsPrivateItem(rule)
})
logicalRule := &LogicalDNSRule{
abstractLogicalRule: abstractLogicalRule{
rules: []adapter.HeadlessRule{nestedRule},
mode: C.LogicalTypeAnd,
},
legacyAddressFilter: true,
}
preLookupMetadata := testMetadata("lookup.example")
require.True(t, logicalRule.Match(&preLookupMetadata))
matchedMetadata := testMetadata("lookup.example")
require.True(t, logicalRule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(netip.MustParseAddr("10.0.0.1"))))
unmatchedMetadata := testMetadata("lookup.example")
require.False(t, logicalRule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(netip.MustParseAddr("8.8.8.8"))))
}
func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) {
t.Parallel()
testCases := []struct {