Simplify DNS router internals

- Replace dnsRuleModeRequirements 4-tuple return with dnsRuleModeFlags struct
- Eliminate redundant hasDNSRuleActionStrategy tree walk by reusing mode flags from buildRules
- Remove single-field lookupWithRulesResponse wrapper
- Accept fields directly in resolveDNSRoute instead of *RuleActionDNSRoute
- Extract rulesAndMode() helper to deduplicate snapshot unpacking
- Trim verbose RuleSetMetadata comment
This commit is contained in:
世界
2026-04-01 17:20:29 +08:00
parent 99b363c878
commit 91f942c8bc
2 changed files with 90 additions and 125 deletions

View File

@@ -66,9 +66,7 @@ type RuleSet interface {
type RuleSetUpdateCallback func(it RuleSet)
// Rule-set metadata only exposes headless-rule capabilities that outer routers
// need before evaluating nested matches. Headless rules do not support
// ip_version, so there is intentionally no ContainsIPVersionRule flag here.
// ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent.
type RuleSetMetadata struct {
ContainsProcessRule bool
ContainsWIFIRule bool

View File

@@ -60,6 +60,13 @@ func (s *rulesSnapshot) retain() {
s.references.Add(1)
}
func (s *rulesSnapshot) rulesAndMode() ([]adapter.DNSRule, bool) {
if s == nil {
return nil, false
}
return s.rules, s.legacyDNSMode
}
func (s *rulesSnapshot) release() {
if s == nil {
return
@@ -129,7 +136,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
func (r *Router) Initialize(rules []option.DNSRule) error {
r.rawRules = append(r.rawRules[:0], rules...)
newRules, _, err := r.buildRules(false)
newRules, _, _, err := r.buildRules(false)
if err != nil {
return err
}
@@ -193,7 +200,7 @@ func (r *Router) rebuildRules(startRules bool) error {
if r.isClosing() {
return nil
}
newRules, legacyDNSMode, err := r.buildRules(startRules)
newRules, legacyDNSMode, modeFlags, err := r.buildRules(startRules)
if err != nil {
if r.isClosing() {
return nil
@@ -207,7 +214,7 @@ func (r *Router) rebuildRules(startRules bool) error {
shouldReportRuleStrategyDeprecated := startRules &&
legacyDNSMode &&
!r.ruleStrategyDeprecatedReported &&
hasDNSRuleActionStrategy(r.rawRules)
modeFlags.neededFromStrategy
newSnapshot := newRulesSnapshot(newRules, legacyDNSMode)
r.stateAccess.Lock()
if r.closing {
@@ -247,22 +254,22 @@ func (r *Router) acquireRulesSnapshot() *rulesSnapshot {
return snapshot
}
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, dnsRuleModeFlags, error) {
for i, ruleOptions := range r.rawRules {
err := R.ValidateNoNestedDNSRuleActions(ruleOptions)
if err != nil {
return nil, false, E.Cause(err, "parse dns rule[", i, "]")
return nil, false, dnsRuleModeFlags{}, E.Cause(err, "parse dns rule[", i, "]")
}
}
router := service.FromContext[adapter.Router](r.ctx)
legacyDNSMode, err := resolveLegacyDNSMode(router, r.rawRules)
legacyDNSMode, modeFlags, err := resolveLegacyDNSMode(router, r.rawRules)
if err != nil {
return nil, false, err
return nil, false, dnsRuleModeFlags{}, err
}
if !legacyDNSMode {
err = validateLegacyDNSModeDisabledRules(r.rawRules)
if err != nil {
return nil, false, err
return nil, false, dnsRuleModeFlags{}, err
}
}
newRules := make([]adapter.DNSRule, 0, len(r.rawRules))
@@ -271,7 +278,7 @@ func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
dnsRule, err = R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyDNSMode)
if err != nil {
closeRules(newRules)
return nil, false, E.Cause(err, "parse dns rule[", i, "]")
return nil, false, dnsRuleModeFlags{}, E.Cause(err, "parse dns rule[", i, "]")
}
newRules = append(newRules, dnsRule)
}
@@ -280,11 +287,11 @@ func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
err = rule.Start()
if err != nil {
closeRules(newRules)
return nil, false, E.Cause(err, "initialize DNS rule[", i, "]")
return nil, false, dnsRuleModeFlags{}, E.Cause(err, "initialize DNS rule[", i, "]")
}
}
}
return newRules, legacyDNSMode, nil
return newRules, legacyDNSMode, modeFlags, nil
}
func closeRules(rules []adapter.DNSRule) {
@@ -433,8 +440,8 @@ const (
dnsRouteStatusResolved
)
func (r *Router) resolveDNSRoute(action *R.RuleActionDNSRoute, allowFakeIP bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, dnsRouteStatus) {
transport, loaded := r.transport.Transport(action.Server)
func (r *Router) resolveDNSRoute(server string, routeOptions R.RuleActionDNSRouteOptions, allowFakeIP bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, dnsRouteStatus) {
transport, loaded := r.transport.Transport(server)
if !loaded {
return nil, dnsRouteStatusMissing
}
@@ -442,7 +449,7 @@ func (r *Router) resolveDNSRoute(action *R.RuleActionDNSRoute, allowFakeIP bool,
if isFakeIP && !allowFakeIP {
return transport, dnsRouteStatusSkipped
}
r.applyDNSRouteOptions(options, action.RuleActionDNSRouteOptions)
r.applyDNSRouteOptions(options, routeOptions)
if isFakeIP {
options.DisableCache = true
}
@@ -484,10 +491,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule,
r.applyDNSRouteOptions(&effectiveOptions, *action)
case *R.RuleActionEvaluate:
queryOptions := effectiveOptions
transport, status := r.resolveDNSRoute(&R.RuleActionDNSRoute{
Server: action.Server,
RuleActionDNSRouteOptions: action.RuleActionDNSRouteOptions,
}, allowFakeIP, &queryOptions)
transport, status := r.resolveDNSRoute(action.Server, action.RuleActionDNSRouteOptions, allowFakeIP, &queryOptions)
switch status {
case dnsRouteStatusMissing:
r.logger.ErrorContext(ctx, "transport not found: ", action.Server)
@@ -512,7 +516,7 @@ func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule,
savedResponse = response
case *R.RuleActionDNSRoute:
queryOptions := effectiveOptions
transport, status := r.resolveDNSRoute(action, allowFakeIP, &queryOptions)
transport, status := r.resolveDNSRoute(action.Server, action.RuleActionDNSRouteOptions, allowFakeIP, &queryOptions)
switch status {
case dnsRouteStatusMissing:
r.logger.ErrorContext(ctx, "transport not found: ", action.Server)
@@ -569,10 +573,6 @@ func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule,
}
}
type lookupWithRulesResponse struct {
addresses []netip.Addr
}
func (r *Router) resolveLookupStrategy(options adapter.DNSQueryOptions) C.DomainStrategy {
if options.LookupStrategy != C.DomainStrategyAsIS {
return options.LookupStrategy
@@ -618,16 +618,14 @@ func (r *Router) lookupWithRules(ctx context.Context, rules []adapter.DNSRule, d
lookupOptions.Strategy = strategy
}
if strategy == C.DomainStrategyIPv4Only {
response, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions)
return response.addresses, err
return r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions)
}
if strategy == C.DomainStrategyIPv6Only {
response, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions)
return response.addresses, err
return r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions)
}
var (
response4 lookupWithRulesResponse
response6 lookupWithRulesResponse
response4 []netip.Addr
response6 []netip.Addr
)
var group task.Group
group.Append("exchange4", func(ctx context.Context) error {
@@ -641,13 +639,13 @@ func (r *Router) lookupWithRules(ctx context.Context, rules []adapter.DNSRule, d
return err
})
err := group.Run(ctx)
if len(response4.addresses) == 0 && len(response6.addresses) == 0 {
if len(response4) == 0 && len(response6) == 0 {
return nil, err
}
return sortAddresses(response4.addresses, response6.addresses, strategy), nil
return sortAddresses(response4, response6, strategy), nil
}
func (r *Router) lookupWithRulesType(ctx context.Context, rules []adapter.DNSRule, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) {
func (r *Router) lookupWithRulesType(ctx context.Context, rules []adapter.DNSRule, domain string, qType uint16, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
request := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
RecursionDesired: true,
@@ -659,18 +657,16 @@ func (r *Router) lookupWithRulesType(ctx context.Context, rules []adapter.DNSRul
}},
}
exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), rules, request, options, false)
result := lookupWithRulesResponse{}
if exchangeResult.rejectAction != nil {
return result, exchangeResult.rejectAction.Error(ctx)
return nil, exchangeResult.rejectAction.Error(ctx)
}
if exchangeResult.err != nil {
return result, exchangeResult.err
return nil, exchangeResult.err
}
if exchangeResult.response.Rcode != mDNS.RcodeSuccess {
return result, RcodeError(exchangeResult.response.Rcode)
return nil, RcodeError(exchangeResult.response.Rcode)
}
result.addresses = filterAddressesByQueryType(MessageToAddresses(exchangeResult.response), qType)
return result, nil
return filterAddressesByQueryType(MessageToAddresses(exchangeResult.response), qType), nil
}
func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions) (*mDNS.Msg, error) {
@@ -688,14 +684,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
}
snapshot := r.acquireRulesSnapshot()
defer snapshot.release()
var (
rules []adapter.DNSRule
legacyDNSMode bool
)
if snapshot != nil {
rules = snapshot.rules
legacyDNSMode = snapshot.legacyDNSMode
}
rules, legacyDNSMode := snapshot.rulesAndMode()
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
var (
response *mDNS.Msg
@@ -803,14 +792,7 @@ done:
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
snapshot := r.acquireRulesSnapshot()
defer snapshot.release()
var (
rules []adapter.DNSRule
legacyDNSMode bool
)
if snapshot != nil {
rules = snapshot.rules
legacyDNSMode = snapshot.legacyDNSMode
}
rules, legacyDNSMode := snapshot.rulesAndMode()
var (
responseAddrs []netip.Addr
err error
@@ -964,84 +946,92 @@ func defaultRuleDisablesLegacyDNSMode(rule option.DefaultDNSRule) bool {
len(rule.QueryType) > 0
}
func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule) (bool, error) {
legacyDNSModeDisabled, needsLegacyDNSMode, needsLegacyDNSModeFromStrategy, err := dnsRuleModeRequirements(router, rules)
type dnsRuleModeFlags struct {
disabled bool
needed bool
neededFromStrategy bool
}
func (f *dnsRuleModeFlags) merge(other dnsRuleModeFlags) {
f.disabled = f.disabled || other.disabled
f.needed = f.needed || other.needed
f.neededFromStrategy = f.neededFromStrategy || other.neededFromStrategy
}
func resolveLegacyDNSMode(router adapter.Router, rules []option.DNSRule) (bool, dnsRuleModeFlags, error) {
flags, err := dnsRuleModeRequirements(router, rules)
if err != nil {
return false, err
return false, flags, err
}
if legacyDNSModeDisabled && needsLegacyDNSModeFromStrategy {
return false, E.New("DNS rule action strategy is only supported in legacyDNSMode")
if flags.disabled && flags.neededFromStrategy {
return false, flags, E.New("DNS rule action strategy is only supported in legacyDNSMode")
}
if legacyDNSModeDisabled {
return false, nil
if flags.disabled {
return false, flags, nil
}
return needsLegacyDNSMode, nil
return flags.needed, flags, nil
}
func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (bool, bool, bool, error) {
var legacyDNSModeDisabled bool
var needsLegacyDNSMode bool
var needsLegacyDNSModeFromStrategy bool
func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (dnsRuleModeFlags, error) {
var flags dnsRuleModeFlags
for i, rule := range rules {
ruleLegacyDNSModeDisabled, ruleNeedsLegacyDNSMode, ruleNeedsLegacyDNSModeFromStrategy, err := dnsRuleModeRequirementsInRule(router, rule)
ruleFlags, err := dnsRuleModeRequirementsInRule(router, rule)
if err != nil {
return false, false, false, E.Cause(err, "dns rule[", i, "]")
return dnsRuleModeFlags{}, E.Cause(err, "dns rule[", i, "]")
}
legacyDNSModeDisabled = legacyDNSModeDisabled || ruleLegacyDNSModeDisabled
needsLegacyDNSMode = needsLegacyDNSMode || ruleNeedsLegacyDNSMode
needsLegacyDNSModeFromStrategy = needsLegacyDNSModeFromStrategy || ruleNeedsLegacyDNSModeFromStrategy
flags.merge(ruleFlags)
}
return legacyDNSModeDisabled, needsLegacyDNSMode, needsLegacyDNSModeFromStrategy, nil
return flags, nil
}
func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (bool, bool, bool, error) {
func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (dnsRuleModeFlags, error) {
switch rule.Type {
case "", C.RuleTypeDefault:
return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions)
case C.RuleTypeLogical:
legacyDNSModeDisabled := dnsRuleActionType(rule) == C.RuleActionTypeEvaluate
needsLegacyDNSModeFromStrategy := dnsRuleActionHasStrategy(rule.LogicalOptions.DNSRuleAction)
needsLegacyDNSMode := needsLegacyDNSModeFromStrategy
for i, subRule := range rule.LogicalOptions.Rules {
subLegacyDNSModeDisabled, subNeedsLegacyDNSMode, subNeedsLegacyDNSModeFromStrategy, err := dnsRuleModeRequirementsInRule(router, subRule)
if err != nil {
return false, false, false, E.Cause(err, "sub rule[", i, "]")
}
legacyDNSModeDisabled = legacyDNSModeDisabled || subLegacyDNSModeDisabled
needsLegacyDNSMode = needsLegacyDNSMode || subNeedsLegacyDNSMode
needsLegacyDNSModeFromStrategy = needsLegacyDNSModeFromStrategy || subNeedsLegacyDNSModeFromStrategy
flags := dnsRuleModeFlags{
disabled: dnsRuleActionType(rule) == C.RuleActionTypeEvaluate,
neededFromStrategy: dnsRuleActionHasStrategy(rule.LogicalOptions.DNSRuleAction),
}
return legacyDNSModeDisabled, needsLegacyDNSMode, needsLegacyDNSModeFromStrategy, nil
flags.needed = flags.neededFromStrategy
for i, subRule := range rule.LogicalOptions.Rules {
subFlags, err := dnsRuleModeRequirementsInRule(router, subRule)
if err != nil {
return dnsRuleModeFlags{}, E.Cause(err, "sub rule[", i, "]")
}
flags.merge(subFlags)
}
return flags, nil
default:
return false, false, false, nil
return dnsRuleModeFlags{}, nil
}
}
func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule) (bool, bool, bool, error) {
legacyDNSModeDisabled := defaultRuleDisablesLegacyDNSMode(rule)
needsLegacyDNSModeFromStrategy := dnsRuleActionHasStrategy(rule.DNSRuleAction)
needsLegacyDNSMode := defaultRuleNeedsLegacyDNSModeFromAddressFilter(rule) || needsLegacyDNSModeFromStrategy
func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule) (dnsRuleModeFlags, error) {
flags := dnsRuleModeFlags{
disabled: defaultRuleDisablesLegacyDNSMode(rule),
neededFromStrategy: dnsRuleActionHasStrategy(rule.DNSRuleAction),
}
flags.needed = defaultRuleNeedsLegacyDNSModeFromAddressFilter(rule) || flags.neededFromStrategy
if len(rule.RuleSet) == 0 {
return legacyDNSModeDisabled, needsLegacyDNSMode, needsLegacyDNSModeFromStrategy, nil
return flags, nil
}
if router == nil {
return false, false, false, E.New("router service not found")
return dnsRuleModeFlags{}, E.New("router service not found")
}
for _, tag := range rule.RuleSet {
ruleSet, loaded := router.RuleSet(tag)
if !loaded {
return false, false, false, E.New("rule-set not found: ", tag)
return dnsRuleModeFlags{}, E.New("rule-set not found: ", tag)
}
metadata := ruleSet.Metadata()
// Rule sets are built from headless rules, so query_type is the only
// per-query DNS predicate they can contribute here. ip_version is not a
// headless-rule item and is therefore intentionally absent from metadata.
legacyDNSModeDisabled = legacyDNSModeDisabled || metadata.ContainsDNSQueryTypeRule
// ip_version is not a headless-rule item, so ContainsIPVersionRule is intentionally absent.
flags.disabled = flags.disabled || metadata.ContainsDNSQueryTypeRule
if !rule.RuleSetIPCIDRMatchSource && metadata.ContainsIPCIDRRule {
needsLegacyDNSMode = true
flags.needed = true
}
}
return legacyDNSModeDisabled, needsLegacyDNSMode, needsLegacyDNSModeFromStrategy, nil
return flags, nil
}
func referencedDNSRuleSetTags(rules []option.DNSRule) []string {
@@ -1126,29 +1116,6 @@ func validateLegacyDNSModeDisabledDefaultRule(rule option.DefaultDNSRule) (bool,
return rule.MatchResponse, nil
}
func hasDNSRuleActionStrategy(rules []option.DNSRule) bool {
for _, rule := range rules {
if dnsRuleHasActionStrategy(rule) {
return true
}
}
return false
}
func dnsRuleHasActionStrategy(rule option.DNSRule) bool {
switch rule.Type {
case "", C.RuleTypeDefault:
return dnsRuleActionHasStrategy(rule.DefaultOptions.DNSRuleAction)
case C.RuleTypeLogical:
if dnsRuleActionHasStrategy(rule.LogicalOptions.DNSRuleAction) {
return true
}
return hasDNSRuleActionStrategy(rule.LogicalOptions.Rules)
default:
return false
}
}
func dnsRuleActionHasStrategy(action option.DNSRuleAction) bool {
switch action.Action {
case "", C.RuleActionTypeRoute, C.RuleActionTypeEvaluate: