diff --git a/adapter/router.go b/adapter/router.go index b8564eb0a..550aa6629 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -66,9 +66,7 @@ type RuleSet interface { type RuleSetUpdateCallback func(it RuleSet) -// Rule-set metadata only exposes headless-rule capabilities that outer routers -// need before evaluating nested matches. Headless rules do not support -// ip_version, so there is intentionally no ContainsIPVersionRule flag here. +// ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent. type RuleSetMetadata struct { ContainsProcessRule bool ContainsWIFIRule bool diff --git a/dns/router.go b/dns/router.go index 52a87ae6d..59d7163bd 100644 --- a/dns/router.go +++ b/dns/router.go @@ -60,6 +60,13 @@ func (s *rulesSnapshot) retain() { s.references.Add(1) } +func (s *rulesSnapshot) rulesAndMode() ([]adapter.DNSRule, bool) { + if s == nil { + return nil, false + } + return s.rules, s.legacyDNSMode +} + func (s *rulesSnapshot) release() { if s == nil { return @@ -129,7 +136,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp func (r *Router) Initialize(rules []option.DNSRule) error { r.rawRules = append(r.rawRules[:0], rules...) - newRules, _, err := r.buildRules(false) + newRules, _, _, err := r.buildRules(false) if err != nil { return err } @@ -193,7 +200,7 @@ func (r *Router) rebuildRules(startRules bool) error { if r.isClosing() { return nil } - newRules, legacyDNSMode, err := r.buildRules(startRules) + newRules, legacyDNSMode, modeFlags, err := r.buildRules(startRules) if err != nil { if r.isClosing() { return nil @@ -207,7 +214,7 @@ func (r *Router) rebuildRules(startRules bool) error { shouldReportRuleStrategyDeprecated := startRules && legacyDNSMode && !r.ruleStrategyDeprecatedReported && - hasDNSRuleActionStrategy(r.rawRules) + modeFlags.neededFromStrategy newSnapshot := newRulesSnapshot(newRules, legacyDNSMode) r.stateAccess.Lock() if r.closing { @@ -247,22 +254,22 @@ func (r *Router) acquireRulesSnapshot() *rulesSnapshot { return snapshot } -func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) { +func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, dnsRuleModeFlags, error) { for i, ruleOptions := range r.rawRules { err := R.ValidateNoNestedDNSRuleActions(ruleOptions) if err != nil { - return nil, false, E.Cause(err, "parse dns rule[", i, "]") + return nil, false, dnsRuleModeFlags{}, E.Cause(err, "parse dns rule[", i, "]") } } router := service.FromContext[adapter.Router](r.ctx) - legacyDNSMode, err := resolveLegacyDNSMode(router, r.rawRules) + legacyDNSMode, modeFlags, err := resolveLegacyDNSMode(router, r.rawRules) if err != nil { - return nil, false, err + return nil, false, dnsRuleModeFlags{}, err } if !legacyDNSMode { err = validateLegacyDNSModeDisabledRules(r.rawRules) if err != nil { - return nil, false, err + return nil, false, dnsRuleModeFlags{}, err } } newRules := make([]adapter.DNSRule, 0, len(r.rawRules)) @@ -271,7 +278,7 @@ func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) { dnsRule, err = R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyDNSMode) if err != nil { closeRules(newRules) - return nil, false, E.Cause(err, "parse dns rule[", i, "]") + return nil, false, dnsRuleModeFlags{}, E.Cause(err, "parse dns rule[", i, "]") } newRules = append(newRules, dnsRule) } @@ -280,11 +287,11 @@ func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) { err = rule.Start() if err != nil { closeRules(newRules) - return nil, false, E.Cause(err, "initialize DNS rule[", i, "]") + return nil, false, dnsRuleModeFlags{}, E.Cause(err, "initialize DNS rule[", i, "]") } } } - return newRules, legacyDNSMode, nil + return newRules, legacyDNSMode, modeFlags, nil } func closeRules(rules []adapter.DNSRule) { @@ -433,8 +440,8 @@ const ( dnsRouteStatusResolved ) -func (r *Router) resolveDNSRoute(action *R.RuleActionDNSRoute, allowFakeIP bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, dnsRouteStatus) { - transport, loaded := r.transport.Transport(action.Server) +func (r *Router) resolveDNSRoute(server string, routeOptions R.RuleActionDNSRouteOptions, allowFakeIP bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, dnsRouteStatus) { + transport, loaded := r.transport.Transport(server) if !loaded { return nil, dnsRouteStatusMissing } @@ -442,7 +449,7 @@ func (r *Router) resolveDNSRoute(action *R.RuleActionDNSRoute, allowFakeIP bool, if isFakeIP && !allowFakeIP { return transport, dnsRouteStatusSkipped } - r.applyDNSRouteOptions(options, action.RuleActionDNSRouteOptions) + r.applyDNSRouteOptions(options, routeOptions) if isFakeIP { options.DisableCache = true } @@ -484,10 +491,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, r.applyDNSRouteOptions(&effectiveOptions, *action) case *R.RuleActionEvaluate: queryOptions := effectiveOptions - transport, status := r.resolveDNSRoute(&R.RuleActionDNSRoute{ - Server: action.Server, - RuleActionDNSRouteOptions: action.RuleActionDNSRouteOptions, - }, allowFakeIP, &queryOptions) + transport, status := r.resolveDNSRoute(action.Server, action.RuleActionDNSRouteOptions, allowFakeIP, &queryOptions) switch status { case dnsRouteStatusMissing: r.logger.ErrorContext(ctx, "transport not found: ", action.Server) @@ -512,7 +516,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, savedResponse = response case *R.RuleActionDNSRoute: queryOptions := effectiveOptions - transport, status := r.resolveDNSRoute(action, allowFakeIP, &queryOptions) + transport, status := r.resolveDNSRoute(action.Server, action.RuleActionDNSRouteOptions, allowFakeIP, &queryOptions) switch status { case dnsRouteStatusMissing: r.logger.ErrorContext(ctx, "transport not found: ", action.Server) @@ -569,10 +573,6 @@ func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, } } -type lookupWithRulesResponse struct { - addresses []netip.Addr -} - func (r *Router) resolveLookupStrategy(options adapter.DNSQueryOptions) C.DomainStrategy { if options.LookupStrategy != C.DomainStrategyAsIS { return options.LookupStrategy @@ -618,16 +618,14 @@ func (r *Router) lookupWithRules(ctx context.Context, rules []adapter.DNSRule, d lookupOptions.Strategy = strategy } if strategy == C.DomainStrategyIPv4Only { - response, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions) - return response.addresses, err + return r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions) } if strategy == C.DomainStrategyIPv6Only { - response, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions) - return response.addresses, err + return r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions) } var ( - response4 lookupWithRulesResponse - response6 lookupWithRulesResponse + response4 []netip.Addr + response6 []netip.Addr ) var group task.Group group.Append("exchange4", func(ctx context.Context) error { @@ -641,13 +639,13 @@ func (r *Router) lookupWithRules(ctx context.Context, rules []adapter.DNSRule, d return err }) err := group.Run(ctx) - if len(response4.addresses) == 0 && len(response6.addresses) == 0 { + if len(response4) == 0 && len(response6) == 0 { return nil, err } - return sortAddresses(response4.addresses, response6.addresses, strategy), nil + return sortAddresses(response4, response6, strategy), nil } -func (r *Router) lookupWithRulesType(ctx context.Context, rules []adapter.DNSRule, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) { +func (r *Router) lookupWithRulesType(ctx context.Context, rules []adapter.DNSRule, domain string, qType uint16, options adapter.DNSQueryOptions) ([]netip.Addr, error) { request := &mDNS.Msg{ MsgHdr: mDNS.MsgHdr{ RecursionDesired: true, @@ -659,18 +657,16 @@ func (r *Router) lookupWithRulesType(ctx context.Context, rules []adapter.DNSRul }}, } exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), rules, request, options, false) - result := lookupWithRulesResponse{} if exchangeResult.rejectAction != nil { - return result, exchangeResult.rejectAction.Error(ctx) + return nil, exchangeResult.rejectAction.Error(ctx) } if exchangeResult.err != nil { - return result, exchangeResult.err + return nil, exchangeResult.err } if exchangeResult.response.Rcode != mDNS.RcodeSuccess { - return result, RcodeError(exchangeResult.response.Rcode) + return nil, RcodeError(exchangeResult.response.Rcode) } - result.addresses = filterAddressesByQueryType(MessageToAddresses(exchangeResult.response), qType) - return result, nil + return filterAddressesByQueryType(MessageToAddresses(exchangeResult.response), qType), nil } func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) { @@ -688,14 +684,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte } snapshot := r.acquireRulesSnapshot() defer snapshot.release() - var ( - rules []adapter.DNSRule - legacyDNSMode bool - ) - if snapshot != nil { - rules = snapshot.rules - legacyDNSMode = snapshot.legacyDNSMode - } + rules, legacyDNSMode := snapshot.rulesAndMode() r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String())) var ( response *mDNS.Msg @@ -803,14 +792,7 @@ done: func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { snapshot := r.acquireRulesSnapshot() defer snapshot.release() - var ( - rules []adapter.DNSRule - legacyDNSMode bool - ) - if snapshot != nil { - rules = snapshot.rules - legacyDNSMode = snapshot.legacyDNSMode - } + rules, legacyDNSMode := snapshot.rulesAndMode() var ( responseAddrs []netip.Addr err error @@ -964,84 +946,92 @@ func defaultRuleDisablesLegacyDNSMode(rule option.DefaultDNSRule) bool { len(rule.QueryType) > 0 } -func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule) (bool, error) { - legacyDNSModeDisabled, needsLegacyDNSMode, needsLegacyDNSModeFromStrategy, err := dnsRuleModeRequirements(router, rules) +type dnsRuleModeFlags struct { + disabled bool + needed bool + neededFromStrategy bool +} + +func (f *dnsRuleModeFlags) merge(other dnsRuleModeFlags) { + f.disabled = f.disabled || other.disabled + f.needed = f.needed || other.needed + f.neededFromStrategy = f.neededFromStrategy || other.neededFromStrategy +} + +func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule) (bool, dnsRuleModeFlags, error) { + flags, err := dnsRuleModeRequirements(router, rules) if err != nil { - return false, err + return false, flags, err } - if legacyDNSModeDisabled && needsLegacyDNSModeFromStrategy { - return false, E.New("DNS rule action strategy is only supported in legacyDNSMode") + if flags.disabled && flags.neededFromStrategy { + return false, flags, E.New("DNS rule action strategy is only supported in legacyDNSMode") } - if legacyDNSModeDisabled { - return false, nil + if flags.disabled { + return false, flags, nil } - return needsLegacyDNSMode, nil + return flags.needed, flags, nil } -func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (bool, bool, bool, error) { - var legacyDNSModeDisabled bool - var needsLegacyDNSMode bool - var needsLegacyDNSModeFromStrategy bool +func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (dnsRuleModeFlags, error) { + var flags dnsRuleModeFlags for i, rule := range rules { - ruleLegacyDNSModeDisabled, ruleNeedsLegacyDNSMode, ruleNeedsLegacyDNSModeFromStrategy, err := dnsRuleModeRequirementsInRule(router, rule) + ruleFlags, err := dnsRuleModeRequirementsInRule(router, rule) if err != nil { - return false, false, false, E.Cause(err, "dns rule[", i, "]") + return dnsRuleModeFlags{}, E.Cause(err, "dns rule[", i, "]") } - legacyDNSModeDisabled = legacyDNSModeDisabled || ruleLegacyDNSModeDisabled - needsLegacyDNSMode = needsLegacyDNSMode || ruleNeedsLegacyDNSMode - needsLegacyDNSModeFromStrategy = needsLegacyDNSModeFromStrategy || ruleNeedsLegacyDNSModeFromStrategy + flags.merge(ruleFlags) } - return legacyDNSModeDisabled, needsLegacyDNSMode, needsLegacyDNSModeFromStrategy, nil + return flags, nil } -func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (bool, bool, bool, error) { +func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (dnsRuleModeFlags, error) { switch rule.Type { case "", C.RuleTypeDefault: return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions) case C.RuleTypeLogical: - legacyDNSModeDisabled := dnsRuleActionType(rule) == C.RuleActionTypeEvaluate - needsLegacyDNSModeFromStrategy := dnsRuleActionHasStrategy(rule.LogicalOptions.DNSRuleAction) - needsLegacyDNSMode := needsLegacyDNSModeFromStrategy - for i, subRule := range rule.LogicalOptions.Rules { - subLegacyDNSModeDisabled, subNeedsLegacyDNSMode, subNeedsLegacyDNSModeFromStrategy, err := dnsRuleModeRequirementsInRule(router, subRule) - if err != nil { - return false, false, false, E.Cause(err, "sub rule[", i, "]") - } - legacyDNSModeDisabled = legacyDNSModeDisabled || subLegacyDNSModeDisabled - needsLegacyDNSMode = needsLegacyDNSMode || subNeedsLegacyDNSMode - needsLegacyDNSModeFromStrategy = needsLegacyDNSModeFromStrategy || subNeedsLegacyDNSModeFromStrategy + flags := dnsRuleModeFlags{ + disabled: dnsRuleActionType(rule) == C.RuleActionTypeEvaluate, + neededFromStrategy: dnsRuleActionHasStrategy(rule.LogicalOptions.DNSRuleAction), } - return legacyDNSModeDisabled, needsLegacyDNSMode, needsLegacyDNSModeFromStrategy, nil + flags.needed = flags.neededFromStrategy + for i, subRule := range rule.LogicalOptions.Rules { + subFlags, err := dnsRuleModeRequirementsInRule(router, subRule) + if err != nil { + return dnsRuleModeFlags{}, E.Cause(err, "sub rule[", i, "]") + } + flags.merge(subFlags) + } + return flags, nil default: - return false, false, false, nil + return dnsRuleModeFlags{}, nil } } -func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule) (bool, bool, bool, error) { - legacyDNSModeDisabled := defaultRuleDisablesLegacyDNSMode(rule) - needsLegacyDNSModeFromStrategy := dnsRuleActionHasStrategy(rule.DNSRuleAction) - needsLegacyDNSMode := defaultRuleNeedsLegacyDNSModeFromAddressFilter(rule) || needsLegacyDNSModeFromStrategy +func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule) (dnsRuleModeFlags, error) { + flags := dnsRuleModeFlags{ + disabled: defaultRuleDisablesLegacyDNSMode(rule), + neededFromStrategy: dnsRuleActionHasStrategy(rule.DNSRuleAction), + } + flags.needed = defaultRuleNeedsLegacyDNSModeFromAddressFilter(rule) || flags.neededFromStrategy if len(rule.RuleSet) == 0 { - return legacyDNSModeDisabled, needsLegacyDNSMode, needsLegacyDNSModeFromStrategy, nil + return flags, nil } if router == nil { - return false, false, false, E.New("router service not found") + return dnsRuleModeFlags{}, E.New("router service not found") } for _, tag := range rule.RuleSet { ruleSet, loaded := router.RuleSet(tag) if !loaded { - return false, false, false, E.New("rule-set not found: ", tag) + return dnsRuleModeFlags{}, E.New("rule-set not found: ", tag) } metadata := ruleSet.Metadata() - // Rule sets are built from headless rules, so query_type is the only - // per-query DNS predicate they can contribute here. ip_version is not a - // headless-rule item and is therefore intentionally absent from metadata. - legacyDNSModeDisabled = legacyDNSModeDisabled || metadata.ContainsDNSQueryTypeRule + // ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent. + flags.disabled = flags.disabled || metadata.ContainsDNSQueryTypeRule if !rule.RuleSetIPCIDRMatchSource && metadata.ContainsIPCIDRRule { - needsLegacyDNSMode = true + flags.needed = true } } - return legacyDNSModeDisabled, needsLegacyDNSMode, needsLegacyDNSModeFromStrategy, nil + return flags, nil } func referencedDNSRuleSetTags(rules []option.DNSRule) []string { @@ -1126,29 +1116,6 @@ func validateLegacyDNSModeDisabledDefaultRule(rule option.DefaultDNSRule) (bool, return rule.MatchResponse, nil } -func hasDNSRuleActionStrategy(rules []option.DNSRule) bool { - for _, rule := range rules { - if dnsRuleHasActionStrategy(rule) { - return true - } - } - return false -} - -func dnsRuleHasActionStrategy(rule option.DNSRule) bool { - switch rule.Type { - case "", C.RuleTypeDefault: - return dnsRuleActionHasStrategy(rule.DefaultOptions.DNSRuleAction) - case C.RuleTypeLogical: - if dnsRuleActionHasStrategy(rule.LogicalOptions.DNSRuleAction) { - return true - } - return hasDNSRuleActionStrategy(rule.LogicalOptions.Rules) - default: - return false - } -} - func dnsRuleActionHasStrategy(action option.DNSRuleAction) bool { switch action.Action { case "", C.RuleActionTypeRoute, C.RuleActionTypeEvaluate: