From bd222fe9dfdc6ac8e5cc4780e9493b5c87900a54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Tue, 31 Mar 2026 13:15:25 +0800 Subject: [PATCH] dns: serialize rebuilds and keep last good rules on failure --- dns/router.go | 37 ++- dns/router_test.go | 627 ++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 574 insertions(+), 90 deletions(-) diff --git a/dns/router.go b/dns/router.go index b30f5a0c8..8e2424b9f 100644 --- a/dns/router.go +++ b/dns/router.go @@ -50,9 +50,9 @@ type Router struct { platformInterface adapter.PlatformInterface legacyDNSMode bool rulesAccess sync.RWMutex + rebuildAccess sync.Mutex closing bool ruleSetCallbacks []dnsRuleSetCallback - runtimeRuleError error addressFilterDeprecatedReported bool ruleStrategyDeprecatedReported bool } @@ -116,11 +116,19 @@ func (r *Router) Start(stage adapter.StartStage) error { return err } monitor.Start("register DNS rule-set callbacks") - err = r.registerRuleSetCallbacks() + needsRulesRefresh, err := r.registerRuleSetCallbacks() monitor.Finish() if err != nil { return err } + if needsRulesRefresh { + monitor.Start("refresh DNS rules after callback registration") + err = r.rebuildRules(true) + monitor.Finish() + if err != nil { + r.logger.Error(E.Cause(err, "refresh DNS rules after callback registration")) + } + } } return nil } @@ -133,7 +141,6 @@ func (r *Router) Close() error { r.ruleSetCallbacks = nil runtimeRules := r.rules r.rules = nil - r.runtimeRuleError = nil for _, callback := range callbacks { callback.ruleSet.UnregisterCallback(callback.element) } @@ -150,6 +157,8 @@ func (r *Router) Close() error { } func (r *Router) rebuildRules(startRules bool) error { + r.rebuildAccess.Lock() + defer r.rebuildAccess.Unlock() if r.isClosing() { return nil } @@ -177,7 +186,6 @@ func (r *Router) rebuildRules(startRules bool) error { oldRules := r.rules r.rules = newRules r.legacyDNSMode = legacyDNSMode - r.runtimeRuleError = nil if shouldReportAddressFilterDeprecated { r.addressFilterDeprecatedReported = true } @@ -246,20 +254,20 @@ func closeRules(rules []adapter.DNSRule) { } } -func (r *Router) registerRuleSetCallbacks() error { +func (r *Router) registerRuleSetCallbacks() (bool, error) { tags := referencedDNSRuleSetTags(r.rawRules) if len(tags) == 0 { - return nil + return false, nil } r.rulesAccess.RLock() if len(r.ruleSetCallbacks) > 0 { r.rulesAccess.RUnlock() - return nil + return true, nil } r.rulesAccess.RUnlock() router := service.FromContext[adapter.Router](r.ctx) if router == nil { - return E.New("router service not found") + return false, E.New("router service not found") } callbacks := make([]dnsRuleSetCallback, 0, len(tags)) for _, tag := range tags { @@ -268,14 +276,11 @@ func (r *Router) registerRuleSetCallbacks() error { for _, callback := range callbacks { callback.ruleSet.UnregisterCallback(callback.element) } - return E.New("rule-set not found: ", tag) + return false, E.New("rule-set not found: ", tag) } element := ruleSet.RegisterCallback(func(adapter.RuleSet) { err := r.rebuildRules(true) if err != nil { - r.rulesAccess.Lock() - r.runtimeRuleError = err - r.rulesAccess.Unlock() r.logger.Error(E.Cause(err, "rebuild DNS rules after rule-set update")) } }) @@ -293,7 +298,7 @@ func (r *Router) registerRuleSetCallbacks() error { for _, callback := range callbacks { callback.ruleSet.UnregisterCallback(callback.element) } - return nil + return true, nil } func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) { @@ -653,9 +658,6 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte } r.rulesAccess.RLock() defer r.rulesAccess.RUnlock() - if r.runtimeRuleError != nil { - return nil, r.runtimeRuleError - } r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String())) var ( response *mDNS.Msg @@ -760,9 +762,6 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { r.rulesAccess.RLock() defer r.rulesAccess.RUnlock() - if r.runtimeRuleError != nil { - return nil, r.runtimeRuleError - } var ( responseAddrs []netip.Addr err error diff --git a/dns/router_test.go b/dns/router_test.go index 46ddcd028..37ebcbc02 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -2,9 +2,10 @@ package dns import ( "context" - "errors" + "io" "net" "net/netip" + "strings" "sync" "testing" "time" @@ -16,6 +17,8 @@ import ( "github.com/sagernet/sing-box/option" rulepkg "github.com/sagernet/sing-box/route/rule" "github.com/sagernet/sing-tun" + "github.com/sagernet/sing/common" + E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/json/badoption" N "github.com/sagernet/sing/common/network" "github.com/sagernet/sing/common/x/list" @@ -38,7 +41,7 @@ func (t *fakeDNSTransport) Tag() string { return t.tag } func (t *fakeDNSTransport) Dependencies() []string { return nil } func (t *fakeDNSTransport) Reset() {} func (t *fakeDNSTransport) Exchange(context.Context, *mDNS.Msg) (*mDNS.Msg, error) { - return nil, errors.New("unused transport exchange") + return nil, E.New("unused transport exchange") } type fakeDNSTransportManager struct { @@ -66,7 +69,7 @@ func (m *fakeDNSTransportManager) FakeIP() adapter.FakeIPTransport { } func (m *fakeDNSTransportManager) Remove(string) error { return nil } func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, string, string, any) error { - return errors.New("unsupported") + return E.New("unsupported") } type fakeDNSClient struct { @@ -80,6 +83,7 @@ type fakeDeprecatedManager struct { } type fakeRouter struct { + access sync.RWMutex ruleSets map[string]adapter.RuleSet } @@ -104,9 +108,19 @@ func (r *fakeRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adap } func (r *fakeRouter) RuleSet(tag string) (adapter.RuleSet, bool) { + r.access.RLock() + defer r.access.RUnlock() ruleSet, loaded := r.ruleSets[tag] return ruleSet, loaded } +func (r *fakeRouter) setRuleSet(tag string, ruleSet adapter.RuleSet) { + r.access.Lock() + defer r.access.Unlock() + if r.ruleSets == nil { + r.ruleSets = make(map[string]adapter.RuleSet) + } + r.ruleSets[tag] = ruleSet +} func (r *fakeRouter) Rules() []adapter.Rule { return nil } func (r *fakeRouter) NeedFindProcess() bool { return false } func (r *fakeRouter) NeedFindNeighbor() bool { return false } @@ -115,10 +129,14 @@ func (r *fakeRouter) AppendTracker(adapter.ConnectionTracker) {} func (r *fakeRouter) ResetNetwork() {} type fakeRuleSet struct { - access sync.Mutex - metadata adapter.RuleSetMetadata - callbacks list.List[adapter.RuleSetUpdateCallback] - refs int + access sync.Mutex + metadata adapter.RuleSetMetadata + metadataRead func(adapter.RuleSetMetadata) adapter.RuleSetMetadata + match func(*adapter.InboundContext) bool + callbacks list.List[adapter.RuleSetUpdateCallback] + refs int + afterIncrementReference func() + beforeDecrementReference func() } func (s *fakeRuleSet) Name() string { return "fake-rule-set" } @@ -126,17 +144,32 @@ func (s *fakeRuleSet) StartContext(context.Context, *adapter.HTTPStartContext) e func (s *fakeRuleSet) PostStart() error { return nil } func (s *fakeRuleSet) Metadata() adapter.RuleSetMetadata { s.access.Lock() - defer s.access.Unlock() - return s.metadata + metadata := s.metadata + metadataRead := s.metadataRead + s.access.Unlock() + if metadataRead != nil { + return metadataRead(metadata) + } + return metadata } func (s *fakeRuleSet) ExtractIPSet() []*netipx.IPSet { return nil } func (s *fakeRuleSet) IncRef() { s.access.Lock() - defer s.access.Unlock() s.refs++ + afterIncrementReference := s.afterIncrementReference + s.access.Unlock() + if afterIncrementReference != nil { + afterIncrementReference() + } } func (s *fakeRuleSet) DecRef() { + s.access.Lock() + beforeDecrementReference := s.beforeDecrementReference + s.access.Unlock() + if beforeDecrementReference != nil { + beforeDecrementReference() + } s.access.Lock() defer s.access.Unlock() s.refs-- @@ -156,9 +189,17 @@ func (s *fakeRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUp defer s.access.Unlock() s.callbacks.Remove(element) } -func (s *fakeRuleSet) Close() error { return nil } -func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true } -func (s *fakeRuleSet) String() string { return "fake-rule-set" } +func (s *fakeRuleSet) Close() error { return nil } +func (s *fakeRuleSet) Match(metadata *adapter.InboundContext) bool { + s.access.Lock() + match := s.match + s.access.Unlock() + if match != nil { + return match(metadata) + } + return true +} +func (s *fakeRuleSet) String() string { return "fake-rule-set" } func (s *fakeRuleSet) updateMetadata(metadata adapter.RuleSetMetadata) { s.access.Lock() s.metadata = metadata @@ -196,7 +237,7 @@ func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTrans func (c *fakeDNSClient) Lookup(_ context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(*mDNS.Msg) bool) ([]netip.Addr, error) { if c.lookup == nil { - return nil, errors.New("unused client lookup") + return nil, E.New("unused client lookup") } addresses, response, err := c.lookup(transport, domain, options) if err != nil { @@ -221,10 +262,14 @@ func newTestRouter(t *testing.T, rules []option.DNSRule, transportManager *fakeD } func newTestRouterWithContext(t *testing.T, ctx context.Context, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router { + return newTestRouterWithContextAndLogger(t, ctx, rules, transportManager, client, log.NewNOPFactory().NewLogger("dns")) +} + +func newTestRouterWithContextAndLogger(t *testing.T, ctx context.Context, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient, dnsLogger log.ContextLogger) *Router { t.Helper() router := &Router{ ctx: ctx, - logger: log.NewNOPFactory().NewLogger("dns"), + logger: dnsLogger, transport: transportManager, client: client, rawRules: make([]option.DNSRule, 0, len(rules)), @@ -240,6 +285,26 @@ func newTestRouterWithContext(t *testing.T, ctx context.Context, rules []option. return router } +func waitForLogMessageContaining(t *testing.T, entries <-chan log.Entry, done <-chan struct{}, substring string) log.Entry { + t.Helper() + timeout := time.After(time.Second) + for { + select { + case entry, ok := <-entries: + if !ok { + t.Fatal("log subscription closed") + } + if strings.Contains(entry.Message, substring) { + return entry + } + case <-done: + t.Fatal("log subscription closed") + case <-timeout: + t.Fatalf("timed out waiting for log message containing %q", substring) + } + } +} + func fixedQuestion(name string, qType uint16) mDNS.Question { return mDNS.Question{ Name: mDNS.Fqdn(name), @@ -541,10 +606,356 @@ func TestRuleSetUpdateReleasesOldRuleSetRefs(t *testing.T) { require.Zero(t, fakeSet.refCount()) } -func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) { +func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t *testing.T) { t.Parallel() - fakeSet := &fakeRuleSet{} + callbackRuleSet := &fakeRuleSet{ + match: func(*adapter.InboundContext) bool { + return false + }, + } + routerService := &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "dynamic-set": callbackRuleSet, + }, + } + ctx := service.ContextWith[adapter.Router](context.Background(), routerService) + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + preservedTransport := &fakeDNSTransport{tag: "preserved", transportType: C.DNSTypeUDP} + wouldBeNewTransport := &fakeDNSTransport{tag: "would-be-new", transportType: C.DNSTypeUDP} + loggerFactory := log.NewDefaultFactory( + context.Background(), + log.Formatter{ + BaseTime: time.Now(), + DisableColors: true, + DisableTimestamp: true, + }, + io.Discard, + "", + nil, + true, + ) + loggerFactory.SetLevel(log.LevelError) + logEntries, logDone, err := loggerFactory.Subscribe() + require.NoError(t, err) + t.Cleanup(func() { + loggerFactory.UnSubscribe(logEntries) + closeErr := loggerFactory.Close() + require.NoError(t, closeErr) + }) + var lastUsedTransport common.TypedValue[string] + router := newTestRouterWithContextAndLogger(t, ctx, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"dynamic-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "would-be-new"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "preserved"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + IPIsPrivate: true, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "preserved"}, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "preserved": preservedTransport, + "would-be-new": wouldBeNewTransport, + }, + }, &fakeDNSClient{ + lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + lastUsedTransport.Store(transport.Tag()) + response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) + return MessageToAddresses(response), response, nil + }, + }, loggerFactory.NewLogger("dns")) + t.Cleanup(func() { + closeErr := router.Close() + require.NoError(t, closeErr) + }) + + require.True(t, router.legacyDNSMode) + require.Equal(t, 1, callbackRuleSet.refCount()) + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) + require.Equal(t, "preserved", lastUsedTransport.Load()) + + rebuildTargetRuleSet := &fakeRuleSet{ + metadata: adapter.RuleSetMetadata{ + ContainsDNSQueryTypeRule: true, + }, + match: func(*adapter.InboundContext) bool { + return true + }, + } + routerService.setRuleSet("dynamic-set", rebuildTargetRuleSet) + + callbackRuleSet.updateMetadata(adapter.RuleSetMetadata{ + ContainsDNSQueryTypeRule: true, + }) + rebuildErrorEntry := waitForLogMessageContaining(t, logEntries, logDone, "rebuild DNS rules after rule-set update") + require.Contains(t, rebuildErrorEntry.Message, "ip_cidr and ip_is_private require match_response") + require.True(t, router.legacyDNSMode) + require.Equal(t, 1, callbackRuleSet.refCount()) + require.Zero(t, rebuildTargetRuleSet.refCount()) + + lastUsedTransport.Store("") + addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) + require.Equal(t, "preserved", lastUsedTransport.Load()) + require.NotEqual(t, "would-be-new", lastUsedTransport.Load()) +} + +func TestRuleSetUpdateSerializesConcurrentRebuilds(t *testing.T) { + t.Parallel() + + callbackRuleSet := &fakeRuleSet{ + match: func(*adapter.InboundContext) bool { + return false + }, + } + routerService := &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "dynamic-set": callbackRuleSet, + }, + } + ctx := service.ContextWith[adapter.Router](context.Background(), routerService) + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + firstTransport := &fakeDNSTransport{tag: "first", transportType: C.DNSTypeUDP} + secondTransport := &fakeDNSTransport{tag: "second", transportType: C.DNSTypeUDP} + var lastUsedTransport common.TypedValue[string] + router := newTestRouterWithContext(t, ctx, []option.DNSRule{ + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"dynamic-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "first"}, + }, + }, + }, + { + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + Domain: badoption.Listable[string]{"example.com"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "second"}, + }, + }, + }, + }, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "first": firstTransport, + "second": secondTransport, + }, + }, &fakeDNSClient{ + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + lastUsedTransport.Store(transport.Tag()) + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60), nil + }, + }) + + addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) + require.Equal(t, "second", lastUsedTransport.Load()) + + callbacks := callbackRuleSet.snapshotCallbacks() + require.Len(t, callbacks, 1) + + firstMetadataEntered := make(chan struct{}) + releaseFirstMetadata := make(chan struct{}) + firstRuleSetStarted := make(chan struct{}) + releaseFirstRuleSetStart := make(chan struct{}) + secondMetadataEntered := make(chan struct{}) + releaseSecondMetadata := make(chan struct{}) + + var metadataAccess sync.Mutex + var metadataCallCount int + var concurrentMetadataCalls int + var maximumConcurrentMetadataCalls int + + recordMetadataEntry := func() func() { + metadataAccess.Lock() + metadataCallCount++ + concurrentMetadataCalls++ + if concurrentMetadataCalls > maximumConcurrentMetadataCalls { + maximumConcurrentMetadataCalls = concurrentMetadataCalls + } + metadataAccess.Unlock() + return func() { + metadataAccess.Lock() + concurrentMetadataCalls-- + metadataAccess.Unlock() + } + } + + firstBuildRuleSet := &fakeRuleSet{ + match: func(*adapter.InboundContext) bool { + return true + }, + metadataRead: func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata { + metadataDone := recordMetadataEntry() + close(firstMetadataEntered) + <-releaseFirstMetadata + metadataDone() + return metadata + }, + afterIncrementReference: func() { + close(firstRuleSetStarted) + <-releaseFirstRuleSetStart + }, + } + secondBuildRuleSet := &fakeRuleSet{ + match: func(*adapter.InboundContext) bool { + return false + }, + metadataRead: func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata { + metadataDone := recordMetadataEntry() + close(secondMetadataEntered) + <-releaseSecondMetadata + metadataDone() + return metadata + }, + } + + routerService.setRuleSet("dynamic-set", firstBuildRuleSet) + + firstCallbackFinished := make(chan struct{}) + go func() { + callbacks[0](callbackRuleSet) + close(firstCallbackFinished) + }() + + select { + case <-firstMetadataEntered: + case <-time.After(time.Second): + t.Fatal("first rebuild did not reach rule-set metadata") + } + + close(releaseFirstMetadata) + + select { + case <-firstRuleSetStarted: + case <-time.After(time.Second): + t.Fatal("first rebuild did not reach rule-set start") + } + + routerService.setRuleSet("dynamic-set", secondBuildRuleSet) + + secondCallbackStarted := make(chan struct{}) + secondCallbackFinished := make(chan struct{}) + go func() { + close(secondCallbackStarted) + callbacks[0](callbackRuleSet) + close(secondCallbackFinished) + }() + + select { + case <-secondCallbackStarted: + case <-time.After(time.Second): + t.Fatal("second rebuild did not start") + } + + select { + case <-secondMetadataEntered: + t.Fatal("second rebuild entered rule-set metadata before the first rebuild completed") + default: + } + + close(releaseFirstRuleSetStart) + + select { + case <-firstCallbackFinished: + case <-time.After(time.Second): + t.Fatal("first rebuild callback did not finish") + } + + select { + case <-secondMetadataEntered: + case <-time.After(time.Second): + t.Fatal("second rebuild did not enter rule-set metadata after the first rebuild finished") + } + + addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) + require.Equal(t, "first", lastUsedTransport.Load()) + + close(releaseSecondMetadata) + + select { + case <-secondCallbackFinished: + case <-time.After(time.Second): + t.Fatal("second rebuild callback did not finish") + } + + metadataAccess.Lock() + require.Equal(t, 2, metadataCallCount) + require.Equal(t, 1, maximumConcurrentMetadataCalls) + metadataAccess.Unlock() + require.Zero(t, callbackRuleSet.refCount()) + require.Zero(t, firstBuildRuleSet.refCount()) + require.Equal(t, 1, secondBuildRuleSet.refCount()) + + lastUsedTransport.Store("") + addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) + require.Equal(t, "second", lastUsedTransport.Load()) + + err = router.Close() + require.NoError(t, err) + require.Zero(t, callbackRuleSet.refCount()) + require.Zero(t, firstBuildRuleSet.refCount()) + require.Zero(t, secondBuildRuleSet.refCount()) +} + +func TestCloseDuringRebuildDiscardsResult(t *testing.T) { + t.Parallel() + + fakeSet := &fakeRuleSet{ + metadata: adapter.RuleSetMetadata{ + ContainsIPCIDRRule: true, + }, + } ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{ ruleSets: map[string]adapter.RuleSet{ "dynamic-set": fakeSet, @@ -560,42 +971,73 @@ func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) { }, DNSRuleAction: option.DNSRuleAction{ Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "default"}, - }, - }, - }, - { - Type: C.RuleTypeDefault, - DefaultOptions: option.DefaultDNSRule{ - RawDefaultDNSRule: option.RawDefaultDNSRule{ - IPIsPrivate: true, - }, - DNSRuleAction: option.DNSRuleAction{ - Action: C.RuleActionTypeRoute, - RouteOptions: option.DNSRouteActionOptions{Server: "default"}, + RouteOptions: option.DNSRouteActionOptions{Server: "installed"}, }, }, }, }, &fakeDNSTransportManager{ defaultTransport: defaultTransport, transports: map[string]adapter.DNSTransport{ - "default": defaultTransport, + "default": defaultTransport, + "discarded": &fakeDNSTransport{tag: "discarded", transportType: C.DNSTypeUDP}, + "installed": &fakeDNSTransport{tag: "installed", transportType: C.DNSTypeUDP}, }, }, &fakeDNSClient{ - lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { - response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) - return MessageToAddresses(response), response, nil + exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { + switch transport.Tag() { + case "discarded", "installed", "default": + return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60), nil + default: + return nil, E.New("unexpected transport: ", transport.Tag()) + } }, }) - require.True(t, router.legacyDNSMode) + require.Equal(t, 1, fakeSet.refCount()) - fakeSet.updateMetadata(adapter.RuleSetMetadata{ - ContainsDNSQueryTypeRule: true, - }) + callbacks := fakeSet.snapshotCallbacks() + require.Len(t, callbacks, 1) - _, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) - require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response") + firstMetadataEntered := make(chan struct{}) + releaseFirstMetadata := make(chan struct{}) + callbackFinished := make(chan struct{}) + fakeSet.metadataRead = func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata { + router.rawRules[0].DefaultOptions.RouteOptions.Server = "discarded" + close(firstMetadataEntered) + <-releaseFirstMetadata + return adapter.RuleSetMetadata{} + } + + go func() { + callbacks[0](fakeSet) + close(callbackFinished) + }() + + select { + case <-firstMetadataEntered: + case <-time.After(time.Second): + t.Fatal("rebuild did not reach rule-set metadata") + } + + err := router.Close() + require.NoError(t, err) + close(releaseFirstMetadata) + + select { + case <-callbackFinished: + case <-time.After(time.Second): + t.Fatal("rebuild callback did not finish after close") + } + + fakeSet.metadataRead = nil + + router.rulesAccess.RLock() + require.True(t, router.closing) + require.Nil(t, router.rules) + require.Empty(t, router.ruleSetCallbacks) + router.rulesAccess.RUnlock() + require.True(t, router.legacyDNSMode) + require.Zero(t, fakeSet.refCount()) } func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) { @@ -661,7 +1103,6 @@ func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) { require.True(t, router.closing) require.Nil(t, router.rules) require.Empty(t, router.ruleSetCallbacks) - require.NoError(t, router.runtimeRuleError) } func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) { @@ -680,7 +1121,7 @@ func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) { case "default": t.Fatal("default transport should not be used when legacy rule matches after response") } - return nil, nil, errors.New("unexpected transport") + return nil, nil, E.New("unexpected transport") }, } router := newTestRouter(t, []option.DNSRule{{ @@ -716,12 +1157,23 @@ func TestLookupLegacyDNSModeFallsBackAfterRejectedAddressLimitResponse(t *testin defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP} - var lookups []string + var lookupAccess sync.Mutex + var lookupTags []string + recordLookup := func(tag string) { + lookupAccess.Lock() + lookupTags = append(lookupTags, tag) + lookupAccess.Unlock() + } + currentLookupTags := func() []string { + lookupAccess.Lock() + defer lookupAccess.Unlock() + return append([]string(nil), lookupTags...) + } client := &fakeDNSClient{ lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { require.Equal(t, "example.com", domain) require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy) - lookups = append(lookups, transport.Tag()) + recordLookup(transport.Tag()) switch transport.Tag() { case "private": response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60) @@ -730,7 +1182,7 @@ func TestLookupLegacyDNSModeFallsBackAfterRejectedAddressLimitResponse(t *testin response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60) return MessageToAddresses(response), response, nil } - return nil, nil, errors.New("unexpected transport") + return nil, nil, E.New("unexpected transport") }, } router := newTestRouter(t, []option.DNSRule{{ @@ -757,7 +1209,7 @@ func TestLookupLegacyDNSModeFallsBackAfterRejectedAddressLimitResponse(t *testin }) require.NoError(t, err) require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses) - require.Equal(t, []string{"private", "default"}, lookups) + require.Equal(t, []string{"private", "default"}, currentLookupTags()) } func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *testing.T) { @@ -785,7 +1237,18 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP} - var lookups []string + var lookupAccess sync.Mutex + var lookupTags []string + recordLookup := func(tag string) { + lookupAccess.Lock() + lookupTags = append(lookupTags, tag) + lookupAccess.Unlock() + } + currentLookupTags := func() []string { + lookupAccess.Lock() + defer lookupAccess.Unlock() + return append([]string(nil), lookupTags...) + } router := newTestRouterWithContext(t, ctx, []option.DNSRule{ { Type: C.RuleTypeDefault, @@ -819,7 +1282,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { require.Equal(t, "example.com", domain) require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy) - lookups = append(lookups, transport.Tag()) + recordLookup(transport.Tag()) switch transport.Tag() { case "private": response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60) @@ -828,7 +1291,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60) return MessageToAddresses(response), response, nil } - return nil, nil, errors.New("unexpected transport") + return nil, nil, E.New("unexpected transport") }, }) @@ -839,7 +1302,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes }) require.NoError(t, err) require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses) - require.Equal(t, []string{"private", "default"}, lookups) + require.Equal(t, []string{"private", "default"}, currentLookupTags()) } func TestDNSResponseAddressesMatchesMessageToAddressesForHTTPSHints(t *testing.T) { @@ -872,7 +1335,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRoute(t *testing.T) { case "selected": return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, } @@ -931,7 +1394,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRouteIgnoresTTL(t *te case "selected": return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, } @@ -990,7 +1453,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRouteWithHTTPSHints(t case "selected": return fixedHTTPSHintResponse(message.Question[0], netip.MustParseAddr("8.8.8.8")), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, } @@ -1049,7 +1512,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRouteWithMappedHTTPSI case "selected": return fixedHTTPSHintResponse(message.Question[0], netip.MustParseAddr("8.8.8.8")), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, } @@ -1119,7 +1582,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateDoesNotLeakAddressesToNextQuery(t case "selected": return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, } @@ -1181,7 +1644,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateRouteResolutionFailureClearsRespon case "default": return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, } @@ -1268,13 +1731,13 @@ func TestExchangeLegacyDNSModeDisabledEvaluateExchangeFailureUsesMatchResponseBo exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { switch transport.Tag() { case "upstream": - return nil, errors.New("upstream exchange failed") + return nil, E.New("upstream exchange failed") case "selected": return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil case "default": return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, } @@ -1332,9 +1795,9 @@ func TestLookupLegacyDNSModeDisabledAllowsPartialSuccess(t *testing.T) { case mDNS.TypeA: return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil case mDNS.TypeAAAA: - return nil, errors.New("ipv6 failed") + return nil, E.New("ipv6 failed") default: - return nil, errors.New("unexpected qtype") + return nil, E.New("unexpected qtype") } }, }) @@ -1451,7 +1914,7 @@ func TestLookupLegacyDNSModeDisabledEvaluateSkipFakeIPPreservesResponse(t *testi } return FixedResponse(0, message.Question[0], nil, 60), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, }) @@ -1494,7 +1957,7 @@ func TestLookupLegacyDNSModeDisabledUsesQueryTypeRule(t *testing.T) { case "only-a": return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, }) @@ -1560,7 +2023,7 @@ func TestLookupLegacyDNSModeDisabledUsesRuleSetQueryTypeRule(t *testing.T) { } return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::9")}, 60), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, }) @@ -1609,7 +2072,7 @@ func TestLookupLegacyDNSModeDisabledUsesIPVersionRule(t *testing.T) { } return FixedResponse(0, message.Question[0], nil, 60), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, }) @@ -1829,7 +2292,18 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) { t.Parallel() defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} - var qTypes []uint16 + var queryTypeAccess sync.Mutex + var queryTypes []uint16 + recordQueryType := func(queryType uint16) { + queryTypeAccess.Lock() + queryTypes = append(queryTypes, queryType) + queryTypeAccess.Unlock() + } + currentQueryTypes := func() []uint16 { + queryTypeAccess.Lock() + defer queryTypeAccess.Unlock() + return append([]uint16(nil), queryTypes...) + } router := newTestRouter(t, nil, &fakeDNSTransportManager{ defaultTransport: defaultTransport, transports: map[string]adapter.DNSTransport{ @@ -1837,7 +2311,7 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) { }, }, &fakeDNSClient{ exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { - qTypes = append(qTypes, message.Question[0].Qtype) + recordQueryType(message.Question[0].Qtype) if message.Question[0].Qtype == mDNS.TypeA { return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil } @@ -1850,7 +2324,7 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) { Strategy: C.DomainStrategyIPv4Only, }) require.NoError(t, err) - require.Equal(t, []uint16{mDNS.TypeA}, qTypes) + require.Equal(t, []uint16{mDNS.TypeA}, currentQueryTypes()) require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses) } @@ -1858,7 +2332,18 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) { t.Parallel() defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} - var qTypes []uint16 + var queryTypeAccess sync.Mutex + var queryTypes []uint16 + recordQueryType := func(queryType uint16) { + queryTypeAccess.Lock() + queryTypes = append(queryTypes, queryType) + queryTypeAccess.Unlock() + } + currentQueryTypes := func() []uint16 { + queryTypeAccess.Lock() + defer queryTypeAccess.Unlock() + return append([]uint16(nil), queryTypes...) + } router := newTestRouter(t, nil, &fakeDNSTransportManager{ defaultTransport: defaultTransport, transports: map[string]adapter.DNSTransport{ @@ -1866,7 +2351,7 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) { }, }, &fakeDNSClient{ exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) { - qTypes = append(qTypes, message.Question[0].Qtype) + recordQueryType(message.Question[0].Qtype) if message.Question[0].Qtype == mDNS.TypeA { return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil } @@ -1878,7 +2363,7 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) { addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) - require.Equal(t, []uint16{mDNS.TypeA}, qTypes) + require.Equal(t, []uint16{mDNS.TypeA}, currentQueryTypes()) require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses) } @@ -1903,7 +2388,7 @@ func TestExchangeLegacyDNSModeDisabledLogicalMatchResponseIPCIDRFallsThrough(t * case "default": return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil default: - return nil, errors.New("unexpected transport") + return nil, E.New("unexpected transport") } }, }