From a5c320114003812bdaeb5de20f6ec74711799298 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Wed, 25 Mar 2026 20:41:00 +0800 Subject: [PATCH] dns: isolate legacy pre-match semantics --- adapter/rule.go | 1 + dns/router.go | 2 +- route/rule/rule_dns.go | 33 +- route/rule/rule_dns_legacy.go | 638 ++++++++++++++++++++++++++ route/rule/rule_set_semantics_test.go | 135 +++++- 5 files changed, 785 insertions(+), 24 deletions(-) create mode 100644 route/rule/rule_dns_legacy.go diff --git a/adapter/rule.go b/adapter/rule.go index 31ed9b424..00470f60e 100644 --- a/adapter/rule.go +++ b/adapter/rule.go @@ -20,6 +20,7 @@ type Rule interface { type DNSRule interface { Rule + LegacyPreMatch(metadata *InboundContext) bool WithAddressLimit() bool MatchAddressLimit(metadata *InboundContext, response *dns.Msg) bool } diff --git a/dns/router.go b/dns/router.go index 5453812c8..5fbf50b75 100644 --- a/dns/router.go +++ b/dns/router.go @@ -146,7 +146,7 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, } metadata.ResetRuleCache() metadata.DestinationAddressMatchFromResponse = false - if currentRule.Match(metadata) { + if currentRule.LegacyPreMatch(metadata) { displayRuleIndex := currentRuleIndex if displayRuleIndex != -1 { displayRuleIndex += displayRuleIndex + 1 diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index d643755bb..cceaec8e9 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -54,8 +54,7 @@ var _ adapter.DNSRule = (*DefaultDNSRule)(nil) type DefaultDNSRule struct { abstractDefaultRule - matchResponse bool - legacyAddressFilter bool + matchResponse bool } func (r *DefaultDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { @@ -68,8 +67,7 @@ func NewDefaultDNSRule(ctx context.Context, logger log.ContextLogger, options op invert: options.Invert, action: NewDNSRuleAction(logger, options.DNSRuleAction), }, - matchResponse: options.MatchResponse, - legacyAddressFilter: legacyAddressFilter, + matchResponse: options.MatchResponse, } if len(options.Inbound) > 0 { item := NewInboundRule(options.Inbound) @@ -357,6 +355,13 @@ func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool { return !r.matchStatesForMatch(metadata).isEmpty() } +func (r *DefaultDNSRule) LegacyPreMatch(metadata *adapter.InboundContext) bool { + if r.matchResponse { + return !r.matchStatesForMatch(metadata).isEmpty() + } + return !r.abstractDefaultRule.legacyMatchStates(metadata).isEmpty() +} + func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet { if r.matchResponse { if metadata.DNSResponse == nil { @@ -367,11 +372,7 @@ func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r matchMetadata.DestinationAddressMatchFromResponse = true return r.abstractDefaultRule.matchStates(&matchMetadata) } - matchMetadata := *metadata - if r.legacyAddressFilter { - matchMetadata.IgnoreDestinationIPCIDRMatch = true - } - return r.abstractDefaultRule.matchStates(&matchMetadata) + return r.abstractDefaultRule.matchStates(metadata) } func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext, response *dns.Msg) bool { @@ -386,7 +387,6 @@ var _ adapter.DNSRule = (*LogicalDNSRule)(nil) type LogicalDNSRule struct { abstractLogicalRule - legacyAddressFilter bool } func (r *LogicalDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatchStateSet { @@ -405,15 +405,11 @@ func matchDNSHeadlessRuleStatesForMatch(rule adapter.HeadlessRule, metadata *ada } func (r *LogicalDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet { - matchMetadata := *metadata - if r.legacyAddressFilter { - matchMetadata.IgnoreDestinationIPCIDRMatch = true - } var stateSet ruleMatchStateSet if r.mode == C.LogicalTypeAnd { stateSet = emptyRuleMatchState() for _, rule := range r.rules { - nestedMetadata := matchMetadata + nestedMetadata := *metadata nestedMetadata.ResetRuleCache() nestedStateSet := matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata) if nestedStateSet.isEmpty() { @@ -426,7 +422,7 @@ func (r *LogicalDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) r } } else { for _, rule := range r.rules { - nestedMetadata := matchMetadata + nestedMetadata := *metadata nestedMetadata.ResetRuleCache() stateSet = stateSet.merge(matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata)) } @@ -450,7 +446,6 @@ func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options op invert: options.Invert, action: NewDNSRuleAction(logger, options.DNSRuleAction), }, - legacyAddressFilter: legacyAddressFilter, } switch options.Mode { case C.LogicalTypeAnd: @@ -494,6 +489,10 @@ func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool { return !r.matchStatesForMatch(metadata).isEmpty() } +func (r *LogicalDNSRule) LegacyPreMatch(metadata *adapter.InboundContext) bool { + return !r.abstractLogicalRule.legacyMatchStates(metadata).isEmpty() +} + func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext, response *dns.Msg) bool { matchMetadata := *metadata matchMetadata.DNSResponse = response diff --git a/route/rule/rule_dns_legacy.go b/route/rule/rule_dns_legacy.go new file mode 100644 index 000000000..23816d7e6 --- /dev/null +++ b/route/rule/rule_dns_legacy.go @@ -0,0 +1,638 @@ +package rule + +import ( + "net/netip" + + "github.com/sagernet/sing-box/adapter" + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing/common" + + "go4.org/netipx" +) + +type legacyResponseLiteralKind uint8 + +const ( + legacyLiteralRequireEmpty legacyResponseLiteralKind = iota + legacyLiteralRequireNonEmpty + legacyLiteralRequireSet + legacyLiteralForbidSet +) + +type legacyResponseLiteral struct { + kind legacyResponseLiteralKind + ipSet *netipx.IPSet +} + +type legacyResponseTerm []legacyResponseLiteral + +type legacyResponseFormula []legacyResponseTerm + +type legacyRuleMatchStateSet [16]legacyResponseFormula + +var ( + legacyAllIPSet = func() *netipx.IPSet { + var builder netipx.IPSetBuilder + builder.Complement() + return common.Must1(builder.IPSet()) + }() + legacyNonPublicIPSet = func() *netipx.IPSet { + var builder netipx.IPSetBuilder + for _, prefix := range []string{ + "0.0.0.0/32", + "10.0.0.0/8", + "127.0.0.0/8", + "169.254.0.0/16", + "172.16.0.0/12", + "192.168.0.0/16", + "224.0.0.0/4", + "::/128", + "::1/128", + "fc00::/7", + "fe80::/10", + "ff00::/8", + } { + builder.AddPrefix(netip.MustParsePrefix(prefix)) + } + return common.Must1(builder.IPSet()) + }() +) + +func legacyFalseFormula() legacyResponseFormula { + return nil +} + +func legacyTrueFormula() legacyResponseFormula { + return legacyResponseFormula{legacyResponseTerm{}} +} + +func legacyLiteralFormula(literal legacyResponseLiteral) legacyResponseFormula { + return legacyResponseFormula{legacyResponseTerm{literal}} +} + +func (f legacyResponseFormula) isFalse() bool { + return len(f) == 0 +} + +func (f legacyResponseFormula) isTrue() bool { + return len(f) == 1 && len(f[0]) == 0 +} + +func (f legacyResponseFormula) or(other legacyResponseFormula) legacyResponseFormula { + if f.isFalse() { + return other + } + if other.isFalse() { + return f + } + result := make(legacyResponseFormula, 0, len(f)+len(other)) + result = append(result, f...) + result = append(result, other...) + return result +} + +func (f legacyResponseFormula) and(other legacyResponseFormula) legacyResponseFormula { + if f.isFalse() || other.isFalse() { + return legacyFalseFormula() + } + if f.isTrue() { + return other + } + if other.isTrue() { + return f + } + var result legacyResponseFormula + for _, leftTerm := range f { + for _, rightTerm := range other { + combined, valid := legacyCombineResponseTerms(leftTerm, rightTerm) + if valid { + result = append(result, combined) + } + } + } + return result +} + +func (f legacyResponseFormula) not() legacyResponseFormula { + if f.isFalse() { + return legacyTrueFormula() + } + result := legacyTrueFormula() + for _, term := range f { + result = result.and(legacyNegateResponseTerm(term)) + if result.isFalse() { + return result + } + } + return result +} + +func legacyNegateResponseTerm(term legacyResponseTerm) legacyResponseFormula { + if len(term) == 0 { + return legacyFalseFormula() + } + result := make(legacyResponseFormula, 0, len(term)) + for _, literal := range term { + result = append(result, legacyResponseTerm{legacyNegateResponseLiteral(literal)}) + } + return result +} + +func legacyNegateResponseLiteral(literal legacyResponseLiteral) legacyResponseLiteral { + switch literal.kind { + case legacyLiteralRequireEmpty: + return legacyResponseLiteral{kind: legacyLiteralRequireNonEmpty} + case legacyLiteralRequireNonEmpty: + return legacyResponseLiteral{kind: legacyLiteralRequireEmpty} + case legacyLiteralRequireSet: + return legacyResponseLiteral{kind: legacyLiteralForbidSet, ipSet: literal.ipSet} + case legacyLiteralForbidSet: + return legacyResponseLiteral{kind: legacyLiteralRequireSet, ipSet: literal.ipSet} + default: + panic("unknown legacy response literal kind") + } +} + +func legacyCombineResponseTerms(left legacyResponseTerm, right legacyResponseTerm) (legacyResponseTerm, bool) { + combined := make(legacyResponseTerm, 0, len(left)+len(right)) + combined = append(combined, left...) + combined = append(combined, right...) + if !legacyResponseTermSatisfiable(combined) { + return nil, false + } + return combined, true +} + +func legacyResponseTermSatisfiable(term legacyResponseTerm) bool { + var ( + requireEmpty bool + requireNonEmpty bool + requiredSets []*netipx.IPSet + forbiddenBuild netipx.IPSetBuilder + hasForbidden bool + ) + for _, literal := range term { + switch literal.kind { + case legacyLiteralRequireEmpty: + requireEmpty = true + case legacyLiteralRequireNonEmpty: + requireNonEmpty = true + case legacyLiteralRequireSet: + requiredSets = append(requiredSets, literal.ipSet) + case legacyLiteralForbidSet: + if literal.ipSet != nil { + forbiddenBuild.AddSet(literal.ipSet) + hasForbidden = true + } + default: + panic("unknown legacy response literal kind") + } + } + if requireEmpty && (requireNonEmpty || len(requiredSets) > 0) { + return false + } + if requireEmpty { + return true + } + var forbidden *netipx.IPSet + if hasForbidden { + forbidden = common.Must1(forbiddenBuild.IPSet()) + } + for _, required := range requiredSets { + if !legacyIPSetHasAllowedIP(required, forbidden) { + return false + } + } + if requireNonEmpty && len(requiredSets) == 0 { + return legacyIPSetHasAllowedIP(legacyAllIPSet, forbidden) + } + return true +} + +func legacyIPSetHasAllowedIP(required *netipx.IPSet, forbidden *netipx.IPSet) bool { + if required == nil { + required = legacyAllIPSet + } + if forbidden == nil { + return len(required.Ranges()) > 0 + } + builder := netipx.IPSetBuilder{} + builder.AddSet(required) + builder.RemoveSet(forbidden) + remaining := common.Must1(builder.IPSet()) + return len(remaining.Ranges()) > 0 +} + +func legacySingleRuleMatchState(state ruleMatchState) legacyRuleMatchStateSet { + return legacySingleRuleMatchStateWithFormula(state, legacyTrueFormula()) +} + +func legacySingleRuleMatchStateWithFormula(state ruleMatchState, formula legacyResponseFormula) legacyRuleMatchStateSet { + var stateSet legacyRuleMatchStateSet + if !formula.isFalse() { + stateSet[state] = formula + } + return stateSet +} + +func (s legacyRuleMatchStateSet) isEmpty() bool { + for _, formula := range s { + if !formula.isFalse() { + return false + } + } + return true +} + +func (s legacyRuleMatchStateSet) merge(other legacyRuleMatchStateSet) legacyRuleMatchStateSet { + var merged legacyRuleMatchStateSet + for state := ruleMatchState(0); state < 16; state++ { + merged[state] = s[state].or(other[state]) + } + return merged +} + +func (s legacyRuleMatchStateSet) combine(other legacyRuleMatchStateSet) legacyRuleMatchStateSet { + if s.isEmpty() || other.isEmpty() { + return legacyRuleMatchStateSet{} + } + var combined legacyRuleMatchStateSet + for left := ruleMatchState(0); left < 16; left++ { + if s[left].isFalse() { + continue + } + for right := ruleMatchState(0); right < 16; right++ { + if other[right].isFalse() { + continue + } + combined[left|right] = combined[left|right].or(s[left].and(other[right])) + } + } + return combined +} + +func (s legacyRuleMatchStateSet) withBase(base ruleMatchState) legacyRuleMatchStateSet { + if s.isEmpty() { + return legacyRuleMatchStateSet{} + } + var withBase legacyRuleMatchStateSet + for state := ruleMatchState(0); state < 16; state++ { + if s[state].isFalse() { + continue + } + withBase[state|base] = withBase[state|base].or(s[state]) + } + return withBase +} + +func (s legacyRuleMatchStateSet) filter(allowed func(ruleMatchState) bool) legacyRuleMatchStateSet { + var filtered legacyRuleMatchStateSet + for state := ruleMatchState(0); state < 16; state++ { + if s[state].isFalse() { + continue + } + if allowed(state) { + filtered[state] = s[state] + } + } + return filtered +} + +func (s legacyRuleMatchStateSet) addBit(bit ruleMatchState) legacyRuleMatchStateSet { + var withBit legacyRuleMatchStateSet + for state := ruleMatchState(0); state < 16; state++ { + if s[state].isFalse() { + continue + } + withBit[state|bit] = withBit[state|bit].or(s[state]) + } + return withBit +} + +func (s legacyRuleMatchStateSet) branchOnBit(bit ruleMatchState, condition legacyResponseFormula) legacyRuleMatchStateSet { + if condition.isFalse() { + return s + } + if condition.isTrue() { + return s.addBit(bit) + } + var branched legacyRuleMatchStateSet + conditionFalse := condition.not() + for state := ruleMatchState(0); state < 16; state++ { + if s[state].isFalse() { + continue + } + if state.has(bit) { + branched[state] = branched[state].or(s[state]) + continue + } + branched[state] = branched[state].or(s[state].and(conditionFalse)) + branched[state|bit] = branched[state|bit].or(s[state].and(condition)) + } + return branched +} + +func (s legacyRuleMatchStateSet) andFormula(formula legacyResponseFormula) legacyRuleMatchStateSet { + if formula.isFalse() || s.isEmpty() { + return legacyRuleMatchStateSet{} + } + if formula.isTrue() { + return s + } + var result legacyRuleMatchStateSet + for state := ruleMatchState(0); state < 16; state++ { + if s[state].isFalse() { + continue + } + result[state] = s[state].and(formula) + } + return result +} + +func (s legacyRuleMatchStateSet) anyFormula() legacyResponseFormula { + var formula legacyResponseFormula + for _, stateFormula := range s { + formula = formula.or(stateFormula) + } + return formula +} + +type legacyRuleStateMatcher interface { + legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet +} + +type legacyRuleStateMatcherWithBase interface { + legacyMatchStatesWithBase(metadata *adapter.InboundContext, base ruleMatchState) legacyRuleMatchStateSet +} + +func legacyMatchHeadlessRuleStates(rule adapter.HeadlessRule, metadata *adapter.InboundContext) legacyRuleMatchStateSet { + return legacyMatchHeadlessRuleStatesWithBase(rule, metadata, 0) +} + +func legacyMatchHeadlessRuleStatesWithBase(rule adapter.HeadlessRule, metadata *adapter.InboundContext, base ruleMatchState) legacyRuleMatchStateSet { + if matcher, loaded := rule.(legacyRuleStateMatcherWithBase); loaded { + return matcher.legacyMatchStatesWithBase(metadata, base) + } + if matcher, loaded := rule.(legacyRuleStateMatcher); loaded { + return matcher.legacyMatchStates(metadata).withBase(base) + } + if rule.Match(metadata) { + return legacySingleRuleMatchState(base) + } + return legacyRuleMatchStateSet{} +} + +func legacyMatchRuleItemStatesWithBase(item RuleItem, metadata *adapter.InboundContext, base ruleMatchState) legacyRuleMatchStateSet { + if matcher, loaded := item.(legacyRuleStateMatcherWithBase); loaded { + return matcher.legacyMatchStatesWithBase(metadata, base) + } + if matcher, loaded := item.(legacyRuleStateMatcher); loaded { + return matcher.legacyMatchStates(metadata).withBase(base) + } + if item.Match(metadata) { + return legacySingleRuleMatchState(base) + } + return legacyRuleMatchStateSet{} +} + +func (r *DefaultHeadlessRule) legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet { + return r.abstractDefaultRule.legacyMatchStates(metadata) +} + +func (r *LogicalHeadlessRule) legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet { + return r.abstractLogicalRule.legacyMatchStates(metadata) +} + +func (r *RuleSetItem) legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet { + return r.legacyMatchStatesWithBase(metadata, 0) +} + +func (r *RuleSetItem) legacyMatchStatesWithBase(metadata *adapter.InboundContext, base ruleMatchState) legacyRuleMatchStateSet { + var stateSet legacyRuleMatchStateSet + for _, ruleSet := range r.setList { + nestedMetadata := *metadata + nestedMetadata.ResetRuleMatchCache() + nestedMetadata.IPCIDRMatchSource = r.ipCidrMatchSource + nestedMetadata.IPCIDRAcceptEmpty = r.ipCidrAcceptEmpty + stateSet = stateSet.merge(legacyMatchHeadlessRuleStatesWithBase(ruleSet, &nestedMetadata, base)) + } + return stateSet +} + +func (s *LocalRuleSet) legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet { + return s.legacyMatchStatesWithBase(metadata, 0) +} + +func (s *LocalRuleSet) legacyMatchStatesWithBase(metadata *adapter.InboundContext, base ruleMatchState) legacyRuleMatchStateSet { + var stateSet legacyRuleMatchStateSet + for _, rule := range s.rules { + nestedMetadata := *metadata + nestedMetadata.ResetRuleMatchCache() + stateSet = stateSet.merge(legacyMatchHeadlessRuleStatesWithBase(rule, &nestedMetadata, base)) + } + return stateSet +} + +func (s *RemoteRuleSet) legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet { + return s.legacyMatchStatesWithBase(metadata, 0) +} + +func (s *RemoteRuleSet) legacyMatchStatesWithBase(metadata *adapter.InboundContext, base ruleMatchState) legacyRuleMatchStateSet { + var stateSet legacyRuleMatchStateSet + for _, rule := range s.rules { + nestedMetadata := *metadata + nestedMetadata.ResetRuleMatchCache() + stateSet = stateSet.merge(legacyMatchHeadlessRuleStatesWithBase(rule, &nestedMetadata, base)) + } + return stateSet +} + +func (r *abstractDefaultRule) legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet { + return r.legacyMatchStatesWithBase(metadata, 0) +} + +func (r *abstractDefaultRule) legacyMatchStatesWithBase(metadata *adapter.InboundContext, inheritedBase ruleMatchState) legacyRuleMatchStateSet { + if len(r.allItems) == 0 { + return legacySingleRuleMatchState(inheritedBase) + } + evaluationBase := inheritedBase + if r.invert { + evaluationBase = 0 + } + stateSet := legacySingleRuleMatchState(evaluationBase) + if len(r.sourceAddressItems) > 0 { + metadata.DidMatch = true + if matchAnyItem(r.sourceAddressItems, metadata) { + stateSet = stateSet.addBit(ruleMatchSourceAddress) + } + } + if r.destinationIPCIDRMatchesSource(metadata) { + metadata.DidMatch = true + stateSet = stateSet.branchOnBit(ruleMatchSourceAddress, legacyDestinationIPFormula(r.destinationIPCIDRItems, metadata)) + } + if len(r.sourcePortItems) > 0 { + metadata.DidMatch = true + if matchAnyItem(r.sourcePortItems, metadata) { + stateSet = stateSet.addBit(ruleMatchSourcePort) + } + } + if len(r.destinationAddressItems) > 0 { + metadata.DidMatch = true + if matchAnyItem(r.destinationAddressItems, metadata) { + stateSet = stateSet.addBit(ruleMatchDestinationAddress) + } + } + if r.legacyDestinationIPCIDRMatchesDestination(metadata) { + metadata.DidMatch = true + stateSet = stateSet.branchOnBit(ruleMatchDestinationAddress, legacyDestinationIPFormula(r.destinationIPCIDRItems, metadata)) + } + if len(r.destinationPortItems) > 0 { + metadata.DidMatch = true + if matchAnyItem(r.destinationPortItems, metadata) { + stateSet = stateSet.addBit(ruleMatchDestinationPort) + } + } + for _, item := range r.items { + metadata.DidMatch = true + if !item.Match(metadata) { + if r.invert { + return legacySingleRuleMatchState(inheritedBase) + } + return legacyRuleMatchStateSet{} + } + } + if r.ruleSetItem != nil { + metadata.DidMatch = true + var merged legacyRuleMatchStateSet + for state := ruleMatchState(0); state < 16; state++ { + if stateSet[state].isFalse() { + continue + } + nestedStateSet := legacyMatchRuleItemStatesWithBase(r.ruleSetItem, metadata, state) + merged = merged.merge(nestedStateSet.andFormula(stateSet[state])) + } + stateSet = merged + } + stateSet = stateSet.filter(func(state ruleMatchState) bool { + if r.legacyRequiresSourceAddressMatch(metadata) && !state.has(ruleMatchSourceAddress) { + return false + } + if len(r.sourcePortItems) > 0 && !state.has(ruleMatchSourcePort) { + return false + } + if r.legacyRequiresDestinationAddressMatch(metadata) && !state.has(ruleMatchDestinationAddress) { + return false + } + if len(r.destinationPortItems) > 0 && !state.has(ruleMatchDestinationPort) { + return false + } + return true + }) + if r.invert { + return legacySingleRuleMatchStateWithFormula(inheritedBase, stateSet.anyFormula().not()) + } + return stateSet +} + +func (r *abstractDefaultRule) legacyRequiresSourceAddressMatch(metadata *adapter.InboundContext) bool { + return len(r.sourceAddressItems) > 0 || r.destinationIPCIDRMatchesSource(metadata) +} + +func (r *abstractDefaultRule) legacyDestinationIPCIDRMatchesDestination(metadata *adapter.InboundContext) bool { + return !metadata.IPCIDRMatchSource && len(r.destinationIPCIDRItems) > 0 +} + +func (r *abstractDefaultRule) legacyRequiresDestinationAddressMatch(metadata *adapter.InboundContext) bool { + return len(r.destinationAddressItems) > 0 || r.legacyDestinationIPCIDRMatchesDestination(metadata) +} + +func (r *abstractLogicalRule) legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet { + return r.legacyMatchStatesWithBase(metadata, 0) +} + +func (r *abstractLogicalRule) legacyMatchStatesWithBase(metadata *adapter.InboundContext, base ruleMatchState) legacyRuleMatchStateSet { + evaluationBase := base + if r.invert { + evaluationBase = 0 + } + var stateSet legacyRuleMatchStateSet + if r.mode == C.LogicalTypeAnd { + stateSet = legacySingleRuleMatchState(evaluationBase) + for _, rule := range r.rules { + nestedMetadata := *metadata + nestedMetadata.ResetRuleCache() + stateSet = stateSet.combine(legacyMatchHeadlessRuleStatesWithBase(rule, &nestedMetadata, evaluationBase)) + if stateSet.isEmpty() && !r.invert { + return legacyRuleMatchStateSet{} + } + } + } else { + for _, rule := range r.rules { + nestedMetadata := *metadata + nestedMetadata.ResetRuleCache() + stateSet = stateSet.merge(legacyMatchHeadlessRuleStatesWithBase(rule, &nestedMetadata, evaluationBase)) + } + } + if r.invert { + return legacySingleRuleMatchStateWithFormula(base, stateSet.anyFormula().not()) + } + return stateSet +} + +func legacyDestinationIPFormula(items []RuleItem, metadata *adapter.InboundContext) legacyResponseFormula { + if legacyDestinationIPResolved(metadata) { + if matchAnyItem(items, metadata) { + return legacyTrueFormula() + } + return legacyFalseFormula() + } + var formula legacyResponseFormula + for _, rawItem := range items { + switch item := rawItem.(type) { + case *IPCIDRItem: + if item.isSource || metadata.IPCIDRMatchSource { + if item.Match(metadata) { + return legacyTrueFormula() + } + continue + } + formula = formula.or(legacyLiteralFormula(legacyResponseLiteral{ + kind: legacyLiteralRequireSet, + ipSet: item.ipSet, + })) + if metadata.IPCIDRAcceptEmpty { + formula = formula.or(legacyLiteralFormula(legacyResponseLiteral{ + kind: legacyLiteralRequireEmpty, + })) + } + case *IPIsPrivateItem: + if item.isSource { + if item.Match(metadata) { + return legacyTrueFormula() + } + continue + } + formula = formula.or(legacyLiteralFormula(legacyResponseLiteral{ + kind: legacyLiteralRequireSet, + ipSet: legacyNonPublicIPSet, + })) + case *IPAcceptAnyItem: + formula = formula.or(legacyLiteralFormula(legacyResponseLiteral{ + kind: legacyLiteralRequireNonEmpty, + })) + default: + if rawItem.Match(metadata) { + return legacyTrueFormula() + } + } + } + return formula +} + +func legacyDestinationIPResolved(metadata *adapter.InboundContext) bool { + return metadata.IPCIDRMatchSource || + metadata.DestinationAddressMatchFromResponse || + metadata.DNSResponse != nil || + metadata.Destination.IsIP() || + len(metadata.DestinationAddresses) > 0 +} diff --git a/route/rule/rule_set_semantics_test.go b/route/rule/rule_set_semantics_test.go index 3a9d4ca9d..66e2646fc 100644 --- a/route/rule/rule_set_semantics_test.go +++ b/route/rule/rule_set_semantics_test.go @@ -787,10 +787,9 @@ func TestDNSLegacyAddressLimitPreLookupDefersDirectRules(t *testing.T) { rule := dnsRuleForTest(func(rule *abstractDefaultRule) { testCase.build(t, rule) }) - rule.legacyAddressFilter = true preLookupMetadata := testMetadata("lookup.example") - require.True(t, rule.Match(&preLookupMetadata)) + require.True(t, rule.LegacyPreMatch(&preLookupMetadata)) matchedMetadata := testMetadata("lookup.example") require.True(t, rule.MatchAddressLimit(&matchedMetadata, testCase.matchedResponse)) @@ -810,10 +809,9 @@ func TestDNSLegacyAddressLimitPreLookupDefersRuleSetDestinationCIDR(t *testing.T rule := dnsRuleForTest(func(rule *abstractDefaultRule) { addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) }) - rule.legacyAddressFilter = true preLookupMetadata := testMetadata("lookup.example") - require.True(t, rule.Match(&preLookupMetadata)) + require.True(t, rule.LegacyPreMatch(&preLookupMetadata)) matchedMetadata := testMetadata("lookup.example") require.True(t, rule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(netip.MustParseAddr("203.0.113.1")))) @@ -833,11 +831,10 @@ func TestDNSLegacyLogicalAddressLimitPreLookupDefersNestedRules(t *testing.T) { rules: []adapter.HeadlessRule{nestedRule}, mode: C.LogicalTypeAnd, }, - legacyAddressFilter: true, } preLookupMetadata := testMetadata("lookup.example") - require.True(t, logicalRule.Match(&preLookupMetadata)) + require.True(t, logicalRule.LegacyPreMatch(&preLookupMetadata)) matchedMetadata := testMetadata("lookup.example") require.True(t, logicalRule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(netip.MustParseAddr("10.0.0.1")))) @@ -846,6 +843,132 @@ func TestDNSLegacyLogicalAddressLimitPreLookupDefersNestedRules(t *testing.T) { require.False(t, logicalRule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(netip.MustParseAddr("8.8.8.8")))) } +func TestDNSLegacyInvertAddressLimitPreLookupRegression(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + build func(*testing.T, *abstractDefaultRule) + matchedAddrs []netip.Addr + unmatchedAddrs []netip.Addr + }{ + { + name: "ip_cidr", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + }, + matchedAddrs: []netip.Addr{netip.MustParseAddr("203.0.113.1")}, + unmatchedAddrs: []netip.Addr{netip.MustParseAddr("8.8.8.8")}, + }, + { + name: "ip_is_private", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPIsPrivateItem(rule) + }, + matchedAddrs: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, + unmatchedAddrs: []netip.Addr{netip.MustParseAddr("8.8.8.8")}, + }, + { + name: "ip_accept_any", + build: func(t *testing.T, rule *abstractDefaultRule) { + t.Helper() + addDestinationIPAcceptAnyItem(rule) + }, + matchedAddrs: []netip.Addr{netip.MustParseAddr("203.0.113.1")}, + }, + } + for _, testCase := range testCases { + testCase := testCase + t.Run(testCase.name, func(t *testing.T) { + t.Parallel() + + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + rule.invert = true + testCase.build(t, rule) + }) + + preLookupMetadata := testMetadata("lookup.example") + require.True(t, rule.LegacyPreMatch(&preLookupMetadata)) + + matchedMetadata := testMetadata("lookup.example") + require.False(t, rule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(testCase.matchedAddrs...))) + + unmatchedMetadata := testMetadata("lookup.example") + require.True(t, rule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(testCase.unmatchedAddrs...))) + }) + } +} + +func TestDNSLegacyInvertLogicalAddressLimitPreLookupRegression(t *testing.T) { + t.Parallel() + + t.Run("wrapper invert keeps nested deferred rule matchable", func(t *testing.T) { + t.Parallel() + + nestedRule := dnsRuleForTest(func(rule *abstractDefaultRule) { + addDestinationIPIsPrivateItem(rule) + }) + logicalRule := &LogicalDNSRule{ + abstractLogicalRule: abstractLogicalRule{ + rules: []adapter.HeadlessRule{nestedRule}, + mode: C.LogicalTypeAnd, + invert: true, + }, + } + + preLookupMetadata := testMetadata("lookup.example") + require.True(t, logicalRule.LegacyPreMatch(&preLookupMetadata)) + + matchedMetadata := testMetadata("lookup.example") + require.False(t, logicalRule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(netip.MustParseAddr("10.0.0.1")))) + + unmatchedMetadata := testMetadata("lookup.example") + require.True(t, logicalRule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(netip.MustParseAddr("8.8.8.8")))) + }) + + t.Run("inverted deferred child does not suppress branch", func(t *testing.T) { + t.Parallel() + + logicalRule := &LogicalDNSRule{ + abstractLogicalRule: abstractLogicalRule{ + rules: []adapter.HeadlessRule{ + dnsRuleForTest(func(rule *abstractDefaultRule) { + rule.invert = true + addDestinationIPIsPrivateItem(rule) + }), + }, + mode: C.LogicalTypeAnd, + }, + } + + preLookupMetadata := testMetadata("lookup.example") + require.True(t, logicalRule.LegacyPreMatch(&preLookupMetadata)) + }) +} + +func TestDNSLegacyInvertRuleSetAddressLimitPreLookupRegression(t *testing.T) { + t.Parallel() + + ruleSet := newLocalRuleSetForTest("dns-legacy-invert-ipcidr", headlessDefaultRule(t, func(rule *abstractDefaultRule) { + rule.invert = true + addDestinationIPCIDRItem(t, rule, []string{"203.0.113.0/24"}) + })) + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + }) + + preLookupMetadata := testMetadata("lookup.example") + require.True(t, rule.LegacyPreMatch(&preLookupMetadata)) + + matchedMetadata := testMetadata("lookup.example") + require.False(t, rule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(netip.MustParseAddr("203.0.113.1")))) + + unmatchedMetadata := testMetadata("lookup.example") + require.True(t, rule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(netip.MustParseAddr("8.8.8.8")))) +} + func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) { t.Parallel() testCases := []struct {