refactor: DNS

This commit is contained in:
世界
2024-12-02 23:17:01 +08:00
parent 8e2baf40f1
commit 7372d239a4
89 changed files with 4792 additions and 1733 deletions

View File

@@ -12,7 +12,6 @@ import (
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
@@ -34,7 +33,7 @@ type Outbound struct {
outbound.Adapter
logger logger.ContextLogger
dialer dialer.ParallelInterfaceDialer
domainStrategy dns.DomainStrategy
domainStrategy C.DomainStrategy
fallbackDelay time.Duration
overrideOption int
overrideDestination M.Socksaddr
@@ -50,7 +49,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeDirect, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
logger: logger,
domainStrategy: dns.DomainStrategy(options.DomainStrategy),
domainStrategy: C.DomainStrategy(options.DomainStrategy),
fallbackDelay: time.Duration(options.FallbackDelay),
dialer: outboundDialer,
// loopBack: newLoopBackDetector(router),
@@ -151,26 +150,26 @@ func (h *Outbound) DialParallel(ctx context.Context, network string, destination
case N.NetworkUDP:
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
var domainStrategy dns.DomainStrategy
if h.domainStrategy != dns.DomainStrategyAsIS {
var domainStrategy C.DomainStrategy
if h.domainStrategy != C.DomainStrategyAsIS {
domainStrategy = h.domainStrategy
} else {
//nolint:staticcheck
domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)
domainStrategy = C.DomainStrategy(metadata.InboundOptions.DomainStrategy)
}
switch domainStrategy {
case dns.DomainStrategyUseIPv4:
case C.DomainStrategyIPv4Only:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is4)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv4 address available for ", destination)
}
case dns.DomainStrategyUseIPv6:
case C.DomainStrategyIPv6Only:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is6)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv6 address available for ", destination)
}
}
return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, nil, nil, nil, h.fallbackDelay)
return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == C.DomainStrategyPreferIPv6, nil, nil, nil, h.fallbackDelay)
}
func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy *C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.Conn, error) {
@@ -191,26 +190,26 @@ func (h *Outbound) DialParallelNetwork(ctx context.Context, network string, dest
case N.NetworkUDP:
h.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
var domainStrategy dns.DomainStrategy
if h.domainStrategy != dns.DomainStrategyAsIS {
var domainStrategy C.DomainStrategy
if h.domainStrategy != C.DomainStrategyAsIS {
domainStrategy = h.domainStrategy
} else {
//nolint:staticcheck
domainStrategy = dns.DomainStrategy(metadata.InboundOptions.DomainStrategy)
domainStrategy = C.DomainStrategy(metadata.InboundOptions.DomainStrategy)
}
switch domainStrategy {
case dns.DomainStrategyUseIPv4:
case C.DomainStrategyIPv4Only:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is4)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv4 address available for ", destination)
}
case dns.DomainStrategyUseIPv6:
case C.DomainStrategyIPv6Only:
destinationAddresses = common.Filter(destinationAddresses, netip.Addr.Is6)
if len(destinationAddresses) == 0 {
return nil, E.New("no IPv6 address available for ", destination)
}
}
return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == dns.DomainStrategyPreferIPv6, networkStrategy, networkType, fallbackNetworkType, fallbackDelay)
return dialer.DialParallelNetwork(ctx, h.dialer, network, destination, destinationAddresses, domainStrategy == C.DomainStrategyPreferIPv6, networkStrategy, networkType, fallbackNetworkType, fallbackDelay)
}
func (h *Outbound) ListenSerialNetworkPacket(ctx context.Context, destination M.Socksaddr, destinationAddresses []netip.Addr, networkStrategy *C.NetworkStrategy, networkType []C.InterfaceType, fallbackNetworkType []C.InterfaceType, fallbackDelay time.Duration) (net.PacketConn, netip.Addr, error) {

View File

@@ -7,7 +7,7 @@ import (
"github.com/sagernet/sing-box/adapter"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing-box/dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/buf"
"github.com/sagernet/sing/common/bufio"
@@ -19,7 +19,7 @@ import (
mDNS "github.com/miekg/dns"
)
func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net.Conn, metadata adapter.InboundContext) error {
func HandleStreamDNSRequest(ctx context.Context, router adapter.DNSRouter, conn net.Conn, metadata adapter.InboundContext) error {
var queryLength uint16
err := binary.Read(conn, binary.BigEndian, &queryLength)
if err != nil {
@@ -41,7 +41,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net
}
metadataInQuery := metadata
go func() error {
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{})
if err != nil {
conn.Close()
return err
@@ -61,7 +61,7 @@ func HandleStreamDNSRequest(ctx context.Context, router adapter.Router, conn net
return nil
}
func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error {
func NewDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, cachedPackets []*N.PacketBuffer, metadata adapter.InboundContext) error {
metadata.Destination = M.Socksaddr{}
var reader N.PacketReader = conn
var counters []N.CountFunc
@@ -123,7 +123,7 @@ func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.P
}
metadataInQuery := metadata
go func() error {
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{})
if err != nil {
cancel(err)
return err
@@ -148,7 +148,7 @@ func NewDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.P
return group.Run(fastClose)
}
func newDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
func newDNSPacketConnection(ctx context.Context, router adapter.DNSRouter, conn N.PacketConn, readWaiter N.PacketReadWaiter, readCounters []N.CountFunc, cached []*N.PacketBuffer, metadata adapter.InboundContext) error {
fastClose, cancel := common.ContextWithCancelCause(ctx)
timeout := canceler.New(fastClose, cancel, C.DNSTimeout)
var group task.Group
@@ -193,7 +193,7 @@ func newDNSPacketConnection(ctx context.Context, router adapter.Router, conn N.P
}
metadataInQuery := metadata
go func() error {
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message)
response, err := router.Exchange(adapter.WithContext(ctx, &metadataInQuery), &message, adapter.DNSQueryOptions{})
if err != nil {
cancel(err)
return err

View File

@@ -14,6 +14,7 @@ import (
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
@@ -22,14 +23,14 @@ func RegisterOutbound(registry *outbound.Registry) {
type Outbound struct {
outbound.Adapter
router adapter.Router
router adapter.DNSRouter
logger logger.ContextLogger
}
func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.StubOptions) (adapter.Outbound, error) {
return &Outbound{
Adapter: outbound.NewAdapter(C.TypeDNS, tag, []string{N.NetworkTCP, N.NetworkUDP}, nil),
router: router,
router: service.FromContext[adapter.DNSRouter](ctx),
logger: logger,
}, nil
}

View File

@@ -17,6 +17,7 @@ import (
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/uot"
"github.com/sagernet/sing/protocol/socks"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
@@ -27,7 +28,7 @@ var _ adapter.Outbound = (*Outbound)(nil)
type Outbound struct {
outbound.Adapter
router adapter.Router
dnsRouter adapter.DNSRouter
logger logger.ContextLogger
client *socks.Client
resolve bool
@@ -50,11 +51,11 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
return nil, err
}
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions),
router: router,
logger: logger,
client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password),
resolve: version == socks.Version4,
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeSOCKS, tag, options.Network.Build(), options.DialerOptions),
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
logger: logger,
client: socks.NewClient(outboundDialer, options.ServerOptions.Build(), version, options.Username, options.Password),
resolve: version == socks.Version4,
}
uotOptions := common.PtrValueOrDefault(options.UDPOverTCP)
if uotOptions.Enabled {
@@ -83,7 +84,7 @@ func (h *Outbound) DialContext(ctx context.Context, network string, destination
return nil, E.Extend(N.ErrUnknownNetwork, network)
}
if h.resolve && destination.IsFqdn() {
destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := h.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}
@@ -101,7 +102,7 @@ func (h *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (n
return h.uotClient.ListenPacket(ctx, destination)
}
if h.resolve && destination.IsFqdn() {
destinationAddresses, err := h.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := h.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}

View File

@@ -13,13 +13,13 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterEndpoint(registry *endpoint.Registry) {
@@ -35,6 +35,7 @@ type Endpoint struct {
endpoint.Adapter
ctx context.Context
router adapter.Router
dnsRouter adapter.DNSRouter
logger logger.ContextLogger
localAddresses []netip.Prefix
endpoint *wireguard.Endpoint
@@ -45,6 +46,7 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL
Adapter: endpoint.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx,
router: router,
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
logger: logger,
localAddresses: options.Address,
}
@@ -79,7 +81,9 @@ func NewEndpoint(ctx context.Context, router adapter.Router, logger log.ContextL
PrivateKey: options.PrivateKey,
ListenPort: options.ListenPort,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
endpointAddresses, lookupErr := ep.dnsRouter.Lookup(ctx, domain, adapter.DNSQueryOptions{
Strategy: C.DomainStrategy(options.DomainStrategy),
})
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
@@ -185,7 +189,7 @@ func (w *Endpoint) DialContext(ctx context.Context, network string, destination
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}
@@ -199,7 +203,7 @@ func (w *Endpoint) DialContext(ctx context.Context, network string, destination
func (w *Endpoint) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
w.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() {
destinationAddresses, err := w.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := w.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}

View File

@@ -13,12 +13,12 @@ import (
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-box/transport/wireguard"
"github.com/sagernet/sing-dns"
"github.com/sagernet/sing/common"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/logger"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/service"
)
func RegisterOutbound(registry *outbound.Registry) {
@@ -33,7 +33,7 @@ var (
type Outbound struct {
outbound.Adapter
ctx context.Context
router adapter.Router
dnsRouter adapter.DNSRouter
logger logger.ContextLogger
localAddresses []netip.Prefix
endpoint *wireguard.Endpoint
@@ -47,7 +47,7 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
outbound := &Outbound{
Adapter: outbound.NewAdapterWithDialerOptions(C.TypeWireGuard, tag, []string{N.NetworkTCP, N.NetworkUDP}, options.DialerOptions),
ctx: ctx,
router: router,
dnsRouter: service.FromContext[adapter.DNSRouter](ctx),
logger: logger,
localAddresses: options.LocalAddress,
}
@@ -94,7 +94,9 @@ func NewOutbound(ctx context.Context, router adapter.Router, logger log.ContextL
Address: options.LocalAddress,
PrivateKey: options.PrivateKey,
ResolvePeer: func(domain string) (netip.Addr, error) {
endpointAddresses, lookupErr := router.Lookup(ctx, domain, dns.DomainStrategy(options.DomainStrategy))
endpointAddresses, lookupErr := outbound.dnsRouter.Lookup(ctx, domain, adapter.DNSQueryOptions{
Strategy: C.DomainStrategy(options.DomainStrategy),
})
if lookupErr != nil {
return netip.Addr{}, lookupErr
}
@@ -137,7 +139,7 @@ func (o *Outbound) DialContext(ctx context.Context, network string, destination
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
}
if destination.IsFqdn() {
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := o.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}
@@ -151,7 +153,7 @@ func (o *Outbound) DialContext(ctx context.Context, network string, destination
func (o *Outbound) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
o.logger.InfoContext(ctx, "outbound packet connection to ", destination)
if destination.IsFqdn() {
destinationAddresses, err := o.router.LookupDefault(ctx, destination.Fqdn)
destinationAddresses, err := o.dnsRouter.Lookup(ctx, destination.Fqdn, adapter.DNSQueryOptions{})
if err != nil {
return nil, err
}