Fix DNS evaluate routing regressions

This commit is contained in:
世界
2026-03-24 17:36:02 +08:00
parent 33e4fcc400
commit 27b60052fe
3 changed files with 258 additions and 51 deletions

View File

@@ -275,7 +275,7 @@ func (r *Router) logRuleMatch(ctx context.Context, ruleIndex int, currentRule ad
}
}
func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) (*mDNS.Msg, adapter.DNSTransport, error) {
func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) (*mDNS.Msg, adapter.DNSTransport, adapter.DNSQueryOptions, error) {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
panic("no context")
@@ -328,7 +328,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
queryOptions.Strategy = r.defaultDomainStrategy
}
response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, queryOptions, nil)
return response, transport, err
return response, transport, queryOptions, err
case *R.RuleActionReject:
switch action.Method {
case C.RuleActionRejectMethodDefault:
@@ -339,12 +339,12 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
Response: true,
},
Question: []mDNS.Question{message.Question[0]},
}, nil, nil
}, nil, effectiveOptions, nil
case C.RuleActionRejectMethodDrop:
return nil, nil, tun.ErrDrop
return nil, nil, effectiveOptions, tun.ErrDrop
}
case *R.RuleActionPredefined:
return action.Response(message), nil, nil
return action.Response(message), nil, effectiveOptions, nil
}
}
queryOptions := effectiveOptions
@@ -354,55 +354,83 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
queryOptions.Strategy = r.defaultDomainStrategy
}
response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, queryOptions, nil)
return response, transport, err
return response, transport, queryOptions, err
}
type lookupWithRulesResponse struct {
addresses []netip.Addr
strategy C.DomainStrategy
}
func (r *Router) resolveLookupStrategy(options adapter.DNSQueryOptions, strategies ...C.DomainStrategy) C.DomainStrategy {
if options.LookupStrategy != C.DomainStrategyAsIS {
return options.LookupStrategy
}
for _, strategy := range strategies {
if strategy != C.DomainStrategyAsIS {
return strategy
}
}
if options.Strategy != C.DomainStrategyAsIS {
return options.Strategy
}
return r.defaultDomainStrategy
}
func lookupStrategyAllowsQueryType(strategy C.DomainStrategy, qType uint16) bool {
switch strategy {
case C.DomainStrategyIPv4Only:
return qType == mDNS.TypeA
case C.DomainStrategyIPv6Only:
return qType == mDNS.TypeAAAA
default:
return true
}
}
func (r *Router) lookupWithRules(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
var strategy C.DomainStrategy
if options.LookupStrategy != C.DomainStrategyAsIS {
strategy = options.LookupStrategy
} else {
strategy = options.Strategy
}
lookupOptions := options
if options.LookupStrategy != C.DomainStrategyAsIS {
lookupOptions.Strategy = strategy
lookupOptions.Strategy = options.LookupStrategy
}
if strategy == C.DomainStrategyIPv4Only {
return r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
if options.LookupStrategy == C.DomainStrategyIPv4Only {
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
return response.addresses, err
}
if strategy == C.DomainStrategyIPv6Only {
return r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
if options.LookupStrategy == C.DomainStrategyIPv6Only {
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
return response.addresses, err
}
var (
response4 []netip.Addr
response6 []netip.Addr
response4 lookupWithRulesResponse
response6 lookupWithRulesResponse
)
var group task.Group
group.Append("exchange4", func(ctx context.Context) error {
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
if err != nil {
return err
}
response4 = response
return nil
result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
response4 = result
return err
})
group.Append("exchange6", func(ctx context.Context) error {
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
if err != nil {
return err
}
response6 = response
return nil
result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
response6 = result
return err
})
err := group.Run(ctx)
if len(response4) == 0 && len(response6) == 0 {
strategy := r.resolveLookupStrategy(options, response4.strategy, response6.strategy)
if !lookupStrategyAllowsQueryType(strategy, mDNS.TypeA) {
response4.addresses = nil
}
if !lookupStrategyAllowsQueryType(strategy, mDNS.TypeAAAA) {
response6.addresses = nil
}
if len(response4.addresses) == 0 && len(response6.addresses) == 0 {
return nil, err
}
return sortAddresses(response4, response6, strategy), nil
return sortAddresses(response4.addresses, response6.addresses, strategy), nil
}
func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType uint16, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) {
request := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
RecursionDesired: true,
@@ -413,14 +441,21 @@ func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType u
Qclass: mDNS.ClassINET,
}},
}
response, _, err := r.exchangeWithRules(adapter.OverrideContext(ctx), request, options, false)
response, _, queryOptions, err := r.exchangeWithRules(adapter.OverrideContext(ctx), request, options, false)
result := lookupWithRulesResponse{
strategy: r.resolveLookupStrategy(options, queryOptions.Strategy),
}
if err != nil {
return nil, err
return result, err
}
if response.Rcode != mDNS.RcodeSuccess {
return nil, RcodeError(response.Rcode)
return result, RcodeError(response.Rcode)
}
return MessageToAddresses(response), nil
if !lookupStrategyAllowsQueryType(result.strategy, qType) {
return result, nil
}
result.addresses = MessageToAddresses(response)
return result, nil
}
func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) {
@@ -461,7 +496,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
}
response, err = r.client.Exchange(ctx, transport, message, options, nil)
} else if !r.legacyAddressFilterMode {
response, transport, err = r.exchangeWithRules(ctx, message, options, true)
response, transport, _, err = r.exchangeWithRules(ctx, message, options, true)
} else {
var (
rule adapter.DNSRule

View File

@@ -305,6 +305,134 @@ func TestLookupNewModeDoesNotUseQueryTypeRule(t *testing.T) {
require.Equal(t, []netip.Addr{netip.MustParseAddr("3.3.3.3")}, addresses)
}
func TestLookupNewModeAppliesRouteStrategyAfterEvaluate(t *testing.T) {
t.Parallel()
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
router := newTestRouter(t, []option.DNSRule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
Domain: badoption.Listable[string]{"example.com"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeEvaluate,
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
MatchResponse: true,
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{
Server: "selected",
Strategy: option.DomainStrategy(C.DomainStrategyIPv4Only),
},
},
},
},
}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
},
}, &fakeDNSClient{
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
if transport.Tag() == "default" {
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
}
switch message.Question[0].Qtype {
case mDNS.TypeA:
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil
case mDNS.TypeAAAA:
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::1")}, 60), nil
default:
return nil, errors.New("unexpected qtype")
}
},
})
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses)
}
func TestExchangeNewModeLogicalMatchResponseIPCIDRFallsThrough(t *testing.T) {
t.Parallel()
transportManager := &fakeDNSTransportManager{
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
transports: map[string]adapter.DNSTransport{
"upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP},
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
},
}
client := &fakeDNSClient{
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
switch transport.Tag() {
case "upstream":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil
case "selected":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
case "default":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil
default:
return nil, errors.New("unexpected transport")
}
},
}
rules := []option.DNSRule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
Domain: badoption.Listable[string]{"example.com"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeEvaluate,
RouteOptions: option.DNSRouteActionOptions{Server: "upstream"},
},
},
},
{
Type: C.RuleTypeLogical,
LogicalOptions: option.LogicalDNSRule{
RawLogicalDNSRule: option.RawLogicalDNSRule{
Mode: C.LogicalTypeOr,
Rules: []option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
MatchResponse: true,
IPCIDR: badoption.Listable[string]{"1.1.1.0/24"},
},
},
}},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
},
},
},
}
router := newTestRouter(t, rules, transportManager, client)
response, err := router.Exchange(context.Background(), &mDNS.Msg{
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
}, adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("4.4.4.4")}, MessageToAddresses(response))
}
func TestOldModeReportsLegacyAddressFilterDeprecation(t *testing.T) {
t.Parallel()

View File

@@ -350,17 +350,19 @@ func (r *DefaultDNSRule) WithAddressLimit() bool {
}
func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool {
return !r.matchStatesForMatch(metadata).isEmpty()
}
func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet {
if r.matchResponse {
if metadata.DNSResponse == nil {
return false
return 0
}
return r.abstractDefaultRule.Match(metadata)
return r.abstractDefaultRule.matchStates(metadata)
}
metadata.IgnoreDestinationIPCIDRMatch = true
defer func() {
metadata.IgnoreDestinationIPCIDRMatch = false
}()
return !r.matchStates(metadata).isEmpty()
matchMetadata := *metadata
matchMetadata.IgnoreDestinationIPCIDRMatch = true
return r.abstractDefaultRule.matchStates(&matchMetadata)
}
func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {
@@ -377,6 +379,52 @@ func (r *LogicalDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatch
return r.abstractLogicalRule.matchStates(metadata)
}
func matchDNSHeadlessRuleStatesForMatch(rule adapter.HeadlessRule, metadata *adapter.InboundContext) ruleMatchStateSet {
switch rule := rule.(type) {
case *DefaultDNSRule:
return rule.matchStatesForMatch(metadata)
case *LogicalDNSRule:
return rule.matchStatesForMatch(metadata)
default:
return matchHeadlessRuleStates(rule, metadata)
}
}
func (r *LogicalDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet {
var stateSet ruleMatchStateSet
if r.mode == C.LogicalTypeAnd {
stateSet = emptyRuleMatchState()
for _, rule := range r.rules {
nestedMetadata := *metadata
nestedMetadata.ResetRuleCache()
nestedStateSet := matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata)
if nestedStateSet.isEmpty() {
if r.invert {
return emptyRuleMatchState()
}
return 0
}
stateSet = stateSet.combine(nestedStateSet)
}
} else {
for _, rule := range r.rules {
nestedMetadata := *metadata
nestedMetadata.ResetRuleCache()
stateSet = stateSet.merge(matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata))
}
if stateSet.isEmpty() {
if r.invert {
return emptyRuleMatchState()
}
return 0
}
}
if r.invert {
return 0
}
return stateSet
}
func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options option.LogicalDNSRule, legacyAddressFilter bool) (*LogicalDNSRule, error) {
r := &LogicalDNSRule{
abstractLogicalRule: abstractLogicalRule{
@@ -424,11 +472,7 @@ func (r *LogicalDNSRule) WithAddressLimit() bool {
}
func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
metadata.IgnoreDestinationIPCIDRMatch = true
defer func() {
metadata.IgnoreDestinationIPCIDRMatch = false
}()
return !r.matchStates(metadata).isEmpty()
return !r.matchStatesForMatch(metadata).isEmpty()
}
func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {