dns: restore init validation and fix rule-set query type

This commit is contained in:
世界
2026-03-26 12:27:22 +08:00
parent 40b9c64a0d
commit e09a6d3206
2 changed files with 140 additions and 33 deletions

View File

@@ -91,6 +91,11 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
func (r *Router) Initialize(rules []option.DNSRule) error {
r.rawRules = append(r.rawRules[:0], rules...)
newRules, _, err := r.buildRules(false)
if err != nil {
return err
}
closeRules(newRules)
return nil
}
@@ -142,43 +147,19 @@ func (r *Router) Close() error {
}
func (r *Router) rebuildRules(startRules bool) error {
router := service.FromContext[adapter.Router](r.ctx)
legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules)
newRules, legacyAddressFilterMode, err := r.buildRules(startRules)
if err != nil {
return err
}
if !legacyAddressFilterMode {
err = validateNonLegacyAddressFilterRules(r.rawRules)
if err != nil {
return err
}
}
newRules := make([]adapter.DNSRule, 0, len(r.rawRules))
for i, ruleOptions := range r.rawRules {
dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyAddressFilterMode)
if err != nil {
closeRules(newRules)
return E.Cause(err, "parse dns rule[", i, "]")
}
newRules = append(newRules, dnsRule)
}
if startRules {
for i, rule := range newRules {
err := rule.Start()
if err != nil {
closeRules(newRules)
return E.Cause(err, "initialize DNS rule[", i, "]")
}
}
}
shouldReportDeprecated := startRules &&
legacyAddressFilterMode &&
!r.deprecatedReported &&
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
r.rulesAccess.Lock()
oldRules := r.rules
r.rules = newRules
r.legacyAddressFilterMode = legacyAddressFilterMode
r.runtimeRuleError = nil
shouldReportDeprecated := legacyAddressFilterMode &&
!r.deprecatedReported &&
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
if shouldReportDeprecated {
r.deprecatedReported = true
}
@@ -190,6 +171,39 @@ func (r *Router) rebuildRules(startRules bool) error {
return nil
}
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
router := service.FromContext[adapter.Router](r.ctx)
legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules)
if err != nil {
return nil, false, err
}
if !legacyAddressFilterMode {
err = validateNonLegacyAddressFilterRules(r.rawRules)
if err != nil {
return nil, false, err
}
}
newRules := make([]adapter.DNSRule, 0, len(r.rawRules))
for i, ruleOptions := range r.rawRules {
dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyAddressFilterMode)
if err != nil {
closeRules(newRules)
return nil, false, E.Cause(err, "parse dns rule[", i, "]")
}
newRules = append(newRules, dnsRule)
}
if startRules {
for i, rule := range newRules {
err := rule.Start()
if err != nil {
closeRules(newRules)
return nil, false, E.Cause(err, "initialize DNS rule[", i, "]")
}
}
}
return newRules, legacyAddressFilterMode, nil
}
func closeRules(rules []adapter.DNSRule) {
for _, rule := range rules {
_ = rule.Close()

View File

@@ -278,7 +278,34 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) {
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
}
func TestStartNewModeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) {
func TestInitializeRejectsInvalidDNSRuleParseError(t *testing.T) {
t.Parallel()
router := &Router{
ctx: context.Background(),
logger: log.NewNOPFactory().NewLogger("dns"),
transport: &fakeDNSTransportManager{},
client: &fakeDNSClient{},
rawRules: make([]option.DNSRule, 0, 1),
rules: make([]adapter.DNSRule, 0, 1),
defaultDomainStrategy: C.DomainStrategyAsIS,
}
err := router.Initialize([]option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
DomainRegex: badoption.Listable[string]{"("},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
},
},
}})
require.ErrorContains(t, err, "domain_regex")
}
func TestInitializeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) {
t.Parallel()
ctx := context.Background()
@@ -336,9 +363,6 @@ func TestStartNewModeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) {
},
},
})
require.NoError(t, err)
err = router.Start(adapter.StartStateStart)
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
}
@@ -1076,6 +1100,75 @@ func TestLookupNewModeUsesQueryTypeRule(t *testing.T) {
require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses)
}
func TestLookupNewModeUsesRuleSetQueryTypeRule(t *testing.T) {
t.Parallel()
ctx := context.Background()
ruleSet, err := rulepkg.NewRuleSet(ctx, log.NewNOPFactory().NewLogger("router"), option.RuleSet{
Type: C.RuleSetTypeInline,
Tag: "query-set",
InlineOptions: option.PlainRuleSet{
Rules: []option.HeadlessRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultHeadlessRule{
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)},
},
}},
},
})
require.NoError(t, err)
ctx = service.ContextWith[adapter.Router](ctx, &fakeRouter{
ruleSets: map[string]adapter.RuleSet{
"query-set": ruleSet,
},
})
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
router := newTestRouterWithContext(t, ctx, []option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
RuleSet: badoption.Listable[string]{"query-set"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "only-a"},
},
},
}}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
"only-a": &fakeDNSTransport{tag: "only-a", transportType: C.DNSTypeUDP},
},
}, &fakeDNSClient{
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
switch transport.Tag() {
case "default":
if message.Question[0].Qtype == mDNS.TypeA {
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("3.3.3.3")}, 60), nil
}
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::4")}, 60), nil
case "only-a":
if message.Question[0].Qtype == mDNS.TypeA {
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil
}
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::9")}, 60), nil
default:
return nil, errors.New("unexpected transport")
}
},
})
require.False(t, router.legacyAddressFilterMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []netip.Addr{
netip.MustParseAddr("9.9.9.9"),
netip.MustParseAddr("2001:db8::4"),
}, addresses)
}
func TestLookupNewModeUsesIPVersionRule(t *testing.T) {
t.Parallel()