From 40b9c64a0d790bc4e1c5e4c6e70a39b4d023e6b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Thu, 26 Mar 2026 11:30:17 +0800 Subject: [PATCH] dns: make rule path selection rule-set aware --- adapter/router.go | 7 +- box.go | 2 +- dns/router.go | 307 +++++++++++++++++++++++++++++----- dns/router_test.go | 259 +++++++++++++++++++++++++++- route/rule/rule_dns_legacy.go | 36 ++-- route/rule/rule_set.go | 4 + route/rule/rule_set_local.go | 1 + route/rule/rule_set_remote.go | 1 + 8 files changed, 556 insertions(+), 61 deletions(-) diff --git a/adapter/router.go b/adapter/router.go index 82e6881a6..a8f66ba67 100644 --- a/adapter/router.go +++ b/adapter/router.go @@ -67,9 +67,10 @@ type RuleSet interface { type RuleSetUpdateCallback func(it RuleSet) type RuleSetMetadata struct { - ContainsProcessRule bool - ContainsWIFIRule bool - ContainsIPCIDRRule bool + ContainsProcessRule bool + ContainsWIFIRule bool + ContainsIPCIDRRule bool + ContainsDNSQueryTypeRule bool } type HTTPStartContext struct { ctx context.Context diff --git a/box.go b/box.go index a765e21d8..82403a29c 100644 --- a/box.go +++ b/box.go @@ -486,7 +486,7 @@ func (s *Box) preStart() error { if err != nil { return err } - err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.dnsRouter, s.network, s.connection, s.router) + err = adapter.Start(s.logger, adapter.StartStateStart, s.outbound, s.dnsTransport, s.network, s.connection, s.router, s.dnsRouter) if err != nil { return err } diff --git a/dns/router.go b/dns/router.go index 3518a8460..f5b80a120 100644 --- a/dns/router.go +++ b/dns/router.go @@ -5,6 +5,7 @@ import ( "errors" "net/netip" "strings" + "sync" "time" "github.com/sagernet/sing-box/adapter" @@ -21,6 +22,7 @@ import ( "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" "github.com/sagernet/sing/common/task" + "github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/contrab/freelru" "github.com/sagernet/sing/contrab/maphash" "github.com/sagernet/sing/service" @@ -30,17 +32,27 @@ import ( var _ adapter.DNSRouter = (*Router)(nil) +type dnsRuleSetCallback struct { + ruleSet adapter.RuleSet + element *list.Element[adapter.RuleSetUpdateCallback] +} + type Router struct { ctx context.Context logger logger.ContextLogger transport adapter.DNSTransportManager outbound adapter.OutboundManager client adapter.DNSClient + rawRules []option.DNSRule rules []adapter.DNSRule defaultDomainStrategy C.DomainStrategy dnsReverseMapping freelru.Cache[netip.Addr, string] platformInterface adapter.PlatformInterface legacyAddressFilterMode bool + rulesAccess sync.RWMutex + ruleSetCallbacks []dnsRuleSetCallback + runtimeRuleError error + deprecatedReported bool } func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOptions) *Router { @@ -49,6 +61,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp logger: logFactory.NewLogger("dns"), transport: service.FromContext[adapter.DNSTransportManager](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx), + rawRules: make([]option.DNSRule, 0, len(options.Rules)), rules: make([]adapter.DNSRule, 0, len(options.Rules)), defaultDomainStrategy: C.DomainStrategy(options.Strategy), } @@ -77,20 +90,7 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp } func (r *Router) Initialize(rules []option.DNSRule) error { - r.legacyAddressFilterMode = hasLegacyAddressFilterItems(rules) - if !r.legacyAddressFilterMode { - err := validateNonLegacyAddressFilterRules(rules) - if err != nil { - return err - } - } - for i, ruleOptions := range rules { - dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, r.legacyAddressFilterMode) - if err != nil { - return E.Cause(err, "parse dns rule[", i, "]") - } - r.rules = append(r.rules, dnsRule) - } + r.rawRules = append(r.rawRules[:0], rules...) return nil } @@ -102,16 +102,17 @@ func (r *Router) Start(stage adapter.StartStage) error { r.client.Start() monitor.Finish() - for i, rule := range r.rules { - monitor.Start("initialize DNS rule[", i, "]") - err := rule.Start() - monitor.Finish() - if err != nil { - return E.Cause(err, "initialize DNS rule[", i, "]") - } + monitor.Start("initialize DNS rules") + err := r.rebuildRules(true) + monitor.Finish() + if err != nil { + return err } - if r.legacyAddressFilterMode && common.Any(r.rules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() }) { - deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter) + monitor.Start("register DNS rule-set callbacks") + err = r.registerRuleSetCallbacks() + monitor.Finish() + if err != nil { + return err } } return nil @@ -119,8 +120,18 @@ func (r *Router) Start(stage adapter.StartStage) error { func (r *Router) Close() error { monitor := taskmonitor.New(r.logger, C.StopTimeout) + r.rulesAccess.Lock() + callbacks := r.ruleSetCallbacks + r.ruleSetCallbacks = nil + runtimeRules := r.rules + r.rules = nil + r.runtimeRuleError = nil + r.rulesAccess.Unlock() + for _, callback := range callbacks { + callback.ruleSet.UnregisterCallback(callback.element) + } var err error - for i, rule := range r.rules { + for i, rule := range runtimeRules { monitor.Start("close dns rule[", i, "]") err = E.Append(err, rule.Close(), func(err error) error { return E.Cause(err, "close dns rule[", i, "]") @@ -130,6 +141,111 @@ func (r *Router) Close() error { return err } +func (r *Router) rebuildRules(startRules bool) error { + router := service.FromContext[adapter.Router](r.ctx) + legacyAddressFilterMode, err := resolveLegacyAddressFilterMode(router, r.rawRules) + if err != nil { + return err + } + if !legacyAddressFilterMode { + err = validateNonLegacyAddressFilterRules(r.rawRules) + if err != nil { + return err + } + } + newRules := make([]adapter.DNSRule, 0, len(r.rawRules)) + for i, ruleOptions := range r.rawRules { + dnsRule, err := R.NewDNSRule(r.ctx, r.logger, ruleOptions, true, legacyAddressFilterMode) + if err != nil { + closeRules(newRules) + return E.Cause(err, "parse dns rule[", i, "]") + } + newRules = append(newRules, dnsRule) + } + if startRules { + for i, rule := range newRules { + err := rule.Start() + if err != nil { + closeRules(newRules) + return E.Cause(err, "initialize DNS rule[", i, "]") + } + } + } + r.rulesAccess.Lock() + oldRules := r.rules + r.rules = newRules + r.legacyAddressFilterMode = legacyAddressFilterMode + r.runtimeRuleError = nil + shouldReportDeprecated := legacyAddressFilterMode && + !r.deprecatedReported && + common.Any(newRules, func(rule adapter.DNSRule) bool { return rule.WithAddressLimit() }) + if shouldReportDeprecated { + r.deprecatedReported = true + } + r.rulesAccess.Unlock() + closeRules(oldRules) + if shouldReportDeprecated { + deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter) + } + return nil +} + +func closeRules(rules []adapter.DNSRule) { + for _, rule := range rules { + _ = rule.Close() + } +} + +func (r *Router) registerRuleSetCallbacks() error { + tags := referencedDNSRuleSetTags(r.rawRules) + if len(tags) == 0 { + return nil + } + r.rulesAccess.RLock() + if len(r.ruleSetCallbacks) > 0 { + r.rulesAccess.RUnlock() + return nil + } + r.rulesAccess.RUnlock() + router := service.FromContext[adapter.Router](r.ctx) + if router == nil { + return E.New("router service not found") + } + callbacks := make([]dnsRuleSetCallback, 0, len(tags)) + for _, tag := range tags { + ruleSet, loaded := router.RuleSet(tag) + if !loaded { + for _, callback := range callbacks { + callback.ruleSet.UnregisterCallback(callback.element) + } + return E.New("rule-set not found: ", tag) + } + element := ruleSet.RegisterCallback(func(adapter.RuleSet) { + err := r.rebuildRules(true) + if err != nil { + r.rulesAccess.Lock() + r.runtimeRuleError = err + r.rulesAccess.Unlock() + r.logger.Error(E.Cause(err, "rebuild DNS rules after rule-set update")) + } + }) + callbacks = append(callbacks, dnsRuleSetCallback{ + ruleSet: ruleSet, + element: element, + }) + } + r.rulesAccess.Lock() + if len(r.ruleSetCallbacks) == 0 { + r.ruleSetCallbacks = callbacks + callbacks = nil + } + r.rulesAccess.Unlock() + for _, callback := range callbacks { + callback.ruleSet.UnregisterCallback(callback.element) + } + return nil +} + func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) { metadata := adapter.ContextFrom(ctx) if metadata == nil { @@ -538,6 +654,11 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte } return &responseMessage, nil } + r.rulesAccess.RLock() + defer r.rulesAccess.RUnlock() + if r.runtimeRuleError != nil { + return nil, r.runtimeRuleError + } r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String())) var ( response *mDNS.Msg @@ -639,6 +760,11 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte } func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { + r.rulesAccess.RLock() + defer r.rulesAccess.RUnlock() + if r.runtimeRuleError != nil { + return nil, r.runtimeRuleError + } var ( responseAddrs []netip.Addr err error @@ -769,28 +895,124 @@ func (r *Router) ResetNetwork() { } } -func hasLegacyAddressFilterItems(rules []option.DNSRule) bool { - return common.Any(rules, hasLegacyAddressFilterItemsInRule) -} - -func hasLegacyAddressFilterItemsInRule(rule option.DNSRule) bool { - switch rule.Type { - case "", C.RuleTypeDefault: - return hasLegacyAddressFilterItemsInDefaultRule(rule.DefaultOptions) - case C.RuleTypeLogical: - return common.Any(rule.LogicalOptions.Rules, hasLegacyAddressFilterItemsInRule) - default: - return false - } -} - -func hasLegacyAddressFilterItemsInDefaultRule(rule option.DefaultDNSRule) bool { +func hasDirectLegacyAddressFilterItemsInDefaultRule(rule option.DefaultDNSRule) bool { if rule.IPAcceptAny || rule.RuleSetIPCIDRAcceptEmpty { return true } return !rule.MatchResponse && (len(rule.IPCIDR) > 0 || rule.IPIsPrivate) } +func hasResponseMatchFields(rule option.DefaultDNSRule) bool { + return rule.ResponseRcode != nil || + len(rule.ResponseAnswer) > 0 || + len(rule.ResponseNs) > 0 || + len(rule.ResponseExtra) > 0 +} + +func defaultRuleForcesNewDNSPath(rule option.DefaultDNSRule) bool { + return rule.MatchResponse || + hasResponseMatchFields(rule) || + rule.Action == C.RuleActionTypeEvaluate || + rule.IPVersion > 0 || + len(rule.QueryType) > 0 +} + +func resolveLegacyAddressFilterMode(router adapter.Router, rules []option.DNSRule) (bool, error) { + forceNew, needsLegacy, err := dnsRuleModeRequirements(router, rules) + if err != nil { + return false, err + } + if forceNew { + return false, nil + } + return needsLegacy, nil +} + +func dnsRuleModeRequirements(router adapter.Router, rules []option.DNSRule) (bool, bool, error) { + var forceNew bool + var needsLegacy bool + for i, rule := range rules { + ruleForceNew, ruleNeedsLegacy, err := dnsRuleModeRequirementsInRule(router, rule) + if err != nil { + return false, false, E.Cause(err, "dns rule[", i, "]") + } + forceNew = forceNew || ruleForceNew + needsLegacy = needsLegacy || ruleNeedsLegacy + } + return forceNew, needsLegacy, nil +} + +func dnsRuleModeRequirementsInRule(router adapter.Router, rule option.DNSRule) (bool, bool, error) { + switch rule.Type { + case "", C.RuleTypeDefault: + return dnsRuleModeRequirementsInDefaultRule(router, rule.DefaultOptions) + case C.RuleTypeLogical: + forceNew := dnsRuleActionType(rule) == C.RuleActionTypeEvaluate + var needsLegacy bool + for i, subRule := range rule.LogicalOptions.Rules { + subForceNew, subNeedsLegacy, err := dnsRuleModeRequirementsInRule(router, subRule) + if err != nil { + return false, false, E.Cause(err, "sub rule[", i, "]") + } + forceNew = forceNew || subForceNew + needsLegacy = needsLegacy || subNeedsLegacy + } + return forceNew, needsLegacy, nil + default: + return false, false, nil + } +} + +func dnsRuleModeRequirementsInDefaultRule(router adapter.Router, rule option.DefaultDNSRule) (bool, bool, error) { + forceNew := defaultRuleForcesNewDNSPath(rule) + needsLegacy := hasDirectLegacyAddressFilterItemsInDefaultRule(rule) + if len(rule.RuleSet) == 0 { + return forceNew, needsLegacy, nil + } + if router == nil { + return false, false, E.New("router service not found") + } + for _, tag := range rule.RuleSet { + ruleSet, loaded := router.RuleSet(tag) + if !loaded { + return false, false, E.New("rule-set not found: ", tag) + } + metadata := ruleSet.Metadata() + forceNew = forceNew || metadata.ContainsDNSQueryTypeRule + if !rule.RuleSetIPCIDRMatchSource && metadata.ContainsIPCIDRRule { + needsLegacy = true + } + } + return forceNew, needsLegacy, nil +} + +func referencedDNSRuleSetTags(rules []option.DNSRule) []string { + tagMap := make(map[string]bool) + var walkRule func(rule option.DNSRule) + walkRule = func(rule option.DNSRule) { + switch rule.Type { + case "", C.RuleTypeDefault: + for _, tag := range rule.DefaultOptions.RuleSet { + tagMap[tag] = true + } + case C.RuleTypeLogical: + for _, subRule := range rule.LogicalOptions.Rules { + walkRule(subRule) + } + } + } + for _, rule := range rules { + walkRule(rule) + } + tags := make([]string, 0, len(tagMap)) + for tag := range tagMap { + if tag != "" { + tags = append(tags, tag) + } + } + return tags +} + func validateNonLegacyAddressFilterRules(rules []option.DNSRule) error { var seenEvaluate bool for i, rule := range rules { @@ -832,10 +1054,7 @@ func validateNonLegacyAddressFilterRuleTree(rule option.DNSRule) (bool, error) { } func validateNonLegacyAddressFilterDefaultRule(rule option.DefaultDNSRule) (bool, error) { - hasResponseRecords := rule.ResponseRcode != nil || - len(rule.ResponseAnswer) > 0 || - len(rule.ResponseNs) > 0 || - len(rule.ResponseExtra) > 0 + hasResponseRecords := hasResponseMatchFields(rule) if hasResponseRecords && !rule.MatchResponse { return false, E.New("response_* items require match_response") } diff --git a/dns/router_test.go b/dns/router_test.go index 7012789fc..2c967c3fa 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -6,17 +6,23 @@ import ( "net" "net/netip" "testing" + "time" "github.com/sagernet/sing-box/adapter" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/experimental/deprecated" "github.com/sagernet/sing-box/log" "github.com/sagernet/sing-box/option" + rulepkg "github.com/sagernet/sing-box/route/rule" + "github.com/sagernet/sing-tun" "github.com/sagernet/sing/common/json/badoption" + N "github.com/sagernet/sing/common/network" + "github.com/sagernet/sing/common/x/list" "github.com/sagernet/sing/service" mDNS "github.com/miekg/dns" "github.com/stretchr/testify/require" + "go4.org/netipx" ) type fakeDNSTransport struct { @@ -72,6 +78,69 @@ type fakeDeprecatedManager struct { features []deprecated.Note } +type fakeRouter struct { + ruleSets map[string]adapter.RuleSet +} + +func (r *fakeRouter) Start(adapter.StartStage) error { return nil } +func (r *fakeRouter) Close() error { return nil } +func (r *fakeRouter) PreMatch(metadata adapter.InboundContext, _ tun.DirectRouteContext, _ time.Duration, _ bool) (tun.DirectRouteDestination, error) { + return nil, nil +} + +func (r *fakeRouter) RouteConnection(context.Context, net.Conn, adapter.InboundContext) error { + return nil +} + +func (r *fakeRouter) RoutePacketConnection(context.Context, N.PacketConn, adapter.InboundContext) error { + return nil +} + +func (r *fakeRouter) RouteConnectionEx(context.Context, net.Conn, adapter.InboundContext, N.CloseHandlerFunc) { +} + +func (r *fakeRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adapter.InboundContext, N.CloseHandlerFunc) { +} + +func (r *fakeRouter) RuleSet(tag string) (adapter.RuleSet, bool) { + ruleSet, loaded := r.ruleSets[tag] + return ruleSet, loaded +} +func (r *fakeRouter) Rules() []adapter.Rule { return nil } +func (r *fakeRouter) NeedFindProcess() bool { return false } +func (r *fakeRouter) NeedFindNeighbor() bool { return false } +func (r *fakeRouter) NeighborResolver() adapter.NeighborResolver { return nil } +func (r *fakeRouter) AppendTracker(adapter.ConnectionTracker) {} +func (r *fakeRouter) ResetNetwork() {} + +type fakeRuleSet struct { + metadata adapter.RuleSetMetadata + callbacks []adapter.RuleSetUpdateCallback +} + +func (s *fakeRuleSet) Name() string { return "fake-rule-set" } +func (s *fakeRuleSet) StartContext(context.Context, *adapter.HTTPStartContext) error { return nil } +func (s *fakeRuleSet) PostStart() error { return nil } +func (s *fakeRuleSet) Metadata() adapter.RuleSetMetadata { return s.metadata } +func (s *fakeRuleSet) ExtractIPSet() []*netipx.IPSet { return nil } +func (s *fakeRuleSet) IncRef() {} +func (s *fakeRuleSet) DecRef() {} +func (s *fakeRuleSet) Cleanup() {} +func (s *fakeRuleSet) RegisterCallback(callback adapter.RuleSetUpdateCallback) *list.Element[adapter.RuleSetUpdateCallback] { + s.callbacks = append(s.callbacks, callback) + return nil +} +func (s *fakeRuleSet) UnregisterCallback(*list.Element[adapter.RuleSetUpdateCallback]) {} +func (s *fakeRuleSet) Close() error { return nil } +func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true } +func (s *fakeRuleSet) String() string { return "fake-rule-set" } +func (s *fakeRuleSet) updateMetadata(metadata adapter.RuleSetMetadata) { + s.metadata = metadata + for _, callback := range s.callbacks { + callback(s) + } +} + func (m *fakeDeprecatedManager) ReportDeprecated(feature deprecated.Note) { m.features = append(m.features, feature) } @@ -108,18 +177,25 @@ func (c *fakeDNSClient) Lookup(_ context.Context, transport adapter.DNSTransport func (c *fakeDNSClient) ClearCache() {} func newTestRouter(t *testing.T, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router { + return newTestRouterWithContext(t, context.Background(), rules, transportManager, client) +} + +func newTestRouterWithContext(t *testing.T, ctx context.Context, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router { t.Helper() router := &Router{ - ctx: context.Background(), + ctx: ctx, logger: log.NewNOPFactory().NewLogger("dns"), transport: transportManager, client: client, + rawRules: make([]option.DNSRule, 0, len(rules)), rules: make([]adapter.DNSRule, 0, len(rules)), defaultDomainStrategy: C.DomainStrategyAsIS, } if rules != nil { err := router.Initialize(rules) require.NoError(t, err) + err = router.Start(adapter.StartStateStart) + require.NoError(t, err) } return router } @@ -202,6 +278,187 @@ func TestValidateNewDNSRules_RequireMatchResponseForDirectIPCIDR(t *testing.T) { require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") } +func TestStartNewModeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ruleSet, err := rulepkg.NewRuleSet(ctx, log.NewNOPFactory().NewLogger("router"), option.RuleSet{ + Type: C.RuleSetTypeInline, + Tag: "query-set", + InlineOptions: option.PlainRuleSet{ + Rules: []option.HeadlessRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultHeadlessRule{ + QueryType: badoption.Listable[option.DNSQueryType]{option.DNSQueryType(mDNS.TypeA)}, + }, + }}, + }, + }) + require.NoError(t, err) + ctx = service.ContextWith[adapter.Router](ctx, &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "query-set": ruleSet, + }, + }) + + router := &Router{ + ctx: ctx, + logger: log.NewNOPFactory().NewLogger("dns"), + transport: &fakeDNSTransportManager{}, + client: &fakeDNSClient{}, + rawRules: make([]option.DNSRule, 0, 2), + rules: make([]adapter.DNSRule, 0, 2), + defaultDomainStrategy: C.DomainStrategyAsIS, + } + err = router.Initialize([]option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"query-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + IPIsPrivate: true, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "private"}, + }, + }, + }, + }) + require.NoError(t, err) + + err = router.Start(adapter.StartStateStart) + require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") +} + +func TestLookupLegacyModeDefersRuleSetDestinationIPMatch(t *testing.T) { + t.Parallel() + + ctx := context.Background() + ruleSet, err := rulepkg.NewRuleSet(ctx, log.NewNOPFactory().NewLogger("router"), option.RuleSet{ + Type: C.RuleSetTypeInline, + Tag: "legacy-ipcidr-set", + InlineOptions: option.PlainRuleSet{ + Rules: []option.HeadlessRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultHeadlessRule{ + IPCIDR: badoption.Listable[string]{"10.0.0.0/8"}, + }, + }}, + }, + }) + require.NoError(t, err) + ctx = service.ContextWith[adapter.Router](ctx, &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "legacy-ipcidr-set": ruleSet, + }, + }) + + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP} + router := newTestRouterWithContext(t, ctx, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"legacy-ipcidr-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "private"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "private": privateTransport, + }, + }, &fakeDNSClient{ + lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + require.Equal(t, "example.com", domain) + require.Equal(t, "private", transport.Tag()) + response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) + return MessageToAddresses(response), response, nil + }, + }) + + require.True(t, router.legacyAddressFilterMode) + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ + LookupStrategy: C.DomainStrategyIPv4Only, + }) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) +} + +func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) { + t.Parallel() + + fakeSet := &fakeRuleSet{} + ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "dynamic-set": fakeSet, + }, + }) + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + router := newTestRouterWithContext(t, ctx, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"dynamic-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + IPIsPrivate: true, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + }, + }, &fakeDNSClient{ + lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) + return MessageToAddresses(response), response, nil + }, + }) + + require.True(t, router.legacyAddressFilterMode) + + fakeSet.updateMetadata(adapter.RuleSetMetadata{ + ContainsDNSQueryTypeRule: true, + }) + + _, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") +} + func TestLookupLegacyModeDefersDirectDestinationIPMatch(t *testing.T) { t.Parallel() diff --git a/route/rule/rule_dns_legacy.go b/route/rule/rule_dns_legacy.go index 8cdad83f3..25088cacc 100644 --- a/route/rule/rule_dns_legacy.go +++ b/route/rule/rule_dns_legacy.go @@ -47,7 +47,12 @@ type legacyResponseConstraint struct { forbiddenSet *netipx.IPSet } -type legacyRuleMatchStateSet [16]legacyResponseFormula +const ( + legacyRuleMatchDeferredDestinationAddress ruleMatchState = 1 << 4 + legacyRuleMatchStateCount = 32 +) + +type legacyRuleMatchStateSet [legacyRuleMatchStateCount]legacyResponseFormula var ( legacyAllIPSet = func() *netipx.IPSet { @@ -350,7 +355,7 @@ func (s legacyRuleMatchStateSet) isEmpty() bool { func (s legacyRuleMatchStateSet) merge(other legacyRuleMatchStateSet) legacyRuleMatchStateSet { var merged legacyRuleMatchStateSet - for state := ruleMatchState(0); state < 16; state++ { + for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ { merged[state] = s[state].or(other[state]) } return merged @@ -361,11 +366,11 @@ func (s legacyRuleMatchStateSet) combine(other legacyRuleMatchStateSet) legacyRu return legacyRuleMatchStateSet{} } var combined legacyRuleMatchStateSet - for left := ruleMatchState(0); left < 16; left++ { + for left := ruleMatchState(0); left < legacyRuleMatchStateCount; left++ { if s[left].isFalse() { continue } - for right := ruleMatchState(0); right < 16; right++ { + for right := ruleMatchState(0); right < legacyRuleMatchStateCount; right++ { if other[right].isFalse() { continue } @@ -380,7 +385,7 @@ func (s legacyRuleMatchStateSet) withBase(base ruleMatchState) legacyRuleMatchSt return legacyRuleMatchStateSet{} } var withBase legacyRuleMatchStateSet - for state := ruleMatchState(0); state < 16; state++ { + for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ { if s[state].isFalse() { continue } @@ -391,7 +396,7 @@ func (s legacyRuleMatchStateSet) withBase(base ruleMatchState) legacyRuleMatchSt func (s legacyRuleMatchStateSet) filter(allowed func(ruleMatchState) bool) legacyRuleMatchStateSet { var filtered legacyRuleMatchStateSet - for state := ruleMatchState(0); state < 16; state++ { + for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ { if s[state].isFalse() { continue } @@ -404,7 +409,7 @@ func (s legacyRuleMatchStateSet) filter(allowed func(ruleMatchState) bool) legac func (s legacyRuleMatchStateSet) addBit(bit ruleMatchState) legacyRuleMatchStateSet { var withBit legacyRuleMatchStateSet - for state := ruleMatchState(0); state < 16; state++ { + for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ { if s[state].isFalse() { continue } @@ -422,7 +427,7 @@ func (s legacyRuleMatchStateSet) branchOnBit(bit ruleMatchState, condition legac } var branched legacyRuleMatchStateSet conditionFalse := condition.not() - for state := ruleMatchState(0); state < 16; state++ { + for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ { if s[state].isFalse() { continue } @@ -444,7 +449,7 @@ func (s legacyRuleMatchStateSet) andFormula(formula legacyResponseFormula) legac return s } var result legacyRuleMatchStateSet - for state := ruleMatchState(0); state < 16; state++ { + for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ { if s[state].isFalse() { continue } @@ -588,7 +593,7 @@ func (r *abstractDefaultRule) legacyMatchStatesWithBase(metadata *adapter.Inboun } if r.legacyDestinationIPCIDRMatchesDestination(metadata) { metadata.DidMatch = true - stateSet = stateSet.branchOnBit(ruleMatchDestinationAddress, legacyDestinationIPFormula(r.destinationIPCIDRItems, metadata)) + stateSet = stateSet.branchOnBit(legacyRuleMatchDeferredDestinationAddress, legacyDestinationIPFormula(r.destinationIPCIDRItems, metadata)) } if len(r.destinationPortItems) > 0 { metadata.DidMatch = true @@ -608,7 +613,7 @@ func (r *abstractDefaultRule) legacyMatchStatesWithBase(metadata *adapter.Inboun if r.ruleSetItem != nil { metadata.DidMatch = true var merged legacyRuleMatchStateSet - for state := ruleMatchState(0); state < 16; state++ { + for state := ruleMatchState(0); state < legacyRuleMatchStateCount; state++ { if stateSet[state].isFalse() { continue } @@ -627,6 +632,9 @@ func (r *abstractDefaultRule) legacyMatchStatesWithBase(metadata *adapter.Inboun if r.legacyRequiresDestinationAddressMatch(metadata) && !state.has(ruleMatchDestinationAddress) { return false } + if r.legacyRequiresDeferredDestinationAddressMatch(metadata) && !state.has(legacyRuleMatchDeferredDestinationAddress) { + return false + } if len(r.destinationPortItems) > 0 && !state.has(ruleMatchDestinationPort) { return false } @@ -647,7 +655,11 @@ func (r *abstractDefaultRule) legacyDestinationIPCIDRMatchesDestination(metadata } func (r *abstractDefaultRule) legacyRequiresDestinationAddressMatch(metadata *adapter.InboundContext) bool { - return len(r.destinationAddressItems) > 0 || r.legacyDestinationIPCIDRMatchesDestination(metadata) + return len(r.destinationAddressItems) > 0 +} + +func (r *abstractDefaultRule) legacyRequiresDeferredDestinationAddressMatch(metadata *adapter.InboundContext) bool { + return r.legacyDestinationIPCIDRMatchesDestination(metadata) } func (r *abstractLogicalRule) legacyMatchStates(metadata *adapter.InboundContext) legacyRuleMatchStateSet { diff --git a/route/rule/rule_set.go b/route/rule/rule_set.go index 39068dbf3..9bffa8fcb 100644 --- a/route/rule/rule_set.go +++ b/route/rule/rule_set.go @@ -69,3 +69,7 @@ func isWIFIHeadlessRule(rule option.DefaultHeadlessRule) bool { func isIPCIDRHeadlessRule(rule option.DefaultHeadlessRule) bool { return len(rule.IPCIDR) > 0 || rule.IPSet != nil } + +func isDNSQueryTypeHeadlessRule(rule option.DefaultHeadlessRule) bool { + return len(rule.QueryType) > 0 +} diff --git a/route/rule/rule_set_local.go b/route/rule/rule_set_local.go index ed873d706..51e8f2723 100644 --- a/route/rule/rule_set_local.go +++ b/route/rule/rule_set_local.go @@ -141,6 +141,7 @@ func (s *LocalRuleSet) reloadRules(headlessRules []option.HeadlessRule) error { metadata.ContainsProcessRule = HasHeadlessRule(headlessRules, isProcessHeadlessRule) metadata.ContainsWIFIRule = HasHeadlessRule(headlessRules, isWIFIHeadlessRule) metadata.ContainsIPCIDRRule = HasHeadlessRule(headlessRules, isIPCIDRHeadlessRule) + metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(headlessRules, isDNSQueryTypeHeadlessRule) s.access.Lock() s.rules = rules s.metadata = metadata diff --git a/route/rule/rule_set_remote.go b/route/rule/rule_set_remote.go index bda6e23f1..4d2691450 100644 --- a/route/rule/rule_set_remote.go +++ b/route/rule/rule_set_remote.go @@ -193,6 +193,7 @@ func (s *RemoteRuleSet) loadBytes(content []byte) error { s.metadata.ContainsProcessRule = HasHeadlessRule(plainRuleSet.Rules, isProcessHeadlessRule) s.metadata.ContainsWIFIRule = HasHeadlessRule(plainRuleSet.Rules, isWIFIHeadlessRule) s.metadata.ContainsIPCIDRRule = HasHeadlessRule(plainRuleSet.Rules, isIPCIDRHeadlessRule) + s.metadata.ContainsDNSQueryTypeRule = HasHeadlessRule(plainRuleSet.Rules, isDNSQueryTypeHeadlessRule) s.rules = rules callbacks := s.callbacks.Array() s.access.Unlock()