diff --git a/route/rule/rule_dns_legacy.go b/route/rule/rule_dns_legacy.go index 23816d7e6..8cdad83f3 100644 --- a/route/rule/rule_dns_legacy.go +++ b/route/rule/rule_dns_legacy.go @@ -24,9 +24,28 @@ type legacyResponseLiteral struct { ipSet *netipx.IPSet } -type legacyResponseTerm []legacyResponseLiteral +type legacyResponseFormulaKind uint8 -type legacyResponseFormula []legacyResponseTerm +const ( + legacyFormulaFalse legacyResponseFormulaKind = iota + legacyFormulaTrue + legacyFormulaLiteral + legacyFormulaAnd + legacyFormulaOr +) + +type legacyResponseFormula struct { + kind legacyResponseFormulaKind + literal legacyResponseLiteral + children []legacyResponseFormula +} + +type legacyResponseConstraint struct { + requireEmpty bool + requireNonEmpty bool + requiredSets []*netipx.IPSet + forbiddenSet *netipx.IPSet +} type legacyRuleMatchStateSet [16]legacyResponseFormula @@ -59,83 +78,59 @@ var ( ) func legacyFalseFormula() legacyResponseFormula { - return nil + return legacyResponseFormula{} } func legacyTrueFormula() legacyResponseFormula { - return legacyResponseFormula{legacyResponseTerm{}} + return legacyResponseFormula{kind: legacyFormulaTrue} } func legacyLiteralFormula(literal legacyResponseLiteral) legacyResponseFormula { - return legacyResponseFormula{legacyResponseTerm{literal}} + return legacyResponseFormula{ + kind: legacyFormulaLiteral, + literal: literal, + } } func (f legacyResponseFormula) isFalse() bool { - return len(f) == 0 + return f.kind == legacyFormulaFalse } func (f legacyResponseFormula) isTrue() bool { - return len(f) == 1 && len(f[0]) == 0 + return f.kind == legacyFormulaTrue } func (f legacyResponseFormula) or(other legacyResponseFormula) legacyResponseFormula { - if f.isFalse() { - return other - } - if other.isFalse() { - return f - } - result := make(legacyResponseFormula, 0, len(f)+len(other)) - result = append(result, f...) - result = append(result, other...) - return result + return legacyOrFormulas(f, other) } func (f legacyResponseFormula) and(other legacyResponseFormula) legacyResponseFormula { - if f.isFalse() || other.isFalse() { - return legacyFalseFormula() - } - if f.isTrue() { - return other - } - if other.isTrue() { - return f - } - var result legacyResponseFormula - for _, leftTerm := range f { - for _, rightTerm := range other { - combined, valid := legacyCombineResponseTerms(leftTerm, rightTerm) - if valid { - result = append(result, combined) - } - } - } - return result + return legacyAndFormulas(f, other) } func (f legacyResponseFormula) not() legacyResponseFormula { - if f.isFalse() { + switch f.kind { + case legacyFormulaFalse: return legacyTrueFormula() - } - result := legacyTrueFormula() - for _, term := range f { - result = result.and(legacyNegateResponseTerm(term)) - if result.isFalse() { - return result - } - } - return result -} - -func legacyNegateResponseTerm(term legacyResponseTerm) legacyResponseFormula { - if len(term) == 0 { + case legacyFormulaTrue: return legacyFalseFormula() + case legacyFormulaLiteral: + return legacyLiteralFormula(legacyNegateResponseLiteral(f.literal)) + case legacyFormulaAnd: + negated := make([]legacyResponseFormula, 0, len(f.children)) + for _, child := range f.children { + negated = append(negated, child.not()) + } + return legacyOrFormulas(negated...) + case legacyFormulaOr: + negated := make([]legacyResponseFormula, 0, len(f.children)) + for _, child := range f.children { + negated = append(negated, child.not()) + } + return legacyAndFormulas(negated...) + default: + panic("unknown legacy response formula kind") } - result := make(legacyResponseFormula, 0, len(term)) - for _, literal := range term { - result = append(result, legacyResponseTerm{legacyNegateResponseLiteral(literal)}) - } - return result } func legacyNegateResponseLiteral(literal legacyResponseLiteral) legacyResponseLiteral { @@ -153,62 +148,171 @@ func legacyNegateResponseLiteral(literal legacyResponseLiteral) legacyResponseLi } } -func legacyCombineResponseTerms(left legacyResponseTerm, right legacyResponseTerm) (legacyResponseTerm, bool) { - combined := make(legacyResponseTerm, 0, len(left)+len(right)) - combined = append(combined, left...) - combined = append(combined, right...) - if !legacyResponseTermSatisfiable(combined) { - return nil, false +func legacyOrFormulas(formulas ...legacyResponseFormula) legacyResponseFormula { + children := make([]legacyResponseFormula, 0, len(formulas)) + for _, formula := range formulas { + if formula.isFalse() { + continue + } + if formula.isTrue() { + return legacyTrueFormula() + } + if formula.kind == legacyFormulaOr { + children = append(children, formula.children...) + continue + } + children = append(children, formula) } - return combined, true -} - -func legacyResponseTermSatisfiable(term legacyResponseTerm) bool { - var ( - requireEmpty bool - requireNonEmpty bool - requiredSets []*netipx.IPSet - forbiddenBuild netipx.IPSetBuilder - hasForbidden bool - ) - for _, literal := range term { - switch literal.kind { - case legacyLiteralRequireEmpty: - requireEmpty = true - case legacyLiteralRequireNonEmpty: - requireNonEmpty = true - case legacyLiteralRequireSet: - requiredSets = append(requiredSets, literal.ipSet) - case legacyLiteralForbidSet: - if literal.ipSet != nil { - forbiddenBuild.AddSet(literal.ipSet) - hasForbidden = true - } - default: - panic("unknown legacy response literal kind") + switch len(children) { + case 0: + return legacyFalseFormula() + case 1: + return children[0] + default: + return legacyResponseFormula{ + kind: legacyFormulaOr, + children: children, } } - if requireEmpty && (requireNonEmpty || len(requiredSets) > 0) { - return false +} + +func legacyAndFormulas(formulas ...legacyResponseFormula) legacyResponseFormula { + children := make([]legacyResponseFormula, 0, len(formulas)) + for _, formula := range formulas { + if formula.isFalse() { + return legacyFalseFormula() + } + if formula.isTrue() { + continue + } + if formula.kind == legacyFormulaAnd { + children = append(children, formula.children...) + continue + } + children = append(children, formula) } - if requireEmpty { + switch len(children) { + case 0: + return legacyTrueFormula() + case 1: + return children[0] + } + result := legacyResponseFormula{ + kind: legacyFormulaAnd, + children: children, + } + if !result.satisfiable() { + return legacyFalseFormula() + } + return result +} + +func (f legacyResponseFormula) satisfiable() bool { + return legacyResponseFormulasSatisfiable(legacyResponseConstraint{}, []legacyResponseFormula{f}) +} + +func legacyResponseFormulasSatisfiable(constraint legacyResponseConstraint, formulas []legacyResponseFormula) bool { + stack := append(make([]legacyResponseFormula, 0, len(formulas)), formulas...) + var disjunctions []legacyResponseFormula + for len(stack) > 0 { + formula := stack[len(stack)-1] + stack = stack[:len(stack)-1] + switch formula.kind { + case legacyFormulaFalse: + return false + case legacyFormulaTrue: + continue + case legacyFormulaLiteral: + var ok bool + constraint, ok = constraint.withLiteral(formula.literal) + if !ok { + return false + } + case legacyFormulaAnd: + stack = append(stack, formula.children...) + case legacyFormulaOr: + if len(formula.children) == 0 { + return false + } + disjunctions = append(disjunctions, formula) + default: + panic("unknown legacy response formula kind") + } + } + if len(disjunctions) == 0 { return true } - var forbidden *netipx.IPSet - if hasForbidden { - forbidden = common.Must1(forbiddenBuild.IPSet()) + bestIndex := 0 + for i := 1; i < len(disjunctions); i++ { + if len(disjunctions[i].children) < len(disjunctions[bestIndex].children) { + bestIndex = i + } } - for _, required := range requiredSets { - if !legacyIPSetHasAllowedIP(required, forbidden) { + selected := disjunctions[bestIndex] + remaining := make([]legacyResponseFormula, 0, len(disjunctions)-1) + remaining = append(remaining, disjunctions[:bestIndex]...) + remaining = append(remaining, disjunctions[bestIndex+1:]...) + for _, child := range selected.children { + nextFormulas := make([]legacyResponseFormula, 0, len(remaining)+1) + nextFormulas = append(nextFormulas, remaining...) + nextFormulas = append(nextFormulas, child) + if legacyResponseFormulasSatisfiable(constraint, nextFormulas) { + return true + } + } + return false +} + +func (c legacyResponseConstraint) withLiteral(literal legacyResponseLiteral) (legacyResponseConstraint, bool) { + switch literal.kind { + case legacyLiteralRequireEmpty: + c.requireEmpty = true + case legacyLiteralRequireNonEmpty: + c.requireNonEmpty = true + case legacyLiteralRequireSet: + requiredSets := make([]*netipx.IPSet, len(c.requiredSets)+1) + copy(requiredSets, c.requiredSets) + requiredSets[len(c.requiredSets)] = literal.ipSet + c.requiredSets = requiredSets + case legacyLiteralForbidSet: + c.forbiddenSet = legacyUnionIPSets(c.forbiddenSet, literal.ipSet) + default: + panic("unknown legacy response literal kind") + } + return c, c.satisfiable() +} + +func (c legacyResponseConstraint) satisfiable() bool { + if c.requireEmpty && (c.requireNonEmpty || len(c.requiredSets) > 0) { + return false + } + if c.requireEmpty { + return true + } + for _, required := range c.requiredSets { + if !legacyIPSetHasAllowedIP(required, c.forbiddenSet) { return false } } - if requireNonEmpty && len(requiredSets) == 0 { - return legacyIPSetHasAllowedIP(legacyAllIPSet, forbidden) + if c.requireNonEmpty && len(c.requiredSets) == 0 { + return legacyIPSetHasAllowedIP(legacyAllIPSet, c.forbiddenSet) } return true } +func legacyUnionIPSets(left *netipx.IPSet, right *netipx.IPSet) *netipx.IPSet { + if left == nil { + return right + } + if right == nil { + return left + } + var builder netipx.IPSetBuilder + builder.AddSet(left) + builder.AddSet(right) + return common.Must1(builder.IPSet()) +} + func legacyIPSetHasAllowedIP(required *netipx.IPSet, forbidden *netipx.IPSet) bool { if required == nil { required = legacyAllIPSet diff --git a/route/rule/rule_set_semantics_test.go b/route/rule/rule_set_semantics_test.go index 66e2646fc..03fb64ef3 100644 --- a/route/rule/rule_set_semantics_test.go +++ b/route/rule/rule_set_semantics_test.go @@ -969,6 +969,94 @@ func TestDNSLegacyInvertRuleSetAddressLimitPreLookupRegression(t *testing.T) { require.True(t, rule.MatchAddressLimit(&unmatchedMetadata, dnsResponseForTest(netip.MustParseAddr("8.8.8.8")))) } +func TestDNSLegacyInvertNegationStressRegression(t *testing.T) { + t.Parallel() + + const branchCount = 20 + unmatchedResponse := dnsResponseForTest(netip.MustParseAddr("203.0.113.250")) + + t.Run("logical wrapper", func(t *testing.T) { + t.Parallel() + + branches := make([]adapter.HeadlessRule, 0, branchCount) + var matchedAddrs []netip.Addr + for i := 0; i < branchCount; i++ { + firstCIDR, secondCIDR, branchAddrs := legacyNegationBranchCIDRs(i) + if matchedAddrs == nil { + matchedAddrs = branchAddrs + } + branches = append(branches, &LogicalDNSRule{ + abstractLogicalRule: abstractLogicalRule{ + mode: C.LogicalTypeAnd, + rules: []adapter.HeadlessRule{ + dnsRuleForTest(func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{firstCIDR}) + }), + dnsRuleForTest(func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{secondCIDR}) + }), + }, + }, + }) + } + + rule := &LogicalDNSRule{ + abstractLogicalRule: abstractLogicalRule{ + rules: branches, + mode: C.LogicalTypeOr, + invert: true, + }, + } + + preLookupMetadata := testMetadata("lookup.example") + require.True(t, rule.LegacyPreMatch(&preLookupMetadata)) + + matchedMetadata := testMetadata("lookup.example") + require.False(t, rule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(matchedAddrs...))) + + unmatchedMetadata := testMetadata("lookup.example") + require.True(t, rule.MatchAddressLimit(&unmatchedMetadata, unmatchedResponse)) + }) + + t.Run("ruleset wrapper", func(t *testing.T) { + t.Parallel() + + branches := make([]adapter.HeadlessRule, 0, branchCount) + var matchedAddrs []netip.Addr + for i := 0; i < branchCount; i++ { + firstCIDR, secondCIDR, branchAddrs := legacyNegationBranchCIDRs(i) + if matchedAddrs == nil { + matchedAddrs = branchAddrs + } + branches = append(branches, headlessLogicalRule( + C.LogicalTypeAnd, + false, + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{firstCIDR}) + }), + headlessDefaultRule(t, func(rule *abstractDefaultRule) { + addDestinationIPCIDRItem(t, rule, []string{secondCIDR}) + }), + )) + } + + ruleSet := newLocalRuleSetForTest("dns-legacy-negation-stress", branches...) + rule := dnsRuleForTest(func(rule *abstractDefaultRule) { + rule.invert = true + addRuleSetItem(rule, &RuleSetItem{setList: []adapter.RuleSet{ruleSet}}) + }) + + preLookupMetadata := testMetadata("lookup.example") + require.True(t, rule.LegacyPreMatch(&preLookupMetadata)) + + matchedMetadata := testMetadata("lookup.example") + require.False(t, rule.MatchAddressLimit(&matchedMetadata, dnsResponseForTest(matchedAddrs...))) + + unmatchedMetadata := testMetadata("lookup.example") + require.True(t, rule.MatchAddressLimit(&unmatchedMetadata, unmatchedResponse)) + }) +} + func TestDNSInvertAddressLimitPreLookupRegression(t *testing.T) { t.Parallel() testCases := []struct { @@ -1149,6 +1237,12 @@ func dnsResponseForTest(addresses ...netip.Addr) *mDNS.Msg { return response } +func legacyNegationBranchCIDRs(index int) (string, string, []netip.Addr) { + first := netip.AddrFrom4([4]byte{198, 18, 0, byte(index*2 + 1)}) + second := netip.AddrFrom4([4]byte{198, 18, 0, byte(index*2 + 2)}) + return first.String() + "/32", second.String() + "/32", []netip.Addr{first, second} +} + func addRuleSetItem(rule *abstractDefaultRule, item *RuleSetItem) { rule.ruleSetItem = item rule.allItems = append(rule.allItems, item)