diff --git a/dns/router.go b/dns/router.go index 8e2424b9f..ae986d712 100644 --- a/dns/router.go +++ b/dns/router.go @@ -6,6 +6,7 @@ import ( "net/netip" "strings" "sync" + "sync/atomic" "time" "github.com/sagernet/sing-box/adapter" @@ -37,6 +38,42 @@ type dnsRuleSetCallback struct { element *list.Element[adapter.RuleSetUpdateCallback] } +type rulesSnapshot struct { + rules []adapter.DNSRule + legacyDNSMode bool + references atomic.Int64 +} + +func newRulesSnapshot(rules []adapter.DNSRule, legacyDNSMode bool) *rulesSnapshot { + snapshot := &rulesSnapshot{ + rules: rules, + legacyDNSMode: legacyDNSMode, + } + snapshot.references.Store(1) + return snapshot +} + +func (s *rulesSnapshot) retain() { + if s == nil { + return + } + s.references.Add(1) +} + +func (s *rulesSnapshot) release() { + if s == nil { + return + } + references := s.references.Add(-1) + switch { + case references > 0: + case references == 0: + closeRules(s.rules) + default: + panic("dns: negative rules snapshot references") + } +} + type Router struct { ctx context.Context logger logger.ContextLogger @@ -44,13 +81,12 @@ type Router struct { outbound adapter.OutboundManager client adapter.DNSClient rawRules []option.DNSRule - rules []adapter.DNSRule + currentRules atomic.Pointer[rulesSnapshot] defaultDomainStrategy C.DomainStrategy dnsReverseMapping freelru.Cache[netip.Addr, string] platformInterface adapter.PlatformInterface - legacyDNSMode bool - rulesAccess sync.RWMutex rebuildAccess sync.Mutex + stateAccess sync.Mutex closing bool ruleSetCallbacks []dnsRuleSetCallback addressFilterDeprecatedReported bool @@ -64,9 +100,9 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp transport: service.FromContext[adapter.DNSTransportManager](ctx), outbound: service.FromContext[adapter.OutboundManager](ctx), rawRules: make([]option.DNSRule, 0, len(options.Rules)), - rules: make([]adapter.DNSRule, 0, len(options.Rules)), defaultDomainStrategy: C.DomainStrategy(options.Strategy), } + router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(options.Rules)), false)) router.client = NewClient(ClientOptions{ DisableCache: options.DNSClientOptions.DisableCache, DisableExpire: options.DNSClientOptions.DisableExpire, @@ -134,26 +170,21 @@ func (r *Router) Start(stage adapter.StartStage) error { } func (r *Router) Close() error { - monitor := taskmonitor.New(r.logger, C.StopTimeout) - r.rulesAccess.Lock() + r.stateAccess.Lock() + if r.closing { + r.stateAccess.Unlock() + return nil + } r.closing = true callbacks := r.ruleSetCallbacks r.ruleSetCallbacks = nil - runtimeRules := r.rules - r.rules = nil + oldSnapshot := r.currentRules.Swap(nil) for _, callback := range callbacks { callback.ruleSet.UnregisterCallback(callback.element) } - r.rulesAccess.Unlock() - var err error - for i, rule := range runtimeRules { - monitor.Start("close dns rule[", i, "]") - err = E.Append(err, rule.Close(), func(err error) error { - return E.Cause(err, "close dns rule[", i, "]") - }) - monitor.Finish() - } - return err + r.stateAccess.Unlock() + oldSnapshot.release() + return nil } func (r *Router) rebuildRules(startRules bool) error { @@ -177,23 +208,22 @@ func (r *Router) rebuildRules(startRules bool) error { legacyDNSMode && !r.ruleStrategyDeprecatedReported && hasDNSRuleActionStrategy(r.rawRules) - r.rulesAccess.Lock() + newSnapshot := newRulesSnapshot(newRules, legacyDNSMode) + r.stateAccess.Lock() if r.closing { - r.rulesAccess.Unlock() - closeRules(newRules) + r.stateAccess.Unlock() + newSnapshot.release() return nil } - oldRules := r.rules - r.rules = newRules - r.legacyDNSMode = legacyDNSMode if shouldReportAddressFilterDeprecated { r.addressFilterDeprecatedReported = true } if shouldReportRuleStrategyDeprecated { r.ruleStrategyDeprecatedReported = true } - r.rulesAccess.Unlock() - closeRules(oldRules) + oldSnapshot := r.currentRules.Swap(newSnapshot) + r.stateAccess.Unlock() + oldSnapshot.release() if shouldReportAddressFilterDeprecated { deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter) } @@ -204,11 +234,19 @@ func (r *Router) rebuildRules(startRules bool) error { } func (r *Router) isClosing() bool { - r.rulesAccess.RLock() - defer r.rulesAccess.RUnlock() + r.stateAccess.Lock() + defer r.stateAccess.Unlock() return r.closing } +func (r *Router) acquireRulesSnapshot() *rulesSnapshot { + r.stateAccess.Lock() + defer r.stateAccess.Unlock() + snapshot := r.currentRules.Load() + snapshot.retain() + return snapshot +} + func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) { for i, ruleOptions := range r.rawRules { err := R.ValidateNoNestedDNSRuleActions(ruleOptions) @@ -259,12 +297,12 @@ func (r *Router) registerRuleSetCallbacks() (bool, error) { if len(tags) == 0 { return false, nil } - r.rulesAccess.RLock() + r.stateAccess.Lock() if len(r.ruleSetCallbacks) > 0 { - r.rulesAccess.RUnlock() + r.stateAccess.Unlock() return true, nil } - r.rulesAccess.RUnlock() + r.stateAccess.Unlock() router := service.FromContext[adapter.Router](r.ctx) if router == nil { return false, E.New("router service not found") @@ -289,19 +327,19 @@ func (r *Router) registerRuleSetCallbacks() (bool, error) { element: element, }) } - r.rulesAccess.Lock() + r.stateAccess.Lock() if len(r.ruleSetCallbacks) == 0 { r.ruleSetCallbacks = callbacks callbacks = nil } - r.rulesAccess.Unlock() + r.stateAccess.Unlock() for _, callback := range callbacks { callback.ruleSet.UnregisterCallback(callback.element) } return true, nil } -func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) { +func (r *Router) matchDNS(ctx context.Context, rules []adapter.DNSRule, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) { metadata := adapter.ContextFrom(ctx) if metadata == nil { panic("no context") @@ -310,8 +348,8 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, if ruleIndex != -1 { currentRuleIndex = ruleIndex + 1 } - for ; currentRuleIndex < len(r.rules); currentRuleIndex++ { - currentRule := r.rules[currentRuleIndex] + for ; currentRuleIndex < len(rules); currentRuleIndex++ { + currentRule := rules[currentRuleIndex] if currentRule.WithAddressLimit() && !isAddressQuery { continue } @@ -422,14 +460,14 @@ type exchangeWithRulesResult struct { err error } -func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) exchangeWithRulesResult { +func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) exchangeWithRulesResult { metadata := adapter.ContextFrom(ctx) if metadata == nil { panic("no context") } effectiveOptions := options var savedResponse *mDNS.Msg - for currentRuleIndex, currentRule := range r.rules { + for currentRuleIndex, currentRule := range rules { metadata.ResetRuleCache() metadata.DNSResponse = savedResponse metadata.DestinationAddressMatchFromResponse = false @@ -578,18 +616,18 @@ func filterAddressesByQueryType(addresses []netip.Addr, qType uint16) []netip.Ad } } -func (r *Router) lookupWithRules(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { +func (r *Router) lookupWithRules(ctx context.Context, rules []adapter.DNSRule, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { strategy := r.resolveLookupStrategy(options) lookupOptions := options if strategy != C.DomainStrategyAsIS { lookupOptions.Strategy = strategy } if strategy == C.DomainStrategyIPv4Only { - response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions) + response, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions) return response.addresses, err } if strategy == C.DomainStrategyIPv6Only { - response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions) + response, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions) return response.addresses, err } var ( @@ -598,12 +636,12 @@ func (r *Router) lookupWithRules(ctx context.Context, domain string, options ada ) var group task.Group group.Append("exchange4", func(ctx context.Context) error { - result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions) + result, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions) response4 = result return err }) group.Append("exchange6", func(ctx context.Context) error { - result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions) + result, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions) response6 = result return err }) @@ -614,7 +652,7 @@ func (r *Router) lookupWithRules(ctx context.Context, domain string, options ada return sortAddresses(response4.addresses, response6.addresses, strategy), nil } -func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) { +func (r *Router) lookupWithRulesType(ctx context.Context, rules []adapter.DNSRule, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) { request := &mDNS.Msg{ MsgHdr: mDNS.MsgHdr{ RecursionDesired: true, @@ -625,7 +663,7 @@ func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType u Qclass: mDNS.ClassINET, }}, } - exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), request, options, false) + exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), rules, request, options, false) result := lookupWithRulesResponse{} if exchangeResult.rejectAction != nil { return result, exchangeResult.rejectAction.Error(ctx) @@ -656,8 +694,16 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte } return &responseMessage, nil } - r.rulesAccess.RLock() - defer r.rulesAccess.RUnlock() + snapshot := r.acquireRulesSnapshot() + defer snapshot.release() + var ( + rules []adapter.DNSRule + legacyDNSMode bool + ) + if snapshot != nil { + rules = snapshot.rules + legacyDNSMode = snapshot.legacyDNSMode + } r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String())) var ( response *mDNS.Msg @@ -683,8 +729,8 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte options.Strategy = r.defaultDomainStrategy } response, err = r.client.Exchange(ctx, transport, message, options, nil) - } else if !r.legacyDNSMode { - exchangeResult := r.exchangeWithRules(ctx, message, options, true) + } else if !legacyDNSMode { + exchangeResult := r.exchangeWithRules(ctx, rules, message, options, true) response, transport, err = exchangeResult.response, exchangeResult.transport, exchangeResult.err } else { var ( @@ -695,7 +741,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte for { dnsCtx := adapter.OverrideContext(ctx) dnsOptions := options - transport, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message), &dnsOptions) + transport, rule, ruleIndex = r.matchDNS(ctx, rules, true, ruleIndex, isAddressQuery(message), &dnsOptions) if rule != nil { switch action := rule.Action().(type) { case *R.RuleActionReject: @@ -760,8 +806,16 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte } func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) { - r.rulesAccess.RLock() - defer r.rulesAccess.RUnlock() + snapshot := r.acquireRulesSnapshot() + defer snapshot.release() + var ( + rules []adapter.DNSRule + legacyDNSMode bool + ) + if snapshot != nil { + rules = snapshot.rules + legacyDNSMode = snapshot.legacyDNSMode + } var ( responseAddrs []netip.Addr err error @@ -797,8 +851,8 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ options.Strategy = r.defaultDomainStrategy } responseAddrs, err = r.client.Lookup(ctx, transport, domain, options, nil) - } else if !r.legacyDNSMode { - responseAddrs, err = r.lookupWithRules(ctx, domain, options) + } else if !legacyDNSMode { + responseAddrs, err = r.lookupWithRules(ctx, rules, domain, options) } else { var ( transport adapter.DNSTransport @@ -809,7 +863,7 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ for { dnsCtx := adapter.OverrideContext(ctx) dnsOptions := options - transport, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true, &dnsOptions) + transport, rule, ruleIndex = r.matchDNS(ctx, rules, false, ruleIndex, true, &dnsOptions) if rule != nil { switch action := rule.Action().(type) { case *R.RuleActionReject: diff --git a/dns/router_test.go b/dns/router_test.go index 37ebcbc02..81ab358b2 100644 --- a/dns/router_test.go +++ b/dns/router_test.go @@ -113,6 +113,7 @@ func (r *fakeRouter) RuleSet(tag string) (adapter.RuleSet, bool) { ruleSet, loaded := r.ruleSets[tag] return ruleSet, loaded } + func (r *fakeRouter) setRuleSet(tag string, ruleSet adapter.RuleSet) { r.access.Lock() defer r.access.Unlock() @@ -135,7 +136,7 @@ type fakeRuleSet struct { match func(*adapter.InboundContext) bool callbacks list.List[adapter.RuleSetUpdateCallback] refs int - afterIncrementReference func() + afterIncrementReference func() beforeDecrementReference func() } @@ -273,9 +274,9 @@ func newTestRouterWithContextAndLogger(t *testing.T, ctx context.Context, rules transport: transportManager, client: client, rawRules: make([]option.DNSRule, 0, len(rules)), - rules: make([]adapter.DNSRule, 0, len(rules)), defaultDomainStrategy: C.DomainStrategyAsIS, } + router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(rules)), false)) if rules != nil { err := router.Initialize(rules) require.NoError(t, err) @@ -427,9 +428,9 @@ func TestInitializeRejectsInvalidDNSRuleParseError(t *testing.T) { transport: &fakeDNSTransportManager{}, client: &fakeDNSClient{}, rawRules: make([]option.DNSRule, 0, 1), - rules: make([]adapter.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } + router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -474,9 +475,9 @@ func TestInitializeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) { transport: &fakeDNSTransportManager{}, client: &fakeDNSClient{}, rawRules: make([]option.DNSRule, 0, 2), - rules: make([]adapter.DNSRule, 0, 2), defaultDomainStrategy: C.DomainStrategyAsIS, } + router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 2), false)) err = router.Initialize([]option.DNSRule{ { Type: C.RuleTypeDefault, @@ -557,7 +558,7 @@ func TestLookupLegacyDNSModeDefersRuleSetDestinationIPMatch(t *testing.T) { }, }) - require.True(t, router.legacyDNSMode) + require.True(t, router.currentRules.Load().legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ LookupStrategy: C.DomainStrategyIPv4Only, @@ -700,7 +701,7 @@ func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t * require.NoError(t, closeErr) }) - require.True(t, router.legacyDNSMode) + require.True(t, router.currentRules.Load().legacyDNSMode) require.Equal(t, 1, callbackRuleSet.refCount()) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) @@ -723,7 +724,7 @@ func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t * }) rebuildErrorEntry := waitForLogMessageContaining(t, logEntries, logDone, "rebuild DNS rules after rule-set update") require.Contains(t, rebuildErrorEntry.Message, "ip_cidr and ip_is_private require match_response") - require.True(t, router.legacyDNSMode) + require.True(t, router.currentRules.Load().legacyDNSMode) require.Equal(t, 1, callbackRuleSet.refCount()) require.Zero(t, rebuildTargetRuleSet.refCount()) @@ -992,7 +993,7 @@ func TestCloseDuringRebuildDiscardsResult(t *testing.T) { } }, }) - require.True(t, router.legacyDNSMode) + require.True(t, router.currentRules.Load().legacyDNSMode) require.Equal(t, 1, fakeSet.refCount()) callbacks := fakeSet.snapshotCallbacks() @@ -1031,12 +1032,11 @@ func TestCloseDuringRebuildDiscardsResult(t *testing.T) { fakeSet.metadataRead = nil - router.rulesAccess.RLock() + router.stateAccess.Lock() require.True(t, router.closing) - require.Nil(t, router.rules) require.Empty(t, router.ruleSetCallbacks) - router.rulesAccess.RUnlock() - require.True(t, router.legacyDNSMode) + router.stateAccess.Unlock() + require.Nil(t, router.currentRules.Load()) require.Zero(t, fakeSet.refCount()) } @@ -1098,11 +1098,218 @@ func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) { } callbacks[0](fakeSet) - router.rulesAccess.RLock() - defer router.rulesAccess.RUnlock() + router.stateAccess.Lock() require.True(t, router.closing) - require.Nil(t, router.rules) require.Empty(t, router.ruleSetCallbacks) + router.stateAccess.Unlock() + require.Nil(t, router.currentRules.Load()) +} + +func TestRuleSetUpdateDoesNotBlockOnInFlightLookup(t *testing.T) { + t.Parallel() + + fakeSet := &fakeRuleSet{ + metadata: adapter.RuleSetMetadata{ + ContainsIPCIDRRule: true, + }, + } + ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "dynamic-set": fakeSet, + }, + }) + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + selectedTransport := &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP} + lookupStarted := make(chan struct{}) + releaseLookup := make(chan struct{}) + router := newTestRouterWithContext(t, ctx, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"dynamic-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "selected": selectedTransport, + }, + }, &fakeDNSClient{ + lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + require.Equal(t, "selected", transport.Tag()) + require.Equal(t, "example.com", domain) + require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy) + close(lookupStarted) + <-releaseLookup + response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) + return MessageToAddresses(response), response, nil + }, + }) + t.Cleanup(func() { + closeErr := router.Close() + require.NoError(t, closeErr) + }) + + require.True(t, router.currentRules.Load().legacyDNSMode) + require.Equal(t, 1, fakeSet.refCount()) + + var ( + addresses []netip.Addr + err error + ) + lookupDone := make(chan struct{}) + go func() { + addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ + LookupStrategy: C.DomainStrategyIPv4Only, + }) + close(lookupDone) + }() + + select { + case <-lookupStarted: + case <-time.After(time.Second): + t.Fatal("lookup did not reach DNS client") + } + + rebuildDone := make(chan struct{}) + go func() { + fakeSet.updateMetadata(adapter.RuleSetMetadata{ + ContainsIPCIDRRule: true, + }) + close(rebuildDone) + }() + + select { + case <-rebuildDone: + case <-time.After(time.Second): + t.Fatal("rebuild blocked on in-flight lookup") + } + + require.Equal(t, 2, fakeSet.refCount()) + + select { + case <-lookupDone: + t.Fatal("lookup finished before release") + default: + } + + close(releaseLookup) + + select { + case <-lookupDone: + case <-time.After(time.Second): + t.Fatal("lookup did not finish after release") + } + + require.NoError(t, err) + require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) + require.Eventually(t, func() bool { + return fakeSet.refCount() == 1 + }, time.Second, 10*time.Millisecond) +} + +func TestCloseReleasesSnapshottedRulesAfterInFlightLookup(t *testing.T) { + t.Parallel() + + fakeSet := &fakeRuleSet{ + metadata: adapter.RuleSetMetadata{ + ContainsIPCIDRRule: true, + }, + } + ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{ + ruleSets: map[string]adapter.RuleSet{ + "dynamic-set": fakeSet, + }, + }) + defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP} + selectedTransport := &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP} + lookupStarted := make(chan struct{}) + releaseLookup := make(chan struct{}) + router := newTestRouterWithContext(t, ctx, []option.DNSRule{{ + Type: C.RuleTypeDefault, + DefaultOptions: option.DefaultDNSRule{ + RawDefaultDNSRule: option.RawDefaultDNSRule{ + RuleSet: badoption.Listable[string]{"dynamic-set"}, + }, + DNSRuleAction: option.DNSRuleAction{ + Action: C.RuleActionTypeRoute, + RouteOptions: option.DNSRouteActionOptions{Server: "selected"}, + }, + }, + }}, &fakeDNSTransportManager{ + defaultTransport: defaultTransport, + transports: map[string]adapter.DNSTransport{ + "default": defaultTransport, + "selected": selectedTransport, + }, + }, &fakeDNSClient{ + lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) { + require.Equal(t, "selected", transport.Tag()) + require.Equal(t, "example.com", domain) + require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy) + close(lookupStarted) + <-releaseLookup + response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60) + return MessageToAddresses(response), response, nil + }, + }) + + require.True(t, router.currentRules.Load().legacyDNSMode) + require.Equal(t, 1, fakeSet.refCount()) + + var ( + addresses []netip.Addr + lookupErr error + closeErr error + ) + lookupDone := make(chan struct{}) + go func() { + addresses, lookupErr = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ + LookupStrategy: C.DomainStrategyIPv4Only, + }) + close(lookupDone) + }() + + select { + case <-lookupStarted: + case <-time.After(time.Second): + t.Fatal("lookup did not reach DNS client") + } + + closeDone := make(chan struct{}) + go func() { + closeErr = router.Close() + close(closeDone) + }() + + require.Eventually(t, func() bool { + return router.currentRules.Load() == nil && fakeSet.refCount() == 1 + }, time.Second, 10*time.Millisecond) + + close(releaseLookup) + + select { + case <-lookupDone: + case <-time.After(time.Second): + t.Fatal("lookup did not finish after release") + } + select { + case <-closeDone: + case <-time.After(time.Second): + t.Fatal("close did not finish") + } + + require.NoError(t, lookupErr) + require.NoError(t, closeErr) + require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses) + require.Eventually(t, func() bool { + return fakeSet.refCount() == 0 + }, time.Second, 10*time.Millisecond) } func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) { @@ -1143,7 +1350,7 @@ func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) { }, }, client) - require.True(t, router.legacyDNSMode) + require.True(t, router.currentRules.Load().legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ LookupStrategy: C.DomainStrategyIPv4Only, @@ -1295,7 +1502,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes }, }) - require.True(t, router.legacyDNSMode) + require.True(t, router.currentRules.Load().legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ LookupStrategy: C.DomainStrategyIPv4Only, @@ -1801,7 +2008,7 @@ func TestLookupLegacyDNSModeDisabledAllowsPartialSuccess(t *testing.T) { } }, }) - router.legacyDNSMode = false + router.currentRules.Load().legacyDNSMode = false addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -1838,7 +2045,7 @@ func TestLookupLegacyDNSModeDisabledSkipsFakeIPRule(t *testing.T) { return FixedResponse(0, message.Question[0], nil, 60), nil }, }) - router.legacyDNSMode = false + router.currentRules.Load().legacyDNSMode = false addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -1918,7 +2125,7 @@ func TestLookupLegacyDNSModeDisabledEvaluateSkipFakeIPPreservesResponse(t *testi } }, }) - router.legacyDNSMode = false + router.currentRules.Load().legacyDNSMode = false addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -1961,7 +2168,7 @@ func TestLookupLegacyDNSModeDisabledUsesQueryTypeRule(t *testing.T) { } }, }) - require.False(t, router.legacyDNSMode) + require.False(t, router.currentRules.Load().legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -2027,7 +2234,7 @@ func TestLookupLegacyDNSModeDisabledUsesRuleSetQueryTypeRule(t *testing.T) { } }, }) - require.False(t, router.legacyDNSMode) + require.False(t, router.currentRules.Load().legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -2076,7 +2283,7 @@ func TestLookupLegacyDNSModeDisabledUsesIPVersionRule(t *testing.T) { } }, }) - require.False(t, router.legacyDNSMode) + require.False(t, router.currentRules.Load().legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -2092,9 +2299,9 @@ func TestInitializeRejectsDNSRuleStrategyWhenLegacyDNSModeIsDisabledByEvaluate(t transport: &fakeDNSTransportManager{}, client: &fakeDNSClient{}, rawRules: make([]option.DNSRule, 0, 1), - rules: make([]adapter.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } + router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -2122,9 +2329,9 @@ func TestInitializeRejectsDNSRuleStrategyWhenLegacyDNSModeIsDisabledByMatchRespo transport: &fakeDNSTransportManager{}, client: &fakeDNSClient{}, rawRules: make([]option.DNSRule, 0, 1), - rules: make([]adapter.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } + router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -2175,7 +2382,7 @@ func TestLookupLegacyDNSModeUsesRouteStrategy(t *testing.T) { }, }) - require.True(t, router.legacyDNSMode) + require.True(t, router.currentRules.Load().legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -2207,7 +2414,7 @@ func TestLookupLegacyDNSModeDisabledReturnsRejectedErrorForRejectAction(t *testi "default": defaultTransport, }, }, &fakeDNSClient{}) - require.False(t, router.legacyDNSMode) + require.False(t, router.currentRules.Load().legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.Nil(t, addresses) @@ -2240,7 +2447,7 @@ func TestExchangeLegacyDNSModeDisabledReturnsRefusedResponseForRejectAction(t *t "default": defaultTransport, }, }, &fakeDNSClient{}) - require.False(t, router.legacyDNSMode) + require.False(t, router.currentRules.Load().legacyDNSMode) response, err := router.Exchange(context.Background(), &mDNS.Msg{ Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)}, @@ -2278,7 +2485,7 @@ func TestLookupLegacyDNSModeDisabledFiltersPerQueryTypeAddressesBeforeMerging(t "default": defaultTransport, }, }, &fakeDNSClient{}) - require.False(t, router.legacyDNSMode) + require.False(t, router.currentRules.Load().legacyDNSMode) addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -2318,7 +2525,7 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) { return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::2")}, 60), nil }, }) - router.legacyDNSMode = false + router.currentRules.Load().legacyDNSMode = false addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{ Strategy: C.DomainStrategyIPv4Only, @@ -2359,7 +2566,7 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) { }, }) router.defaultDomainStrategy = C.DomainStrategyIPv4Only - router.legacyDNSMode = false + router.currentRules.Load().legacyDNSMode = false addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{}) require.NoError(t, err) @@ -2445,9 +2652,9 @@ func TestLegacyDNSModeReportsLegacyAddressFilterDeprecation(t *testing.T) { ctx: ctx, logger: log.NewNOPFactory().NewLogger("dns"), client: &fakeDNSClient{}, - rules: make([]adapter.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } + router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{ @@ -2477,9 +2684,9 @@ func TestLegacyDNSModeReportsDNSRuleStrategyDeprecation(t *testing.T) { ctx: ctx, logger: log.NewNOPFactory().NewLogger("dns"), client: &fakeDNSClient{}, - rules: make([]adapter.DNSRule, 0, 1), defaultDomainStrategy: C.DomainStrategyAsIS, } + router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false)) err := router.Initialize([]option.DNSRule{{ Type: C.RuleTypeDefault, DefaultOptions: option.DefaultDNSRule{