dns: use refcounted snapshot to narrow rule lock scope

Exchange and Lookup held rulesAccess.RLock across all DNS network I/O,
blocking rebuildRules from swapping in new rules until every in-flight
query finished. Replace the RWMutex with an atomic pointer to a
refcounted rulesSnapshot so queries only hold a snapshot reference
during execution, allowing concurrent rule rebuilds.
This commit is contained in:
世界
2026-03-31 15:29:16 +08:00
parent bd222fe9df
commit 866731344f
2 changed files with 349 additions and 88 deletions

View File

@@ -6,6 +6,7 @@ import (
"net/netip"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/sing-box/adapter"
@@ -37,6 +38,42 @@ type dnsRuleSetCallback struct {
element *list.Element[adapter.RuleSetUpdateCallback]
}
type rulesSnapshot struct {
rules []adapter.DNSRule
legacyDNSMode bool
references atomic.Int64
}
func newRulesSnapshot(rules []adapter.DNSRule, legacyDNSMode bool) *rulesSnapshot {
snapshot := &rulesSnapshot{
rules: rules,
legacyDNSMode: legacyDNSMode,
}
snapshot.references.Store(1)
return snapshot
}
func (s *rulesSnapshot) retain() {
if s == nil {
return
}
s.references.Add(1)
}
func (s *rulesSnapshot) release() {
if s == nil {
return
}
references := s.references.Add(-1)
switch {
case references > 0:
case references == 0:
closeRules(s.rules)
default:
panic("dns: negative rules snapshot references")
}
}
type Router struct {
ctx context.Context
logger logger.ContextLogger
@@ -44,13 +81,12 @@ type Router struct {
outbound adapter.OutboundManager
client adapter.DNSClient
rawRules []option.DNSRule
rules []adapter.DNSRule
currentRules atomic.Pointer[rulesSnapshot]
defaultDomainStrategy C.DomainStrategy
dnsReverseMapping freelru.Cache[netip.Addr, string]
platformInterface adapter.PlatformInterface
legacyDNSMode bool
rulesAccess sync.RWMutex
rebuildAccess sync.Mutex
stateAccess sync.Mutex
closing bool
ruleSetCallbacks []dnsRuleSetCallback
addressFilterDeprecatedReported bool
@@ -64,9 +100,9 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
transport: service.FromContext[adapter.DNSTransportManager](ctx),
outbound: service.FromContext[adapter.OutboundManager](ctx),
rawRules: make([]option.DNSRule, 0, len(options.Rules)),
rules: make([]adapter.DNSRule, 0, len(options.Rules)),
defaultDomainStrategy: C.DomainStrategy(options.Strategy),
}
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(options.Rules)), false))
router.client = NewClient(ClientOptions{
DisableCache: options.DNSClientOptions.DisableCache,
DisableExpire: options.DNSClientOptions.DisableExpire,
@@ -134,26 +170,21 @@ func (r *Router) Start(stage adapter.StartStage) error {
}
func (r *Router) Close() error {
monitor := taskmonitor.New(r.logger, C.StopTimeout)
r.rulesAccess.Lock()
r.stateAccess.Lock()
if r.closing {
r.stateAccess.Unlock()
return nil
}
r.closing = true
callbacks := r.ruleSetCallbacks
r.ruleSetCallbacks = nil
runtimeRules := r.rules
r.rules = nil
oldSnapshot := r.currentRules.Swap(nil)
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
r.rulesAccess.Unlock()
var err error
for i, rule := range runtimeRules {
monitor.Start("close dns rule[", i, "]")
err = E.Append(err, rule.Close(), func(err error) error {
return E.Cause(err, "close dns rule[", i, "]")
})
monitor.Finish()
}
return err
r.stateAccess.Unlock()
oldSnapshot.release()
return nil
}
func (r *Router) rebuildRules(startRules bool) error {
@@ -177,23 +208,22 @@ func (r *Router) rebuildRules(startRules bool) error {
legacyDNSMode &&
!r.ruleStrategyDeprecatedReported &&
hasDNSRuleActionStrategy(r.rawRules)
r.rulesAccess.Lock()
newSnapshot := newRulesSnapshot(newRules, legacyDNSMode)
r.stateAccess.Lock()
if r.closing {
r.rulesAccess.Unlock()
closeRules(newRules)
r.stateAccess.Unlock()
newSnapshot.release()
return nil
}
oldRules := r.rules
r.rules = newRules
r.legacyDNSMode = legacyDNSMode
if shouldReportAddressFilterDeprecated {
r.addressFilterDeprecatedReported = true
}
if shouldReportRuleStrategyDeprecated {
r.ruleStrategyDeprecatedReported = true
}
r.rulesAccess.Unlock()
closeRules(oldRules)
oldSnapshot := r.currentRules.Swap(newSnapshot)
r.stateAccess.Unlock()
oldSnapshot.release()
if shouldReportAddressFilterDeprecated {
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter)
}
@@ -204,11 +234,19 @@ func (r *Router) rebuildRules(startRules bool) error {
}
func (r *Router) isClosing() bool {
r.rulesAccess.RLock()
defer r.rulesAccess.RUnlock()
r.stateAccess.Lock()
defer r.stateAccess.Unlock()
return r.closing
}
func (r *Router) acquireRulesSnapshot() *rulesSnapshot {
r.stateAccess.Lock()
defer r.stateAccess.Unlock()
snapshot := r.currentRules.Load()
snapshot.retain()
return snapshot
}
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
for i, ruleOptions := range r.rawRules {
err := R.ValidateNoNestedDNSRuleActions(ruleOptions)
@@ -259,12 +297,12 @@ func (r *Router) registerRuleSetCallbacks() (bool, error) {
if len(tags) == 0 {
return false, nil
}
r.rulesAccess.RLock()
r.stateAccess.Lock()
if len(r.ruleSetCallbacks) > 0 {
r.rulesAccess.RUnlock()
r.stateAccess.Unlock()
return true, nil
}
r.rulesAccess.RUnlock()
r.stateAccess.Unlock()
router := service.FromContext[adapter.Router](r.ctx)
if router == nil {
return false, E.New("router service not found")
@@ -289,19 +327,19 @@ func (r *Router) registerRuleSetCallbacks() (bool, error) {
element: element,
})
}
r.rulesAccess.Lock()
r.stateAccess.Lock()
if len(r.ruleSetCallbacks) == 0 {
r.ruleSetCallbacks = callbacks
callbacks = nil
}
r.rulesAccess.Unlock()
r.stateAccess.Unlock()
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
return true, nil
}
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
func (r *Router) matchDNS(ctx context.Context, rules []adapter.DNSRule, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
panic("no context")
@@ -310,8 +348,8 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int,
if ruleIndex != -1 {
currentRuleIndex = ruleIndex + 1
}
for ; currentRuleIndex < len(r.rules); currentRuleIndex++ {
currentRule := r.rules[currentRuleIndex]
for ; currentRuleIndex < len(rules); currentRuleIndex++ {
currentRule := rules[currentRuleIndex]
if currentRule.WithAddressLimit() && !isAddressQuery {
continue
}
@@ -422,14 +460,14 @@ type exchangeWithRulesResult struct {
err error
}
func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) exchangeWithRulesResult {
func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) exchangeWithRulesResult {
metadata := adapter.ContextFrom(ctx)
if metadata == nil {
panic("no context")
}
effectiveOptions := options
var savedResponse *mDNS.Msg
for currentRuleIndex, currentRule := range r.rules {
for currentRuleIndex, currentRule := range rules {
metadata.ResetRuleCache()
metadata.DNSResponse = savedResponse
metadata.DestinationAddressMatchFromResponse = false
@@ -578,18 +616,18 @@ func filterAddressesByQueryType(addresses []netip.Addr, qType uint16) []netip.Ad
}
}
func (r *Router) lookupWithRules(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
func (r *Router) lookupWithRules(ctx context.Context, rules []adapter.DNSRule, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
strategy := r.resolveLookupStrategy(options)
lookupOptions := options
if strategy != C.DomainStrategyAsIS {
lookupOptions.Strategy = strategy
}
if strategy == C.DomainStrategyIPv4Only {
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
response, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions)
return response.addresses, err
}
if strategy == C.DomainStrategyIPv6Only {
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
response, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions)
return response.addresses, err
}
var (
@@ -598,12 +636,12 @@ func (r *Router) lookupWithRules(ctx context.Context, domain string, options ada
)
var group task.Group
group.Append("exchange4", func(ctx context.Context) error {
result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
result, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions)
response4 = result
return err
})
group.Append("exchange6", func(ctx context.Context) error {
result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
result, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions)
response6 = result
return err
})
@@ -614,7 +652,7 @@ func (r *Router) lookupWithRules(ctx context.Context, domain string, options ada
return sortAddresses(response4.addresses, response6.addresses, strategy), nil
}
func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) {
func (r *Router) lookupWithRulesType(ctx context.Context, rules []adapter.DNSRule, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) {
request := &mDNS.Msg{
MsgHdr: mDNS.MsgHdr{
RecursionDesired: true,
@@ -625,7 +663,7 @@ func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType u
Qclass: mDNS.ClassINET,
}},
}
exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), request, options, false)
exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), rules, request, options, false)
result := lookupWithRulesResponse{}
if exchangeResult.rejectAction != nil {
return result, exchangeResult.rejectAction.Error(ctx)
@@ -656,8 +694,16 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
}
return &responseMessage, nil
}
r.rulesAccess.RLock()
defer r.rulesAccess.RUnlock()
snapshot := r.acquireRulesSnapshot()
defer snapshot.release()
var (
rules []adapter.DNSRule
legacyDNSMode bool
)
if snapshot != nil {
rules = snapshot.rules
legacyDNSMode = snapshot.legacyDNSMode
}
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
var (
response *mDNS.Msg
@@ -683,8 +729,8 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
options.Strategy = r.defaultDomainStrategy
}
response, err = r.client.Exchange(ctx, transport, message, options, nil)
} else if !r.legacyDNSMode {
exchangeResult := r.exchangeWithRules(ctx, message, options, true)
} else if !legacyDNSMode {
exchangeResult := r.exchangeWithRules(ctx, rules, message, options, true)
response, transport, err = exchangeResult.response, exchangeResult.transport, exchangeResult.err
} else {
var (
@@ -695,7 +741,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
for {
dnsCtx := adapter.OverrideContext(ctx)
dnsOptions := options
transport, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message), &dnsOptions)
transport, rule, ruleIndex = r.matchDNS(ctx, rules, true, ruleIndex, isAddressQuery(message), &dnsOptions)
if rule != nil {
switch action := rule.Action().(type) {
case *R.RuleActionReject:
@@ -760,8 +806,16 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
}
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
r.rulesAccess.RLock()
defer r.rulesAccess.RUnlock()
snapshot := r.acquireRulesSnapshot()
defer snapshot.release()
var (
rules []adapter.DNSRule
legacyDNSMode bool
)
if snapshot != nil {
rules = snapshot.rules
legacyDNSMode = snapshot.legacyDNSMode
}
var (
responseAddrs []netip.Addr
err error
@@ -797,8 +851,8 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ
options.Strategy = r.defaultDomainStrategy
}
responseAddrs, err = r.client.Lookup(ctx, transport, domain, options, nil)
} else if !r.legacyDNSMode {
responseAddrs, err = r.lookupWithRules(ctx, domain, options)
} else if !legacyDNSMode {
responseAddrs, err = r.lookupWithRules(ctx, rules, domain, options)
} else {
var (
transport adapter.DNSTransport
@@ -809,7 +863,7 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ
for {
dnsCtx := adapter.OverrideContext(ctx)
dnsOptions := options
transport, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true, &dnsOptions)
transport, rule, ruleIndex = r.matchDNS(ctx, rules, false, ruleIndex, true, &dnsOptions)
if rule != nil {
switch action := rule.Action().(type) {
case *R.RuleActionReject:

View File

@@ -113,6 +113,7 @@ func (r *fakeRouter) RuleSet(tag string) (adapter.RuleSet, bool) {
ruleSet, loaded := r.ruleSets[tag]
return ruleSet, loaded
}
func (r *fakeRouter) setRuleSet(tag string, ruleSet adapter.RuleSet) {
r.access.Lock()
defer r.access.Unlock()
@@ -135,7 +136,7 @@ type fakeRuleSet struct {
match func(*adapter.InboundContext) bool
callbacks list.List[adapter.RuleSetUpdateCallback]
refs int
afterIncrementReference func()
afterIncrementReference func()
beforeDecrementReference func()
}
@@ -273,9 +274,9 @@ func newTestRouterWithContextAndLogger(t *testing.T, ctx context.Context, rules
transport: transportManager,
client: client,
rawRules: make([]option.DNSRule, 0, len(rules)),
rules: make([]adapter.DNSRule, 0, len(rules)),
defaultDomainStrategy: C.DomainStrategyAsIS,
}
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(rules)), false))
if rules != nil {
err := router.Initialize(rules)
require.NoError(t, err)
@@ -427,9 +428,9 @@ func TestInitializeRejectsInvalidDNSRuleParseError(t *testing.T) {
transport: &fakeDNSTransportManager{},
client: &fakeDNSClient{},
rawRules: make([]option.DNSRule, 0, 1),
rules: make([]adapter.DNSRule, 0, 1),
defaultDomainStrategy: C.DomainStrategyAsIS,
}
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false))
err := router.Initialize([]option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
@@ -474,9 +475,9 @@ func TestInitializeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) {
transport: &fakeDNSTransportManager{},
client: &fakeDNSClient{},
rawRules: make([]option.DNSRule, 0, 2),
rules: make([]adapter.DNSRule, 0, 2),
defaultDomainStrategy: C.DomainStrategyAsIS,
}
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 2), false))
err = router.Initialize([]option.DNSRule{
{
Type: C.RuleTypeDefault,
@@ -557,7 +558,7 @@ func TestLookupLegacyDNSModeDefersRuleSetDestinationIPMatch(t *testing.T) {
},
})
require.True(t, router.legacyDNSMode)
require.True(t, router.currentRules.Load().legacyDNSMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
LookupStrategy: C.DomainStrategyIPv4Only,
@@ -700,7 +701,7 @@ func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t *
require.NoError(t, closeErr)
})
require.True(t, router.legacyDNSMode)
require.True(t, router.currentRules.Load().legacyDNSMode)
require.Equal(t, 1, callbackRuleSet.refCount())
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
@@ -723,7 +724,7 @@ func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t *
})
rebuildErrorEntry := waitForLogMessageContaining(t, logEntries, logDone, "rebuild DNS rules after rule-set update")
require.Contains(t, rebuildErrorEntry.Message, "ip_cidr and ip_is_private require match_response")
require.True(t, router.legacyDNSMode)
require.True(t, router.currentRules.Load().legacyDNSMode)
require.Equal(t, 1, callbackRuleSet.refCount())
require.Zero(t, rebuildTargetRuleSet.refCount())
@@ -992,7 +993,7 @@ func TestCloseDuringRebuildDiscardsResult(t *testing.T) {
}
},
})
require.True(t, router.legacyDNSMode)
require.True(t, router.currentRules.Load().legacyDNSMode)
require.Equal(t, 1, fakeSet.refCount())
callbacks := fakeSet.snapshotCallbacks()
@@ -1031,12 +1032,11 @@ func TestCloseDuringRebuildDiscardsResult(t *testing.T) {
fakeSet.metadataRead = nil
router.rulesAccess.RLock()
router.stateAccess.Lock()
require.True(t, router.closing)
require.Nil(t, router.rules)
require.Empty(t, router.ruleSetCallbacks)
router.rulesAccess.RUnlock()
require.True(t, router.legacyDNSMode)
router.stateAccess.Unlock()
require.Nil(t, router.currentRules.Load())
require.Zero(t, fakeSet.refCount())
}
@@ -1098,11 +1098,218 @@ func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) {
}
callbacks[0](fakeSet)
router.rulesAccess.RLock()
defer router.rulesAccess.RUnlock()
router.stateAccess.Lock()
require.True(t, router.closing)
require.Nil(t, router.rules)
require.Empty(t, router.ruleSetCallbacks)
router.stateAccess.Unlock()
require.Nil(t, router.currentRules.Load())
}
func TestRuleSetUpdateDoesNotBlockOnInFlightLookup(t *testing.T) {
t.Parallel()
fakeSet := &fakeRuleSet{
metadata: adapter.RuleSetMetadata{
ContainsIPCIDRRule: true,
},
}
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
ruleSets: map[string]adapter.RuleSet{
"dynamic-set": fakeSet,
},
})
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
selectedTransport := &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}
lookupStarted := make(chan struct{})
releaseLookup := make(chan struct{})
router := newTestRouterWithContext(t, ctx, []option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
RuleSet: badoption.Listable[string]{"dynamic-set"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
},
},
}}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
"selected": selectedTransport,
},
}, &fakeDNSClient{
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
require.Equal(t, "selected", transport.Tag())
require.Equal(t, "example.com", domain)
require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy)
close(lookupStarted)
<-releaseLookup
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
return MessageToAddresses(response), response, nil
},
})
t.Cleanup(func() {
closeErr := router.Close()
require.NoError(t, closeErr)
})
require.True(t, router.currentRules.Load().legacyDNSMode)
require.Equal(t, 1, fakeSet.refCount())
var (
addresses []netip.Addr
err error
)
lookupDone := make(chan struct{})
go func() {
addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
LookupStrategy: C.DomainStrategyIPv4Only,
})
close(lookupDone)
}()
select {
case <-lookupStarted:
case <-time.After(time.Second):
t.Fatal("lookup did not reach DNS client")
}
rebuildDone := make(chan struct{})
go func() {
fakeSet.updateMetadata(adapter.RuleSetMetadata{
ContainsIPCIDRRule: true,
})
close(rebuildDone)
}()
select {
case <-rebuildDone:
case <-time.After(time.Second):
t.Fatal("rebuild blocked on in-flight lookup")
}
require.Equal(t, 2, fakeSet.refCount())
select {
case <-lookupDone:
t.Fatal("lookup finished before release")
default:
}
close(releaseLookup)
select {
case <-lookupDone:
case <-time.After(time.Second):
t.Fatal("lookup did not finish after release")
}
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
require.Eventually(t, func() bool {
return fakeSet.refCount() == 1
}, time.Second, 10*time.Millisecond)
}
func TestCloseReleasesSnapshottedRulesAfterInFlightLookup(t *testing.T) {
t.Parallel()
fakeSet := &fakeRuleSet{
metadata: adapter.RuleSetMetadata{
ContainsIPCIDRRule: true,
},
}
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
ruleSets: map[string]adapter.RuleSet{
"dynamic-set": fakeSet,
},
})
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
selectedTransport := &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}
lookupStarted := make(chan struct{})
releaseLookup := make(chan struct{})
router := newTestRouterWithContext(t, ctx, []option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
RuleSet: badoption.Listable[string]{"dynamic-set"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
},
},
}}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
"selected": selectedTransport,
},
}, &fakeDNSClient{
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
require.Equal(t, "selected", transport.Tag())
require.Equal(t, "example.com", domain)
require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy)
close(lookupStarted)
<-releaseLookup
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
return MessageToAddresses(response), response, nil
},
})
require.True(t, router.currentRules.Load().legacyDNSMode)
require.Equal(t, 1, fakeSet.refCount())
var (
addresses []netip.Addr
lookupErr error
closeErr error
)
lookupDone := make(chan struct{})
go func() {
addresses, lookupErr = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
LookupStrategy: C.DomainStrategyIPv4Only,
})
close(lookupDone)
}()
select {
case <-lookupStarted:
case <-time.After(time.Second):
t.Fatal("lookup did not reach DNS client")
}
closeDone := make(chan struct{})
go func() {
closeErr = router.Close()
close(closeDone)
}()
require.Eventually(t, func() bool {
return router.currentRules.Load() == nil && fakeSet.refCount() == 1
}, time.Second, 10*time.Millisecond)
close(releaseLookup)
select {
case <-lookupDone:
case <-time.After(time.Second):
t.Fatal("lookup did not finish after release")
}
select {
case <-closeDone:
case <-time.After(time.Second):
t.Fatal("close did not finish")
}
require.NoError(t, lookupErr)
require.NoError(t, closeErr)
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
require.Eventually(t, func() bool {
return fakeSet.refCount() == 0
}, time.Second, 10*time.Millisecond)
}
func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) {
@@ -1143,7 +1350,7 @@ func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) {
},
}, client)
require.True(t, router.legacyDNSMode)
require.True(t, router.currentRules.Load().legacyDNSMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
LookupStrategy: C.DomainStrategyIPv4Only,
@@ -1295,7 +1502,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes
},
})
require.True(t, router.legacyDNSMode)
require.True(t, router.currentRules.Load().legacyDNSMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
LookupStrategy: C.DomainStrategyIPv4Only,
@@ -1801,7 +2008,7 @@ func TestLookupLegacyDNSModeDisabledAllowsPartialSuccess(t *testing.T) {
}
},
})
router.legacyDNSMode = false
router.currentRules.Load().legacyDNSMode = false
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
@@ -1838,7 +2045,7 @@ func TestLookupLegacyDNSModeDisabledSkipsFakeIPRule(t *testing.T) {
return FixedResponse(0, message.Question[0], nil, 60), nil
},
})
router.legacyDNSMode = false
router.currentRules.Load().legacyDNSMode = false
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
@@ -1918,7 +2125,7 @@ func TestLookupLegacyDNSModeDisabledEvaluateSkipFakeIPPreservesResponse(t *testi
}
},
})
router.legacyDNSMode = false
router.currentRules.Load().legacyDNSMode = false
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
@@ -1961,7 +2168,7 @@ func TestLookupLegacyDNSModeDisabledUsesQueryTypeRule(t *testing.T) {
}
},
})
require.False(t, router.legacyDNSMode)
require.False(t, router.currentRules.Load().legacyDNSMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
@@ -2027,7 +2234,7 @@ func TestLookupLegacyDNSModeDisabledUsesRuleSetQueryTypeRule(t *testing.T) {
}
},
})
require.False(t, router.legacyDNSMode)
require.False(t, router.currentRules.Load().legacyDNSMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
@@ -2076,7 +2283,7 @@ func TestLookupLegacyDNSModeDisabledUsesIPVersionRule(t *testing.T) {
}
},
})
require.False(t, router.legacyDNSMode)
require.False(t, router.currentRules.Load().legacyDNSMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
@@ -2092,9 +2299,9 @@ func TestInitializeRejectsDNSRuleStrategyWhenLegacyDNSModeIsDisabledByEvaluate(t
transport: &fakeDNSTransportManager{},
client: &fakeDNSClient{},
rawRules: make([]option.DNSRule, 0, 1),
rules: make([]adapter.DNSRule, 0, 1),
defaultDomainStrategy: C.DomainStrategyAsIS,
}
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false))
err := router.Initialize([]option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
@@ -2122,9 +2329,9 @@ func TestInitializeRejectsDNSRuleStrategyWhenLegacyDNSModeIsDisabledByMatchRespo
transport: &fakeDNSTransportManager{},
client: &fakeDNSClient{},
rawRules: make([]option.DNSRule, 0, 1),
rules: make([]adapter.DNSRule, 0, 1),
defaultDomainStrategy: C.DomainStrategyAsIS,
}
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false))
err := router.Initialize([]option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
@@ -2175,7 +2382,7 @@ func TestLookupLegacyDNSModeUsesRouteStrategy(t *testing.T) {
},
})
require.True(t, router.legacyDNSMode)
require.True(t, router.currentRules.Load().legacyDNSMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
@@ -2207,7 +2414,7 @@ func TestLookupLegacyDNSModeDisabledReturnsRejectedErrorForRejectAction(t *testi
"default": defaultTransport,
},
}, &fakeDNSClient{})
require.False(t, router.legacyDNSMode)
require.False(t, router.currentRules.Load().legacyDNSMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.Nil(t, addresses)
@@ -2240,7 +2447,7 @@ func TestExchangeLegacyDNSModeDisabledReturnsRefusedResponseForRejectAction(t *t
"default": defaultTransport,
},
}, &fakeDNSClient{})
require.False(t, router.legacyDNSMode)
require.False(t, router.currentRules.Load().legacyDNSMode)
response, err := router.Exchange(context.Background(), &mDNS.Msg{
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
@@ -2278,7 +2485,7 @@ func TestLookupLegacyDNSModeDisabledFiltersPerQueryTypeAddressesBeforeMerging(t
"default": defaultTransport,
},
}, &fakeDNSClient{})
require.False(t, router.legacyDNSMode)
require.False(t, router.currentRules.Load().legacyDNSMode)
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
@@ -2318,7 +2525,7 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) {
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::2")}, 60), nil
},
})
router.legacyDNSMode = false
router.currentRules.Load().legacyDNSMode = false
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
Strategy: C.DomainStrategyIPv4Only,
@@ -2359,7 +2566,7 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) {
},
})
router.defaultDomainStrategy = C.DomainStrategyIPv4Only
router.legacyDNSMode = false
router.currentRules.Load().legacyDNSMode = false
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
@@ -2445,9 +2652,9 @@ func TestLegacyDNSModeReportsLegacyAddressFilterDeprecation(t *testing.T) {
ctx: ctx,
logger: log.NewNOPFactory().NewLogger("dns"),
client: &fakeDNSClient{},
rules: make([]adapter.DNSRule, 0, 1),
defaultDomainStrategy: C.DomainStrategyAsIS,
}
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false))
err := router.Initialize([]option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
@@ -2477,9 +2684,9 @@ func TestLegacyDNSModeReportsDNSRuleStrategyDeprecation(t *testing.T) {
ctx: ctx,
logger: log.NewNOPFactory().NewLogger("dns"),
client: &fakeDNSClient{},
rules: make([]adapter.DNSRule, 0, 1),
defaultDomainStrategy: C.DomainStrategyAsIS,
}
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false))
err := router.Initialize([]option.DNSRule{{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{