dns: serialize rebuilds and keep last good rules on failure

This commit is contained in:
世界
2026-03-31 13:15:25 +08:00
parent b44cf24745
commit bd222fe9df
2 changed files with 574 additions and 90 deletions

View File

@@ -50,9 +50,9 @@ type Router struct {
platformInterface adapter.PlatformInterface
legacyDNSMode bool
rulesAccess sync.RWMutex
rebuildAccess sync.Mutex
closing bool
ruleSetCallbacks []dnsRuleSetCallback
runtimeRuleError error
addressFilterDeprecatedReported bool
ruleStrategyDeprecatedReported bool
}
@@ -116,11 +116,19 @@ func (r *Router) Start(stage adapter.StartStage) error {
return err
}
monitor.Start("register DNS rule-set callbacks")
err = r.registerRuleSetCallbacks()
needsRulesRefresh, err := r.registerRuleSetCallbacks()
monitor.Finish()
if err != nil {
return err
}
if needsRulesRefresh {
monitor.Start("refresh DNS rules after callback registration")
err = r.rebuildRules(true)
monitor.Finish()
if err != nil {
r.logger.Error(E.Cause(err, "refresh DNS rules after callback registration"))
}
}
}
return nil
}
@@ -133,7 +141,6 @@ func (r *Router) Close() error {
r.ruleSetCallbacks = nil
runtimeRules := r.rules
r.rules = nil
r.runtimeRuleError = nil
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
@@ -150,6 +157,8 @@ func (r *Router) Close() error {
}
func (r *Router) rebuildRules(startRules bool) error {
r.rebuildAccess.Lock()
defer r.rebuildAccess.Unlock()
if r.isClosing() {
return nil
}
@@ -177,7 +186,6 @@ func (r *Router) rebuildRules(startRules bool) error {
oldRules := r.rules
r.rules = newRules
r.legacyDNSMode = legacyDNSMode
r.runtimeRuleError = nil
if shouldReportAddressFilterDeprecated {
r.addressFilterDeprecatedReported = true
}
@@ -246,20 +254,20 @@ func closeRules(rules []adapter.DNSRule) {
}
}
func (r *Router) registerRuleSetCallbacks() error {
func (r *Router) registerRuleSetCallbacks() (bool, error) {
tags := referencedDNSRuleSetTags(r.rawRules)
if len(tags) == 0 {
return nil
return false, nil
}
r.rulesAccess.RLock()
if len(r.ruleSetCallbacks) > 0 {
r.rulesAccess.RUnlock()
return nil
return true, nil
}
r.rulesAccess.RUnlock()
router := service.FromContext[adapter.Router](r.ctx)
if router == nil {
return E.New("router service not found")
return false, E.New("router service not found")
}
callbacks := make([]dnsRuleSetCallback, 0, len(tags))
for _, tag := range tags {
@@ -268,14 +276,11 @@ func (r *Router) registerRuleSetCallbacks() error {
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
return E.New("rule-set not found: ", tag)
return false, E.New("rule-set not found: ", tag)
}
element := ruleSet.RegisterCallback(func(adapter.RuleSet) {
err := r.rebuildRules(true)
if err != nil {
r.rulesAccess.Lock()
r.runtimeRuleError = err
r.rulesAccess.Unlock()
r.logger.Error(E.Cause(err, "rebuild DNS rules after rule-set update"))
}
})
@@ -293,7 +298,7 @@ func (r *Router) registerRuleSetCallbacks() error {
for _, callback := range callbacks {
callback.ruleSet.UnregisterCallback(callback.element)
}
return nil
return true, nil
}
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
@@ -653,9 +658,6 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
}
r.rulesAccess.RLock()
defer r.rulesAccess.RUnlock()
if r.runtimeRuleError != nil {
return nil, r.runtimeRuleError
}
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
var (
response *mDNS.Msg
@@ -760,9 +762,6 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
r.rulesAccess.RLock()
defer r.rulesAccess.RUnlock()
if r.runtimeRuleError != nil {
return nil, r.runtimeRuleError
}
var (
responseAddrs []netip.Addr
err error

View File

@@ -2,9 +2,10 @@ package dns
import (
"context"
"errors"
"io"
"net"
"net/netip"
"strings"
"sync"
"testing"
"time"
@@ -16,6 +17,8 @@ import (
"github.com/sagernet/sing-box/option"
rulepkg "github.com/sagernet/sing-box/route/rule"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json/badoption"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/x/list"
@@ -38,7 +41,7 @@ func (t *fakeDNSTransport) Tag() string { return t.tag }
func (t *fakeDNSTransport) Dependencies() []string { return nil }
func (t *fakeDNSTransport) Reset() {}
func (t *fakeDNSTransport) Exchange(context.Context, *mDNS.Msg) (*mDNS.Msg, error) {
return nil, errors.New("unused transport exchange")
return nil, E.New("unused transport exchange")
}
type fakeDNSTransportManager struct {
@@ -66,7 +69,7 @@ func (m *fakeDNSTransportManager) FakeIP() adapter.FakeIPTransport {
}
func (m *fakeDNSTransportManager) Remove(string) error { return nil }
func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, string, string, any) error {
return errors.New("unsupported")
return E.New("unsupported")
}
type fakeDNSClient struct {
@@ -80,6 +83,7 @@ type fakeDeprecatedManager struct {
}
type fakeRouter struct {
access sync.RWMutex
ruleSets map[string]adapter.RuleSet
}
@@ -104,9 +108,19 @@ func (r *fakeRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adap
}
func (r *fakeRouter) RuleSet(tag string) (adapter.RuleSet, bool) {
r.access.RLock()
defer r.access.RUnlock()
ruleSet, loaded := r.ruleSets[tag]
return ruleSet, loaded
}
func (r *fakeRouter) setRuleSet(tag string, ruleSet adapter.RuleSet) {
r.access.Lock()
defer r.access.Unlock()
if r.ruleSets == nil {
r.ruleSets = make(map[string]adapter.RuleSet)
}
r.ruleSets[tag] = ruleSet
}
func (r *fakeRouter) Rules() []adapter.Rule { return nil }
func (r *fakeRouter) NeedFindProcess() bool { return false }
func (r *fakeRouter) NeedFindNeighbor() bool { return false }
@@ -115,10 +129,14 @@ func (r *fakeRouter) AppendTracker(adapter.ConnectionTracker) {}
func (r *fakeRouter) ResetNetwork() {}
type fakeRuleSet struct {
access sync.Mutex
metadata adapter.RuleSetMetadata
callbacks list.List[adapter.RuleSetUpdateCallback]
refs int
access sync.Mutex
metadata adapter.RuleSetMetadata
metadataRead func(adapter.RuleSetMetadata) adapter.RuleSetMetadata
match func(*adapter.InboundContext) bool
callbacks list.List[adapter.RuleSetUpdateCallback]
refs int
afterIncrementReference func()
beforeDecrementReference func()
}
func (s *fakeRuleSet) Name() string { return "fake-rule-set" }
@@ -126,17 +144,32 @@ func (s *fakeRuleSet) StartContext(context.Context, *adapter.HTTPStartContext) e
func (s *fakeRuleSet) PostStart() error { return nil }
func (s *fakeRuleSet) Metadata() adapter.RuleSetMetadata {
s.access.Lock()
defer s.access.Unlock()
return s.metadata
metadata := s.metadata
metadataRead := s.metadataRead
s.access.Unlock()
if metadataRead != nil {
return metadataRead(metadata)
}
return metadata
}
func (s *fakeRuleSet) ExtractIPSet() []*netipx.IPSet { return nil }
func (s *fakeRuleSet) IncRef() {
s.access.Lock()
defer s.access.Unlock()
s.refs++
afterIncrementReference := s.afterIncrementReference
s.access.Unlock()
if afterIncrementReference != nil {
afterIncrementReference()
}
}
func (s *fakeRuleSet) DecRef() {
s.access.Lock()
beforeDecrementReference := s.beforeDecrementReference
s.access.Unlock()
if beforeDecrementReference != nil {
beforeDecrementReference()
}
s.access.Lock()
defer s.access.Unlock()
s.refs--
@@ -156,9 +189,17 @@ func (s *fakeRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUp
defer s.access.Unlock()
s.callbacks.Remove(element)
}
func (s *fakeRuleSet) Close() error { return nil }
func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true }
func (s *fakeRuleSet) String() string { return "fake-rule-set" }
func (s *fakeRuleSet) Close() error { return nil }
func (s *fakeRuleSet) Match(metadata *adapter.InboundContext) bool {
s.access.Lock()
match := s.match
s.access.Unlock()
if match != nil {
return match(metadata)
}
return true
}
func (s *fakeRuleSet) String() string { return "fake-rule-set" }
func (s *fakeRuleSet) updateMetadata(metadata adapter.RuleSetMetadata) {
s.access.Lock()
s.metadata = metadata
@@ -196,7 +237,7 @@ func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTrans
func (c *fakeDNSClient) Lookup(_ context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(*mDNS.Msg) bool) ([]netip.Addr, error) {
if c.lookup == nil {
return nil, errors.New("unused client lookup")
return nil, E.New("unused client lookup")
}
addresses, response, err := c.lookup(transport, domain, options)
if err != nil {
@@ -221,10 +262,14 @@ func newTestRouter(t *testing.T, rules []option.DNSRule, transportManager *fakeD
}
func newTestRouterWithContext(t *testing.T, ctx context.Context, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router {
return newTestRouterWithContextAndLogger(t, ctx, rules, transportManager, client, log.NewNOPFactory().NewLogger("dns"))
}
func newTestRouterWithContextAndLogger(t *testing.T, ctx context.Context, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient, dnsLogger log.ContextLogger) *Router {
t.Helper()
router := &Router{
ctx: ctx,
logger: log.NewNOPFactory().NewLogger("dns"),
logger: dnsLogger,
transport: transportManager,
client: client,
rawRules: make([]option.DNSRule, 0, len(rules)),
@@ -240,6 +285,26 @@ func newTestRouterWithContext(t *testing.T, ctx context.Context, rules []option.
return router
}
func waitForLogMessageContaining(t *testing.T, entries <-chan log.Entry, done <-chan struct{}, substring string) log.Entry {
t.Helper()
timeout := time.After(time.Second)
for {
select {
case entry, ok := <-entries:
if !ok {
t.Fatal("log subscription closed")
}
if strings.Contains(entry.Message, substring) {
return entry
}
case <-done:
t.Fatal("log subscription closed")
case <-timeout:
t.Fatalf("timed out waiting for log message containing %q", substring)
}
}
}
func fixedQuestion(name string, qType uint16) mDNS.Question {
return mDNS.Question{
Name: mDNS.Fqdn(name),
@@ -541,10 +606,356 @@ func TestRuleSetUpdateReleasesOldRuleSetRefs(t *testing.T) {
require.Zero(t, fakeSet.refCount())
}
func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) {
func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t *testing.T) {
t.Parallel()
fakeSet := &fakeRuleSet{}
callbackRuleSet := &fakeRuleSet{
match: func(*adapter.InboundContext) bool {
return false
},
}
routerService := &fakeRouter{
ruleSets: map[string]adapter.RuleSet{
"dynamic-set": callbackRuleSet,
},
}
ctx := service.ContextWith[adapter.Router](context.Background(), routerService)
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
preservedTransport := &fakeDNSTransport{tag: "preserved", transportType: C.DNSTypeUDP}
wouldBeNewTransport := &fakeDNSTransport{tag: "would-be-new", transportType: C.DNSTypeUDP}
loggerFactory := log.NewDefaultFactory(
context.Background(),
log.Formatter{
BaseTime: time.Now(),
DisableColors: true,
DisableTimestamp: true,
},
io.Discard,
"",
nil,
true,
)
loggerFactory.SetLevel(log.LevelError)
logEntries, logDone, err := loggerFactory.Subscribe()
require.NoError(t, err)
t.Cleanup(func() {
loggerFactory.UnSubscribe(logEntries)
closeErr := loggerFactory.Close()
require.NoError(t, closeErr)
})
var lastUsedTransport common.TypedValue[string]
router := newTestRouterWithContextAndLogger(t, ctx, []option.DNSRule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
RuleSet: badoption.Listable[string]{"dynamic-set"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "would-be-new"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
Domain: badoption.Listable[string]{"example.com"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "preserved"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
IPIsPrivate: true,
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "preserved"},
},
},
},
}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
"preserved": preservedTransport,
"would-be-new": wouldBeNewTransport,
},
}, &fakeDNSClient{
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
lastUsedTransport.Store(transport.Tag())
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
return MessageToAddresses(response), response, nil
},
}, loggerFactory.NewLogger("dns"))
t.Cleanup(func() {
closeErr := router.Close()
require.NoError(t, closeErr)
})
require.True(t, router.legacyDNSMode)
require.Equal(t, 1, callbackRuleSet.refCount())
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
require.Equal(t, "preserved", lastUsedTransport.Load())
rebuildTargetRuleSet := &fakeRuleSet{
metadata: adapter.RuleSetMetadata{
ContainsDNSQueryTypeRule: true,
},
match: func(*adapter.InboundContext) bool {
return true
},
}
routerService.setRuleSet("dynamic-set", rebuildTargetRuleSet)
callbackRuleSet.updateMetadata(adapter.RuleSetMetadata{
ContainsDNSQueryTypeRule: true,
})
rebuildErrorEntry := waitForLogMessageContaining(t, logEntries, logDone, "rebuild DNS rules after rule-set update")
require.Contains(t, rebuildErrorEntry.Message, "ip_cidr and ip_is_private require match_response")
require.True(t, router.legacyDNSMode)
require.Equal(t, 1, callbackRuleSet.refCount())
require.Zero(t, rebuildTargetRuleSet.refCount())
lastUsedTransport.Store("")
addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
require.Equal(t, "preserved", lastUsedTransport.Load())
require.NotEqual(t, "would-be-new", lastUsedTransport.Load())
}
func TestRuleSetUpdateSerializesConcurrentRebuilds(t *testing.T) {
t.Parallel()
callbackRuleSet := &fakeRuleSet{
match: func(*adapter.InboundContext) bool {
return false
},
}
routerService := &fakeRouter{
ruleSets: map[string]adapter.RuleSet{
"dynamic-set": callbackRuleSet,
},
}
ctx := service.ContextWith[adapter.Router](context.Background(), routerService)
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
firstTransport := &fakeDNSTransport{tag: "first", transportType: C.DNSTypeUDP}
secondTransport := &fakeDNSTransport{tag: "second", transportType: C.DNSTypeUDP}
var lastUsedTransport common.TypedValue[string]
router := newTestRouterWithContext(t, ctx, []option.DNSRule{
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
RuleSet: badoption.Listable[string]{"dynamic-set"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "first"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
Domain: badoption.Listable[string]{"example.com"},
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "second"},
},
},
},
}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
"first": firstTransport,
"second": secondTransport,
},
}, &fakeDNSClient{
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
lastUsedTransport.Store(transport.Tag())
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60), nil
},
})
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
require.Equal(t, "second", lastUsedTransport.Load())
callbacks := callbackRuleSet.snapshotCallbacks()
require.Len(t, callbacks, 1)
firstMetadataEntered := make(chan struct{})
releaseFirstMetadata := make(chan struct{})
firstRuleSetStarted := make(chan struct{})
releaseFirstRuleSetStart := make(chan struct{})
secondMetadataEntered := make(chan struct{})
releaseSecondMetadata := make(chan struct{})
var metadataAccess sync.Mutex
var metadataCallCount int
var concurrentMetadataCalls int
var maximumConcurrentMetadataCalls int
recordMetadataEntry := func() func() {
metadataAccess.Lock()
metadataCallCount++
concurrentMetadataCalls++
if concurrentMetadataCalls > maximumConcurrentMetadataCalls {
maximumConcurrentMetadataCalls = concurrentMetadataCalls
}
metadataAccess.Unlock()
return func() {
metadataAccess.Lock()
concurrentMetadataCalls--
metadataAccess.Unlock()
}
}
firstBuildRuleSet := &fakeRuleSet{
match: func(*adapter.InboundContext) bool {
return true
},
metadataRead: func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata {
metadataDone := recordMetadataEntry()
close(firstMetadataEntered)
<-releaseFirstMetadata
metadataDone()
return metadata
},
afterIncrementReference: func() {
close(firstRuleSetStarted)
<-releaseFirstRuleSetStart
},
}
secondBuildRuleSet := &fakeRuleSet{
match: func(*adapter.InboundContext) bool {
return false
},
metadataRead: func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata {
metadataDone := recordMetadataEntry()
close(secondMetadataEntered)
<-releaseSecondMetadata
metadataDone()
return metadata
},
}
routerService.setRuleSet("dynamic-set", firstBuildRuleSet)
firstCallbackFinished := make(chan struct{})
go func() {
callbacks[0](callbackRuleSet)
close(firstCallbackFinished)
}()
select {
case <-firstMetadataEntered:
case <-time.After(time.Second):
t.Fatal("first rebuild did not reach rule-set metadata")
}
close(releaseFirstMetadata)
select {
case <-firstRuleSetStarted:
case <-time.After(time.Second):
t.Fatal("first rebuild did not reach rule-set start")
}
routerService.setRuleSet("dynamic-set", secondBuildRuleSet)
secondCallbackStarted := make(chan struct{})
secondCallbackFinished := make(chan struct{})
go func() {
close(secondCallbackStarted)
callbacks[0](callbackRuleSet)
close(secondCallbackFinished)
}()
select {
case <-secondCallbackStarted:
case <-time.After(time.Second):
t.Fatal("second rebuild did not start")
}
select {
case <-secondMetadataEntered:
t.Fatal("second rebuild entered rule-set metadata before the first rebuild completed")
default:
}
close(releaseFirstRuleSetStart)
select {
case <-firstCallbackFinished:
case <-time.After(time.Second):
t.Fatal("first rebuild callback did not finish")
}
select {
case <-secondMetadataEntered:
case <-time.After(time.Second):
t.Fatal("second rebuild did not enter rule-set metadata after the first rebuild finished")
}
addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
require.Equal(t, "first", lastUsedTransport.Load())
close(releaseSecondMetadata)
select {
case <-secondCallbackFinished:
case <-time.After(time.Second):
t.Fatal("second rebuild callback did not finish")
}
metadataAccess.Lock()
require.Equal(t, 2, metadataCallCount)
require.Equal(t, 1, maximumConcurrentMetadataCalls)
metadataAccess.Unlock()
require.Zero(t, callbackRuleSet.refCount())
require.Zero(t, firstBuildRuleSet.refCount())
require.Equal(t, 1, secondBuildRuleSet.refCount())
lastUsedTransport.Store("")
addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
require.Equal(t, "second", lastUsedTransport.Load())
err = router.Close()
require.NoError(t, err)
require.Zero(t, callbackRuleSet.refCount())
require.Zero(t, firstBuildRuleSet.refCount())
require.Zero(t, secondBuildRuleSet.refCount())
}
func TestCloseDuringRebuildDiscardsResult(t *testing.T) {
t.Parallel()
fakeSet := &fakeRuleSet{
metadata: adapter.RuleSetMetadata{
ContainsIPCIDRRule: true,
},
}
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
ruleSets: map[string]adapter.RuleSet{
"dynamic-set": fakeSet,
@@ -560,42 +971,73 @@ func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) {
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
},
},
},
{
Type: C.RuleTypeDefault,
DefaultOptions: option.DefaultDNSRule{
RawDefaultDNSRule: option.RawDefaultDNSRule{
IPIsPrivate: true,
},
DNSRuleAction: option.DNSRuleAction{
Action: C.RuleActionTypeRoute,
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
RouteOptions: option.DNSRouteActionOptions{Server: "installed"},
},
},
},
}, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
"default": defaultTransport,
"default": defaultTransport,
"discarded": &fakeDNSTransport{tag: "discarded", transportType: C.DNSTypeUDP},
"installed": &fakeDNSTransport{tag: "installed", transportType: C.DNSTypeUDP},
},
}, &fakeDNSClient{
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
return MessageToAddresses(response), response, nil
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
switch transport.Tag() {
case "discarded", "installed", "default":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60), nil
default:
return nil, E.New("unexpected transport: ", transport.Tag())
}
},
})
require.True(t, router.legacyDNSMode)
require.Equal(t, 1, fakeSet.refCount())
fakeSet.updateMetadata(adapter.RuleSetMetadata{
ContainsDNSQueryTypeRule: true,
})
callbacks := fakeSet.snapshotCallbacks()
require.Len(t, callbacks, 1)
_, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
firstMetadataEntered := make(chan struct{})
releaseFirstMetadata := make(chan struct{})
callbackFinished := make(chan struct{})
fakeSet.metadataRead = func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata {
router.rawRules[0].DefaultOptions.RouteOptions.Server = "discarded"
close(firstMetadataEntered)
<-releaseFirstMetadata
return adapter.RuleSetMetadata{}
}
go func() {
callbacks[0](fakeSet)
close(callbackFinished)
}()
select {
case <-firstMetadataEntered:
case <-time.After(time.Second):
t.Fatal("rebuild did not reach rule-set metadata")
}
err := router.Close()
require.NoError(t, err)
close(releaseFirstMetadata)
select {
case <-callbackFinished:
case <-time.After(time.Second):
t.Fatal("rebuild callback did not finish after close")
}
fakeSet.metadataRead = nil
router.rulesAccess.RLock()
require.True(t, router.closing)
require.Nil(t, router.rules)
require.Empty(t, router.ruleSetCallbacks)
router.rulesAccess.RUnlock()
require.True(t, router.legacyDNSMode)
require.Zero(t, fakeSet.refCount())
}
func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) {
@@ -661,7 +1103,6 @@ func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) {
require.True(t, router.closing)
require.Nil(t, router.rules)
require.Empty(t, router.ruleSetCallbacks)
require.NoError(t, router.runtimeRuleError)
}
func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) {
@@ -680,7 +1121,7 @@ func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) {
case "default":
t.Fatal("default transport should not be used when legacy rule matches after response")
}
return nil, nil, errors.New("unexpected transport")
return nil, nil, E.New("unexpected transport")
},
}
router := newTestRouter(t, []option.DNSRule{{
@@ -716,12 +1157,23 @@ func TestLookupLegacyDNSModeFallsBackAfterRejectedAddressLimitResponse(t *testin
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP}
var lookups []string
var lookupAccess sync.Mutex
var lookupTags []string
recordLookup := func(tag string) {
lookupAccess.Lock()
lookupTags = append(lookupTags, tag)
lookupAccess.Unlock()
}
currentLookupTags := func() []string {
lookupAccess.Lock()
defer lookupAccess.Unlock()
return append([]string(nil), lookupTags...)
}
client := &fakeDNSClient{
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
require.Equal(t, "example.com", domain)
require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy)
lookups = append(lookups, transport.Tag())
recordLookup(transport.Tag())
switch transport.Tag() {
case "private":
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60)
@@ -730,7 +1182,7 @@ func TestLookupLegacyDNSModeFallsBackAfterRejectedAddressLimitResponse(t *testin
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60)
return MessageToAddresses(response), response, nil
}
return nil, nil, errors.New("unexpected transport")
return nil, nil, E.New("unexpected transport")
},
}
router := newTestRouter(t, []option.DNSRule{{
@@ -757,7 +1209,7 @@ func TestLookupLegacyDNSModeFallsBackAfterRejectedAddressLimitResponse(t *testin
})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses)
require.Equal(t, []string{"private", "default"}, lookups)
require.Equal(t, []string{"private", "default"}, currentLookupTags())
}
func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *testing.T) {
@@ -785,7 +1237,18 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP}
var lookups []string
var lookupAccess sync.Mutex
var lookupTags []string
recordLookup := func(tag string) {
lookupAccess.Lock()
lookupTags = append(lookupTags, tag)
lookupAccess.Unlock()
}
currentLookupTags := func() []string {
lookupAccess.Lock()
defer lookupAccess.Unlock()
return append([]string(nil), lookupTags...)
}
router := newTestRouterWithContext(t, ctx, []option.DNSRule{
{
Type: C.RuleTypeDefault,
@@ -819,7 +1282,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
require.Equal(t, "example.com", domain)
require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy)
lookups = append(lookups, transport.Tag())
recordLookup(transport.Tag())
switch transport.Tag() {
case "private":
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60)
@@ -828,7 +1291,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60)
return MessageToAddresses(response), response, nil
}
return nil, nil, errors.New("unexpected transport")
return nil, nil, E.New("unexpected transport")
},
})
@@ -839,7 +1302,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes
})
require.NoError(t, err)
require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses)
require.Equal(t, []string{"private", "default"}, lookups)
require.Equal(t, []string{"private", "default"}, currentLookupTags())
}
func TestDNSResponseAddressesMatchesMessageToAddressesForHTTPSHints(t *testing.T) {
@@ -872,7 +1335,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRoute(t *testing.T) {
case "selected":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
}
@@ -931,7 +1394,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRouteIgnoresTTL(t *te
case "selected":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
}
@@ -990,7 +1453,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRouteWithHTTPSHints(t
case "selected":
return fixedHTTPSHintResponse(message.Question[0], netip.MustParseAddr("8.8.8.8")), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
}
@@ -1049,7 +1512,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRouteWithMappedHTTPSI
case "selected":
return fixedHTTPSHintResponse(message.Question[0], netip.MustParseAddr("8.8.8.8")), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
}
@@ -1119,7 +1582,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateDoesNotLeakAddressesToNextQuery(t
case "selected":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
}
@@ -1181,7 +1644,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateRouteResolutionFailureClearsRespon
case "default":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
}
@@ -1268,13 +1731,13 @@ func TestExchangeLegacyDNSModeDisabledEvaluateExchangeFailureUsesMatchResponseBo
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
switch transport.Tag() {
case "upstream":
return nil, errors.New("upstream exchange failed")
return nil, E.New("upstream exchange failed")
case "selected":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
case "default":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
}
@@ -1332,9 +1795,9 @@ func TestLookupLegacyDNSModeDisabledAllowsPartialSuccess(t *testing.T) {
case mDNS.TypeA:
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
case mDNS.TypeAAAA:
return nil, errors.New("ipv6 failed")
return nil, E.New("ipv6 failed")
default:
return nil, errors.New("unexpected qtype")
return nil, E.New("unexpected qtype")
}
},
})
@@ -1451,7 +1914,7 @@ func TestLookupLegacyDNSModeDisabledEvaluateSkipFakeIPPreservesResponse(t *testi
}
return FixedResponse(0, message.Question[0], nil, 60), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
})
@@ -1494,7 +1957,7 @@ func TestLookupLegacyDNSModeDisabledUsesQueryTypeRule(t *testing.T) {
case "only-a":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
})
@@ -1560,7 +2023,7 @@ func TestLookupLegacyDNSModeDisabledUsesRuleSetQueryTypeRule(t *testing.T) {
}
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::9")}, 60), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
})
@@ -1609,7 +2072,7 @@ func TestLookupLegacyDNSModeDisabledUsesIPVersionRule(t *testing.T) {
}
return FixedResponse(0, message.Question[0], nil, 60), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
})
@@ -1829,7 +2292,18 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) {
t.Parallel()
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
var qTypes []uint16
var queryTypeAccess sync.Mutex
var queryTypes []uint16
recordQueryType := func(queryType uint16) {
queryTypeAccess.Lock()
queryTypes = append(queryTypes, queryType)
queryTypeAccess.Unlock()
}
currentQueryTypes := func() []uint16 {
queryTypeAccess.Lock()
defer queryTypeAccess.Unlock()
return append([]uint16(nil), queryTypes...)
}
router := newTestRouter(t, nil, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
@@ -1837,7 +2311,7 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) {
},
}, &fakeDNSClient{
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
qTypes = append(qTypes, message.Question[0].Qtype)
recordQueryType(message.Question[0].Qtype)
if message.Question[0].Qtype == mDNS.TypeA {
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil
}
@@ -1850,7 +2324,7 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) {
Strategy: C.DomainStrategyIPv4Only,
})
require.NoError(t, err)
require.Equal(t, []uint16{mDNS.TypeA}, qTypes)
require.Equal(t, []uint16{mDNS.TypeA}, currentQueryTypes())
require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses)
}
@@ -1858,7 +2332,18 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) {
t.Parallel()
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
var qTypes []uint16
var queryTypeAccess sync.Mutex
var queryTypes []uint16
recordQueryType := func(queryType uint16) {
queryTypeAccess.Lock()
queryTypes = append(queryTypes, queryType)
queryTypeAccess.Unlock()
}
currentQueryTypes := func() []uint16 {
queryTypeAccess.Lock()
defer queryTypeAccess.Unlock()
return append([]uint16(nil), queryTypes...)
}
router := newTestRouter(t, nil, &fakeDNSTransportManager{
defaultTransport: defaultTransport,
transports: map[string]adapter.DNSTransport{
@@ -1866,7 +2351,7 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) {
},
}, &fakeDNSClient{
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
qTypes = append(qTypes, message.Question[0].Qtype)
recordQueryType(message.Question[0].Qtype)
if message.Question[0].Qtype == mDNS.TypeA {
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil
}
@@ -1878,7 +2363,7 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) {
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
require.NoError(t, err)
require.Equal(t, []uint16{mDNS.TypeA}, qTypes)
require.Equal(t, []uint16{mDNS.TypeA}, currentQueryTypes())
require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses)
}
@@ -1903,7 +2388,7 @@ func TestExchangeLegacyDNSModeDisabledLogicalMatchResponseIPCIDRFallsThrough(t *
case "default":
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil
default:
return nil, errors.New("unexpected transport")
return nil, E.New("unexpected transport")
}
},
}