From aa8dd6e44fa7dd21066b7975951a106474e4f1e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B8=96=E7=95=8C?= Date: Mon, 29 Dec 2025 20:44:30 +0800 Subject: [PATCH] Fix DNS transports --- adapter/dns.go | 1 + dns/router.go | 2 +- dns/transport/base.go | 145 ++++++++++++++++ dns/transport/connector.go | 205 ++++++++++++++++++++++ dns/transport/dhcp/dhcp.go | 7 + dns/transport/fakeip/memory.go | 4 + dns/transport/fakeip/store.go | 38 ++-- dns/transport/hosts/hosts.go | 3 + dns/transport/https.go | 14 +- dns/transport/local/local.go | 3 + dns/transport/local/local_darwin.go | 6 + dns/transport/quic/http3.go | 71 ++++++-- dns/transport/quic/quic.go | 138 +++++++++------ dns/transport/tcp.go | 13 +- dns/transport/tls.go | 52 ++++-- dns/transport/udp.go | 259 +++++++++++++++------------- experimental/libbox/dns.go | 3 + service/resolved/transport.go | 10 ++ 18 files changed, 754 insertions(+), 220 deletions(-) create mode 100644 dns/transport/base.go create mode 100644 dns/transport/connector.go diff --git a/adapter/dns.go b/adapter/dns.go index bf73f4e5a..8f065e2e8 100644 --- a/adapter/dns.go +++ b/adapter/dns.go @@ -68,6 +68,7 @@ type DNSTransport interface { Type() string Tag() string Dependencies() []string + Reset() Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error) } diff --git a/dns/router.go b/dns/router.go index 1038fdf08..e82cab290 100644 --- a/dns/router.go +++ b/dns/router.go @@ -444,6 +444,6 @@ func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) { func (r *Router) ResetNetwork() { r.ClearCache() for _, transport := range r.transport.Transports() { - transport.Close() + transport.Reset() } } diff --git a/dns/transport/base.go b/dns/transport/base.go new file mode 100644 index 000000000..06e41fd02 --- /dev/null +++ b/dns/transport/base.go @@ -0,0 +1,145 @@ +package transport + +import ( + "context" + "os" + "sync" + + C "github.com/sagernet/sing-box/constant" + "github.com/sagernet/sing-box/dns" + E "github.com/sagernet/sing/common/exceptions" + "github.com/sagernet/sing/common/logger" +) + +type TransportState int + +const ( + StateNew TransportState = iota + StateStarted + StateClosing + StateClosed +) + +var ( + ErrTransportClosed = os.ErrClosed + ErrConnectionReset = E.New("connection reset") +) + +type BaseTransport struct { + dns.TransportAdapter + Logger logger.ContextLogger + + mutex sync.Mutex + state TransportState + inFlight int32 + queriesComplete chan struct{} + closeCtx context.Context + closeCancel context.CancelFunc +} + +func NewBaseTransport(adapter dns.TransportAdapter, logger logger.ContextLogger) *BaseTransport { + ctx, cancel := context.WithCancel(context.Background()) + return &BaseTransport{ + TransportAdapter: adapter, + Logger: logger, + state: StateNew, + closeCtx: ctx, + closeCancel: cancel, + } +} + +func (t *BaseTransport) State() TransportState { + t.mutex.Lock() + defer t.mutex.Unlock() + return t.state +} + +func (t *BaseTransport) SetStarted() error { + t.mutex.Lock() + defer t.mutex.Unlock() + switch t.state { + case StateNew: + t.state = StateStarted + return nil + case StateStarted: + return nil + default: + return ErrTransportClosed + } +} + +func (t *BaseTransport) BeginQuery() bool { + t.mutex.Lock() + defer t.mutex.Unlock() + if t.state != StateStarted { + return false + } + t.inFlight++ + return true +} + +func (t *BaseTransport) EndQuery() { + t.mutex.Lock() + if t.inFlight > 0 { + t.inFlight-- + } + if t.inFlight == 0 && t.queriesComplete != nil { + close(t.queriesComplete) + t.queriesComplete = nil + } + t.mutex.Unlock() +} + +func (t *BaseTransport) CloseContext() context.Context { + return t.closeCtx +} + +func (t *BaseTransport) Shutdown(ctx context.Context) error { + t.mutex.Lock() + + if t.state >= StateClosing { + t.mutex.Unlock() + return nil + } + + if t.state == StateNew { + t.state = StateClosed + t.mutex.Unlock() + t.closeCancel() + return nil + } + + t.state = StateClosing + + if t.inFlight == 0 { + t.state = StateClosed + t.mutex.Unlock() + t.closeCancel() + return nil + } + + t.queriesComplete = make(chan struct{}) + queriesComplete := t.queriesComplete + t.mutex.Unlock() + + t.closeCancel() + + select { + case <-queriesComplete: + t.mutex.Lock() + t.state = StateClosed + t.mutex.Unlock() + return nil + case <-ctx.Done(): + t.mutex.Lock() + t.state = StateClosed + t.mutex.Unlock() + return ctx.Err() + } +} + +func (t *BaseTransport) Close() error { + ctx, cancel := context.WithTimeout(context.Background(), C.TCPTimeout) + defer cancel() + return t.Shutdown(ctx) +} diff --git a/dns/transport/connector.go b/dns/transport/connector.go new file mode 100644 index 000000000..18fad0a5e --- /dev/null +++ b/dns/transport/connector.go @@ -0,0 +1,205 @@ +package transport + +import ( + "context" + "net" + "sync" +) + +type ConnectorCallbacks[T any] struct { + IsClosed func(connection T) bool + Close func(connection T) + Reset func(connection T) +} + +type Connector[T any] struct { + dial func(ctx context.Context) (T, error) + callbacks ConnectorCallbacks[T] + + access sync.Mutex + connection T + hasConnection bool + connecting chan struct{} + + closeCtx context.Context + closed bool +} + +func NewConnector[T any](closeCtx context.Context, dial func(context.Context) (T, error), callbacks ConnectorCallbacks[T]) *Connector[T] { + return &Connector[T]{ + dial: dial, + callbacks: callbacks, + closeCtx: closeCtx, + } +} + +func NewSingleflightConnector(closeCtx context.Context, dial func(context.Context) (*Connection, error)) *Connector[*Connection] { + return NewConnector(closeCtx, dial, ConnectorCallbacks[*Connection]{ + IsClosed: func(connection *Connection) bool { + return connection.IsClosed() + }, + Close: func(connection *Connection) { + connection.CloseWithError(ErrTransportClosed) + }, + Reset: func(connection *Connection) { + connection.CloseWithError(ErrConnectionReset) + }, + }) +} + +func (c *Connector[T]) Get(ctx context.Context) (T, error) { + var zero T + for { + c.access.Lock() + + if c.closed { + c.access.Unlock() + return zero, ErrTransportClosed + } + + if c.hasConnection && !c.callbacks.IsClosed(c.connection) { + connection := c.connection + c.access.Unlock() + return connection, nil + } + + c.hasConnection = false + + if c.connecting != nil { + connecting := c.connecting + c.access.Unlock() + + select { + case <-connecting: + continue + case <-ctx.Done(): + return zero, ctx.Err() + case <-c.closeCtx.Done(): + return zero, ErrTransportClosed + } + } + + c.connecting = make(chan struct{}) + c.access.Unlock() + + connection, err := c.dialWithCancellation(ctx) + + c.access.Lock() + close(c.connecting) + c.connecting = nil + + if err != nil { + c.access.Unlock() + return zero, err + } + + if c.closed { + c.callbacks.Close(connection) + c.access.Unlock() + return zero, ErrTransportClosed + } + + c.connection = connection + c.hasConnection = true + result := c.connection + c.access.Unlock() + + return result, nil + } +} + +func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, error) { + dialCtx, cancel := context.WithCancel(ctx) + defer cancel() + + go func() { + select { + case <-c.closeCtx.Done(): + cancel() + case <-dialCtx.Done(): + } + }() + + return c.dial(dialCtx) +} + +func (c *Connector[T]) Close() error { + c.access.Lock() + defer c.access.Unlock() + + if c.closed { + return nil + } + c.closed = true + + if c.hasConnection { + c.callbacks.Close(c.connection) + c.hasConnection = false + } + + return nil +} + +func (c *Connector[T]) Reset() { + c.access.Lock() + defer c.access.Unlock() + + if c.hasConnection { + c.callbacks.Reset(c.connection) + c.hasConnection = false + } +} + +type Connection struct { + net.Conn + + closeOnce sync.Once + done chan struct{} + closeError error +} + +func WrapConnection(conn net.Conn) *Connection { + return &Connection{ + Conn: conn, + done: make(chan struct{}), + } +} + +func (c *Connection) Done() <-chan struct{} { + return c.done +} + +func (c *Connection) IsClosed() bool { + select { + case <-c.done: + return true + default: + return false + } +} + +func (c *Connection) CloseError() error { + select { + case <-c.done: + if c.closeError != nil { + return c.closeError + } + return ErrTransportClosed + default: + return nil + } +} + +func (c *Connection) Close() error { + return c.CloseWithError(ErrTransportClosed) +} + +func (c *Connection) CloseWithError(err error) error { + var returnError error + c.closeOnce.Do(func() { + c.closeError = err + returnError = c.Conn.Close() + close(c.done) + }) + return returnError +} diff --git a/dns/transport/dhcp/dhcp.go b/dns/transport/dhcp/dhcp.go index 3f13d1d95..3f4eb7212 100644 --- a/dns/transport/dhcp/dhcp.go +++ b/dns/transport/dhcp/dhcp.go @@ -108,6 +108,13 @@ func (t *Transport) Close() error { return nil } +func (t *Transport) Reset() { + t.transportLock.Lock() + t.updatedAt = time.Time{} + t.servers = nil + t.transportLock.Unlock() +} + func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { servers, err := t.fetch() if err != nil { diff --git a/dns/transport/fakeip/memory.go b/dns/transport/fakeip/memory.go index 1640ab349..0cf8ecc7d 100644 --- a/dns/transport/fakeip/memory.go +++ b/dns/transport/fakeip/memory.go @@ -82,8 +82,12 @@ func (s *MemoryStorage) FakeIPLoadDomain(domain string, isIPv6 bool) (netip.Addr } func (s *MemoryStorage) FakeIPReset() error { + s.addressAccess.Lock() + s.domainAccess.Lock() s.addressCache = make(map[netip.Addr]string) s.domainCache4 = make(map[string]netip.Addr) s.domainCache6 = make(map[string]netip.Addr) + s.domainAccess.Unlock() + s.addressAccess.Unlock() return nil } diff --git a/dns/transport/fakeip/store.go b/dns/transport/fakeip/store.go index 83677b0d0..4c09ed7a9 100644 --- a/dns/transport/fakeip/store.go +++ b/dns/transport/fakeip/store.go @@ -3,6 +3,7 @@ package fakeip import ( "context" "net/netip" + "sync" "github.com/sagernet/sing-box/adapter" E "github.com/sagernet/sing/common/exceptions" @@ -13,13 +14,15 @@ import ( var _ adapter.FakeIPStore = (*Store)(nil) type Store struct { - ctx context.Context - logger logger.Logger - inet4Range netip.Prefix - inet6Range netip.Prefix - storage adapter.FakeIPStorage - inet4Current netip.Addr - inet6Current netip.Addr + ctx context.Context + logger logger.Logger + inet4Range netip.Prefix + inet6Range netip.Prefix + storage adapter.FakeIPStorage + + addressAccess sync.Mutex + inet4Current netip.Addr + inet6Current netip.Addr } func NewStore(ctx context.Context, logger logger.Logger, inet4Range netip.Prefix, inet6Range netip.Prefix) *Store { @@ -65,18 +68,30 @@ func (s *Store) Close() error { if s.storage == nil { return nil } - return s.storage.FakeIPSaveMetadata(&adapter.FakeIPMetadata{ + s.addressAccess.Lock() + metadata := &adapter.FakeIPMetadata{ Inet4Range: s.inet4Range, Inet6Range: s.inet6Range, Inet4Current: s.inet4Current, Inet6Current: s.inet6Current, - }) + } + s.addressAccess.Unlock() + return s.storage.FakeIPSaveMetadata(metadata) } func (s *Store) Create(domain string, isIPv6 bool) (netip.Addr, error) { if address, loaded := s.storage.FakeIPLoadDomain(domain, isIPv6); loaded { return address, nil } + + s.addressAccess.Lock() + defer s.addressAccess.Unlock() + + // Double-check after acquiring lock + if address, loaded := s.storage.FakeIPLoadDomain(domain, isIPv6); loaded { + return address, nil + } + var address netip.Addr if !isIPv6 { if !s.inet4Current.IsValid() { @@ -99,7 +114,10 @@ func (s *Store) Create(domain string, isIPv6 bool) (netip.Addr, error) { s.inet6Current = nextAddress address = nextAddress } - s.storage.FakeIPStoreAsync(address, domain, s.logger) + err := s.storage.FakeIPStore(address, domain) + if err != nil { + s.logger.Warn("save FakeIP cache: ", err) + } s.storage.FakeIPSaveMetadataAsync(&adapter.FakeIPMetadata{ Inet4Range: s.inet4Range, Inet6Range: s.inet6Range, diff --git a/dns/transport/hosts/hosts.go b/dns/transport/hosts/hosts.go index a5eecb402..f0e70a9a3 100644 --- a/dns/transport/hosts/hosts.go +++ b/dns/transport/hosts/hosts.go @@ -59,6 +59,9 @@ func (t *Transport) Close() error { return nil } +func (t *Transport) Reset() { +} + func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { question := message.Question[0] domain := mDNS.CanonicalName(question.Name) diff --git a/dns/transport/https.go b/dns/transport/https.go index 95fe7ed19..b508e6eae 100644 --- a/dns/transport/https.go +++ b/dns/transport/https.go @@ -145,6 +145,13 @@ func (t *HTTPSTransport) Close() error { return nil } +func (t *HTTPSTransport) Reset() { + t.transportAccess.Lock() + defer t.transportAccess.Unlock() + t.transport.CloseIdleConnections() + t.transport = t.transport.Clone() +} + func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { startAt := time.Now() response, err := t.exchange(ctx, message) @@ -182,7 +189,10 @@ func (t *HTTPSTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS request.Header = t.headers.Clone() request.Header.Set("Content-Type", MimeType) request.Header.Set("Accept", MimeType) - response, err := t.transport.RoundTrip(request) + t.transportAccess.Lock() + currentTransport := t.transport + t.transportAccess.Unlock() + response, err := currentTransport.RoundTrip(request) requestBuffer.Release() if err != nil { return nil, err @@ -194,12 +204,12 @@ func (t *HTTPSTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS var responseMessage mDNS.Msg if response.ContentLength > 0 { responseBuffer := buf.NewSize(int(response.ContentLength)) + defer responseBuffer.Release() _, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength)) if err != nil { return nil, err } err = responseMessage.Unpack(responseBuffer.Bytes()) - responseBuffer.Release() } else { rawMessage, err = io.ReadAll(response.Body) if err != nil { diff --git a/dns/transport/local/local.go b/dns/transport/local/local.go index 51b8c18c5..a42abc764 100644 --- a/dns/transport/local/local.go +++ b/dns/transport/local/local.go @@ -76,6 +76,9 @@ func (t *Transport) Close() error { return nil } +func (t *Transport) Reset() { +} + func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { if t.resolved != nil { resolverObject := t.resolved.Object() diff --git a/dns/transport/local/local_darwin.go b/dns/transport/local/local_darwin.go index ee759b914..5f1e60b15 100644 --- a/dns/transport/local/local_darwin.go +++ b/dns/transport/local/local_darwin.go @@ -92,6 +92,12 @@ func (t *Transport) Close() error { ) } +func (t *Transport) Reset() { + if t.dhcpTransport != nil { + t.dhcpTransport.Reset() + } +} + func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { question := message.Question[0] if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA { diff --git a/dns/transport/quic/http3.go b/dns/transport/quic/http3.go index 0459d685c..c3a5ca81c 100644 --- a/dns/transport/quic/http3.go +++ b/dns/transport/quic/http3.go @@ -8,10 +8,12 @@ import ( "net/http" "net/url" "strconv" + "sync" "github.com/sagernet/quic-go" "github.com/sagernet/quic-go/http3" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" @@ -23,6 +25,7 @@ import ( "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" "github.com/sagernet/sing/common/logger" + M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" sHTTP "github.com/sagernet/sing/protocol/http" @@ -37,11 +40,14 @@ func RegisterHTTP3Transport(registry *dns.TransportRegistry) { type HTTP3Transport struct { dns.TransportAdapter - logger logger.ContextLogger - dialer N.Dialer - destination *url.URL - headers http.Header - transport *http3.Transport + logger logger.ContextLogger + dialer N.Dialer + destination *url.URL + headers http.Header + serverAddr M.Socksaddr + tlsConfig *tls.STDConfig + transportAccess sync.Mutex + transport *http3.Transport } func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) { @@ -95,33 +101,57 @@ func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options if !serverAddr.IsValid() { return nil, E.New("invalid server address: ", serverAddr) } - return &HTTP3Transport{ + t := &HTTP3Transport{ TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTP3, tag, options.RemoteDNSServerOptions), logger: logger, dialer: transportDialer, destination: &destinationURL, headers: headers, - transport: &http3.Transport{ - Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (*quic.Conn, error) { - conn, dialErr := transportDialer.DialContext(ctx, N.NetworkUDP, serverAddr) - if dialErr != nil { - return nil, dialErr - } - return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg) - }, - TLSClientConfig: stdConfig, + serverAddr: serverAddr, + tlsConfig: stdConfig, + } + t.transport = t.newTransport() + return t, nil +} + +func (t *HTTP3Transport) newTransport() *http3.Transport { + return &http3.Transport{ + Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (*quic.Conn, error) { + conn, dialErr := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) + if dialErr != nil { + return nil, dialErr + } + quicConn, dialErr := quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg) + if dialErr != nil { + conn.Close() + return nil, dialErr + } + return quicConn, nil }, - }, nil + TLSClientConfig: t.tlsConfig, + } } func (t *HTTP3Transport) Start(stage adapter.StartStage) error { - return nil + if stage != adapter.StartStateStart { + return nil + } + return dialer.InitializeDetour(t.dialer) } func (t *HTTP3Transport) Close() error { + t.transportAccess.Lock() + defer t.transportAccess.Unlock() return t.transport.Close() } +func (t *HTTP3Transport) Reset() { + t.transportAccess.Lock() + defer t.transportAccess.Unlock() + t.transport.Close() + t.transport = t.newTransport() +} + func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { exMessage := *message exMessage.Id = 0 @@ -140,7 +170,10 @@ func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS request.Header = t.headers.Clone() request.Header.Set("Content-Type", transport.MimeType) request.Header.Set("Accept", transport.MimeType) - response, err := t.transport.RoundTrip(request) + t.transportAccess.Lock() + currentTransport := t.transport + t.transportAccess.Unlock() + response, err := currentTransport.RoundTrip(request) requestBuffer.Release() if err != nil { return nil, err @@ -152,12 +185,12 @@ func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS var responseMessage mDNS.Msg if response.ContentLength > 0 { responseBuffer := buf.NewSize(int(response.ContentLength)) + defer responseBuffer.Release() _, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength)) if err != nil { return nil, err } err = responseMessage.Unpack(responseBuffer.Bytes()) - responseBuffer.Release() } else { rawMessage, err = io.ReadAll(response.Body) if err != nil { diff --git a/dns/transport/quic/quic.go b/dns/transport/quic/quic.go index a54cddcb9..264610069 100644 --- a/dns/transport/quic/quic.go +++ b/dns/transport/quic/quic.go @@ -3,10 +3,11 @@ package quic import ( "context" "errors" - "sync" + "os" "github.com/sagernet/quic-go" "github.com/sagernet/sing-box/adapter" + "github.com/sagernet/sing-box/common/dialer" "github.com/sagernet/sing-box/common/tls" C "github.com/sagernet/sing-box/constant" "github.com/sagernet/sing-box/dns" @@ -17,7 +18,6 @@ import ( "github.com/sagernet/sing/common" "github.com/sagernet/sing/common/bufio" E "github.com/sagernet/sing/common/exceptions" - "github.com/sagernet/sing/common/logger" M "github.com/sagernet/sing/common/metadata" N "github.com/sagernet/sing/common/network" @@ -31,14 +31,14 @@ func RegisterTransport(registry *dns.TransportRegistry) { } type Transport struct { - dns.TransportAdapter + *transport.BaseTransport + ctx context.Context - logger logger.ContextLogger dialer N.Dialer serverAddr M.Socksaddr tlsConfig tls.Config - access sync.Mutex - connection *quic.Conn + + connector *transport.Connector[*quic.Conn] } func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) { @@ -62,38 +62,84 @@ func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options if !serverAddr.IsValid() { return nil, E.New("invalid server address: ", serverAddr) } - return &Transport{ - TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions), - ctx: ctx, - logger: logger, - dialer: transportDialer, - serverAddr: serverAddr, - tlsConfig: tlsConfig, - }, nil + + t := &Transport{ + BaseTransport: transport.NewBaseTransport( + dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions), + logger, + ), + ctx: ctx, + dialer: transportDialer, + serverAddr: serverAddr, + tlsConfig: tlsConfig, + } + + t.connector = transport.NewConnector(t.CloseContext(), t.dial, transport.ConnectorCallbacks[*quic.Conn]{ + IsClosed: func(connection *quic.Conn) bool { + return common.Done(connection.Context()) + }, + Close: func(connection *quic.Conn) { + connection.CloseWithError(0, "") + }, + Reset: func(connection *quic.Conn) { + connection.CloseWithError(0, "") + }, + }) + + return t, nil +} + +func (t *Transport) dial(ctx context.Context) (*quic.Conn, error) { + conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) + if err != nil { + return nil, E.Cause(err, "dial UDP connection") + } + earlyConnection, err := sQUIC.DialEarly( + ctx, + bufio.NewUnbindPacketConn(conn), + t.serverAddr.UDPAddr(), + t.tlsConfig, + nil, + ) + if err != nil { + conn.Close() + return nil, E.Cause(err, "establish QUIC connection") + } + return earlyConnection, nil } func (t *Transport) Start(stage adapter.StartStage) error { - return nil + if stage != adapter.StartStateStart { + return nil + } + err := t.SetStarted() + if err != nil { + return err + } + return dialer.InitializeDetour(t.dialer) } func (t *Transport) Close() error { - t.access.Lock() - defer t.access.Unlock() - connection := t.connection - if connection != nil { - connection.CloseWithError(0, "") - } - return nil + return E.Errors(t.BaseTransport.Close(), t.connector.Close()) +} + +func (t *Transport) Reset() { + t.connector.Reset() } func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + if !t.BeginQuery() { + return nil, transport.ErrTransportClosed + } + defer t.EndQuery() + var ( conn *quic.Conn err error response *mDNS.Msg ) for i := 0; i < 2; i++ { - conn, err = t.openConnection() + conn, err = t.connector.Get(ctx) if err != nil { return nil, err } @@ -103,58 +149,38 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, } else if !isQUICRetryError(err) { return nil, err } else { - conn.CloseWithError(quic.ApplicationErrorCode(0), "") + t.connector.Reset() continue } } return nil, err } -func (t *Transport) openConnection() (*quic.Conn, error) { - connection := t.connection - if connection != nil && !common.Done(connection.Context()) { - return connection, nil - } - t.access.Lock() - defer t.access.Unlock() - connection = t.connection - if connection != nil && !common.Done(connection.Context()) { - return connection, nil - } - conn, err := t.dialer.DialContext(t.ctx, N.NetworkUDP, t.serverAddr) - if err != nil { - return nil, err - } - earlyConnection, err := sQUIC.DialEarly( - t.ctx, - bufio.NewUnbindPacketConn(conn), - t.serverAddr.UDPAddr(), - t.tlsConfig, - nil, - ) - if err != nil { - return nil, err - } - t.connection = earlyConnection - return earlyConnection, nil -} - func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, conn *quic.Conn) (*mDNS.Msg, error) { stream, err := conn.OpenStreamSync(ctx) if err != nil { - return nil, err + return nil, E.Cause(err, "open stream") } + defer stream.CancelRead(0) err = transport.WriteMessage(stream, 0, message) if err != nil { stream.Close() - return nil, err + return nil, E.Cause(err, "write request") } stream.Close() - return transport.ReadMessage(stream) + response, err := transport.ReadMessage(stream) + if err != nil { + return nil, E.Cause(err, "read response") + } + return response, nil } // https://github.com/AdguardTeam/dnsproxy/blob/fd1868577652c639cce3da00e12ca548f421baf1/upstream/upstream_quic.go#L394 func isQUICRetryError(err error) (ok bool) { + if errors.Is(err, os.ErrClosed) { + return true + } + var qAppErr *quic.ApplicationError if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 { return true diff --git a/dns/transport/tcp.go b/dns/transport/tcp.go index 3039c5742..59333de8d 100644 --- a/dns/transport/tcp.go +++ b/dns/transport/tcp.go @@ -62,17 +62,24 @@ func (t *TCPTransport) Close() error { return nil } +func (t *TCPTransport) Reset() { +} + func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr) if err != nil { - return nil, err + return nil, E.Cause(err, "dial TCP connection") } defer conn.Close() err = WriteMessage(conn, 0, message) if err != nil { - return nil, err + return nil, E.Cause(err, "write request") } - return ReadMessage(conn) + response, err := ReadMessage(conn) + if err != nil { + return nil, E.Cause(err, "read response") + } + return response, nil } func ReadMessage(reader io.Reader) (*mDNS.Msg, error) { diff --git a/dns/transport/tls.go b/dns/transport/tls.go index 932a72a8e..4d463296b 100644 --- a/dns/transport/tls.go +++ b/dns/transport/tls.go @@ -3,6 +3,7 @@ package transport import ( "context" "sync" + "time" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" @@ -28,8 +29,8 @@ func RegisterTLS(registry *dns.TransportRegistry) { } type TLSTransport struct { - dns.TransportAdapter - logger logger.ContextLogger + *BaseTransport + dialer tls.Dialer serverAddr M.Socksaddr tlsConfig tls.Config @@ -65,11 +66,10 @@ func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options o func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport { return &TLSTransport{ - TransportAdapter: adapter, - logger: logger, - dialer: tls.NewDialer(dialer, tlsConfig), - serverAddr: serverAddr, - tlsConfig: tlsConfig, + BaseTransport: NewBaseTransport(adapter, logger), + dialer: tls.NewDialer(dialer, tlsConfig), + serverAddr: serverAddr, + tlsConfig: tlsConfig, } } @@ -77,37 +77,59 @@ func (t *TLSTransport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } + err := t.SetStarted() + if err != nil { + return err + } return dialer.InitializeDetour(t.dialer) } func (t *TLSTransport) Close() error { + t.access.Lock() + for connection := t.connections.Front(); connection != nil; connection = connection.Next() { + connection.Value.Close() + } + t.connections.Init() + t.access.Unlock() + return t.BaseTransport.Close() +} + +func (t *TLSTransport) Reset() { t.access.Lock() defer t.access.Unlock() for connection := t.connections.Front(); connection != nil; connection = connection.Next() { connection.Value.Close() } t.connections.Init() - return nil } func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + if !t.BeginQuery() { + return nil, ErrTransportClosed + } + defer t.EndQuery() + t.access.Lock() conn := t.connections.PopFront() t.access.Unlock() if conn != nil { - response, err := t.exchange(message, conn) + response, err := t.exchange(ctx, message, conn) if err == nil { return response, nil } + t.Logger.DebugContext(ctx, "discarded pooled connection: ", err) } tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr) if err != nil { - return nil, err + return nil, E.Cause(err, "dial TLS connection") } - return t.exchange(message, &tlsDNSConn{Conn: tlsConn}) + return t.exchange(ctx, message, &tlsDNSConn{Conn: tlsConn}) } -func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) { +func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) { + if deadline, ok := ctx.Deadline(); ok { + conn.SetDeadline(deadline) + } conn.queryId++ err := WriteMessage(conn, conn.queryId, message) if err != nil { @@ -120,6 +142,12 @@ func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, return nil, E.Cause(err, "read response") } t.access.Lock() + if t.State() >= StateClosing { + t.access.Unlock() + conn.Close() + return response, nil + } + conn.SetDeadline(time.Time{}) t.connections.PushBack(conn) t.access.Unlock() return response, nil diff --git a/dns/transport/udp.go b/dns/transport/udp.go index 48924c650..a72725458 100644 --- a/dns/transport/udp.go +++ b/dns/transport/udp.go @@ -2,9 +2,8 @@ package transport import ( "context" - "net" - "os" "sync" + "sync/atomic" "github.com/sagernet/sing-box/adapter" "github.com/sagernet/sing-box/common/dialer" @@ -28,15 +27,23 @@ func RegisterUDP(registry *dns.TransportRegistry) { } type UDPTransport struct { - dns.TransportAdapter - logger logger.ContextLogger - dialer N.Dialer - serverAddr M.Socksaddr - udpSize int - tcpTransport *TCPTransport - access sync.Mutex - conn *dnsConnection - done chan struct{} + *BaseTransport + + dialer N.Dialer + serverAddr M.Socksaddr + udpSize atomic.Int32 + + connector *Connector[*Connection] + + callbackAccess sync.RWMutex + queryId uint16 + callbacks map[uint16]*udpCallback +} + +type udpCallback struct { + access sync.Mutex + response *mDNS.Msg + done chan struct{} } func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) { @@ -54,180 +61,198 @@ func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options o return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil } -func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr) *UDPTransport { - return &UDPTransport{ - TransportAdapter: adapter, - logger: logger, - dialer: dialer, - serverAddr: serverAddr, - udpSize: 2048, - tcpTransport: &TCPTransport{ - dialer: dialer, - serverAddr: serverAddr, - }, - done: make(chan struct{}), +func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialerInstance N.Dialer, serverAddr M.Socksaddr) *UDPTransport { + t := &UDPTransport{ + BaseTransport: NewBaseTransport(adapter, logger), + dialer: dialerInstance, + serverAddr: serverAddr, + callbacks: make(map[uint16]*udpCallback), } + t.udpSize.Store(2048) + t.connector = NewSingleflightConnector(t.CloseContext(), t.dial) + return t +} + +func (t *UDPTransport) dial(ctx context.Context) (*Connection, error) { + rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) + if err != nil { + return nil, E.Cause(err, "dial UDP connection") + } + conn := WrapConnection(rawConn) + go t.recvLoop(conn) + return conn, nil } func (t *UDPTransport) Start(stage adapter.StartStage) error { if stage != adapter.StartStateStart { return nil } + err := t.SetStarted() + if err != nil { + return err + } return dialer.InitializeDetour(t.dialer) } func (t *UDPTransport) Close() error { - t.access.Lock() - defer t.access.Unlock() - close(t.done) - t.done = make(chan struct{}) - return nil + return E.Errors(t.BaseTransport.Close(), t.connector.Close()) +} + +func (t *UDPTransport) Reset() { + t.connector.Reset() +} + +func (t *UDPTransport) nextAvailableQueryId() (uint16, error) { + start := t.queryId + for { + t.queryId++ + if _, exists := t.callbacks[t.queryId]; !exists { + return t.queryId, nil + } + if t.queryId == start { + return 0, E.New("no available query ID") + } + } } func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + if !t.BeginQuery() { + return nil, ErrTransportClosed + } + defer t.EndQuery() + response, err := t.exchange(ctx, message) if err != nil { return nil, err } if response.Truncated { - t.logger.InfoContext(ctx, "response truncated, retrying with TCP") - return t.tcpTransport.Exchange(ctx, message) + t.Logger.InfoContext(ctx, "response truncated, retrying with TCP") + return t.exchangeTCP(ctx, message) + } + return response, nil +} + +func (t *UDPTransport) exchangeTCP(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { + conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr) + if err != nil { + return nil, E.Cause(err, "dial TCP connection") + } + defer conn.Close() + err = WriteMessage(conn, message.Id, message) + if err != nil { + return nil, E.Cause(err, "write request") + } + response, err := ReadMessage(conn) + if err != nil { + return nil, E.Cause(err, "read response") } return response, nil } func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { - t.access.Lock() if edns0Opt := message.IsEdns0(); edns0Opt != nil { - if udpSize := int(edns0Opt.UDPSize()); udpSize > t.udpSize { - t.udpSize = udpSize - close(t.done) - t.done = make(chan struct{}) + udpSize := int32(edns0Opt.UDPSize()) + for { + current := t.udpSize.Load() + if udpSize <= current { + break + } + if t.udpSize.CompareAndSwap(current, udpSize) { + t.connector.Reset() + break + } } } - t.access.Unlock() - conn, err := t.open(ctx) + + conn, err := t.connector.Get(ctx) if err != nil { return nil, err } - buffer := buf.NewSize(1 + message.Len()) - defer buffer.Release() - exMessage := *message - exMessage.Compress = true - messageId := message.Id - callback := &dnsCallback{ + + callback := &udpCallback{ done: make(chan struct{}), } - conn.access.Lock() - conn.queryId++ - exMessage.Id = conn.queryId - conn.callbacks[exMessage.Id] = callback - conn.access.Unlock() + + t.callbackAccess.Lock() + queryId, err := t.nextAvailableQueryId() + if err != nil { + t.callbackAccess.Unlock() + return nil, err + } + t.callbacks[queryId] = callback + t.callbackAccess.Unlock() + defer func() { - conn.access.Lock() - delete(conn.callbacks, exMessage.Id) - conn.access.Unlock() + t.callbackAccess.Lock() + delete(t.callbacks, queryId) + t.callbackAccess.Unlock() }() + + buffer := buf.NewSize(1 + message.Len()) + defer buffer.Release() + + exMessage := *message + exMessage.Compress = true + originalId := message.Id + exMessage.Id = queryId + rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes()) if err != nil { return nil, err } + _, err = conn.Write(rawMessage) if err != nil { - conn.Close(err) - return nil, err + conn.CloseWithError(err) + return nil, E.Cause(err, "write request") } + select { case <-callback.done: - callback.message.Id = messageId - return callback.message, nil - case <-conn.done: - return nil, conn.err - case <-t.done: - return nil, os.ErrClosed + callback.response.Id = originalId + return callback.response, nil + case <-conn.Done(): + return nil, conn.CloseError() + case <-t.CloseContext().Done(): + return nil, ErrTransportClosed case <-ctx.Done(): - conn.Close(ctx.Err()) return nil, ctx.Err() } } -func (t *UDPTransport) open(ctx context.Context) (*dnsConnection, error) { - t.access.Lock() - defer t.access.Unlock() - if t.conn != nil { - select { - case <-t.conn.done: - default: - return t.conn, nil - } - } - conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr) - if err != nil { - return nil, err - } - dnsConn := &dnsConnection{ - Conn: conn, - done: make(chan struct{}), - callbacks: make(map[uint16]*dnsCallback), - } - go t.recvLoop(dnsConn) - t.conn = dnsConn - return dnsConn, nil -} - -func (t *UDPTransport) recvLoop(conn *dnsConnection) { +func (t *UDPTransport) recvLoop(conn *Connection) { for { - buffer := buf.NewSize(t.udpSize) + buffer := buf.NewSize(int(t.udpSize.Load())) _, err := buffer.ReadOnceFrom(conn) if err != nil { buffer.Release() - conn.Close(err) + conn.CloseWithError(err) return } + var message mDNS.Msg err = message.Unpack(buffer.Bytes()) buffer.Release() if err != nil { - conn.Close(err) - return + t.Logger.Debug("discarded malformed UDP response: ", err) + continue } - conn.access.RLock() - callback, loaded := conn.callbacks[message.Id] - conn.access.RUnlock() + + t.callbackAccess.RLock() + callback, loaded := t.callbacks[message.Id] + t.callbackAccess.RUnlock() + if !loaded { continue } + callback.access.Lock() select { case <-callback.done: default: - callback.message = &message + callback.response = &message close(callback.done) } callback.access.Unlock() } } - -type dnsConnection struct { - net.Conn - access sync.RWMutex - done chan struct{} - closeOnce sync.Once - err error - queryId uint16 - callbacks map[uint16]*dnsCallback -} - -func (c *dnsConnection) Close(err error) { - c.closeOnce.Do(func() { - c.err = err - close(c.done) - }) - c.Conn.Close() -} - -type dnsCallback struct { - access sync.Mutex - message *mDNS.Msg - done chan struct{} -} diff --git a/experimental/libbox/dns.go b/experimental/libbox/dns.go index d5c97b7ee..b7b3b0f67 100644 --- a/experimental/libbox/dns.go +++ b/experimental/libbox/dns.go @@ -46,6 +46,9 @@ func (p *platformTransport) Close() error { return nil } +func (p *platformTransport) Reset() { +} + func (p *platformTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) { response := &ExchangeContext{ context: ctx, diff --git a/service/resolved/transport.go b/service/resolved/transport.go index c54a63416..ac20663ae 100644 --- a/service/resolved/transport.go +++ b/service/resolved/transport.go @@ -110,6 +110,16 @@ func (t *Transport) Close() error { return nil } +func (t *Transport) Reset() { + t.linkAccess.RLock() + defer t.linkAccess.RUnlock() + for _, servers := range t.linkServers { + for _, server := range servers.Servers { + server.Reset() + } + } +} + func (t *Transport) updateTransports(link *TransportLink) error { t.linkAccess.Lock() defer t.linkAccess.Unlock()