mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
dns: restore init validation and fix rule-set query type
This commit is contained in:
@@ -91,6 +91,11 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
|
||||
|
||||
func (r *Router) Initialize(rules []option.DNSRule) error {
|
||||
r.rawRules = append(r.rawRules[:0], rules...)
|
||||
newRules, _, err := r.buildRules(false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
closeRules(newRules)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -142,43 +147,19 @@ func (r *Router) Close() error {
|
||||
}
|
||||
|
||||
func (r *Router) rebuildRules(startRules bool) error {
|
||||
router := service.FromContext[adapter.Router](r.ctx)
|
||||
legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules)
|
||||
newRules, legacyAddressFilterMode, err := r.buildRules(startRules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !legacyAddressFilterMode {
|
||||
err = validateNonLegacyAddressFilterRules(r.rawRules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
newRules := make([]adapter.DNSRule, 0, len(r.rawRules))
|
||||
for i, ruleOptions := range r.rawRules {
|
||||
dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyAddressFilterMode)
|
||||
if err != nil {
|
||||
closeRules(newRules)
|
||||
return E.Cause(err, "parse dns rule[", i, "]")
|
||||
}
|
||||
newRules = append(newRules, dnsRule)
|
||||
}
|
||||
if startRules {
|
||||
for i, rule := range newRules {
|
||||
err := rule.Start()
|
||||
if err != nil {
|
||||
closeRules(newRules)
|
||||
return E.Cause(err, "initialize DNS rule[", i, "]")
|
||||
}
|
||||
}
|
||||
}
|
||||
shouldReportDeprecated := startRules &&
|
||||
legacyAddressFilterMode &&
|
||||
!r.deprecatedReported &&
|
||||
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
|
||||
r.rulesAccess.Lock()
|
||||
oldRules := r.rules
|
||||
r.rules = newRules
|
||||
r.legacyAddressFilterMode = legacyAddressFilterMode
|
||||
r.runtimeRuleError = nil
|
||||
shouldReportDeprecated := legacyAddressFilterMode &&
|
||||
!r.deprecatedReported &&
|
||||
common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() })
|
||||
if shouldReportDeprecated {
|
||||
r.deprecatedReported = true
|
||||
}
|
||||
@@ -190,6 +171,39 @@ func (r *Router) rebuildRules(startRules bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
|
||||
router := service.FromContext[adapter.Router](r.ctx)
|
||||
legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
if !legacyAddressFilterMode {
|
||||
err = validateNonLegacyAddressFilterRules(r.rawRules)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
newRules := make([]adapter.DNSRule, 0, len(r.rawRules))
|
||||
for i, ruleOptions := range r.rawRules {
|
||||
dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyAddressFilterMode)
|
||||
if err != nil {
|
||||
closeRules(newRules)
|
||||
return nil, false, E.Cause(err, "parse dns rule[", i, "]")
|
||||
}
|
||||
newRules = append(newRules, dnsRule)
|
||||
}
|
||||
if startRules {
|
||||
for i, rule := range newRules {
|
||||
err := rule.Start()
|
||||
if err != nil {
|
||||
closeRules(newRules)
|
||||
return nil, false, E.Cause(err, "initialize DNS rule[", i, "]")
|
||||
}
|
||||
}
|
||||
}
|
||||
return newRules, legacyAddressFilterMode, nil
|
||||
}
|
||||
|
||||
func closeRules(rules []adapter.DNSRule) {
|
||||
for _, rule := range rules {
|
||||
_ = rule.Close()
|
||||
|
||||
@@ -278,7 +278,34 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) {
|
||||
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
|
||||
}
|
||||
|
||||
func TestStartNewModeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) {
|
||||
func TestInitializeRejectsInvalidDNSRuleParseError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
router := &Router{
|
||||
ctx: context.Background(),
|
||||
logger: log.NewNOPFactory().NewLogger("dns"),
|
||||
transport: &fakeDNSTransportManager{},
|
||||
client: &fakeDNSClient{},
|
||||
rawRules: make([]option.DNSRule, 0, 1),
|
||||
rules: make([]adapter.DNSRule, 0, 1),
|
||||
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||
}
|
||||
err := router.Initialize([]option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
DomainRegex: badoption.Listable[string]{"("},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
||||
},
|
||||
},
|
||||
}})
|
||||
require.ErrorContains(t, err, "domain_regex")
|
||||
}
|
||||
|
||||
func TestInitializeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
@@ -336,9 +363,6 @@ func TestStartNewModeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) {
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = router.Start(adapter.StartStateStart)
|
||||
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
|
||||
}
|
||||
|
||||
@@ -1076,6 +1100,75 @@ func TestLookupNewModeUsesQueryTypeRule(t *testing.T) {
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses)
|
||||
}
|
||||
|
||||
func TestLookupNewModeUsesRuleSetQueryTypeRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
ruleSet, err := rulepkg.NewRuleSet(ctx, log.NewNOPFactory().NewLogger("router"), option.RuleSet{
|
||||
Type: C.RuleSetTypeInline,
|
||||
Tag: "query-set",
|
||||
InlineOptions: option.PlainRuleSet{
|
||||
Rules: []option.HeadlessRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultHeadlessRule{
|
||||
QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)},
|
||||
},
|
||||
}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
ctx = service.ContextWith[adapter.Router](ctx, &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"query-set": ruleSet,
|
||||
},
|
||||
})
|
||||
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
router := newTestRouterWithContext(t, ctx, []option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
RuleSet: badoption.Listable[string]{"query-set"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "only-a"},
|
||||
},
|
||||
},
|
||||
}}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
"only-a": &fakeDNSTransport{tag: "only-a", transportType: C.DNSTypeUDP},
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
switch transport.Tag() {
|
||||
case "default":
|
||||
if message.Question[0].Qtype == mDNS.TypeA {
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("3.3.3.3")}, 60), nil
|
||||
}
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::4")}, 60), nil
|
||||
case "only-a":
|
||||
if message.Question[0].Qtype == mDNS.TypeA {
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil
|
||||
}
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::9")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
})
|
||||
require.False(t, router.legacyAddressFilterMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{
|
||||
netip.MustParseAddr("9.9.9.9"),
|
||||
netip.MustParseAddr("2001:db8::4"),
|
||||
}, addresses)
|
||||
}
|
||||
|
||||
func TestLookupNewModeUsesIPVersionRule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user