mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
Fix DNS evaluate routing regressions
This commit is contained in:
113
dns/router.go
113
dns/router.go
@@ -275,7 +275,7 @@ func (r *Router) logRuleMatch(ctx context.Context, ruleIndex int, currentRule ad
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) (*mDNS.Msg, adapter.DNSTransport, error) {
|
||||
func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) (*mDNS.Msg, adapter.DNSTransport, adapter.DNSQueryOptions, error) {
|
||||
metadata := adapter.ContextFrom(ctx)
|
||||
if metadata == nil {
|
||||
panic("no context")
|
||||
@@ -328,7 +328,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
|
||||
queryOptions.Strategy = r.defaultDomainStrategy
|
||||
}
|
||||
response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, queryOptions, nil)
|
||||
return response, transport, err
|
||||
return response, transport, queryOptions, err
|
||||
case *R.RuleActionReject:
|
||||
switch action.Method {
|
||||
case C.RuleActionRejectMethodDefault:
|
||||
@@ -339,12 +339,12 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
|
||||
Response: true,
|
||||
},
|
||||
Question: []mDNS.Question{message.Question[0]},
|
||||
}, nil, nil
|
||||
}, nil, effectiveOptions, nil
|
||||
case C.RuleActionRejectMethodDrop:
|
||||
return nil, nil, tun.ErrDrop
|
||||
return nil, nil, effectiveOptions, tun.ErrDrop
|
||||
}
|
||||
case *R.RuleActionPredefined:
|
||||
return action.Response(message), nil, nil
|
||||
return action.Response(message), nil, effectiveOptions, nil
|
||||
}
|
||||
}
|
||||
queryOptions := effectiveOptions
|
||||
@@ -354,55 +354,83 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio
|
||||
queryOptions.Strategy = r.defaultDomainStrategy
|
||||
}
|
||||
response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, queryOptions, nil)
|
||||
return response, transport, err
|
||||
return response, transport, queryOptions, err
|
||||
}
|
||||
|
||||
type lookupWithRulesResponse struct {
|
||||
addresses []netip.Addr
|
||||
strategy C.DomainStrategy
|
||||
}
|
||||
|
||||
func (r *Router) resolveLookupStrategy(options adapter.DNSQueryOptions, strategies ...C.DomainStrategy) C.DomainStrategy {
|
||||
if options.LookupStrategy != C.DomainStrategyAsIS {
|
||||
return options.LookupStrategy
|
||||
}
|
||||
for _, strategy := range strategies {
|
||||
if strategy != C.DomainStrategyAsIS {
|
||||
return strategy
|
||||
}
|
||||
}
|
||||
if options.Strategy != C.DomainStrategyAsIS {
|
||||
return options.Strategy
|
||||
}
|
||||
return r.defaultDomainStrategy
|
||||
}
|
||||
|
||||
func lookupStrategyAllowsQueryType(strategy C.DomainStrategy, qType uint16) bool {
|
||||
switch strategy {
|
||||
case C.DomainStrategyIPv4Only:
|
||||
return qType == mDNS.TypeA
|
||||
case C.DomainStrategyIPv6Only:
|
||||
return qType == mDNS.TypeAAAA
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) lookupWithRules(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
|
||||
var strategy C.DomainStrategy
|
||||
if options.LookupStrategy != C.DomainStrategyAsIS {
|
||||
strategy = options.LookupStrategy
|
||||
} else {
|
||||
strategy = options.Strategy
|
||||
}
|
||||
lookupOptions := options
|
||||
if options.LookupStrategy != C.DomainStrategyAsIS {
|
||||
lookupOptions.Strategy = strategy
|
||||
lookupOptions.Strategy = options.LookupStrategy
|
||||
}
|
||||
if strategy == C.DomainStrategyIPv4Only {
|
||||
return r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
|
||||
if options.LookupStrategy == C.DomainStrategyIPv4Only {
|
||||
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
|
||||
return response.addresses, err
|
||||
}
|
||||
if strategy == C.DomainStrategyIPv6Only {
|
||||
return r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
|
||||
if options.LookupStrategy == C.DomainStrategyIPv6Only {
|
||||
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
|
||||
return response.addresses, err
|
||||
}
|
||||
var (
|
||||
response4 []netip.Addr
|
||||
response6 []netip.Addr
|
||||
response4 lookupWithRulesResponse
|
||||
response6 lookupWithRulesResponse
|
||||
)
|
||||
var group task.Group
|
||||
group.Append("exchange4", func(ctx context.Context) error {
|
||||
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response4 = response
|
||||
return nil
|
||||
result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
|
||||
response4 = result
|
||||
return err
|
||||
})
|
||||
group.Append("exchange6", func(ctx context.Context) error {
|
||||
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
response6 = response
|
||||
return nil
|
||||
result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
|
||||
response6 = result
|
||||
return err
|
||||
})
|
||||
err := group.Run(ctx)
|
||||
if len(response4) == 0 && len(response6) == 0 {
|
||||
strategy := r.resolveLookupStrategy(options, response4.strategy, response6.strategy)
|
||||
if !lookupStrategyAllowsQueryType(strategy, mDNS.TypeA) {
|
||||
response4.addresses = nil
|
||||
}
|
||||
if !lookupStrategyAllowsQueryType(strategy, mDNS.TypeAAAA) {
|
||||
response6.addresses = nil
|
||||
}
|
||||
if len(response4.addresses) == 0 && len(response6.addresses) == 0 {
|
||||
return nil, err
|
||||
}
|
||||
return sortAddresses(response4, response6, strategy), nil
|
||||
return sortAddresses(response4.addresses, response6.addresses, strategy), nil
|
||||
}
|
||||
|
||||
func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType uint16, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
|
||||
func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) {
|
||||
request := &mDNS.Msg{
|
||||
MsgHdr: mDNS.MsgHdr{
|
||||
RecursionDesired: true,
|
||||
@@ -413,14 +441,21 @@ func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType u
|
||||
Qclass: mDNS.ClassINET,
|
||||
}},
|
||||
}
|
||||
response, _, err := r.exchangeWithRules(adapter.OverrideContext(ctx), request, options, false)
|
||||
response, _, queryOptions, err := r.exchangeWithRules(adapter.OverrideContext(ctx), request, options, false)
|
||||
result := lookupWithRulesResponse{
|
||||
strategy: r.resolveLookupStrategy(options, queryOptions.Strategy),
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return result, err
|
||||
}
|
||||
if response.Rcode != mDNS.RcodeSuccess {
|
||||
return nil, RcodeError(response.Rcode)
|
||||
return result, RcodeError(response.Rcode)
|
||||
}
|
||||
return MessageToAddresses(response), nil
|
||||
if !lookupStrategyAllowsQueryType(result.strategy, qType) {
|
||||
return result, nil
|
||||
}
|
||||
result.addresses = MessageToAddresses(response)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) {
|
||||
@@ -461,7 +496,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
}
|
||||
response, err = r.client.Exchange(ctx, transport, message, options, nil)
|
||||
} else if !r.legacyAddressFilterMode {
|
||||
response, transport, err = r.exchangeWithRules(ctx, message, options, true)
|
||||
response, transport, _, err = r.exchangeWithRules(ctx, message, options, true)
|
||||
} else {
|
||||
var (
|
||||
rule adapter.DNSRule
|
||||
|
||||
@@ -305,6 +305,134 @@ func TestLookupNewModeDoesNotUseQueryTypeRule(t *testing.T) {
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("3.3.3.3")}, addresses)
|
||||
}
|
||||
|
||||
func TestLookupNewModeAppliesRouteStrategyAfterEvaluate(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
router := newTestRouter(t, []option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
Domain: badoption.Listable[string]{"example.com"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeEvaluate,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
MatchResponse: true,
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{
|
||||
Server: "selected",
|
||||
Strategy: option.DomainStrategy(C.DomainStrategyIPv4Only),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
if transport.Tag() == "default" {
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
|
||||
}
|
||||
switch message.Question[0].Qtype {
|
||||
case mDNS.TypeA:
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil
|
||||
case mDNS.TypeAAAA:
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::1")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected qtype")
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses)
|
||||
}
|
||||
|
||||
func TestExchangeNewModeLogicalMatchResponseIPCIDRFallsThrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
transportManager := &fakeDNSTransportManager{
|
||||
defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP},
|
||||
"selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP},
|
||||
"default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP},
|
||||
},
|
||||
}
|
||||
client := &fakeDNSClient{
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
switch transport.Tag() {
|
||||
case "upstream":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil
|
||||
case "selected":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
|
||||
case "default":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
rules := []option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
Domain: badoption.Listable[string]{"example.com"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeEvaluate,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "upstream"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeLogical,
|
||||
LogicalOptions: option.LogicalDNSRule{
|
||||
RawLogicalDNSRule: option.RawLogicalDNSRule{
|
||||
Mode: C.LogicalTypeOr,
|
||||
Rules: []option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
MatchResponse: true,
|
||||
IPCIDR: badoption.Listable[string]{"1.1.1.0/24"},
|
||||
},
|
||||
},
|
||||
}},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := newTestRouter(t, rules, transportManager, client)
|
||||
|
||||
response, err := router.Exchange(context.Background(), &mDNS.Msg{
|
||||
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
|
||||
}, adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("4.4.4.4")}, MessageToAddresses(response))
|
||||
}
|
||||
|
||||
func TestOldModeReportsLegacyAddressFilterDeprecation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -350,17 +350,19 @@ func (r *DefaultDNSRule) WithAddressLimit() bool {
|
||||
}
|
||||
|
||||
func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool {
|
||||
return !r.matchStatesForMatch(metadata).isEmpty()
|
||||
}
|
||||
|
||||
func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet {
|
||||
if r.matchResponse {
|
||||
if metadata.DNSResponse == nil {
|
||||
return false
|
||||
return 0
|
||||
}
|
||||
return r.abstractDefaultRule.Match(metadata)
|
||||
return r.abstractDefaultRule.matchStates(metadata)
|
||||
}
|
||||
metadata.IgnoreDestinationIPCIDRMatch = true
|
||||
defer func() {
|
||||
metadata.IgnoreDestinationIPCIDRMatch = false
|
||||
}()
|
||||
return !r.matchStates(metadata).isEmpty()
|
||||
matchMetadata := *metadata
|
||||
matchMetadata.IgnoreDestinationIPCIDRMatch = true
|
||||
return r.abstractDefaultRule.matchStates(&matchMetadata)
|
||||
}
|
||||
|
||||
func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {
|
||||
@@ -377,6 +379,52 @@ func (r *LogicalDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatch
|
||||
return r.abstractLogicalRule.matchStates(metadata)
|
||||
}
|
||||
|
||||
func matchDNSHeadlessRuleStatesForMatch(rule adapter.HeadlessRule, metadata *adapter.InboundContext) ruleMatchStateSet {
|
||||
switch rule := rule.(type) {
|
||||
case *DefaultDNSRule:
|
||||
return rule.matchStatesForMatch(metadata)
|
||||
case *LogicalDNSRule:
|
||||
return rule.matchStatesForMatch(metadata)
|
||||
default:
|
||||
return matchHeadlessRuleStates(rule, metadata)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *LogicalDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet {
|
||||
var stateSet ruleMatchStateSet
|
||||
if r.mode == C.LogicalTypeAnd {
|
||||
stateSet = emptyRuleMatchState()
|
||||
for _, rule := range r.rules {
|
||||
nestedMetadata := *metadata
|
||||
nestedMetadata.ResetRuleCache()
|
||||
nestedStateSet := matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata)
|
||||
if nestedStateSet.isEmpty() {
|
||||
if r.invert {
|
||||
return emptyRuleMatchState()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
stateSet = stateSet.combine(nestedStateSet)
|
||||
}
|
||||
} else {
|
||||
for _, rule := range r.rules {
|
||||
nestedMetadata := *metadata
|
||||
nestedMetadata.ResetRuleCache()
|
||||
stateSet = stateSet.merge(matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata))
|
||||
}
|
||||
if stateSet.isEmpty() {
|
||||
if r.invert {
|
||||
return emptyRuleMatchState()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
}
|
||||
if r.invert {
|
||||
return 0
|
||||
}
|
||||
return stateSet
|
||||
}
|
||||
|
||||
func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options option.LogicalDNSRule, legacyAddressFilter bool) (*LogicalDNSRule, error) {
|
||||
r := &LogicalDNSRule{
|
||||
abstractLogicalRule: abstractLogicalRule{
|
||||
@@ -424,11 +472,7 @@ func (r *LogicalDNSRule) WithAddressLimit() bool {
|
||||
}
|
||||
|
||||
func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool {
|
||||
metadata.IgnoreDestinationIPCIDRMatch = true
|
||||
defer func() {
|
||||
metadata.IgnoreDestinationIPCIDRMatch = false
|
||||
}()
|
||||
return !r.matchStates(metadata).isEmpty()
|
||||
return !r.matchStatesForMatch(metadata).isEmpty()
|
||||
}
|
||||
|
||||
func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {
|
||||
|
||||
Reference in New Issue
Block a user