diff --git a/dns/router.go b/dns/router.go index 8e552bd52..19232f288 100644 --- a/dns/router.go +++ b/dns/router.go @@ -275,7 +275,7 @@ func (r *Router) logRuleMatch(ctx context.Context, ruleIndex int, currentRule ad } } -func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) (*mDNS.Msg, adapter.DNSTransport, error) { +func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) (*mDNS.Msg, adapter.DNSTransport, adapter.DNSQueryOptions, error) { metadata := adapter.ContextFrom(ctx) if metadata == nil { panic("no context") @@ -328,7 +328,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio queryOptions.Strategy = r.defaultDomainStrategy } response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, queryOptions, nil) - return response, transport, err + return response, transport, queryOptions, err case *R.RuleActionReject: switch action.Method { case C.RuleActionRejectMethodDefault: @@ -339,12 +339,12 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio Response: true, }, Question: []mDNS.Question{message.Question[0]}, - }, nil, nil + }, nil, effectiveOptions, nil case C.RuleActionRejectMethodDrop: - return nil, nil, tun.ErrDrop + return nil, nil, effectiveOptions, tun.ErrDrop } case *R.RuleActionPredefined: - return action.Response(message), nil, nil + return action.Response(message), nil, effectiveOptions, nil } } queryOptions := effectiveOptions @@ -354,55 +354,83 @@ func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, optio queryOptions.Strategy = r.defaultDomainStrategy } response, err := r.client.Exchange(adapter.OverrideContext(ctx), transport, message, queryOptions, nil) - return response, transport, err + return response, transport, queryOptions, err +} + +type lookupWithRulesResponse struct { + addresses []netip.Addr + strategy C.DomainStrategy +} + +func (r *Router) resolveLookupStrategy(options adapter.DNSQueryOptions, strategies ...C.DomainStrategy) C.DomainStrategy { + if options.LookupStrategy != C.DomainStrategyAsIS { + return options.LookupStrategy + } + for _, strategy := range strategies { + if strategy != C.DomainStrategyAsIS { + return strategy + } + } + if options.Strategy != C.DomainStrategyAsIS { + return options.Strategy + } + return r.defaultDomainStrategy +} + +func lookupStrategyAllowsQueryType(strategy C.DomainStrategy, qType uint16) bool { + switch strategy { + case C.DomainStrategyIPv4Only: + return qType == mDNS.TypeA + case C.DomainStrategyIPv6Only: + return qType == mDNS.TypeAAAA + default: + return true + } } func (r *Router) lookupWithRules(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { - var strategy C.DomainStrategy - if options.LookupStrategy != C.DomainStrategyAsIS { - strategy = options.LookupStrategy - } else { - strategy = options.Strategy - } lookupOptions := options if options.LookupStrategy != C.DomainStrategyAsIS { - lookupOptions.Strategy = strategy + lookupOptions.Strategy = options.LookupStrategy } - if strategy == C.DomainStrategyIPv4Only { - return r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions) + if options.LookupStrategy == C.DomainStrategyIPv4Only { + response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions) + return response.addresses, err } - if strategy == C.DomainStrategyIPv6Only { - return r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions) + if options.LookupStrategy == C.DomainStrategyIPv6Only { + response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions) + return response.addresses, err } var ( - response4 []netip.Addr - response6 []netip.Addr + response4 lookupWithRulesResponse + response6 lookupWithRulesResponse ) var group task.Group group.Append("exchange4", func(ctx context.Context) error { - response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions) - if err != nil { - return err - } - response4 = response - return nil + result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions) + response4 = result + return err }) group.Append("exchange6", func(ctx context.Context) error { - response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions) - if err != nil { - return err - } - response6 = response - return nil + result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions) + response6 = result + return err }) err := group.Run(ctx) - if len(response4) == 0 && len(response6) == 0 { + strategy := r.resolveLookupStrategy(options, response4.strategy, response6.strategy) + if !lookupStrategyAllowsQueryType(strategy, mDNS.TypeA) { + response4.addresses = nil + } + if !lookupStrategyAllowsQueryType(strategy, mDNS.TypeAAAA) { + response6.addresses = nil + } + if len(response4.addresses) == 0 && len(response6.addresses) == 0 { return nil, err } - return sortAddresses(response4, response6, strategy), nil + return sortAddresses(response4.addresses, response6.addresses, strategy), nil } -func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType uint16, options adapter.DNSQueryOptions) ([]netip.Addr, error) { +func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) { request := &mDNS.Msg{ MsgHdr: mDNS.MsgHdr{ RecursionDesired: true, @@ -413,14 +441,21 @@ func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType u Qclass: mDNS.ClassINET, }}, } - response, _, err := r.exchangeWithRules(adapter.OverrideContext(ctx), request, options, false) + response, _, queryOptions, err := r.exchangeWithRules(adapter.OverrideContext(ctx), request, options, false) + result := lookupWithRulesResponse{ + strategy: r.resolveLookupStrategy(options, queryOptions.Strategy), + } if err != nil { - return nil, err + return result, err } if response.Rcode != mDNS.RcodeSuccess { - return nil, RcodeError(response.Rcode) + return result, RcodeError(response.Rcode) } - return MessageToAddresses(response), nil + if !lookupStrategyAllowsQueryType(result.strategy, qType) { + return result, nil + } + result.addresses = MessageToAddresses(response) + return result, nil } func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) { @@ -461,7 +496,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte } response, err = r.client.Exchange(ctx, transport, message, options, nil) } else if !r.legacyAddressFilterMode { - response, transport, err = r.exchangeWithRules(ctx, message, options, true) + response, transport, _, err = r.exchangeWithRules(ctx, message, options, true) } else { var ( rule adapter.DNSRule diff --git a/dns/router_test.go b/dns/router_test.go index d7deb848f..f4bbd9a39 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -305,6 +305,134 @@ func TestLookupNewModeDoesNotUseQueryTypeRule(t *testing.T) { require.Equal(t, []netip.Addr{netip.MustParseAddr("3.3.3.3")}, addresses) } +func TestLookupNewModeAppliesRouteStrategyAfterEvaluate(t *testing.T) { + t.Parallel() + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouter(t, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + MatchResponse: true, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{ + Server: "selected", + Strategy: option.DomainStrategy(C.DomainStrategyIPv4Only), + }, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, + }, + }, &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + if transport.Tag() == "default" { + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil + } + switch message.Question[0].Qtype { + case mDNS.TypeA: + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil + case mDNS.TypeAAAA: + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::1")}, 60), nil + default: + return nil, errors.New("unexpected qtype") + } + }, + }) + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses) +} + +func TestExchangeNewModeLogicalMatchResponseIPCIDRFallsThrough(t *testing.T) { + t.Parallel() + + transportManager := &fakeDNSTransportManager{ + defaultTransport: &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + transports: map[string]adapter.DNSTransport{ + "upstream": &fakeDNSTransport{tag: "upstream", transportType: C.DNSTypeUDP}, + "selected": &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}, + "default": &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}, + }, + } + client := &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + switch transport.Tag() { + case "upstream": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil + case "selected": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil + case "default": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil + default: + return nil, errors.New("unexpected transport") + } + }, + } + rules := []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeEvaluate, + RouteOptions: option.DNSRouteActionOptions{Server: "upstream"}, + }, + }, + }, + { + Type: C.RuleTypeLogical, + LogicalOptions: option.LogicalDNSRule{ + RawLogicalDNSRule: option.RawLogicalDNSRule{ + Mode: C.LogicalTypeOr, + Rules: []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + MatchResponse: true, + IPCIDR: badoption.Listable[string]{"1.1.1.0/24"}, + }, + }, + }}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, + }, + } + router := newTestRouter(t, rules, transportManager, client) + + response, err := router.Exchange(context.Background(), &mDNS.Msg{ + Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, + }, adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("4.4.4.4")}, MessageToAddresses(response)) +} + func TestOldModeReportsLegacyAddressFilterDeprecation(t *testing.T) { t.Parallel() diff --git a/route/rule/rule_dns.go b/route/rule/rule_dns.go index 1bb42cbaf..4d6636dc5 100644 --- a/route/rule/rule_dns.go +++ b/route/rule/rule_dns.go @@ -350,17 +350,19 @@ func (r *DefaultDNSRule) WithAddressLimit() bool { } func (r *DefaultDNSRule) Match(metadata *adapter.InboundContext) bool { + return !r.matchStatesForMatch(metadata).isEmpty() +} + +func (r *DefaultDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet { if r.matchResponse { if metadata.DNSResponse == nil { - return false + return 0 } - return r.abstractDefaultRule.Match(metadata) + return r.abstractDefaultRule.matchStates(metadata) } - metadata.IgnoreDestinationIPCIDRMatch = true - defer func() { - metadata.IgnoreDestinationIPCIDRMatch = false - }() - return !r.matchStates(metadata).isEmpty() + matchMetadata := *metadata + matchMetadata.IgnoreDestinationIPCIDRMatch = true + return r.abstractDefaultRule.matchStates(&matchMetadata) } func (r *DefaultDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool { @@ -377,6 +379,52 @@ func (r *LogicalDNSRule) matchStates(metadata *adapter.InboundContext) ruleMatch return r.abstractLogicalRule.matchStates(metadata) } +func matchDNSHeadlessRuleStatesForMatch(rule adapter.HeadlessRule, metadata *adapter.InboundContext) ruleMatchStateSet { + switch rule := rule.(type) { + case *DefaultDNSRule: + return rule.matchStatesForMatch(metadata) + case *LogicalDNSRule: + return rule.matchStatesForMatch(metadata) + default: + return matchHeadlessRuleStates(rule, metadata) + } +} + +func (r *LogicalDNSRule) matchStatesForMatch(metadata *adapter.InboundContext) ruleMatchStateSet { + var stateSet ruleMatchStateSet + if r.mode == C.LogicalTypeAnd { + stateSet = emptyRuleMatchState() + for _, rule := range r.rules { + nestedMetadata := *metadata + nestedMetadata.ResetRuleCache() + nestedStateSet := matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata) + if nestedStateSet.isEmpty() { + if r.invert { + return emptyRuleMatchState() + } + return 0 + } + stateSet = stateSet.combine(nestedStateSet) + } + } else { + for _, rule := range r.rules { + nestedMetadata := *metadata + nestedMetadata.ResetRuleCache() + stateSet = stateSet.merge(matchDNSHeadlessRuleStatesForMatch(rule, &nestedMetadata)) + } + if stateSet.isEmpty() { + if r.invert { + return emptyRuleMatchState() + } + return 0 + } + } + if r.invert { + return 0 + } + return stateSet +} + func NewLogicalDNSRule(ctx context.Context, logger log.ContextLogger, options option.LogicalDNSRule, legacyAddressFilter bool) (*LogicalDNSRule, error) { r := &LogicalDNSRule{ abstractLogicalRule: abstractLogicalRule{ @@ -424,11 +472,7 @@ func (r *LogicalDNSRule) WithAddressLimit() bool { } func (r *LogicalDNSRule) Match(metadata *adapter.InboundContext) bool { - metadata.IgnoreDestinationIPCIDRMatch = true - defer func() { - metadata.IgnoreDestinationIPCIDRMatch = false - }() - return !r.matchStates(metadata).isEmpty() + return !r.matchStatesForMatch(metadata).isEmpty() } func (r *LogicalDNSRule) MatchAddressLimit(metadata *adapter.InboundContext) bool {