mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-14 04:38:28 +10:00
dns: serialize rebuilds and keep last good rules on failure
This commit is contained in:
@@ -50,9 +50,9 @@ type Router struct {
|
||||
platformInterface adapter.PlatformInterface
|
||||
legacyDNSMode bool
|
||||
rulesAccess sync.RWMutex
|
||||
rebuildAccess sync.Mutex
|
||||
closing bool
|
||||
ruleSetCallbacks []dnsRuleSetCallback
|
||||
runtimeRuleError error
|
||||
addressFilterDeprecatedReported bool
|
||||
ruleStrategyDeprecatedReported bool
|
||||
}
|
||||
@@ -116,11 +116,19 @@ func (r *Router) Start(stage adapter.StartStage) error {
|
||||
return err
|
||||
}
|
||||
monitor.Start("register DNS rule-set callbacks")
|
||||
err = r.registerRuleSetCallbacks()
|
||||
needsRulesRefresh, err := r.registerRuleSetCallbacks()
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if needsRulesRefresh {
|
||||
monitor.Start("refresh DNS rules after callback registration")
|
||||
err = r.rebuildRules(true)
|
||||
monitor.Finish()
|
||||
if err != nil {
|
||||
r.logger.Error(E.Cause(err, "refresh DNS rules after callback registration"))
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -133,7 +141,6 @@ func (r *Router) Close() error {
|
||||
r.ruleSetCallbacks = nil
|
||||
runtimeRules := r.rules
|
||||
r.rules = nil
|
||||
r.runtimeRuleError = nil
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
@@ -150,6 +157,8 @@ func (r *Router) Close() error {
|
||||
}
|
||||
|
||||
func (r *Router) rebuildRules(startRules bool) error {
|
||||
r.rebuildAccess.Lock()
|
||||
defer r.rebuildAccess.Unlock()
|
||||
if r.isClosing() {
|
||||
return nil
|
||||
}
|
||||
@@ -177,7 +186,6 @@ func (r *Router) rebuildRules(startRules bool) error {
|
||||
oldRules := r.rules
|
||||
r.rules = newRules
|
||||
r.legacyDNSMode = legacyDNSMode
|
||||
r.runtimeRuleError = nil
|
||||
if shouldReportAddressFilterDeprecated {
|
||||
r.addressFilterDeprecatedReported = true
|
||||
}
|
||||
@@ -246,20 +254,20 @@ func closeRules(rules []adapter.DNSRule) {
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Router) registerRuleSetCallbacks() error {
|
||||
func (r *Router) registerRuleSetCallbacks() (bool, error) {
|
||||
tags := referencedDNSRuleSetTags(r.rawRules)
|
||||
if len(tags) == 0 {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
r.rulesAccess.RLock()
|
||||
if len(r.ruleSetCallbacks) > 0 {
|
||||
r.rulesAccess.RUnlock()
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
r.rulesAccess.RUnlock()
|
||||
router := service.FromContext[adapter.Router](r.ctx)
|
||||
if router == nil {
|
||||
return E.New("router service not found")
|
||||
return false, E.New("router service not found")
|
||||
}
|
||||
callbacks := make([]dnsRuleSetCallback, 0, len(tags))
|
||||
for _, tag := range tags {
|
||||
@@ -268,14 +276,11 @@ func (r *Router) registerRuleSetCallbacks() error {
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
return E.New("rule-set not found: ", tag)
|
||||
return false, E.New("rule-set not found: ", tag)
|
||||
}
|
||||
element := ruleSet.RegisterCallback(func(adapter.RuleSet) {
|
||||
err := r.rebuildRules(true)
|
||||
if err != nil {
|
||||
r.rulesAccess.Lock()
|
||||
r.runtimeRuleError = err
|
||||
r.rulesAccess.Unlock()
|
||||
r.logger.Error(E.Cause(err, "rebuild DNS rules after rule-set update"))
|
||||
}
|
||||
})
|
||||
@@ -293,7 +298,7 @@ func (r *Router) registerRuleSetCallbacks() error {
|
||||
for _, callback := range callbacks {
|
||||
callback.ruleSet.UnregisterCallback(callback.element)
|
||||
}
|
||||
return nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *Router) matchDNS(ctx context.Context, allowFakeIP bool, ruleIndex int, isAddressQuery bool, options *adapter.DNSQueryOptions) (adapter.DNSTransport, adapter.DNSRule, int) {
|
||||
@@ -653,9 +658,6 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
}
|
||||
r.rulesAccess.RLock()
|
||||
defer r.rulesAccess.RUnlock()
|
||||
if r.runtimeRuleError != nil {
|
||||
return nil, r.runtimeRuleError
|
||||
}
|
||||
r.logger.DebugContext(ctx, "exchange ", FormatQuestion(message.Question[0].String()))
|
||||
var (
|
||||
response *mDNS.Msg
|
||||
@@ -760,9 +762,6 @@ func (r *Router) Exchange(ctx context.Context, message *mDNS.Msg, options adapte
|
||||
func (r *Router) Lookup(ctx context.Context, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, error) {
|
||||
r.rulesAccess.RLock()
|
||||
defer r.rulesAccess.RUnlock()
|
||||
if r.runtimeRuleError != nil {
|
||||
return nil, r.runtimeRuleError
|
||||
}
|
||||
var (
|
||||
responseAddrs []netip.Addr
|
||||
err error
|
||||
|
||||
@@ -2,9 +2,10 @@ package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -16,6 +17,8 @@ import (
|
||||
"github.com/sagernet/sing-box/option"
|
||||
rulepkg "github.com/sagernet/sing-box/route/rule"
|
||||
"github.com/sagernet/sing-tun"
|
||||
"github.com/sagernet/sing/common"
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
"github.com/sagernet/sing/common/json/badoption"
|
||||
N "github.com/sagernet/sing/common/network"
|
||||
"github.com/sagernet/sing/common/x/list"
|
||||
@@ -38,7 +41,7 @@ func (t *fakeDNSTransport) Tag() string { return t.tag }
|
||||
func (t *fakeDNSTransport) Dependencies() []string { return nil }
|
||||
func (t *fakeDNSTransport) Reset() {}
|
||||
func (t *fakeDNSTransport) Exchange(context.Context, *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
return nil, errors.New("unused transport exchange")
|
||||
return nil, E.New("unused transport exchange")
|
||||
}
|
||||
|
||||
type fakeDNSTransportManager struct {
|
||||
@@ -66,7 +69,7 @@ func (m *fakeDNSTransportManager) FakeIP() adapter.FakeIPTransport {
|
||||
}
|
||||
func (m *fakeDNSTransportManager) Remove(string) error { return nil }
|
||||
func (m *fakeDNSTransportManager) Create(context.Context, log.ContextLogger, string, string, any) error {
|
||||
return errors.New("unsupported")
|
||||
return E.New("unsupported")
|
||||
}
|
||||
|
||||
type fakeDNSClient struct {
|
||||
@@ -80,6 +83,7 @@ type fakeDeprecatedManager struct {
|
||||
}
|
||||
|
||||
type fakeRouter struct {
|
||||
access sync.RWMutex
|
||||
ruleSets map[string]adapter.RuleSet
|
||||
}
|
||||
|
||||
@@ -104,9 +108,19 @@ func (r *fakeRouter) RoutePacketConnectionEx(context.Context, N.PacketConn, adap
|
||||
}
|
||||
|
||||
func (r *fakeRouter) RuleSet(tag string) (adapter.RuleSet, bool) {
|
||||
r.access.RLock()
|
||||
defer r.access.RUnlock()
|
||||
ruleSet, loaded := r.ruleSets[tag]
|
||||
return ruleSet, loaded
|
||||
}
|
||||
func (r *fakeRouter) setRuleSet(tag string, ruleSet adapter.RuleSet) {
|
||||
r.access.Lock()
|
||||
defer r.access.Unlock()
|
||||
if r.ruleSets == nil {
|
||||
r.ruleSets = make(map[string]adapter.RuleSet)
|
||||
}
|
||||
r.ruleSets[tag] = ruleSet
|
||||
}
|
||||
func (r *fakeRouter) Rules() []adapter.Rule { return nil }
|
||||
func (r *fakeRouter) NeedFindProcess() bool { return false }
|
||||
func (r *fakeRouter) NeedFindNeighbor() bool { return false }
|
||||
@@ -115,10 +129,14 @@ func (r *fakeRouter) AppendTracker(adapter.ConnectionTracker) {}
|
||||
func (r *fakeRouter) ResetNetwork() {}
|
||||
|
||||
type fakeRuleSet struct {
|
||||
access sync.Mutex
|
||||
metadata adapter.RuleSetMetadata
|
||||
callbacks list.List[adapter.RuleSetUpdateCallback]
|
||||
refs int
|
||||
access sync.Mutex
|
||||
metadata adapter.RuleSetMetadata
|
||||
metadataRead func(adapter.RuleSetMetadata) adapter.RuleSetMetadata
|
||||
match func(*adapter.InboundContext) bool
|
||||
callbacks list.List[adapter.RuleSetUpdateCallback]
|
||||
refs int
|
||||
afterIncrementReference func()
|
||||
beforeDecrementReference func()
|
||||
}
|
||||
|
||||
func (s *fakeRuleSet) Name() string { return "fake-rule-set" }
|
||||
@@ -126,17 +144,32 @@ func (s *fakeRuleSet) StartContext(context.Context, *adapter.HTTPStartContext) e
|
||||
func (s *fakeRuleSet) PostStart() error { return nil }
|
||||
func (s *fakeRuleSet) Metadata() adapter.RuleSetMetadata {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
return s.metadata
|
||||
metadata := s.metadata
|
||||
metadataRead := s.metadataRead
|
||||
s.access.Unlock()
|
||||
if metadataRead != nil {
|
||||
return metadataRead(metadata)
|
||||
}
|
||||
return metadata
|
||||
}
|
||||
func (s *fakeRuleSet) ExtractIPSet() []*netipx.IPSet { return nil }
|
||||
func (s *fakeRuleSet) IncRef() {
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
s.refs++
|
||||
afterIncrementReference := s.afterIncrementReference
|
||||
s.access.Unlock()
|
||||
if afterIncrementReference != nil {
|
||||
afterIncrementReference()
|
||||
}
|
||||
}
|
||||
|
||||
func (s *fakeRuleSet) DecRef() {
|
||||
s.access.Lock()
|
||||
beforeDecrementReference := s.beforeDecrementReference
|
||||
s.access.Unlock()
|
||||
if beforeDecrementReference != nil {
|
||||
beforeDecrementReference()
|
||||
}
|
||||
s.access.Lock()
|
||||
defer s.access.Unlock()
|
||||
s.refs--
|
||||
@@ -156,9 +189,17 @@ func (s *fakeRuleSet) UnregisterCallback(element *list.Element[adapter.RuleSetUp
|
||||
defer s.access.Unlock()
|
||||
s.callbacks.Remove(element)
|
||||
}
|
||||
func (s *fakeRuleSet) Close() error { return nil }
|
||||
func (s *fakeRuleSet) Match(*adapter.InboundContext) bool { return true }
|
||||
func (s *fakeRuleSet) String() string { return "fake-rule-set" }
|
||||
func (s *fakeRuleSet) Close() error { return nil }
|
||||
func (s *fakeRuleSet) Match(metadata *adapter.InboundContext) bool {
|
||||
s.access.Lock()
|
||||
match := s.match
|
||||
s.access.Unlock()
|
||||
if match != nil {
|
||||
return match(metadata)
|
||||
}
|
||||
return true
|
||||
}
|
||||
func (s *fakeRuleSet) String() string { return "fake-rule-set" }
|
||||
func (s *fakeRuleSet) updateMetadata(metadata adapter.RuleSetMetadata) {
|
||||
s.access.Lock()
|
||||
s.metadata = metadata
|
||||
@@ -196,7 +237,7 @@ func (c *fakeDNSClient) Exchange(ctx context.Context, transport adapter.DNSTrans
|
||||
|
||||
func (c *fakeDNSClient) Lookup(_ context.Context, transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions, responseChecker func(*mDNS.Msg) bool) ([]netip.Addr, error) {
|
||||
if c.lookup == nil {
|
||||
return nil, errors.New("unused client lookup")
|
||||
return nil, E.New("unused client lookup")
|
||||
}
|
||||
addresses, response, err := c.lookup(transport, domain, options)
|
||||
if err != nil {
|
||||
@@ -221,10 +262,14 @@ func newTestRouter(t *testing.T, rules []option.DNSRule, transportManager *fakeD
|
||||
}
|
||||
|
||||
func newTestRouterWithContext(t *testing.T, ctx context.Context, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient) *Router {
|
||||
return newTestRouterWithContextAndLogger(t, ctx, rules, transportManager, client, log.NewNOPFactory().NewLogger("dns"))
|
||||
}
|
||||
|
||||
func newTestRouterWithContextAndLogger(t *testing.T, ctx context.Context, rules []option.DNSRule, transportManager *fakeDNSTransportManager, client *fakeDNSClient, dnsLogger log.ContextLogger) *Router {
|
||||
t.Helper()
|
||||
router := &Router{
|
||||
ctx: ctx,
|
||||
logger: log.NewNOPFactory().NewLogger("dns"),
|
||||
logger: dnsLogger,
|
||||
transport: transportManager,
|
||||
client: client,
|
||||
rawRules: make([]option.DNSRule, 0, len(rules)),
|
||||
@@ -240,6 +285,26 @@ func newTestRouterWithContext(t *testing.T, ctx context.Context, rules []option.
|
||||
return router
|
||||
}
|
||||
|
||||
func waitForLogMessageContaining(t *testing.T, entries <-chan log.Entry, done <-chan struct{}, substring string) log.Entry {
|
||||
t.Helper()
|
||||
timeout := time.After(time.Second)
|
||||
for {
|
||||
select {
|
||||
case entry, ok := <-entries:
|
||||
if !ok {
|
||||
t.Fatal("log subscription closed")
|
||||
}
|
||||
if strings.Contains(entry.Message, substring) {
|
||||
return entry
|
||||
}
|
||||
case <-done:
|
||||
t.Fatal("log subscription closed")
|
||||
case <-timeout:
|
||||
t.Fatalf("timed out waiting for log message containing %q", substring)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func fixedQuestion(name string, qType uint16) mDNS.Question {
|
||||
return mDNS.Question{
|
||||
Name: mDNS.Fqdn(name),
|
||||
@@ -541,10 +606,356 @@ func TestRuleSetUpdateReleasesOldRuleSetRefs(t *testing.T) {
|
||||
require.Zero(t, fakeSet.refCount())
|
||||
}
|
||||
|
||||
func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) {
|
||||
func TestRuleSetUpdateKeepsLastSuccessfullyCompiledRuleGraphWhenRebuildFails(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fakeSet := &fakeRuleSet{}
|
||||
callbackRuleSet := &fakeRuleSet{
|
||||
match: func(*adapter.InboundContext) bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
routerService := &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"dynamic-set": callbackRuleSet,
|
||||
},
|
||||
}
|
||||
ctx := service.ContextWith[adapter.Router](context.Background(), routerService)
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
preservedTransport := &fakeDNSTransport{tag: "preserved", transportType: C.DNSTypeUDP}
|
||||
wouldBeNewTransport := &fakeDNSTransport{tag: "would-be-new", transportType: C.DNSTypeUDP}
|
||||
loggerFactory := log.NewDefaultFactory(
|
||||
context.Background(),
|
||||
log.Formatter{
|
||||
BaseTime: time.Now(),
|
||||
DisableColors: true,
|
||||
DisableTimestamp: true,
|
||||
},
|
||||
io.Discard,
|
||||
"",
|
||||
nil,
|
||||
true,
|
||||
)
|
||||
loggerFactory.SetLevel(log.LevelError)
|
||||
logEntries, logDone, err := loggerFactory.Subscribe()
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
loggerFactory.UnSubscribe(logEntries)
|
||||
closeErr := loggerFactory.Close()
|
||||
require.NoError(t, closeErr)
|
||||
})
|
||||
var lastUsedTransport common.TypedValue[string]
|
||||
router := newTestRouterWithContextAndLogger(t, ctx, []option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
RuleSet: badoption.Listable[string]{"dynamic-set"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "would-be-new"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
Domain: badoption.Listable[string]{"example.com"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "preserved"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
IPIsPrivate: true,
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "preserved"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
"preserved": preservedTransport,
|
||||
"would-be-new": wouldBeNewTransport,
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
|
||||
lastUsedTransport.Store(transport.Tag())
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
|
||||
return MessageToAddresses(response), response, nil
|
||||
},
|
||||
}, loggerFactory.NewLogger("dns"))
|
||||
t.Cleanup(func() {
|
||||
closeErr := router.Close()
|
||||
require.NoError(t, closeErr)
|
||||
})
|
||||
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.Equal(t, 1, callbackRuleSet.refCount())
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
|
||||
require.Equal(t, "preserved", lastUsedTransport.Load())
|
||||
|
||||
rebuildTargetRuleSet := &fakeRuleSet{
|
||||
metadata: adapter.RuleSetMetadata{
|
||||
ContainsDNSQueryTypeRule: true,
|
||||
},
|
||||
match: func(*adapter.InboundContext) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
routerService.setRuleSet("dynamic-set", rebuildTargetRuleSet)
|
||||
|
||||
callbackRuleSet.updateMetadata(adapter.RuleSetMetadata{
|
||||
ContainsDNSQueryTypeRule: true,
|
||||
})
|
||||
rebuildErrorEntry := waitForLogMessageContaining(t, logEntries, logDone, "rebuild DNS rules after rule-set update")
|
||||
require.Contains(t, rebuildErrorEntry.Message, "ip_cidr and ip_is_private require match_response")
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.Equal(t, 1, callbackRuleSet.refCount())
|
||||
require.Zero(t, rebuildTargetRuleSet.refCount())
|
||||
|
||||
lastUsedTransport.Store("")
|
||||
addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
|
||||
require.Equal(t, "preserved", lastUsedTransport.Load())
|
||||
require.NotEqual(t, "would-be-new", lastUsedTransport.Load())
|
||||
}
|
||||
|
||||
func TestRuleSetUpdateSerializesConcurrentRebuilds(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
callbackRuleSet := &fakeRuleSet{
|
||||
match: func(*adapter.InboundContext) bool {
|
||||
return false
|
||||
},
|
||||
}
|
||||
routerService := &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"dynamic-set": callbackRuleSet,
|
||||
},
|
||||
}
|
||||
ctx := service.ContextWith[adapter.Router](context.Background(), routerService)
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
firstTransport := &fakeDNSTransport{tag: "first", transportType: C.DNSTypeUDP}
|
||||
secondTransport := &fakeDNSTransport{tag: "second", transportType: C.DNSTypeUDP}
|
||||
var lastUsedTransport common.TypedValue[string]
|
||||
router := newTestRouterWithContext(t, ctx, []option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
RuleSet: badoption.Listable[string]{"dynamic-set"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "first"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
Domain: badoption.Listable[string]{"example.com"},
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "second"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
"first": firstTransport,
|
||||
"second": secondTransport,
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
lastUsedTransport.Store(transport.Tag())
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60), nil
|
||||
},
|
||||
})
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
|
||||
require.Equal(t, "second", lastUsedTransport.Load())
|
||||
|
||||
callbacks := callbackRuleSet.snapshotCallbacks()
|
||||
require.Len(t, callbacks, 1)
|
||||
|
||||
firstMetadataEntered := make(chan struct{})
|
||||
releaseFirstMetadata := make(chan struct{})
|
||||
firstRuleSetStarted := make(chan struct{})
|
||||
releaseFirstRuleSetStart := make(chan struct{})
|
||||
secondMetadataEntered := make(chan struct{})
|
||||
releaseSecondMetadata := make(chan struct{})
|
||||
|
||||
var metadataAccess sync.Mutex
|
||||
var metadataCallCount int
|
||||
var concurrentMetadataCalls int
|
||||
var maximumConcurrentMetadataCalls int
|
||||
|
||||
recordMetadataEntry := func() func() {
|
||||
metadataAccess.Lock()
|
||||
metadataCallCount++
|
||||
concurrentMetadataCalls++
|
||||
if concurrentMetadataCalls > maximumConcurrentMetadataCalls {
|
||||
maximumConcurrentMetadataCalls = concurrentMetadataCalls
|
||||
}
|
||||
metadataAccess.Unlock()
|
||||
return func() {
|
||||
metadataAccess.Lock()
|
||||
concurrentMetadataCalls--
|
||||
metadataAccess.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
firstBuildRuleSet := &fakeRuleSet{
|
||||
match: func(*adapter.InboundContext) bool {
|
||||
return true
|
||||
},
|
||||
metadataRead: func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata {
|
||||
metadataDone := recordMetadataEntry()
|
||||
close(firstMetadataEntered)
|
||||
<-releaseFirstMetadata
|
||||
metadataDone()
|
||||
return metadata
|
||||
},
|
||||
afterIncrementReference: func() {
|
||||
close(firstRuleSetStarted)
|
||||
<-releaseFirstRuleSetStart
|
||||
},
|
||||
}
|
||||
secondBuildRuleSet := &fakeRuleSet{
|
||||
match: func(*adapter.InboundContext) bool {
|
||||
return false
|
||||
},
|
||||
metadataRead: func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata {
|
||||
metadataDone := recordMetadataEntry()
|
||||
close(secondMetadataEntered)
|
||||
<-releaseSecondMetadata
|
||||
metadataDone()
|
||||
return metadata
|
||||
},
|
||||
}
|
||||
|
||||
routerService.setRuleSet("dynamic-set", firstBuildRuleSet)
|
||||
|
||||
firstCallbackFinished := make(chan struct{})
|
||||
go func() {
|
||||
callbacks[0](callbackRuleSet)
|
||||
close(firstCallbackFinished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-firstMetadataEntered:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("first rebuild did not reach rule-set metadata")
|
||||
}
|
||||
|
||||
close(releaseFirstMetadata)
|
||||
|
||||
select {
|
||||
case <-firstRuleSetStarted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("first rebuild did not reach rule-set start")
|
||||
}
|
||||
|
||||
routerService.setRuleSet("dynamic-set", secondBuildRuleSet)
|
||||
|
||||
secondCallbackStarted := make(chan struct{})
|
||||
secondCallbackFinished := make(chan struct{})
|
||||
go func() {
|
||||
close(secondCallbackStarted)
|
||||
callbacks[0](callbackRuleSet)
|
||||
close(secondCallbackFinished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-secondCallbackStarted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("second rebuild did not start")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-secondMetadataEntered:
|
||||
t.Fatal("second rebuild entered rule-set metadata before the first rebuild completed")
|
||||
default:
|
||||
}
|
||||
|
||||
close(releaseFirstRuleSetStart)
|
||||
|
||||
select {
|
||||
case <-firstCallbackFinished:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("first rebuild callback did not finish")
|
||||
}
|
||||
|
||||
select {
|
||||
case <-secondMetadataEntered:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("second rebuild did not enter rule-set metadata after the first rebuild finished")
|
||||
}
|
||||
|
||||
addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
|
||||
require.Equal(t, "first", lastUsedTransport.Load())
|
||||
|
||||
close(releaseSecondMetadata)
|
||||
|
||||
select {
|
||||
case <-secondCallbackFinished:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("second rebuild callback did not finish")
|
||||
}
|
||||
|
||||
metadataAccess.Lock()
|
||||
require.Equal(t, 2, metadataCallCount)
|
||||
require.Equal(t, 1, maximumConcurrentMetadataCalls)
|
||||
metadataAccess.Unlock()
|
||||
require.Zero(t, callbackRuleSet.refCount())
|
||||
require.Zero(t, firstBuildRuleSet.refCount())
|
||||
require.Equal(t, 1, secondBuildRuleSet.refCount())
|
||||
|
||||
lastUsedTransport.Store("")
|
||||
addresses, err = router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("10.0.0.1")}, addresses)
|
||||
require.Equal(t, "second", lastUsedTransport.Load())
|
||||
|
||||
err = router.Close()
|
||||
require.NoError(t, err)
|
||||
require.Zero(t, callbackRuleSet.refCount())
|
||||
require.Zero(t, firstBuildRuleSet.refCount())
|
||||
require.Zero(t, secondBuildRuleSet.refCount())
|
||||
}
|
||||
|
||||
func TestCloseDuringRebuildDiscardsResult(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fakeSet := &fakeRuleSet{
|
||||
metadata: adapter.RuleSetMetadata{
|
||||
ContainsIPCIDRRule: true,
|
||||
},
|
||||
}
|
||||
ctx := service.ContextWith[adapter.Router](context.Background(), &fakeRouter{
|
||||
ruleSets: map[string]adapter.RuleSet{
|
||||
"dynamic-set": fakeSet,
|
||||
@@ -560,42 +971,73 @@ func TestRuleSetUpdateSetsRuntimeErrorWhenRebuildFails(t *testing.T) {
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
DefaultOptions: option.DefaultDNSRule{
|
||||
RawDefaultDNSRule: option.RawDefaultDNSRule{
|
||||
IPIsPrivate: true,
|
||||
},
|
||||
DNSRuleAction: option.DNSRuleAction{
|
||||
Action: C.RuleActionTypeRoute,
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "default"},
|
||||
RouteOptions: option.DNSRouteActionOptions{Server: "installed"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
"default": defaultTransport,
|
||||
"default": defaultTransport,
|
||||
"discarded": &fakeDNSTransport{tag: "discarded", transportType: C.DNSTypeUDP},
|
||||
"installed": &fakeDNSTransport{tag: "installed", transportType: C.DNSTypeUDP},
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60)
|
||||
return MessageToAddresses(response), response, nil
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
switch transport.Tag() {
|
||||
case "discarded", "installed", "default":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("10.0.0.1")}, 60), nil
|
||||
default:
|
||||
return nil, E.New("unexpected transport: ", transport.Tag())
|
||||
}
|
||||
},
|
||||
})
|
||||
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.Equal(t, 1, fakeSet.refCount())
|
||||
|
||||
fakeSet.updateMetadata(adapter.RuleSetMetadata{
|
||||
ContainsDNSQueryTypeRule: true,
|
||||
})
|
||||
callbacks := fakeSet.snapshotCallbacks()
|
||||
require.Len(t, callbacks, 1)
|
||||
|
||||
_, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.ErrorContains(t, err, "ip_cidr and ip_is_private require match_response")
|
||||
firstMetadataEntered := make(chan struct{})
|
||||
releaseFirstMetadata := make(chan struct{})
|
||||
callbackFinished := make(chan struct{})
|
||||
fakeSet.metadataRead = func(metadata adapter.RuleSetMetadata) adapter.RuleSetMetadata {
|
||||
router.rawRules[0].DefaultOptions.RouteOptions.Server = "discarded"
|
||||
close(firstMetadataEntered)
|
||||
<-releaseFirstMetadata
|
||||
return adapter.RuleSetMetadata{}
|
||||
}
|
||||
|
||||
go func() {
|
||||
callbacks[0](fakeSet)
|
||||
close(callbackFinished)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-firstMetadataEntered:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("rebuild did not reach rule-set metadata")
|
||||
}
|
||||
|
||||
err := router.Close()
|
||||
require.NoError(t, err)
|
||||
close(releaseFirstMetadata)
|
||||
|
||||
select {
|
||||
case <-callbackFinished:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("rebuild callback did not finish after close")
|
||||
}
|
||||
|
||||
fakeSet.metadataRead = nil
|
||||
|
||||
router.rulesAccess.RLock()
|
||||
require.True(t, router.closing)
|
||||
require.Nil(t, router.rules)
|
||||
require.Empty(t, router.ruleSetCallbacks)
|
||||
router.rulesAccess.RUnlock()
|
||||
require.True(t, router.legacyDNSMode)
|
||||
require.Zero(t, fakeSet.refCount())
|
||||
}
|
||||
|
||||
func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) {
|
||||
@@ -661,7 +1103,6 @@ func TestCloseIgnoresSnapshottedRuleSetCallback(t *testing.T) {
|
||||
require.True(t, router.closing)
|
||||
require.Nil(t, router.rules)
|
||||
require.Empty(t, router.ruleSetCallbacks)
|
||||
require.NoError(t, router.runtimeRuleError)
|
||||
}
|
||||
|
||||
func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) {
|
||||
@@ -680,7 +1121,7 @@ func TestLookupLegacyDNSModeDefersDirectDestinationIPMatch(t *testing.T) {
|
||||
case "default":
|
||||
t.Fatal("default transport should not be used when legacy rule matches after response")
|
||||
}
|
||||
return nil, nil, errors.New("unexpected transport")
|
||||
return nil, nil, E.New("unexpected transport")
|
||||
},
|
||||
}
|
||||
router := newTestRouter(t, []option.DNSRule{{
|
||||
@@ -716,12 +1157,23 @@ func TestLookupLegacyDNSModeFallsBackAfterRejectedAddressLimitResponse(t *testin
|
||||
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP}
|
||||
var lookups []string
|
||||
var lookupAccess sync.Mutex
|
||||
var lookupTags []string
|
||||
recordLookup := func(tag string) {
|
||||
lookupAccess.Lock()
|
||||
lookupTags = append(lookupTags, tag)
|
||||
lookupAccess.Unlock()
|
||||
}
|
||||
currentLookupTags := func() []string {
|
||||
lookupAccess.Lock()
|
||||
defer lookupAccess.Unlock()
|
||||
return append([]string(nil), lookupTags...)
|
||||
}
|
||||
client := &fakeDNSClient{
|
||||
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
|
||||
require.Equal(t, "example.com", domain)
|
||||
require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy)
|
||||
lookups = append(lookups, transport.Tag())
|
||||
recordLookup(transport.Tag())
|
||||
switch transport.Tag() {
|
||||
case "private":
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60)
|
||||
@@ -730,7 +1182,7 @@ func TestLookupLegacyDNSModeFallsBackAfterRejectedAddressLimitResponse(t *testin
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60)
|
||||
return MessageToAddresses(response), response, nil
|
||||
}
|
||||
return nil, nil, errors.New("unexpected transport")
|
||||
return nil, nil, E.New("unexpected transport")
|
||||
},
|
||||
}
|
||||
router := newTestRouter(t, []option.DNSRule{{
|
||||
@@ -757,7 +1209,7 @@ func TestLookupLegacyDNSModeFallsBackAfterRejectedAddressLimitResponse(t *testin
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses)
|
||||
require.Equal(t, []string{"private", "default"}, lookups)
|
||||
require.Equal(t, []string{"private", "default"}, currentLookupTags())
|
||||
}
|
||||
|
||||
func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *testing.T) {
|
||||
@@ -785,7 +1237,18 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes
|
||||
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
privateTransport := &fakeDNSTransport{tag: "private", transportType: C.DNSTypeUDP}
|
||||
var lookups []string
|
||||
var lookupAccess sync.Mutex
|
||||
var lookupTags []string
|
||||
recordLookup := func(tag string) {
|
||||
lookupAccess.Lock()
|
||||
lookupTags = append(lookupTags, tag)
|
||||
lookupAccess.Unlock()
|
||||
}
|
||||
currentLookupTags := func() []string {
|
||||
lookupAccess.Lock()
|
||||
defer lookupAccess.Unlock()
|
||||
return append([]string(nil), lookupTags...)
|
||||
}
|
||||
router := newTestRouterWithContext(t, ctx, []option.DNSRule{
|
||||
{
|
||||
Type: C.RuleTypeDefault,
|
||||
@@ -819,7 +1282,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes
|
||||
lookup: func(transport adapter.DNSTransport, domain string, options adapter.DNSQueryOptions) ([]netip.Addr, *mDNS.Msg, error) {
|
||||
require.Equal(t, "example.com", domain)
|
||||
require.Equal(t, C.DomainStrategyIPv4Only, options.LookupStrategy)
|
||||
lookups = append(lookups, transport.Tag())
|
||||
recordLookup(transport.Tag())
|
||||
switch transport.Tag() {
|
||||
case "private":
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60)
|
||||
@@ -828,7 +1291,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes
|
||||
response := FixedResponse(0, fixedQuestion(domain, mDNS.TypeA), []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60)
|
||||
return MessageToAddresses(response), response, nil
|
||||
}
|
||||
return nil, nil, errors.New("unexpected transport")
|
||||
return nil, nil, E.New("unexpected transport")
|
||||
},
|
||||
})
|
||||
|
||||
@@ -839,7 +1302,7 @@ func TestLookupLegacyDNSModeRuleSetAcceptEmptyDoesNotTreatMismatchAsEmpty(t *tes
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("9.9.9.9")}, addresses)
|
||||
require.Equal(t, []string{"private", "default"}, lookups)
|
||||
require.Equal(t, []string{"private", "default"}, currentLookupTags())
|
||||
}
|
||||
|
||||
func TestDNSResponseAddressesMatchesMessageToAddressesForHTTPSHints(t *testing.T) {
|
||||
@@ -872,7 +1335,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRoute(t *testing.T) {
|
||||
case "selected":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -931,7 +1394,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRouteIgnoresTTL(t *te
|
||||
case "selected":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -990,7 +1453,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRouteWithHTTPSHints(t
|
||||
case "selected":
|
||||
return fixedHTTPSHintResponse(message.Question[0], netip.MustParseAddr("8.8.8.8")), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -1049,7 +1512,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateMatchResponseRouteWithMappedHTTPSI
|
||||
case "selected":
|
||||
return fixedHTTPSHintResponse(message.Question[0], netip.MustParseAddr("8.8.8.8")), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -1119,7 +1582,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateDoesNotLeakAddressesToNextQuery(t
|
||||
case "selected":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -1181,7 +1644,7 @@ func TestExchangeLegacyDNSModeDisabledEvaluateRouteResolutionFailureClearsRespon
|
||||
case "default":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -1268,13 +1731,13 @@ func TestExchangeLegacyDNSModeDisabledEvaluateExchangeFailureUsesMatchResponseBo
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
switch transport.Tag() {
|
||||
case "upstream":
|
||||
return nil, errors.New("upstream exchange failed")
|
||||
return nil, E.New("upstream exchange failed")
|
||||
case "selected":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("8.8.8.8")}, 60), nil
|
||||
case "default":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -1332,9 +1795,9 @@ func TestLookupLegacyDNSModeDisabledAllowsPartialSuccess(t *testing.T) {
|
||||
case mDNS.TypeA:
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("1.1.1.1")}, 60), nil
|
||||
case mDNS.TypeAAAA:
|
||||
return nil, errors.New("ipv6 failed")
|
||||
return nil, E.New("ipv6 failed")
|
||||
default:
|
||||
return nil, errors.New("unexpected qtype")
|
||||
return nil, E.New("unexpected qtype")
|
||||
}
|
||||
},
|
||||
})
|
||||
@@ -1451,7 +1914,7 @@ func TestLookupLegacyDNSModeDisabledEvaluateSkipFakeIPPreservesResponse(t *testi
|
||||
}
|
||||
return FixedResponse(0, message.Question[0], nil, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
})
|
||||
@@ -1494,7 +1957,7 @@ func TestLookupLegacyDNSModeDisabledUsesQueryTypeRule(t *testing.T) {
|
||||
case "only-a":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("9.9.9.9")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
})
|
||||
@@ -1560,7 +2023,7 @@ func TestLookupLegacyDNSModeDisabledUsesRuleSetQueryTypeRule(t *testing.T) {
|
||||
}
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2001:db8::9")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
})
|
||||
@@ -1609,7 +2072,7 @@ func TestLookupLegacyDNSModeDisabledUsesIPVersionRule(t *testing.T) {
|
||||
}
|
||||
return FixedResponse(0, message.Question[0], nil, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
})
|
||||
@@ -1829,7 +2292,18 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
var qTypes []uint16
|
||||
var queryTypeAccess sync.Mutex
|
||||
var queryTypes []uint16
|
||||
recordQueryType := func(queryType uint16) {
|
||||
queryTypeAccess.Lock()
|
||||
queryTypes = append(queryTypes, queryType)
|
||||
queryTypeAccess.Unlock()
|
||||
}
|
||||
currentQueryTypes := func() []uint16 {
|
||||
queryTypeAccess.Lock()
|
||||
defer queryTypeAccess.Unlock()
|
||||
return append([]uint16(nil), queryTypes...)
|
||||
}
|
||||
router := newTestRouter(t, nil, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
@@ -1837,7 +2311,7 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) {
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
qTypes = append(qTypes, message.Question[0].Qtype)
|
||||
recordQueryType(message.Question[0].Qtype)
|
||||
if message.Question[0].Qtype == mDNS.TypeA {
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil
|
||||
}
|
||||
@@ -1850,7 +2324,7 @@ func TestLookupLegacyDNSModeDisabledUsesInputStrategy(t *testing.T) {
|
||||
Strategy: C.DomainStrategyIPv4Only,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []uint16{mDNS.TypeA}, qTypes)
|
||||
require.Equal(t, []uint16{mDNS.TypeA}, currentQueryTypes())
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses)
|
||||
}
|
||||
|
||||
@@ -1858,7 +2332,18 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defaultTransport := &fakeDNSTransport{tag: "default", transportType: C.DNSTypeUDP}
|
||||
var qTypes []uint16
|
||||
var queryTypeAccess sync.Mutex
|
||||
var queryTypes []uint16
|
||||
recordQueryType := func(queryType uint16) {
|
||||
queryTypeAccess.Lock()
|
||||
queryTypes = append(queryTypes, queryType)
|
||||
queryTypeAccess.Unlock()
|
||||
}
|
||||
currentQueryTypes := func() []uint16 {
|
||||
queryTypeAccess.Lock()
|
||||
defer queryTypeAccess.Unlock()
|
||||
return append([]uint16(nil), queryTypes...)
|
||||
}
|
||||
router := newTestRouter(t, nil, &fakeDNSTransportManager{
|
||||
defaultTransport: defaultTransport,
|
||||
transports: map[string]adapter.DNSTransport{
|
||||
@@ -1866,7 +2351,7 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) {
|
||||
},
|
||||
}, &fakeDNSClient{
|
||||
exchange: func(transport adapter.DNSTransport, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||
qTypes = append(qTypes, message.Question[0].Qtype)
|
||||
recordQueryType(message.Question[0].Qtype)
|
||||
if message.Question[0].Qtype == mDNS.TypeA {
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("2.2.2.2")}, 60), nil
|
||||
}
|
||||
@@ -1878,7 +2363,7 @@ func TestLookupLegacyDNSModeDisabledUsesDefaultStrategy(t *testing.T) {
|
||||
|
||||
addresses, err := router.Lookup(context.Background(), "example.com", adapter.DNSQueryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []uint16{mDNS.TypeA}, qTypes)
|
||||
require.Equal(t, []uint16{mDNS.TypeA}, currentQueryTypes())
|
||||
require.Equal(t, []netip.Addr{netip.MustParseAddr("2.2.2.2")}, addresses)
|
||||
}
|
||||
|
||||
@@ -1903,7 +2388,7 @@ func TestExchangeLegacyDNSModeDisabledLogicalMatchResponseIPCIDRFallsThrough(t *
|
||||
case "default":
|
||||
return FixedResponse(0, message.Question[0], []netip.Addr{netip.MustParseAddr("4.4.4.4")}, 60), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected transport")
|
||||
return nil, E.New("unexpected transport")
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user