From db7655e7d31c010d7c18d41036439e52ab4d4c4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 12:27:22 +0800 Subject: [PATCH] dns: restore init validation and fix rule-set query type --- dns/router.go | 72 +++++++++++++++++++------------- dns/router_test.go | 101 +++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 140 insertions(+), 33 deletions(-) diff --git a/dns/router.go b/dns/router.go index f5b80a120..fe8783766 100644 --- a/dns/router.go +++ b/dns/router.go @@ -91,6 +91,11 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp func (r *Router) Initialize(rules []option.DNSRule) error { r.rawRules = append(r.rawRules[:0], rules...) + newRules, _, err := r.buildRules(false) + if err != nil { + return err + } + closeRules(newRules) return nil } @@ -142,43 +147,19 @@ func (r *Router) Close() error { } func (r *Router) rebuildRules(startRules bool) error { - router := service.FromContext[adapter.Router](r.ctx) - legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules) + newRules, legacyAddressFilterMode, err := r.buildRules(startRules) if err != nil { return err } - if !legacyAddressFilterMode { - err = validateNonLegacyAddressFilterRules(r.rawRules) - if err != nil { - return err - } - } - newRules := make([]adapter.DNSRule, 0, len(r.rawRules)) - for i, ruleOptions := range r.rawRules { - dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyAddressFilterMode) - if err != nil { - closeRules(newRules) - return E.Cause(err, "parse dns rule[", i, "]") - } - newRules = append(newRules, dnsRule) - } - if startRules { - for i, rule := range newRules { - err := rule.Start() - if err != nil { - closeRules(newRules) - return E.Cause(err, "initialize DNS rule[", i, "]") - } - } - } + shouldReportDeprecated := startRules && + legacyAddressFilterMode && + !r.deprecatedReported && + common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() }) r.rulesAccess.Lock() oldRules := r.rules r.rules = newRules r.legacyAddressFilterMode = legacyAddressFilterMode r.runtimeRuleError = nil - shouldReportDeprecated := legacyAddressFilterMode && - !r.deprecatedReported && - common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() }) if shouldReportDeprecated { r.deprecatedReported = true } @@ -190,6 +171,39 @@ func (r *Router) rebuildRules(startRules bool) error { return nil } +func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) { + router := service.FromContext[adapter.Router](r.ctx) + legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules) + if err != nil { + return nil, false, err + } + if !legacyAddressFilterMode { + err = validateNonLegacyAddressFilterRules(r.rawRules) + if err != nil { + return nil, false, err + } + } + newRules := make([]adapter.DNSRule, 0, len(r.rawRules)) + for i, ruleOptions := range r.rawRules { + dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyAddressFilterMode) + if err != nil { + closeRules(newRules) + return nil, false, E.Cause(err, "parse dns rule[", i, "]") + } + newRules = append(newRules, dnsRule) + } + if startRules { + for i, rule := range newRules { + err := rule.Start() + if err != nil { + closeRules(newRules) + return nil, false, E.Cause(err, "initialize DNS rule[", i, "]") + } + } + } + return newRules, legacyAddressFilterMode, nil +} + func closeRules(rules []adapter.DNSRule) { for _, rule := range rules { _ = rule.Close() diff --git a/dns/router_test.go b/dns/router_test.go index 2c967c3fa..1f66f1757 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -278,7 +278,34 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) { require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") } -func TestStartNewModeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) { +func TestInitializeRejectsInvalidDNSRuleParseError(t *testing.T) { + t.Parallel() + + router := &Router{ + ctx: context.Background(), + logger: log.NewNOPFactory().NewLogger("dns"), + transport: &fakeDNSTransportManager{}, + client: &fakeDNSClient{}, + rawRules: make([]option.DNSRule, 0, 1), + rules: make([]adapter.DNSRule, 0, 1), + defaultDomainStrategy: C.DomainStrategyAsIS, + } + err := router.Initialize([]option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + DomainRegex: badoption.Listable[string]{"("}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + }, + }, + }}) + require.ErrorContains(t, err, "domain_regex") +} + +func TestInitializeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) { t.Parallel() ctx := context.Background() @@ -336,9 +363,6 @@ func TestStartNewModeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) { }, }, }) - require.NoError(t, err) - - err = router.Start(adapter.StartStateStart) require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") } @@ -1076,6 +1100,75 @@ func TestLookupNewModeUsesQueryTypeRule(t *testing.T) { require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses) } +func TestLookupNewModeUsesRuleSetQueryTypeRule(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ruleSet, err := rulepkg.NewRuleSet(ctx, log.NewNOPFactory().NewLogger("router"), option.RuleSet{ + Type: C.RuleSetTypeInline, + Tag: "query-set", + InlineOptions: option.PlainRuleSet{ + Rules: []option.HeadlessRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultHeadlessRule{ + QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)}, + }, + }}, + }, + }) + require.NoError(t, err) + ctx = service.ContextWith[adapter.Router](ctx, &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "query-set": ruleSet, + }, + }) + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouterWithContext(t, ctx, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"query-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "only-a"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "only-a": &fakeDNSTransport{tag: "only-a", transportType: C.DNSTypeUDP}, + }, + }, &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + switch transport.Tag() { + case "default": + if message.Question[0].Qtype == mDNS.TypeA { + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("3.3.3.3")}, 60), nil + } + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::4")}, 60), nil + case "only-a": + if message.Question[0].Qtype == mDNS.TypeA { + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil + } + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::9")}, 60), nil + default: + return nil, errors.New("unexpected transport") + } + }, + }) + require.False(t, router.legacyAddressFilterMode) + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{ + netip.MustParseAddr("9.9.9.9"), + netip.MustParseAddr("2001:db8::4"), + }, addresses) +} + func TestLookupNewModeUsesIPVersionRule(t *testing.T) { t.Parallel()