refactor: Outbound domain resolver

This commit is contained in:
世界
2025-01-12 12:45:27 +08:00
parent 364a055f77
commit 2cb5ff521a
25 changed files with 93 additions and 36 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/experimental/deprecated"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
@@ -15,7 +16,7 @@ import (
"github.com/sagernet/sing/service"
)
func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) {
func New(ctx context.Context, options option.DialerOptions, remoteIsDomain bool) (N.Dialer, error) {
if options.IsWireGuardListener {
return NewDefault(ctx, options)
}
@@ -35,13 +36,25 @@ func New(ctx context.Context, options option.DialerOptions) (N.Dialer, error) {
}
dialer = NewDetour(outboundManager, options.Detour)
}
if options.Detour == "" {
if remoteIsDomain && options.Detour == "" && options.DomainResolver == "" {
deprecated.Report(ctx, deprecated.OptionMissingDomainResolverInDialOptions)
}
if (options.Detour == "" && remoteIsDomain) || options.DomainResolver != "" {
router := service.FromContext[adapter.DNSRouter](ctx)
if router != nil {
var resolveTransport adapter.DNSTransport
if options.DomainResolver != "" {
transport, loaded := service.FromContext[adapter.DNSTransportManager](ctx).Transport(options.DomainResolver)
if !loaded {
return nil, E.New("DNS server not found: " + options.DomainResolver)
}
resolveTransport = transport
}
dialer = NewResolveDialer(
router,
dialer,
options.Detour == "" && !options.TCPFastOpen,
resolveTransport,
C.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay))
}
@@ -60,10 +73,19 @@ func NewDirect(ctx context.Context, options option.DialerOptions) (ParallelInter
if err != nil {
return nil, err
}
var resolveTransport adapter.DNSTransport
if options.DomainResolver != "" {
transport, loaded := service.FromContext[adapter.DNSTransportManager](ctx).Transport(options.DomainResolver)
if !loaded {
return nil, E.New("DNS server not found: " + options.DomainResolver)
}
resolveTransport = transport
}
return NewResolveParallelInterfaceDialer(
service.FromContext[adapter.DNSRouter](ctx),
dialer,
true,
resolveTransport,
C.DomainStrategy(options.DomainStrategy),
time.Duration(options.FallbackDelay),
), nil

View File

@@ -22,15 +22,17 @@ type resolveDialer struct {
dialer N.Dialer
parallel bool
router adapter.DNSRouter
transport adapter.DNSTransport
strategy C.DomainStrategy
fallbackDelay time.Duration
}
func NewResolveDialer(router adapter.DNSRouter, dialer N.Dialer, parallel bool, strategy C.DomainStrategy, fallbackDelay time.Duration) N.Dialer {
func NewResolveDialer(router adapter.DNSRouter, dialer N.Dialer, parallel bool, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) N.Dialer {
return &resolveDialer{
dialer,
parallel,
router,
transport,
strategy,
fallbackDelay,
}
@@ -41,12 +43,13 @@ type resolveParallelNetworkDialer struct {
dialer ParallelInterfaceDialer
}
func NewResolveParallelInterfaceDialer(router adapter.DNSRouter, dialer ParallelInterfaceDialer, parallel bool, strategy C.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer {
func NewResolveParallelInterfaceDialer(router adapter.DNSRouter, dialer ParallelInterfaceDialer, parallel bool, transport adapter.DNSTransport, strategy C.DomainStrategy, fallbackDelay time.Duration) ParallelInterfaceDialer {
return &resolveParallelNetworkDialer{
resolveDialer{
dialer,
parallel,
router,
transport,
strategy,
fallbackDelay,
},
@@ -59,7 +62,7 @@ func (d *resolveDialer) DialContext(ctx context.Context, network string, destina
return d.dialer.DialContext(ctx, network, destination)
}
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Strategy: d.strategy})
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Transport: d.transport, Strategy: d.strategy})
if err != nil {
return nil, err
}
@@ -75,7 +78,7 @@ func (d *resolveDialer) ListenPacket(ctx context.Context, destination M.Socksadd
return d.dialer.ListenPacket(ctx, destination)
}
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Strategy: d.strategy})
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Transport: d.transport, Strategy: d.strategy})
if err != nil {
return nil, err
}
@@ -91,9 +94,7 @@ func (d *resolveParallelNetworkDialer) DialParallelInterface(ctx context.Context
return d.dialer.DialContext(ctx, network, destination)
}
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{
Strategy: d.strategy,
})
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Transport: d.transport, Strategy: d.strategy})
if err != nil {
return nil, err
}
@@ -112,7 +113,7 @@ func (d *resolveParallelNetworkDialer) ListenSerialInterfacePacket(ctx context.C
return d.dialer.ListenPacket(ctx, destination)
}
ctx = log.ContextWithOverrideLevel(ctx, log.LevelDebug)
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Strategy: d.strategy})
addresses, err := d.router.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{Transport: d.transport, Strategy: d.strategy})
if err != nil {
return nil, err
}

View File

@@ -101,7 +101,7 @@ func NewRealityServer(ctx context.Context, logger log.Logger, options option.Inb
tlsConfig.ShortIds[shortID] = true
}
handshakeDialer, err := dialer.New(ctx, options.Reality.Handshake.DialerOptions)
handshakeDialer, err := dialer.New(ctx, options.Reality.Handshake.DialerOptions, options.Reality.Handshake.ServerIsDomain())
if err != nil {
return nil, err
}