mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
dns: use refcounted snapshot to narrow rule lock scope
Exchange and Lookup held rulesAccess.RLock across all DNS network I/O, blocking rebuildRules from swapping in new rules until every in-flight query finished. Replace the RWMutex with an atomic pointer to a refcounted rulesSnapshot so queries only hold a snapshot reference during execution, allowing concurrent rule rebuilds.
This commit is contained in:
164
dns/router.go
164
dns/router.go
@@ -6,6 +6,7 @@ import (
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
@@ -37,6 +38,42 @@ type dnsRuleSetCallback struct {
|
||||
element *list.Element[adapter.RuleSetUpdateCallback]
|
||||
}
|
||||
|
||||
type rulesSnapshot struct {
|
||||
rules []adapter.DNSRule
|
||||
legacyDNSMode bool
|
||||
references atomic.Int64
|
||||
}
|
||||
|
||||
func newRulesSnapshot(rules []adapter.DNSRule, legacyDNSMode bool) *rulesSnapshot {
|
||||
snapshot := &rulesSnapshot{
|
||||
rules: rules,
|
||||
legacyDNSMode: legacyDNSMode,
|
||||
}
|
||||
snapshot.references.Store(1)
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (s *rulesSnapshot) retain() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.references.Add(1)
|
||||
}
|
||||
|
||||
func (s *rulesSnapshot) release() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
references := s.references.Add(-1)
|
||||
switch {
|
||||
case references > 0:
|
||||
case references == 0:
|
||||
closeRules(s.rules)
|
||||
default:
|
||||
panic("dns: negative rules snapshot references")
|
||||
}
|
||||
}
|
||||
|
||||
type Router struct {
|
||||
ctx context.Context
|
||||
logger logger.ContextLogger
|
||||
@@ -44,13 +81,12 @@ type Router struct {
|
||||
outbound adapter.OutboundManager
|
||||
client adapter.DNSClient
|
||||
rawRules []option.DNSRule
|
||||
rules []adapter.DNSRule
|
||||
currentRules atomic.Pointer[rulesSnapshot]
|
||||
defaultDomainStrategy C.DomainStrategy
|
||||
dnsReverseMapping freelru.Cache[netip.Addr, string]
|
||||
platformInterface adapter.PlatformInterface
|
||||
legacyDNSMode bool
|
||||
rulesAccess sync.RWMutex
|
||||
rebuildAccess sync.Mutex
|
||||
stateAccess sync.Mutex
|
||||
closing bool
|
||||
ruleSetCallbacks []dnsRuleSetCallback
|
||||
addressFilterDeprecatedReported bool
|
||||
@@ -64,9 +100,9 @@ func NewRouter(ctx context.Context, logFactory log.Factory, options option.DNSOp
|
||||
transport: service.FromContext[adapter.DNSTransportManager](ctx),
|
||||
outbound: service.FromContext[adapter.OutboundManager](ctx),
|
||||
rawRules: make([]option.DNSRule, 0, len(options.Rules)),
|
||||
rules: make([]adapter.DNSRule, 0, len(options.Rules)),
|
||||
defaultDomainStrategy: C.DomainStrategy(options.Strategy),
|
||||
}
|
||||
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(options.Rules)), false))
|
||||
router.client = NewClient(ClientOptions{
|
||||
DisableCache: options.DNSClientOptions.DisableCache,
|
||||
DisableExpire: options.DNSClientOptions.DisableExpire,
|
||||
@@ -134,26 +170,21 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
}
|
||||
|
||||
func (r *Router) Close() error {
|
||||
monitor := taskmonitor.New(r.logger, C.StopTimeout)
|
||||
r.rulesAccess.Lock()
|
||||
r.stateAccess.Lock()
|
||||
if r.closing {
|
||||
r.stateAccess.Unlock()
|
||||
return nil
|
||||
}
|
||||
r.closing = true
|
||||
callbacks := r.ruleSetCallbacks
|
||||
r.ruleSetCallbacks = nil
|
||||
runtimeRules := r.rules
|
||||
r.rules = nil
|
||||
oldSnapshot := r.currentRules.Swap(nil)
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
r.rulesAccess.Unlock()
|
||||
var err error
|
||||
for i, rule := range runtimeRules {
|
||||
monitor.Start("close dns rule[", i, "]")
|
||||
err = E.Append(err, rule.Close(), func(err error) error {
|
||||
return E.Cause(err, "close dns rule[", i, "]")
|
||||
})
|
||||
monitor.Finish()
|
||||
}
|
||||
return err
|
||||
r.stateAccess.Unlock()
|
||||
oldSnapshot.release()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *Router) rebuildRules(startRules bool) error {
|
||||
@@ -177,23 +208,22 @@ func (r *Router) rebuildRules(startRules bool) error {
|
||||
legacyDNSMode &&
|
||||
!r.ruleStrategyDeprecatedReported &&
|
||||
hasDNSRuleActionStrategy(r.rawRules)
|
||||
r.rulesAccess.Lock()
|
||||
newSnapshot := newRulesSnapshot(newRules, legacyDNSMode)
|
||||
r.stateAccess.Lock()
|
||||
if r.closing {
|
||||
r.rulesAccess.Unlock()
|
||||
closeRules(newRules)
|
||||
r.stateAccess.Unlock()
|
||||
newSnapshot.release()
|
||||
return nil
|
||||
}
|
||||
oldRules := r.rules
|
||||
r.rules = newRules
|
||||
r.legacyDNSMode = legacyDNSMode
|
||||
if shouldReportAddressFilterDeprecated {
|
||||
r.addressFilterDeprecatedReported = true
|
||||
}
|
||||
if shouldReportRuleStrategyDeprecated {
|
||||
r.ruleStrategyDeprecatedReported = true
|
||||
}
|
||||
r.rulesAccess.Unlock()
|
||||
closeRules(oldRules)
|
||||
oldSnapshot := r.currentRules.Swap(newSnapshot)
|
||||
r.stateAccess.Unlock()
|
||||
oldSnapshot.release()
|
||||
if shouldReportAddressFilterDeprecated {
|
||||
deprecated.Report(r.ctx, deprecated.OptionLegacyDNSAddressFilter)
|
||||
}
|
||||
@@ -204,11 +234,19 @@ func (r *Router) rebuildRules(startRules bool) error {
|
||||
}
|
||||
|
||||
func (r *Router) isClosing() bool {
|
||||
r.rulesAccess.RLock()
|
||||
defer r.rulesAccess.RUnlock()
|
||||
r.stateAccess.Lock()
|
||||
defer r.stateAccess.Unlock()
|
||||
return r.closing
|
||||
}
|
||||
|
||||
func (r *Router) acquireRulesSnapshot() *rulesSnapshot {
|
||||
r.stateAccess.Lock()
|
||||
defer r.stateAccess.Unlock()
|
||||
snapshot := r.currentRules.Load()
|
||||
snapshot.retain()
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (r *Router) buildRules(startRules bool) ([]adapter.DNSRule, bool, error) {
|
||||
for i, ruleOptions := range r.rawRules {
|
||||
err := R.ValidateNoNestedDNSRuleActions(ruleOptions)
|
||||
@@ -259,12 +297,12 @@ func (r *Router) registerRuleSetCallbacks() (bool, error) {
|
||||
if len(tags) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
r.rulesAccess.RLock()
|
||||
r.stateAccess.Lock()
|
||||
if len(r.ruleSetCallbacks) > 0 {
|
||||
r.rulesAccess.RUnlock()
|
||||
r.stateAccess.Unlock()
|
||||
return true, nil
|
||||
}
|
||||
r.rulesAccess.RUnlock()
|
||||
r.stateAccess.Unlock()
|
||||
router := service.FromContext[adapter.Router](r.ctx)
|
||||
if router == nil {
|
||||
return false, E.New("router service not found")
|
||||
@@ -289,19 +327,19 @@ func (r *Router) registerRuleSetCallbacks() (bool, error) {
|
||||
element: element,
|
||||
})
|
||||
}
|
||||
r.rulesAccess.Lock()
|
||||
r.stateAccess.Lock()
|
||||
if len(r.ruleSetCallbacks) == 0 {
|
||||
r.ruleSetCallbacks = callbacks
|
||||
callbacks = nil
|
||||
}
|
||||
r.rulesAccess.Unlock()
|
||||
r.stateAccess.Unlock()
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
|
||||
func (r *Router) matchDNS(ctx context.Context, rules []adapter.DNSRule, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
|
||||
metadata := adapter.ContextFrom(ctx)
|
||||
if metadata == nil {
|
||||
panic("no context")
|
||||
@@ -310,8 +348,8 @@ func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int,
|
||||
if ruleIndex != -1 {
|
||||
currentRuleIndex = ruleIndex + 1
|
||||
}
|
||||
for ; currentRuleIndex < len(r.rules); currentRuleIndex++ {
|
||||
currentRule := r.rules[currentRuleIndex]
|
||||
for ; currentRuleIndex < len(rules); currentRuleIndex++ {
|
||||
currentRule := rules[currentRuleIndex]
|
||||
if currentRule.WithAddressLimit() && !isAddressQuery {
|
||||
continue
|
||||
}
|
||||
@@ -422,14 +460,14 @@ type exchangeWithRulesResult struct {
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *Router) exchangeWithRules(ctx context.Context, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) exchangeWithRulesResult {
|
||||
func (r *Router) exchangeWithRules(ctx context.Context, rules []adapter.DNSRule, message *mDNS.Msg, options adapter.DNSQueryOptions, allowFakeIP bool) exchangeWithRulesResult {
|
||||
metadata := adapter.ContextFrom(ctx)
|
||||
if metadata == nil {
|
||||
panic("no context")
|
||||
}
|
||||
effectiveOptions := options
|
||||
var savedResponse *mDNS.Msg
|
||||
for currentRuleIndex, currentRule := range r.rules {
|
||||
for currentRuleIndex, currentRule := range rules {
|
||||
metadata.ResetRuleCache()
|
||||
metadata.DNSResponse = savedResponse
|
||||
metadata.DestinationAddressMatchFromResponse = false
|
||||
@@ -578,18 +616,18 @@ func filterAddressesByQueryType(addresses []netip.Addr, qType uint16) []netip.Ad
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) lookupWithRules(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
|
||||
func (r *Router) lookupWithRules(ctx context.Context, rules []adapter.DNSRule, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
|
||||
strategy := r.resolveLookupStrategy(options)
|
||||
lookupOptions := options
|
||||
if strategy != C.DomainStrategyAsIS {
|
||||
lookupOptions.Strategy = strategy
|
||||
}
|
||||
if strategy == C.DomainStrategyIPv4Only {
|
||||
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
|
||||
response, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions)
|
||||
return response.addresses, err
|
||||
}
|
||||
if strategy == C.DomainStrategyIPv6Only {
|
||||
response, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
|
||||
response, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions)
|
||||
return response.addresses, err
|
||||
}
|
||||
var (
|
||||
@@ -598,12 +636,12 @@ func (r *Router) lookupWithRules(ctx context.Context, domain string, options ada
|
||||
)
|
||||
var group task.Group
|
||||
group.Append("exchange4", func(ctx context.Context) error {
|
||||
result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeA, lookupOptions)
|
||||
result, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeA, lookupOptions)
|
||||
response4 = result
|
||||
return err
|
||||
})
|
||||
group.Append("exchange6", func(ctx context.Context) error {
|
||||
result, err := r.lookupWithRulesType(ctx, domain, mDNS.TypeAAAA, lookupOptions)
|
||||
result, err := r.lookupWithRulesType(ctx, rules, domain, mDNS.TypeAAAA, lookupOptions)
|
||||
response6 = result
|
||||
return err
|
||||
})
|
||||
@@ -614,7 +652,7 @@ func (r *Router) lookupWithRules(ctx context.Context, domain string, options ada
|
||||
return sortAddresses(response4.addresses, response6.addresses, strategy), nil
|
||||
}
|
||||
|
||||
func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) {
|
||||
func (r *Router) lookupWithRulesType(ctx context.Context, rules []adapter.DNSRule, domain string, qType uint16, options adapter.DNSQueryOptions) (lookupWithRulesResponse, error) {
|
||||
request := &mDNS.Msg{
|
||||
MsgHdr: mDNS.MsgHdr{
|
||||
RecursionDesired: true,
|
||||
@@ -625,7 +663,7 @@ func (r *Router) lookupWithRulesType(ctx context.Context, domain string, qType u
|
||||
Qclass: mDNS.ClassINET,
|
||||
}},
|
||||
}
|
||||
exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), request, options, false)
|
||||
exchangeResult := r.exchangeWithRules(withLookupQueryMetadata(ctx, qType), rules, request, options, false)
|
||||
result := lookupWithRulesResponse{}
|
||||
if exchangeResult.rejectAction != nil {
|
||||
return result, exchangeResult.rejectAction.Error(ctx)
|
||||
@@ -656,8 +694,16 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
}
|
||||
return &responseMessage, nil
|
||||
}
|
||||
r.rulesAccess.RLock()
|
||||
defer r.rulesAccess.RUnlock()
|
||||
snapshot := r.acquireRulesSnapshot()
|
||||
defer snapshot.release()
|
||||
var (
|
||||
rules []adapter.DNSRule
|
||||
legacyDNSMode bool
|
||||
)
|
||||
if snapshot != nil {
|
||||
rules = snapshot.rules
|
||||
legacyDNSMode = snapshot.legacyDNSMode
|
||||
}
|
||||
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
|
||||
var (
|
||||
response *mDNS.Msg
|
||||
@@ -683,8 +729,8 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
options.Strategy = r.defaultDomainStrategy
|
||||
}
|
||||
response, err = r.client.Exchange(ctx, transport, message, options, nil)
|
||||
} else if !r.legacyDNSMode {
|
||||
exchangeResult := r.exchangeWithRules(ctx, message, options, true)
|
||||
} else if !legacyDNSMode {
|
||||
exchangeResult := r.exchangeWithRules(ctx, rules, message, options, true)
|
||||
response, transport, err = exchangeResult.response, exchangeResult.transport, exchangeResult.err
|
||||
} else {
|
||||
var (
|
||||
@@ -695,7 +741,7 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
for {
|
||||
dnsCtx := adapter.OverrideContext(ctx)
|
||||
dnsOptions := options
|
||||
transport, rule, ruleIndex = r.matchDNS(ctx, true, ruleIndex, isAddressQuery(message), &dnsOptions)
|
||||
transport, rule, ruleIndex = r.matchDNS(ctx, rules, true, ruleIndex, isAddressQuery(message), &dnsOptions)
|
||||
if rule != nil {
|
||||
switch action := rule.Action().(type) {
|
||||
case *R.RuleActionReject:
|
||||
@@ -760,8 +806,16 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
}
|
||||
|
||||
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
|
||||
r.rulesAccess.RLock()
|
||||
defer r.rulesAccess.RUnlock()
|
||||
snapshot := r.acquireRulesSnapshot()
|
||||
defer snapshot.release()
|
||||
var (
|
||||
rules []adapter.DNSRule
|
||||
legacyDNSMode bool
|
||||
)
|
||||
if snapshot != nil {
|
||||
rules = snapshot.rules
|
||||
legacyDNSMode = snapshot.legacyDNSMode
|
||||
}
|
||||
var (
|
||||
responseAddrs []netip.Addr
|
||||
err error
|
||||
@@ -797,8 +851,8 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ
|
||||
options.Strategy = r.defaultDomainStrategy
|
||||
}
|
||||
responseAddrs, err = r.client.Lookup(ctx, transport, domain, options, nil)
|
||||
} else if !r.legacyDNSMode {
|
||||
responseAddrs, err = r.lookupWithRules(ctx, domain, options)
|
||||
} else if !legacyDNSMode {
|
||||
responseAddrs, err = r.lookupWithRules(ctx, rules, domain, options)
|
||||
} else {
|
||||
var (
|
||||
transport adapter.DNSTransport
|
||||
@@ -809,7 +863,7 @@ func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQ
|
||||
for {
|
||||
dnsCtx := adapter.OverrideContext(ctx)
|
||||
dnsOptions := options
|
||||
transport, rule, ruleIndex = r.matchDNS(ctx, false, ruleIndex, true, &dnsOptions)
|
||||
transport, rule, ruleIndex = r.matchDNS(ctx, rules, false, ruleIndex, true, &dnsOptions)
|
||||
if rule != nil {
|
||||
switch action := rule.Action().(type) {
|
||||
case *R.RuleActionReject:
|
||||
|
||||
@@ -113,6 +113,7 @@ func (r *fakeRouter) RuleSet(tag string) (adapter.RuleSet, bool) {
|
||||
ruleSet, loaded := r.ruleSets[tag]
|
||||
return ruleSet, loaded
|
||||
}
|
||||
|
||||
func (r *fakeRouter) setRuleSet(tag string, ruleSet adapter.RuleSet) {
|
||||
r.access.Lock()
|
||||
defer r.access.Unlock()
|
||||
@@ -135,7 +136,7 @@ type fakeRuleSet struct {
|
||||
match func(*adapter.InboundContext) bool
|
||||
callbacks list.List[adapter.RuleSetUpdateCallback]
|
||||
refs int
|
||||
afterIncrementReference func()
|
||||
afterIncrementReference func()
|
||||
beforeDecrementReference func()
|
||||
}
|
||||
|
||||
@@ -273,9 +274,9 @@ func newTestRouterWithContextAndLogger(t *testing.T, ctx context.Context, rules
|
||||
transport: transportManager,
|
||||
client: client,
|
||||
rawRules: make([]option.DNSRule, 0, len(rules)),
|
||||
rules: make([]adapter.DNSRule, 0, len(rules)),
|
||||
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||
}
|
||||
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, len(rules)), false))
|
||||
if rules != nil {
|
||||
err := router.Initialize(rules)
|
||||
require.NoError(t, err)
|
||||
@@ -427,9 +428,9 @@ func TestInitializeRejectsInvalidDNSRuleParseError(t *testing.T) {
|
||||
transport: &fakeDNSTransportManager{},
|
||||
client: &fakeDNSClient{},
|
||||
rawRules: make([]option.DNSRule, 0, 1),
|
||||
rules: make([]adapter.DNSRule, 0, 1),
|
||||
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||
}
|
||||
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false))
|
||||
err := router.Initialize([]option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
@@ -474,9 +475,9 @@ func TestInitializeRejectsDirectLegacyRuleWhenRuleSetForcesNew(t *testing.T) {
|
||||
transport: &fakeDNSTransportManager{},
|
||||
client: &fakeDNSClient{},
|
||||
rawRules: make([]option.DNSRule, 0, 2),
|
||||
rules: make([]adapter.DNSRule, 0, 2),
|
||||
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||
}
|
||||
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 2), false))
|
||||
err = router.Initialize([]option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
@@ -557,7 +558,7 @@ func TestLookupLegacyDNSModeDefersRuleSetDestinationIPMatch(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.True(t, router.currentRules.Load().legacyDNSMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
|
||||
LookupStrategy: C.DomainStrategyIPv4Only,
|
||||
@@ -700,7 +701,7 @@ func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t *
|
||||
require.NoError(t, closeErr)
|
||||
})
|
||||
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.True(t, router.currentRules.Load().legacyDNSMode)
|
||||
require.Equal(t, 1, callbackRuleSet.refCount())
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
@@ -723,7 +724,7 @@ func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t *
|
||||
})
|
||||
rebuildErrorEntry := waitForLogMessageContaining(t, logEntries, logDone, "rebuild DNS rules after rule-set update")
|
||||
require.Contains(t, rebuildErrorEntry.Message, "ip_cidr and ip_is_private require match_response")
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.True(t, router.currentRules.Load().legacyDNSMode)
|
||||
require.Equal(t, 1, callbackRuleSet.refCount())
|
||||
require.Zero(t, rebuildTargetRuleSet.refCount())
|
||||
|
||||
@@ -992,7 +993,7 @@ func TestCloseDuringRebuildDiscardsResult(t *testing.T) {
|
||||
}
|
||||
},
|
||||
})
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.True(t, router.currentRules.Load().legacyDNSMode)
|
||||
require.Equal(t, 1, fakeSet.refCount())
|
||||
|
||||
callbacks := fakeSet.snapshotCallbacks()
|
||||
@@ -1031,12 +1032,11 @@ func TestCloseDuringRebuildDiscardsResult(t *testing.T) {
|
||||
|
||||
fakeSet.metadataRead = nil
|
||||
|
||||
router.rulesAccess.RLock()
|
||||
router.stateAccess.Lock()
|
||||
require.True(t, router.closing)
|
||||
require.Nil(t, router.rules)
|
||||
require.Empty(t, router.ruleSetCallbacks)
|
||||
router.rulesAccess.RUnlock()
|
||||
require.True(t, router.legacyDNSMode)
|
||||
router.stateAccess.Unlock()
|
||||
require.Nil(t, router.currentRules.Load())
|
||||
require.Zero(t, fakeSet.refCount())
|
||||
}
|
||||
|
||||
@@ -1098,11 +1098,218 @@ func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) {
|
||||
}
|
||||
callbacks[0](fakeSet)
|
||||
|
||||
router.rulesAccess.RLock()
|
||||
defer router.rulesAccess.RUnlock()
|
||||
router.stateAccess.Lock()
|
||||
require.True(t, router.closing)
|
||||
require.Nil(t, router.rules)
|
||||
require.Empty(t, router.ruleSetCallbacks)
|
||||
router.stateAccess.Unlock()
|
||||
require.Nil(t, router.currentRules.Load())
|
||||
}
|
||||
|
||||
func TestRuleSetUpdateDoesNotBlockOnInFlightLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fakeSet := &fakeRuleSet{
|
||||
metadata: adapter.RuleSetMetadata{
|
||||
ContainsIPCIDRRule: true,
|
||||
},
|
||||
}
|
||||
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"dynamic-set": fakeSet,
|
||||
},
|
||||
})
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
selectedTransport := &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}
|
||||
lookupStarted := make(chan struct{})
|
||||
releaseLookup := make(chan struct{})
|
||||
router := newTestRouterWithContext(t, ctx, []option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
RuleSet: badoption.Listable[string]{"dynamic-set"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||
},
|
||||
},
|
||||
}}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
"selected": selectedTransport,
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
|
||||
require.Equal(t, "selected", transport.Tag())
|
||||
require.Equal(t, "example.com", domain)
|
||||
require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy)
|
||||
close(lookupStarted)
|
||||
<-releaseLookup
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
|
||||
return MessageToAddresses(response), response, nil
|
||||
},
|
||||
})
|
||||
t.Cleanup(func() {
|
||||
closeErr := router.Close()
|
||||
require.NoError(t, closeErr)
|
||||
})
|
||||
|
||||
require.True(t, router.currentRules.Load().legacyDNSMode)
|
||||
require.Equal(t, 1, fakeSet.refCount())
|
||||
|
||||
var (
|
||||
addresses []netip.Addr
|
||||
err error
|
||||
)
|
||||
lookupDone := make(chan struct{})
|
||||
go func() {
|
||||
addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
|
||||
LookupStrategy: C.DomainStrategyIPv4Only,
|
||||
})
|
||||
close(lookupDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-lookupStarted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("lookup did not reach DNS client")
|
||||
}
|
||||
|
||||
rebuildDone := make(chan struct{})
|
||||
go func() {
|
||||
fakeSet.updateMetadata(adapter.RuleSetMetadata{
|
||||
ContainsIPCIDRRule: true,
|
||||
})
|
||||
close(rebuildDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-rebuildDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("rebuild blocked on in-flight lookup")
|
||||
}
|
||||
|
||||
require.Equal(t, 2, fakeSet.refCount())
|
||||
|
||||
select {
|
||||
case <-lookupDone:
|
||||
t.Fatal("lookup finished before release")
|
||||
default:
|
||||
}
|
||||
|
||||
close(releaseLookup)
|
||||
|
||||
select {
|
||||
case <-lookupDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("lookup did not finish after release")
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
|
||||
require.Eventually(t, func() bool {
|
||||
return fakeSet.refCount() == 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestCloseReleasesSnapshottedRulesAfterInFlightLookup(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fakeSet := &fakeRuleSet{
|
||||
metadata: adapter.RuleSetMetadata{
|
||||
ContainsIPCIDRRule: true,
|
||||
},
|
||||
}
|
||||
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"dynamic-set": fakeSet,
|
||||
},
|
||||
})
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
selectedTransport := &fakeDNSTransport{tag: "selected", transportType: C.DNSTypeUDP}
|
||||
lookupStarted := make(chan struct{})
|
||||
releaseLookup := make(chan struct{})
|
||||
router := newTestRouterWithContext(t, ctx, []option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
RuleSet: badoption.Listable[string]{"dynamic-set"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "selected"},
|
||||
},
|
||||
},
|
||||
}}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
"selected": selectedTransport,
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
|
||||
require.Equal(t, "selected", transport.Tag())
|
||||
require.Equal(t, "example.com", domain)
|
||||
require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy)
|
||||
close(lookupStarted)
|
||||
<-releaseLookup
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
|
||||
return MessageToAddresses(response), response, nil
|
||||
},
|
||||
})
|
||||
|
||||
require.True(t, router.currentRules.Load().legacyDNSMode)
|
||||
require.Equal(t, 1, fakeSet.refCount())
|
||||
|
||||
var (
|
||||
addresses []netip.Addr
|
||||
lookupErr error
|
||||
closeErr error
|
||||
)
|
||||
lookupDone := make(chan struct{})
|
||||
go func() {
|
||||
addresses, lookupErr = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
|
||||
LookupStrategy: C.DomainStrategyIPv4Only,
|
||||
})
|
||||
close(lookupDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-lookupStarted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("lookup did not reach DNS client")
|
||||
}
|
||||
|
||||
closeDone := make(chan struct{})
|
||||
go func() {
|
||||
closeErr = router.Close()
|
||||
close(closeDone)
|
||||
}()
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return router.currentRules.Load() == nil && fakeSet.refCount() == 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
close(releaseLookup)
|
||||
|
||||
select {
|
||||
case <-lookupDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("lookup did not finish after release")
|
||||
}
|
||||
select {
|
||||
case <-closeDone:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("close did not finish")
|
||||
}
|
||||
|
||||
require.NoError(t, lookupErr)
|
||||
require.NoError(t, closeErr)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
|
||||
require.Eventually(t, func() bool {
|
||||
return fakeSet.refCount() == 0
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
}
|
||||
|
||||
func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) {
|
||||
@@ -1143,7 +1350,7 @@ func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) {
|
||||
},
|
||||
}, client)
|
||||
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.True(t, router.currentRules.Load().legacyDNSMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
|
||||
LookupStrategy: C.DomainStrategyIPv4Only,
|
||||
@@ -1295,7 +1502,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes
|
||||
},
|
||||
})
|
||||
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.True(t, router.currentRules.Load().legacyDNSMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
|
||||
LookupStrategy: C.DomainStrategyIPv4Only,
|
||||
@@ -1801,7 +2008,7 @@ func TestLookupLegacyDNSModeDisabledAllowsPartialSuccess(t *testing.T) {
|
||||
}
|
||||
},
|
||||
})
|
||||
router.legacyDNSMode = false
|
||||
router.currentRules.Load().legacyDNSMode = false
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
@@ -1838,7 +2045,7 @@ func TestLookupLegacyDNSModeDisabledSkipsFakeIPRule(t *testing.T) {
|
||||
return FixedResponse(0, message.Question[0], nil, 60), nil
|
||||
},
|
||||
})
|
||||
router.legacyDNSMode = false
|
||||
router.currentRules.Load().legacyDNSMode = false
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
@@ -1918,7 +2125,7 @@ func TestLookupLegacyDNSModeDisabledEvaluateSkipFakeIPPreservesResponse(t *testi
|
||||
}
|
||||
},
|
||||
})
|
||||
router.legacyDNSMode = false
|
||||
router.currentRules.Load().legacyDNSMode = false
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
@@ -1961,7 +2168,7 @@ func TestLookupLegacyDNSModeDisabledUsesQueryTypeRule(t *testing.T) {
|
||||
}
|
||||
},
|
||||
})
|
||||
require.False(t, router.legacyDNSMode)
|
||||
require.False(t, router.currentRules.Load().legacyDNSMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
@@ -2027,7 +2234,7 @@ func TestLookupLegacyDNSModeDisabledUsesRuleSetQueryTypeRule(t *testing.T) {
|
||||
}
|
||||
},
|
||||
})
|
||||
require.False(t, router.legacyDNSMode)
|
||||
require.False(t, router.currentRules.Load().legacyDNSMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
@@ -2076,7 +2283,7 @@ func TestLookupLegacyDNSModeDisabledUsesIPVersionRule(t *testing.T) {
|
||||
}
|
||||
},
|
||||
})
|
||||
require.False(t, router.legacyDNSMode)
|
||||
require.False(t, router.currentRules.Load().legacyDNSMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
@@ -2092,9 +2299,9 @@ func TestInitializeRejectsDNSRuleStrategyWhenLegacyDNSModeIsDisabledByEvaluate(t
|
||||
transport: &fakeDNSTransportManager{},
|
||||
client: &fakeDNSClient{},
|
||||
rawRules: make([]option.DNSRule, 0, 1),
|
||||
rules: make([]adapter.DNSRule, 0, 1),
|
||||
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||
}
|
||||
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false))
|
||||
err := router.Initialize([]option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
@@ -2122,9 +2329,9 @@ func TestInitializeRejectsDNSRuleStrategyWhenLegacyDNSModeIsDisabledByMatchRespo
|
||||
transport: &fakeDNSTransportManager{},
|
||||
client: &fakeDNSClient{},
|
||||
rawRules: make([]option.DNSRule, 0, 1),
|
||||
rules: make([]adapter.DNSRule, 0, 1),
|
||||
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||
}
|
||||
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false))
|
||||
err := router.Initialize([]option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
@@ -2175,7 +2382,7 @@ func TestLookupLegacyDNSModeUsesRouteStrategy(t *testing.T) {
|
||||
},
|
||||
})
|
||||
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.True(t, router.currentRules.Load().legacyDNSMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
@@ -2207,7 +2414,7 @@ func TestLookupLegacyDNSModeDisabledReturnsRejectedErrorForRejectAction(t *testi
|
||||
"default": defaultTransport,
|
||||
},
|
||||
}, &fakeDNSClient{})
|
||||
require.False(t, router.legacyDNSMode)
|
||||
require.False(t, router.currentRules.Load().legacyDNSMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.Nil(t, addresses)
|
||||
@@ -2240,7 +2447,7 @@ func TestExchangeLegacyDNSModeDisabledReturnsRefusedResponseForRejectAction(t *t
|
||||
"default": defaultTransport,
|
||||
},
|
||||
}, &fakeDNSClient{})
|
||||
require.False(t, router.legacyDNSMode)
|
||||
require.False(t, router.currentRules.Load().legacyDNSMode)
|
||||
|
||||
response, err := router.Exchange(context.Background(), &mDNS.Msg{
|
||||
Question: []mDNS.Question{fixedQuestion("example.com", mDNS.TypeA)},
|
||||
@@ -2278,7 +2485,7 @@ func TestLookupLegacyDNSModeDisabledFiltersPerQueryTypeAddressesBeforeMerging(t
|
||||
"default": defaultTransport,
|
||||
},
|
||||
}, &fakeDNSClient{})
|
||||
require.False(t, router.legacyDNSMode)
|
||||
require.False(t, router.currentRules.Load().legacyDNSMode)
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
@@ -2318,7 +2525,7 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) {
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::2")}, 60), nil
|
||||
},
|
||||
})
|
||||
router.legacyDNSMode = false
|
||||
router.currentRules.Load().legacyDNSMode = false
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{
|
||||
Strategy: C.DomainStrategyIPv4Only,
|
||||
@@ -2359,7 +2566,7 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) {
|
||||
},
|
||||
})
|
||||
router.defaultDomainStrategy = C.DomainStrategyIPv4Only
|
||||
router.legacyDNSMode = false
|
||||
router.currentRules.Load().legacyDNSMode = false
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
@@ -2445,9 +2652,9 @@ func TestLegacyDNSModeReportsLegacyAddressFilterDeprecation(t *testing.T) {
|
||||
ctx: ctx,
|
||||
logger: log.NewNOPFactory().NewLogger("dns"),
|
||||
client: &fakeDNSClient{},
|
||||
rules: make([]adapter.DNSRule, 0, 1),
|
||||
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||
}
|
||||
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false))
|
||||
err := router.Initialize([]option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
@@ -2477,9 +2684,9 @@ func TestLegacyDNSModeReportsDNSRuleStrategyDeprecation(t *testing.T) {
|
||||
ctx: ctx,
|
||||
logger: log.NewNOPFactory().NewLogger("dns"),
|
||||
client: &fakeDNSClient{},
|
||||
rules: make([]adapter.DNSRule, 0, 1),
|
||||
defaultDomainStrategy: C.DomainStrategyAsIS,
|
||||
}
|
||||
router.currentRules.Store(newRulesSnapshot(make([]adapter.DNSRule, 0, 1), false))
|
||||
err := router.Initialize([]option.DNSRule{{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
|
||||
Reference in New Issue
Block a user