mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
Fix DNS transports
This commit is contained in:
@@ -68,6 +68,7 @@ type DNSTransport interface {
|
|||||||
Type() string
|
Type() string
|
||||||
Tag() string
|
Tag() string
|
||||||
Dependencies() []string
|
Dependencies() []string
|
||||||
|
Reset()
|
||||||
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
|
Exchange(ctx context.Context, message *dns.Msg) (*dns.Msg, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -444,6 +444,6 @@ func (r *Router) LookupReverseMapping(ip netip.Addr) (string, bool) {
|
|||||||
func (r *Router) ResetNetwork() {
|
func (r *Router) ResetNetwork() {
|
||||||
r.ClearCache()
|
r.ClearCache()
|
||||||
for _, transport := range r.transport.Transports() {
|
for _, transport := range r.transport.Transports() {
|
||||||
transport.Close()
|
transport.Reset()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
145
dns/transport/base.go
Normal file
145
dns/transport/base.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
C "github.com/sagernet/sing-box/constant"
|
||||||
|
"github.com/sagernet/sing-box/dns"
|
||||||
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TransportState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StateNew TransportState = iota
|
||||||
|
StateStarted
|
||||||
|
StateClosing
|
||||||
|
StateClosed
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrTransportClosed = os.ErrClosed
|
||||||
|
ErrConnectionReset = E.New("connection reset")
|
||||||
|
)
|
||||||
|
|
||||||
|
type BaseTransport struct {
|
||||||
|
dns.TransportAdapter
|
||||||
|
Logger logger.ContextLogger
|
||||||
|
|
||||||
|
mutex sync.Mutex
|
||||||
|
state TransportState
|
||||||
|
inFlight int32
|
||||||
|
queriesComplete chan struct{}
|
||||||
|
closeCtx context.Context
|
||||||
|
closeCancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewBaseTransport(adapter dns.TransportAdapter, logger logger.ContextLogger) *BaseTransport {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
return &BaseTransport{
|
||||||
|
TransportAdapter: adapter,
|
||||||
|
Logger: logger,
|
||||||
|
state: StateNew,
|
||||||
|
closeCtx: ctx,
|
||||||
|
closeCancel: cancel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *BaseTransport) State() TransportState {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
return t.state
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *BaseTransport) SetStarted() error {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
switch t.state {
|
||||||
|
case StateNew:
|
||||||
|
t.state = StateStarted
|
||||||
|
return nil
|
||||||
|
case StateStarted:
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return ErrTransportClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *BaseTransport) BeginQuery() bool {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
if t.state != StateStarted {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
t.inFlight++
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *BaseTransport) EndQuery() {
|
||||||
|
t.mutex.Lock()
|
||||||
|
if t.inFlight > 0 {
|
||||||
|
t.inFlight--
|
||||||
|
}
|
||||||
|
if t.inFlight == 0 && t.queriesComplete != nil {
|
||||||
|
close(t.queriesComplete)
|
||||||
|
t.queriesComplete = nil
|
||||||
|
}
|
||||||
|
t.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *BaseTransport) CloseContext() context.Context {
|
||||||
|
return t.closeCtx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *BaseTransport) Shutdown(ctx context.Context) error {
|
||||||
|
t.mutex.Lock()
|
||||||
|
|
||||||
|
if t.state >= StateClosing {
|
||||||
|
t.mutex.Unlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if t.state == StateNew {
|
||||||
|
t.state = StateClosed
|
||||||
|
t.mutex.Unlock()
|
||||||
|
t.closeCancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.state = StateClosing
|
||||||
|
|
||||||
|
if t.inFlight == 0 {
|
||||||
|
t.state = StateClosed
|
||||||
|
t.mutex.Unlock()
|
||||||
|
t.closeCancel()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
t.queriesComplete = make(chan struct{})
|
||||||
|
queriesComplete := t.queriesComplete
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.closeCancel()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-queriesComplete:
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.state = StateClosed
|
||||||
|
t.mutex.Unlock()
|
||||||
|
return nil
|
||||||
|
case <-ctx.Done():
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.state = StateClosed
|
||||||
|
t.mutex.Unlock()
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *BaseTransport) Close() error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), C.TCPTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return t.Shutdown(ctx)
|
||||||
|
}
|
||||||
205
dns/transport/connector.go
Normal file
205
dns/transport/connector.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package transport
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ConnectorCallbacks[T any] struct {
|
||||||
|
IsClosed func(connection T) bool
|
||||||
|
Close func(connection T)
|
||||||
|
Reset func(connection T)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Connector[T any] struct {
|
||||||
|
dial func(ctx context.Context) (T, error)
|
||||||
|
callbacks ConnectorCallbacks[T]
|
||||||
|
|
||||||
|
access sync.Mutex
|
||||||
|
connection T
|
||||||
|
hasConnection bool
|
||||||
|
connecting chan struct{}
|
||||||
|
|
||||||
|
closeCtx context.Context
|
||||||
|
closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnector[T any](closeCtx context.Context, dial func(context.Context) (T, error), callbacks ConnectorCallbacks[T]) *Connector[T] {
|
||||||
|
return &Connector[T]{
|
||||||
|
dial: dial,
|
||||||
|
callbacks: callbacks,
|
||||||
|
closeCtx: closeCtx,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSingleflightConnector(closeCtx context.Context, dial func(context.Context) (*Connection, error)) *Connector[*Connection] {
|
||||||
|
return NewConnector(closeCtx, dial, ConnectorCallbacks[*Connection]{
|
||||||
|
IsClosed: func(connection *Connection) bool {
|
||||||
|
return connection.IsClosed()
|
||||||
|
},
|
||||||
|
Close: func(connection *Connection) {
|
||||||
|
connection.CloseWithError(ErrTransportClosed)
|
||||||
|
},
|
||||||
|
Reset: func(connection *Connection) {
|
||||||
|
connection.CloseWithError(ErrConnectionReset)
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
||||||
|
var zero T
|
||||||
|
for {
|
||||||
|
c.access.Lock()
|
||||||
|
|
||||||
|
if c.closed {
|
||||||
|
c.access.Unlock()
|
||||||
|
return zero, ErrTransportClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.hasConnection && !c.callbacks.IsClosed(c.connection) {
|
||||||
|
connection := c.connection
|
||||||
|
c.access.Unlock()
|
||||||
|
return connection, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.hasConnection = false
|
||||||
|
|
||||||
|
if c.connecting != nil {
|
||||||
|
connecting := c.connecting
|
||||||
|
c.access.Unlock()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-connecting:
|
||||||
|
continue
|
||||||
|
case <-ctx.Done():
|
||||||
|
return zero, ctx.Err()
|
||||||
|
case <-c.closeCtx.Done():
|
||||||
|
return zero, ErrTransportClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.connecting = make(chan struct{})
|
||||||
|
c.access.Unlock()
|
||||||
|
|
||||||
|
connection, err := c.dialWithCancellation(ctx)
|
||||||
|
|
||||||
|
c.access.Lock()
|
||||||
|
close(c.connecting)
|
||||||
|
c.connecting = nil
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.access.Unlock()
|
||||||
|
return zero, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.closed {
|
||||||
|
c.callbacks.Close(connection)
|
||||||
|
c.access.Unlock()
|
||||||
|
return zero, ErrTransportClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
c.connection = connection
|
||||||
|
c.hasConnection = true
|
||||||
|
result := c.connection
|
||||||
|
c.access.Unlock()
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, error) {
|
||||||
|
dialCtx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-c.closeCtx.Done():
|
||||||
|
cancel()
|
||||||
|
case <-dialCtx.Done():
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
return c.dial(dialCtx)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connector[T]) Close() error {
|
||||||
|
c.access.Lock()
|
||||||
|
defer c.access.Unlock()
|
||||||
|
|
||||||
|
if c.closed {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
c.closed = true
|
||||||
|
|
||||||
|
if c.hasConnection {
|
||||||
|
c.callbacks.Close(c.connection)
|
||||||
|
c.hasConnection = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connector[T]) Reset() {
|
||||||
|
c.access.Lock()
|
||||||
|
defer c.access.Unlock()
|
||||||
|
|
||||||
|
if c.hasConnection {
|
||||||
|
c.callbacks.Reset(c.connection)
|
||||||
|
c.hasConnection = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Connection struct {
|
||||||
|
net.Conn
|
||||||
|
|
||||||
|
closeOnce sync.Once
|
||||||
|
done chan struct{}
|
||||||
|
closeError error
|
||||||
|
}
|
||||||
|
|
||||||
|
func WrapConnection(conn net.Conn) *Connection {
|
||||||
|
return &Connection{
|
||||||
|
Conn: conn,
|
||||||
|
done: make(chan struct{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Done() <-chan struct{} {
|
||||||
|
return c.done
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) IsClosed() bool {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) CloseError() error {
|
||||||
|
select {
|
||||||
|
case <-c.done:
|
||||||
|
if c.closeError != nil {
|
||||||
|
return c.closeError
|
||||||
|
}
|
||||||
|
return ErrTransportClosed
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) Close() error {
|
||||||
|
return c.CloseWithError(ErrTransportClosed)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Connection) CloseWithError(err error) error {
|
||||||
|
var returnError error
|
||||||
|
c.closeOnce.Do(func() {
|
||||||
|
c.closeError = err
|
||||||
|
returnError = c.Conn.Close()
|
||||||
|
close(c.done)
|
||||||
|
})
|
||||||
|
return returnError
|
||||||
|
}
|
||||||
@@ -108,6 +108,13 @@ func (t *Transport) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Transport) Reset() {
|
||||||
|
t.transportLock.Lock()
|
||||||
|
t.updatedAt = time.Time{}
|
||||||
|
t.servers = nil
|
||||||
|
t.transportLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
servers, err := t.fetch()
|
servers, err := t.fetch()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -82,8 +82,12 @@ func (s *MemoryStorage) FakeIPLoadDomain(domain string, isIPv6 bool) (netip.Addr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *MemoryStorage) FakeIPReset() error {
|
func (s *MemoryStorage) FakeIPReset() error {
|
||||||
|
s.addressAccess.Lock()
|
||||||
|
s.domainAccess.Lock()
|
||||||
s.addressCache = make(map[netip.Addr]string)
|
s.addressCache = make(map[netip.Addr]string)
|
||||||
s.domainCache4 = make(map[string]netip.Addr)
|
s.domainCache4 = make(map[string]netip.Addr)
|
||||||
s.domainCache6 = make(map[string]netip.Addr)
|
s.domainCache6 = make(map[string]netip.Addr)
|
||||||
|
s.domainAccess.Unlock()
|
||||||
|
s.addressAccess.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package fakeip
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
@@ -13,13 +14,15 @@ import (
|
|||||||
var _ adapter.FakeIPStore = (*Store)(nil)
|
var _ adapter.FakeIPStore = (*Store)(nil)
|
||||||
|
|
||||||
type Store struct {
|
type Store struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
logger logger.Logger
|
logger logger.Logger
|
||||||
inet4Range netip.Prefix
|
inet4Range netip.Prefix
|
||||||
inet6Range netip.Prefix
|
inet6Range netip.Prefix
|
||||||
storage adapter.FakeIPStorage
|
storage adapter.FakeIPStorage
|
||||||
inet4Current netip.Addr
|
|
||||||
inet6Current netip.Addr
|
addressAccess sync.Mutex
|
||||||
|
inet4Current netip.Addr
|
||||||
|
inet6Current netip.Addr
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewStore(ctx context.Context, logger logger.Logger, inet4Range netip.Prefix, inet6Range netip.Prefix) *Store {
|
func NewStore(ctx context.Context, logger logger.Logger, inet4Range netip.Prefix, inet6Range netip.Prefix) *Store {
|
||||||
@@ -65,18 +68,30 @@ func (s *Store) Close() error {
|
|||||||
if s.storage == nil {
|
if s.storage == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return s.storage.FakeIPSaveMetadata(&adapter.FakeIPMetadata{
|
s.addressAccess.Lock()
|
||||||
|
metadata := &adapter.FakeIPMetadata{
|
||||||
Inet4Range: s.inet4Range,
|
Inet4Range: s.inet4Range,
|
||||||
Inet6Range: s.inet6Range,
|
Inet6Range: s.inet6Range,
|
||||||
Inet4Current: s.inet4Current,
|
Inet4Current: s.inet4Current,
|
||||||
Inet6Current: s.inet6Current,
|
Inet6Current: s.inet6Current,
|
||||||
})
|
}
|
||||||
|
s.addressAccess.Unlock()
|
||||||
|
return s.storage.FakeIPSaveMetadata(metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Store) Create(domain string, isIPv6 bool) (netip.Addr, error) {
|
func (s *Store) Create(domain string, isIPv6 bool) (netip.Addr, error) {
|
||||||
if address, loaded := s.storage.FakeIPLoadDomain(domain, isIPv6); loaded {
|
if address, loaded := s.storage.FakeIPLoadDomain(domain, isIPv6); loaded {
|
||||||
return address, nil
|
return address, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.addressAccess.Lock()
|
||||||
|
defer s.addressAccess.Unlock()
|
||||||
|
|
||||||
|
// Double-check after acquiring lock
|
||||||
|
if address, loaded := s.storage.FakeIPLoadDomain(domain, isIPv6); loaded {
|
||||||
|
return address, nil
|
||||||
|
}
|
||||||
|
|
||||||
var address netip.Addr
|
var address netip.Addr
|
||||||
if !isIPv6 {
|
if !isIPv6 {
|
||||||
if !s.inet4Current.IsValid() {
|
if !s.inet4Current.IsValid() {
|
||||||
@@ -99,7 +114,10 @@ func (s *Store) Create(domain string, isIPv6 bool) (netip.Addr, error) {
|
|||||||
s.inet6Current = nextAddress
|
s.inet6Current = nextAddress
|
||||||
address = nextAddress
|
address = nextAddress
|
||||||
}
|
}
|
||||||
s.storage.FakeIPStoreAsync(address, domain, s.logger)
|
err := s.storage.FakeIPStore(address, domain)
|
||||||
|
if err != nil {
|
||||||
|
s.logger.Warn("save FakeIP cache: ", err)
|
||||||
|
}
|
||||||
s.storage.FakeIPSaveMetadataAsync(&adapter.FakeIPMetadata{
|
s.storage.FakeIPSaveMetadataAsync(&adapter.FakeIPMetadata{
|
||||||
Inet4Range: s.inet4Range,
|
Inet4Range: s.inet4Range,
|
||||||
Inet6Range: s.inet6Range,
|
Inet6Range: s.inet6Range,
|
||||||
|
|||||||
@@ -59,6 +59,9 @@ func (t *Transport) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Transport) Reset() {
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
question := message.Question[0]
|
question := message.Question[0]
|
||||||
domain := mDNS.CanonicalName(question.Name)
|
domain := mDNS.CanonicalName(question.Name)
|
||||||
|
|||||||
@@ -145,6 +145,13 @@ func (t *HTTPSTransport) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *HTTPSTransport) Reset() {
|
||||||
|
t.transportAccess.Lock()
|
||||||
|
defer t.transportAccess.Unlock()
|
||||||
|
t.transport.CloseIdleConnections()
|
||||||
|
t.transport = t.transport.Clone()
|
||||||
|
}
|
||||||
|
|
||||||
func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *HTTPSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
startAt := time.Now()
|
startAt := time.Now()
|
||||||
response, err := t.exchange(ctx, message)
|
response, err := t.exchange(ctx, message)
|
||||||
@@ -182,7 +189,10 @@ func (t *HTTPSTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS
|
|||||||
request.Header = t.headers.Clone()
|
request.Header = t.headers.Clone()
|
||||||
request.Header.Set("Content-Type", MimeType)
|
request.Header.Set("Content-Type", MimeType)
|
||||||
request.Header.Set("Accept", MimeType)
|
request.Header.Set("Accept", MimeType)
|
||||||
response, err := t.transport.RoundTrip(request)
|
t.transportAccess.Lock()
|
||||||
|
currentTransport := t.transport
|
||||||
|
t.transportAccess.Unlock()
|
||||||
|
response, err := currentTransport.RoundTrip(request)
|
||||||
requestBuffer.Release()
|
requestBuffer.Release()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -194,12 +204,12 @@ func (t *HTTPSTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS
|
|||||||
var responseMessage mDNS.Msg
|
var responseMessage mDNS.Msg
|
||||||
if response.ContentLength > 0 {
|
if response.ContentLength > 0 {
|
||||||
responseBuffer := buf.NewSize(int(response.ContentLength))
|
responseBuffer := buf.NewSize(int(response.ContentLength))
|
||||||
|
defer responseBuffer.Release()
|
||||||
_, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength))
|
_, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = responseMessage.Unpack(responseBuffer.Bytes())
|
err = responseMessage.Unpack(responseBuffer.Bytes())
|
||||||
responseBuffer.Release()
|
|
||||||
} else {
|
} else {
|
||||||
rawMessage, err = io.ReadAll(response.Body)
|
rawMessage, err = io.ReadAll(response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -76,6 +76,9 @@ func (t *Transport) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Transport) Reset() {
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
if t.resolved != nil {
|
if t.resolved != nil {
|
||||||
resolverObject := t.resolved.Object()
|
resolverObject := t.resolved.Object()
|
||||||
|
|||||||
@@ -92,6 +92,12 @@ func (t *Transport) Close() error {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Transport) Reset() {
|
||||||
|
if t.dhcpTransport != nil {
|
||||||
|
t.dhcpTransport.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
question := message.Question[0]
|
question := message.Question[0]
|
||||||
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
|
if question.Qtype == mDNS.TypeA || question.Qtype == mDNS.TypeAAAA {
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
"github.com/sagernet/quic-go"
|
||||||
"github.com/sagernet/quic-go/http3"
|
"github.com/sagernet/quic-go/http3"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/dns"
|
"github.com/sagernet/sing-box/dns"
|
||||||
@@ -23,6 +25,7 @@ import (
|
|||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/logger"
|
"github.com/sagernet/sing/common/logger"
|
||||||
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
sHTTP "github.com/sagernet/sing/protocol/http"
|
sHTTP "github.com/sagernet/sing/protocol/http"
|
||||||
|
|
||||||
@@ -37,11 +40,14 @@ func RegisterHTTP3Transport(registry *dns.TransportRegistry) {
|
|||||||
|
|
||||||
type HTTP3Transport struct {
|
type HTTP3Transport struct {
|
||||||
dns.TransportAdapter
|
dns.TransportAdapter
|
||||||
logger logger.ContextLogger
|
logger logger.ContextLogger
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
destination *url.URL
|
destination *url.URL
|
||||||
headers http.Header
|
headers http.Header
|
||||||
transport *http3.Transport
|
serverAddr M.Socksaddr
|
||||||
|
tlsConfig *tls.STDConfig
|
||||||
|
transportAccess sync.Mutex
|
||||||
|
transport *http3.Transport
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) {
|
func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteHTTPSDNSServerOptions) (adapter.DNSTransport, error) {
|
||||||
@@ -95,33 +101,57 @@ func NewHTTP3(ctx context.Context, logger log.ContextLogger, tag string, options
|
|||||||
if !serverAddr.IsValid() {
|
if !serverAddr.IsValid() {
|
||||||
return nil, E.New("invalid server address: ", serverAddr)
|
return nil, E.New("invalid server address: ", serverAddr)
|
||||||
}
|
}
|
||||||
return &HTTP3Transport{
|
t := &HTTP3Transport{
|
||||||
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTP3, tag, options.RemoteDNSServerOptions),
|
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeHTTP3, tag, options.RemoteDNSServerOptions),
|
||||||
logger: logger,
|
logger: logger,
|
||||||
dialer: transportDialer,
|
dialer: transportDialer,
|
||||||
destination: &destinationURL,
|
destination: &destinationURL,
|
||||||
headers: headers,
|
headers: headers,
|
||||||
transport: &http3.Transport{
|
serverAddr: serverAddr,
|
||||||
Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (*quic.Conn, error) {
|
tlsConfig: stdConfig,
|
||||||
conn, dialErr := transportDialer.DialContext(ctx, N.NetworkUDP, serverAddr)
|
}
|
||||||
if dialErr != nil {
|
t.transport = t.newTransport()
|
||||||
return nil, dialErr
|
return t, nil
|
||||||
}
|
}
|
||||||
return quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg)
|
|
||||||
},
|
func (t *HTTP3Transport) newTransport() *http3.Transport {
|
||||||
TLSClientConfig: stdConfig,
|
return &http3.Transport{
|
||||||
|
Dial: func(ctx context.Context, addr string, tlsCfg *tls.STDConfig, cfg *quic.Config) (*quic.Conn, error) {
|
||||||
|
conn, dialErr := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
|
||||||
|
if dialErr != nil {
|
||||||
|
return nil, dialErr
|
||||||
|
}
|
||||||
|
quicConn, dialErr := quic.DialEarly(ctx, bufio.NewUnbindPacketConn(conn), conn.RemoteAddr(), tlsCfg, cfg)
|
||||||
|
if dialErr != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, dialErr
|
||||||
|
}
|
||||||
|
return quicConn, nil
|
||||||
},
|
},
|
||||||
}, nil
|
TLSClientConfig: t.tlsConfig,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *HTTP3Transport) Start(stage adapter.StartStage) error {
|
func (t *HTTP3Transport) Start(stage adapter.StartStage) error {
|
||||||
return nil
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return dialer.InitializeDetour(t.dialer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *HTTP3Transport) Close() error {
|
func (t *HTTP3Transport) Close() error {
|
||||||
|
t.transportAccess.Lock()
|
||||||
|
defer t.transportAccess.Unlock()
|
||||||
return t.transport.Close()
|
return t.transport.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *HTTP3Transport) Reset() {
|
||||||
|
t.transportAccess.Lock()
|
||||||
|
defer t.transportAccess.Unlock()
|
||||||
|
t.transport.Close()
|
||||||
|
t.transport = t.newTransport()
|
||||||
|
}
|
||||||
|
|
||||||
func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
exMessage := *message
|
exMessage := *message
|
||||||
exMessage.Id = 0
|
exMessage.Id = 0
|
||||||
@@ -140,7 +170,10 @@ func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS
|
|||||||
request.Header = t.headers.Clone()
|
request.Header = t.headers.Clone()
|
||||||
request.Header.Set("Content-Type", transport.MimeType)
|
request.Header.Set("Content-Type", transport.MimeType)
|
||||||
request.Header.Set("Accept", transport.MimeType)
|
request.Header.Set("Accept", transport.MimeType)
|
||||||
response, err := t.transport.RoundTrip(request)
|
t.transportAccess.Lock()
|
||||||
|
currentTransport := t.transport
|
||||||
|
t.transportAccess.Unlock()
|
||||||
|
response, err := currentTransport.RoundTrip(request)
|
||||||
requestBuffer.Release()
|
requestBuffer.Release()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -152,12 +185,12 @@ func (t *HTTP3Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS
|
|||||||
var responseMessage mDNS.Msg
|
var responseMessage mDNS.Msg
|
||||||
if response.ContentLength > 0 {
|
if response.ContentLength > 0 {
|
||||||
responseBuffer := buf.NewSize(int(response.ContentLength))
|
responseBuffer := buf.NewSize(int(response.ContentLength))
|
||||||
|
defer responseBuffer.Release()
|
||||||
_, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength))
|
_, err = responseBuffer.ReadFullFrom(response.Body, int(response.ContentLength))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
err = responseMessage.Unpack(responseBuffer.Bytes())
|
err = responseMessage.Unpack(responseBuffer.Bytes())
|
||||||
responseBuffer.Release()
|
|
||||||
} else {
|
} else {
|
||||||
rawMessage, err = io.ReadAll(response.Body)
|
rawMessage, err = io.ReadAll(response.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ package quic
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"os"
|
||||||
|
|
||||||
"github.com/sagernet/quic-go"
|
"github.com/sagernet/quic-go"
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
"github.com/sagernet/sing-box/common/tls"
|
"github.com/sagernet/sing-box/common/tls"
|
||||||
C "github.com/sagernet/sing-box/constant"
|
C "github.com/sagernet/sing-box/constant"
|
||||||
"github.com/sagernet/sing-box/dns"
|
"github.com/sagernet/sing-box/dns"
|
||||||
@@ -17,7 +18,6 @@ import (
|
|||||||
"github.com/sagernet/sing/common"
|
"github.com/sagernet/sing/common"
|
||||||
"github.com/sagernet/sing/common/bufio"
|
"github.com/sagernet/sing/common/bufio"
|
||||||
E "github.com/sagernet/sing/common/exceptions"
|
E "github.com/sagernet/sing/common/exceptions"
|
||||||
"github.com/sagernet/sing/common/logger"
|
|
||||||
M "github.com/sagernet/sing/common/metadata"
|
M "github.com/sagernet/sing/common/metadata"
|
||||||
N "github.com/sagernet/sing/common/network"
|
N "github.com/sagernet/sing/common/network"
|
||||||
|
|
||||||
@@ -31,14 +31,14 @@ func RegisterTransport(registry *dns.TransportRegistry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Transport struct {
|
type Transport struct {
|
||||||
dns.TransportAdapter
|
*transport.BaseTransport
|
||||||
|
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
logger logger.ContextLogger
|
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
serverAddr M.Socksaddr
|
serverAddr M.Socksaddr
|
||||||
tlsConfig tls.Config
|
tlsConfig tls.Config
|
||||||
access sync.Mutex
|
|
||||||
connection *quic.Conn
|
connector *transport.Connector[*quic.Conn]
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
|
func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteTLSDNSServerOptions) (adapter.DNSTransport, error) {
|
||||||
@@ -62,38 +62,84 @@ func NewQUIC(ctx context.Context, logger log.ContextLogger, tag string, options
|
|||||||
if !serverAddr.IsValid() {
|
if !serverAddr.IsValid() {
|
||||||
return nil, E.New("invalid server address: ", serverAddr)
|
return nil, E.New("invalid server address: ", serverAddr)
|
||||||
}
|
}
|
||||||
return &Transport{
|
|
||||||
TransportAdapter: dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions),
|
t := &Transport{
|
||||||
ctx: ctx,
|
BaseTransport: transport.NewBaseTransport(
|
||||||
logger: logger,
|
dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeQUIC, tag, options.RemoteDNSServerOptions),
|
||||||
dialer: transportDialer,
|
logger,
|
||||||
serverAddr: serverAddr,
|
),
|
||||||
tlsConfig: tlsConfig,
|
ctx: ctx,
|
||||||
}, nil
|
dialer: transportDialer,
|
||||||
|
serverAddr: serverAddr,
|
||||||
|
tlsConfig: tlsConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.connector = transport.NewConnector(t.CloseContext(), t.dial, transport.ConnectorCallbacks[*quic.Conn]{
|
||||||
|
IsClosed: func(connection *quic.Conn) bool {
|
||||||
|
return common.Done(connection.Context())
|
||||||
|
},
|
||||||
|
Close: func(connection *quic.Conn) {
|
||||||
|
connection.CloseWithError(0, "")
|
||||||
|
},
|
||||||
|
Reset: func(connection *quic.Conn) {
|
||||||
|
connection.CloseWithError(0, "")
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
return t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *Transport) dial(ctx context.Context) (*quic.Conn, error) {
|
||||||
|
conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "dial UDP connection")
|
||||||
|
}
|
||||||
|
earlyConnection, err := sQUIC.DialEarly(
|
||||||
|
ctx,
|
||||||
|
bufio.NewUnbindPacketConn(conn),
|
||||||
|
t.serverAddr.UDPAddr(),
|
||||||
|
t.tlsConfig,
|
||||||
|
nil,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, E.Cause(err, "establish QUIC connection")
|
||||||
|
}
|
||||||
|
return earlyConnection, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) Start(stage adapter.StartStage) error {
|
func (t *Transport) Start(stage adapter.StartStage) error {
|
||||||
return nil
|
if stage != adapter.StartStateStart {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
err := t.SetStarted()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return dialer.InitializeDetour(t.dialer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) Close() error {
|
func (t *Transport) Close() error {
|
||||||
t.access.Lock()
|
return E.Errors(t.BaseTransport.Close(), t.connector.Close())
|
||||||
defer t.access.Unlock()
|
}
|
||||||
connection := t.connection
|
|
||||||
if connection != nil {
|
func (t *Transport) Reset() {
|
||||||
connection.CloseWithError(0, "")
|
t.connector.Reset()
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
|
if !t.BeginQuery() {
|
||||||
|
return nil, transport.ErrTransportClosed
|
||||||
|
}
|
||||||
|
defer t.EndQuery()
|
||||||
|
|
||||||
var (
|
var (
|
||||||
conn *quic.Conn
|
conn *quic.Conn
|
||||||
err error
|
err error
|
||||||
response *mDNS.Msg
|
response *mDNS.Msg
|
||||||
)
|
)
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
conn, err = t.openConnection()
|
conn, err = t.connector.Get(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -103,58 +149,38 @@ func (t *Transport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg,
|
|||||||
} else if !isQUICRetryError(err) {
|
} else if !isQUICRetryError(err) {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
} else {
|
||||||
conn.CloseWithError(quic.ApplicationErrorCode(0), "")
|
t.connector.Reset()
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *Transport) openConnection() (*quic.Conn, error) {
|
|
||||||
connection := t.connection
|
|
||||||
if connection != nil && !common.Done(connection.Context()) {
|
|
||||||
return connection, nil
|
|
||||||
}
|
|
||||||
t.access.Lock()
|
|
||||||
defer t.access.Unlock()
|
|
||||||
connection = t.connection
|
|
||||||
if connection != nil && !common.Done(connection.Context()) {
|
|
||||||
return connection, nil
|
|
||||||
}
|
|
||||||
conn, err := t.dialer.DialContext(t.ctx, N.NetworkUDP, t.serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
earlyConnection, err := sQUIC.DialEarly(
|
|
||||||
t.ctx,
|
|
||||||
bufio.NewUnbindPacketConn(conn),
|
|
||||||
t.serverAddr.UDPAddr(),
|
|
||||||
t.tlsConfig,
|
|
||||||
nil,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
t.connection = earlyConnection
|
|
||||||
return earlyConnection, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, conn *quic.Conn) (*mDNS.Msg, error) {
|
func (t *Transport) exchange(ctx context.Context, message *mDNS.Msg, conn *quic.Conn) (*mDNS.Msg, error) {
|
||||||
stream, err := conn.OpenStreamSync(ctx)
|
stream, err := conn.OpenStreamSync(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, E.Cause(err, "open stream")
|
||||||
}
|
}
|
||||||
|
defer stream.CancelRead(0)
|
||||||
err = transport.WriteMessage(stream, 0, message)
|
err = transport.WriteMessage(stream, 0, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
stream.Close()
|
stream.Close()
|
||||||
return nil, err
|
return nil, E.Cause(err, "write request")
|
||||||
}
|
}
|
||||||
stream.Close()
|
stream.Close()
|
||||||
return transport.ReadMessage(stream)
|
response, err := transport.ReadMessage(stream)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "read response")
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// https://github.com/AdguardTeam/dnsproxy/blob/fd1868577652c639cce3da00e12ca548f421baf1/upstream/upstream_quic.go#L394
|
// https://github.com/AdguardTeam/dnsproxy/blob/fd1868577652c639cce3da00e12ca548f421baf1/upstream/upstream_quic.go#L394
|
||||||
func isQUICRetryError(err error) (ok bool) {
|
func isQUICRetryError(err error) (ok bool) {
|
||||||
|
if errors.Is(err, os.ErrClosed) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
var qAppErr *quic.ApplicationError
|
var qAppErr *quic.ApplicationError
|
||||||
if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 {
|
if errors.As(err, &qAppErr) && qAppErr.ErrorCode == 0 {
|
||||||
return true
|
return true
|
||||||
|
|||||||
@@ -62,17 +62,24 @@ func (t *TCPTransport) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TCPTransport) Reset() {
|
||||||
|
}
|
||||||
|
|
||||||
func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *TCPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
|
conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, E.Cause(err, "dial TCP connection")
|
||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
err = WriteMessage(conn, 0, message)
|
err = WriteMessage(conn, 0, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, E.Cause(err, "write request")
|
||||||
}
|
}
|
||||||
return ReadMessage(conn)
|
response, err := ReadMessage(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "read response")
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func ReadMessage(reader io.Reader) (*mDNS.Msg, error) {
|
func ReadMessage(reader io.Reader) (*mDNS.Msg, error) {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package transport
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
@@ -28,8 +29,8 @@ func RegisterTLS(registry *dns.TransportRegistry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type TLSTransport struct {
|
type TLSTransport struct {
|
||||||
dns.TransportAdapter
|
*BaseTransport
|
||||||
logger logger.ContextLogger
|
|
||||||
dialer tls.Dialer
|
dialer tls.Dialer
|
||||||
serverAddr M.Socksaddr
|
serverAddr M.Socksaddr
|
||||||
tlsConfig tls.Config
|
tlsConfig tls.Config
|
||||||
@@ -65,11 +66,10 @@ func NewTLS(ctx context.Context, logger log.ContextLogger, tag string, options o
|
|||||||
|
|
||||||
func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport {
|
func NewTLSRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr, tlsConfig tls.Config) *TLSTransport {
|
||||||
return &TLSTransport{
|
return &TLSTransport{
|
||||||
TransportAdapter: adapter,
|
BaseTransport: NewBaseTransport(adapter, logger),
|
||||||
logger: logger,
|
dialer: tls.NewDialer(dialer, tlsConfig),
|
||||||
dialer: tls.NewDialer(dialer, tlsConfig),
|
serverAddr: serverAddr,
|
||||||
serverAddr: serverAddr,
|
tlsConfig: tlsConfig,
|
||||||
tlsConfig: tlsConfig,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,37 +77,59 @@ func (t *TLSTransport) Start(stage adapter.StartStage) error {
|
|||||||
if stage != adapter.StartStateStart {
|
if stage != adapter.StartStateStart {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
err := t.SetStarted()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return dialer.InitializeDetour(t.dialer)
|
return dialer.InitializeDetour(t.dialer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TLSTransport) Close() error {
|
func (t *TLSTransport) Close() error {
|
||||||
|
t.access.Lock()
|
||||||
|
for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
|
||||||
|
connection.Value.Close()
|
||||||
|
}
|
||||||
|
t.connections.Init()
|
||||||
|
t.access.Unlock()
|
||||||
|
return t.BaseTransport.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TLSTransport) Reset() {
|
||||||
t.access.Lock()
|
t.access.Lock()
|
||||||
defer t.access.Unlock()
|
defer t.access.Unlock()
|
||||||
for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
|
for connection := t.connections.Front(); connection != nil; connection = connection.Next() {
|
||||||
connection.Value.Close()
|
connection.Value.Close()
|
||||||
}
|
}
|
||||||
t.connections.Init()
|
t.connections.Init()
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *TLSTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
|
if !t.BeginQuery() {
|
||||||
|
return nil, ErrTransportClosed
|
||||||
|
}
|
||||||
|
defer t.EndQuery()
|
||||||
|
|
||||||
t.access.Lock()
|
t.access.Lock()
|
||||||
conn := t.connections.PopFront()
|
conn := t.connections.PopFront()
|
||||||
t.access.Unlock()
|
t.access.Unlock()
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
response, err := t.exchange(message, conn)
|
response, err := t.exchange(ctx, message, conn)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
t.Logger.DebugContext(ctx, "discarded pooled connection: ", err)
|
||||||
}
|
}
|
||||||
tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr)
|
tlsConn, err := t.dialer.DialTLSContext(ctx, t.serverAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, E.Cause(err, "dial TLS connection")
|
||||||
}
|
}
|
||||||
return t.exchange(message, &tlsDNSConn{Conn: tlsConn})
|
return t.exchange(ctx, message, &tlsDNSConn{Conn: tlsConn})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
|
func (t *TLSTransport) exchange(ctx context.Context, message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg, error) {
|
||||||
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
|
conn.SetDeadline(deadline)
|
||||||
|
}
|
||||||
conn.queryId++
|
conn.queryId++
|
||||||
err := WriteMessage(conn, conn.queryId, message)
|
err := WriteMessage(conn, conn.queryId, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -120,6 +142,12 @@ func (t *TLSTransport) exchange(message *mDNS.Msg, conn *tlsDNSConn) (*mDNS.Msg,
|
|||||||
return nil, E.Cause(err, "read response")
|
return nil, E.Cause(err, "read response")
|
||||||
}
|
}
|
||||||
t.access.Lock()
|
t.access.Lock()
|
||||||
|
if t.State() >= StateClosing {
|
||||||
|
t.access.Unlock()
|
||||||
|
conn.Close()
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
conn.SetDeadline(time.Time{})
|
||||||
t.connections.PushBack(conn)
|
t.connections.PushBack(conn)
|
||||||
t.access.Unlock()
|
t.access.Unlock()
|
||||||
return response, nil
|
return response, nil
|
||||||
|
|||||||
@@ -2,9 +2,8 @@ package transport
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/sagernet/sing-box/adapter"
|
"github.com/sagernet/sing-box/adapter"
|
||||||
"github.com/sagernet/sing-box/common/dialer"
|
"github.com/sagernet/sing-box/common/dialer"
|
||||||
@@ -28,15 +27,23 @@ func RegisterUDP(registry *dns.TransportRegistry) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type UDPTransport struct {
|
type UDPTransport struct {
|
||||||
dns.TransportAdapter
|
*BaseTransport
|
||||||
logger logger.ContextLogger
|
|
||||||
dialer N.Dialer
|
dialer N.Dialer
|
||||||
serverAddr M.Socksaddr
|
serverAddr M.Socksaddr
|
||||||
udpSize int
|
udpSize atomic.Int32
|
||||||
tcpTransport *TCPTransport
|
|
||||||
access sync.Mutex
|
connector *Connector[*Connection]
|
||||||
conn *dnsConnection
|
|
||||||
done chan struct{}
|
callbackAccess sync.RWMutex
|
||||||
|
queryId uint16
|
||||||
|
callbacks map[uint16]*udpCallback
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpCallback struct {
|
||||||
|
access sync.Mutex
|
||||||
|
response *mDNS.Msg
|
||||||
|
done chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
|
func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options option.RemoteDNSServerOptions) (adapter.DNSTransport, error) {
|
||||||
@@ -54,180 +61,198 @@ func NewUDP(ctx context.Context, logger log.ContextLogger, tag string, options o
|
|||||||
return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil
|
return NewUDPRaw(logger, dns.NewTransportAdapterWithRemoteOptions(C.DNSTypeUDP, tag, options), transportDialer, serverAddr), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialer N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
|
func NewUDPRaw(logger logger.ContextLogger, adapter dns.TransportAdapter, dialerInstance N.Dialer, serverAddr M.Socksaddr) *UDPTransport {
|
||||||
return &UDPTransport{
|
t := &UDPTransport{
|
||||||
TransportAdapter: adapter,
|
BaseTransport: NewBaseTransport(adapter, logger),
|
||||||
logger: logger,
|
dialer: dialerInstance,
|
||||||
dialer: dialer,
|
serverAddr: serverAddr,
|
||||||
serverAddr: serverAddr,
|
callbacks: make(map[uint16]*udpCallback),
|
||||||
udpSize: 2048,
|
|
||||||
tcpTransport: &TCPTransport{
|
|
||||||
dialer: dialer,
|
|
||||||
serverAddr: serverAddr,
|
|
||||||
},
|
|
||||||
done: make(chan struct{}),
|
|
||||||
}
|
}
|
||||||
|
t.udpSize.Store(2048)
|
||||||
|
t.connector = NewSingleflightConnector(t.CloseContext(), t.dial)
|
||||||
|
return t
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTransport) dial(ctx context.Context) (*Connection, error) {
|
||||||
|
rawConn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "dial UDP connection")
|
||||||
|
}
|
||||||
|
conn := WrapConnection(rawConn)
|
||||||
|
go t.recvLoop(conn)
|
||||||
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) Start(stage adapter.StartStage) error {
|
func (t *UDPTransport) Start(stage adapter.StartStage) error {
|
||||||
if stage != adapter.StartStateStart {
|
if stage != adapter.StartStateStart {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
err := t.SetStarted()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
return dialer.InitializeDetour(t.dialer)
|
return dialer.InitializeDetour(t.dialer)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) Close() error {
|
func (t *UDPTransport) Close() error {
|
||||||
t.access.Lock()
|
return E.Errors(t.BaseTransport.Close(), t.connector.Close())
|
||||||
defer t.access.Unlock()
|
}
|
||||||
close(t.done)
|
|
||||||
t.done = make(chan struct{})
|
func (t *UDPTransport) Reset() {
|
||||||
return nil
|
t.connector.Reset()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTransport) nextAvailableQueryId() (uint16, error) {
|
||||||
|
start := t.queryId
|
||||||
|
for {
|
||||||
|
t.queryId++
|
||||||
|
if _, exists := t.callbacks[t.queryId]; !exists {
|
||||||
|
return t.queryId, nil
|
||||||
|
}
|
||||||
|
if t.queryId == start {
|
||||||
|
return 0, E.New("no available query ID")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *UDPTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
|
if !t.BeginQuery() {
|
||||||
|
return nil, ErrTransportClosed
|
||||||
|
}
|
||||||
|
defer t.EndQuery()
|
||||||
|
|
||||||
response, err := t.exchange(ctx, message)
|
response, err := t.exchange(ctx, message)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if response.Truncated {
|
if response.Truncated {
|
||||||
t.logger.InfoContext(ctx, "response truncated, retrying with TCP")
|
t.Logger.InfoContext(ctx, "response truncated, retrying with TCP")
|
||||||
return t.tcpTransport.Exchange(ctx, message)
|
return t.exchangeTCP(ctx, message)
|
||||||
|
}
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTransport) exchangeTCP(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
|
conn, err := t.dialer.DialContext(ctx, N.NetworkTCP, t.serverAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "dial TCP connection")
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
err = WriteMessage(conn, message.Id, message)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "write request")
|
||||||
|
}
|
||||||
|
response, err := ReadMessage(conn)
|
||||||
|
if err != nil {
|
||||||
|
return nil, E.Cause(err, "read response")
|
||||||
}
|
}
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (t *UDPTransport) exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
t.access.Lock()
|
|
||||||
if edns0Opt := message.IsEdns0(); edns0Opt != nil {
|
if edns0Opt := message.IsEdns0(); edns0Opt != nil {
|
||||||
if udpSize := int(edns0Opt.UDPSize()); udpSize > t.udpSize {
|
udpSize := int32(edns0Opt.UDPSize())
|
||||||
t.udpSize = udpSize
|
for {
|
||||||
close(t.done)
|
current := t.udpSize.Load()
|
||||||
t.done = make(chan struct{})
|
if udpSize <= current {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if t.udpSize.CompareAndSwap(current, udpSize) {
|
||||||
|
t.connector.Reset()
|
||||||
|
break
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
t.access.Unlock()
|
|
||||||
conn, err := t.open(ctx)
|
conn, err := t.connector.Get(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
buffer := buf.NewSize(1 + message.Len())
|
|
||||||
defer buffer.Release()
|
callback := &udpCallback{
|
||||||
exMessage := *message
|
|
||||||
exMessage.Compress = true
|
|
||||||
messageId := message.Id
|
|
||||||
callback := &dnsCallback{
|
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
}
|
}
|
||||||
conn.access.Lock()
|
|
||||||
conn.queryId++
|
t.callbackAccess.Lock()
|
||||||
exMessage.Id = conn.queryId
|
queryId, err := t.nextAvailableQueryId()
|
||||||
conn.callbacks[exMessage.Id] = callback
|
if err != nil {
|
||||||
conn.access.Unlock()
|
t.callbackAccess.Unlock()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
t.callbacks[queryId] = callback
|
||||||
|
t.callbackAccess.Unlock()
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
conn.access.Lock()
|
t.callbackAccess.Lock()
|
||||||
delete(conn.callbacks, exMessage.Id)
|
delete(t.callbacks, queryId)
|
||||||
conn.access.Unlock()
|
t.callbackAccess.Unlock()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
buffer := buf.NewSize(1 + message.Len())
|
||||||
|
defer buffer.Release()
|
||||||
|
|
||||||
|
exMessage := *message
|
||||||
|
exMessage.Compress = true
|
||||||
|
originalId := message.Id
|
||||||
|
exMessage.Id = queryId
|
||||||
|
|
||||||
rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
|
rawMessage, err := exMessage.PackBuffer(buffer.FreeBytes())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = conn.Write(rawMessage)
|
_, err = conn.Write(rawMessage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close(err)
|
conn.CloseWithError(err)
|
||||||
return nil, err
|
return nil, E.Cause(err, "write request")
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-callback.done:
|
case <-callback.done:
|
||||||
callback.message.Id = messageId
|
callback.response.Id = originalId
|
||||||
return callback.message, nil
|
return callback.response, nil
|
||||||
case <-conn.done:
|
case <-conn.Done():
|
||||||
return nil, conn.err
|
return nil, conn.CloseError()
|
||||||
case <-t.done:
|
case <-t.CloseContext().Done():
|
||||||
return nil, os.ErrClosed
|
return nil, ErrTransportClosed
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
conn.Close(ctx.Err())
|
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *UDPTransport) open(ctx context.Context) (*dnsConnection, error) {
|
func (t *UDPTransport) recvLoop(conn *Connection) {
|
||||||
t.access.Lock()
|
|
||||||
defer t.access.Unlock()
|
|
||||||
if t.conn != nil {
|
|
||||||
select {
|
|
||||||
case <-t.conn.done:
|
|
||||||
default:
|
|
||||||
return t.conn, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
conn, err := t.dialer.DialContext(ctx, N.NetworkUDP, t.serverAddr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
dnsConn := &dnsConnection{
|
|
||||||
Conn: conn,
|
|
||||||
done: make(chan struct{}),
|
|
||||||
callbacks: make(map[uint16]*dnsCallback),
|
|
||||||
}
|
|
||||||
go t.recvLoop(dnsConn)
|
|
||||||
t.conn = dnsConn
|
|
||||||
return dnsConn, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *UDPTransport) recvLoop(conn *dnsConnection) {
|
|
||||||
for {
|
for {
|
||||||
buffer := buf.NewSize(t.udpSize)
|
buffer := buf.NewSize(int(t.udpSize.Load()))
|
||||||
_, err := buffer.ReadOnceFrom(conn)
|
_, err := buffer.ReadOnceFrom(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
conn.Close(err)
|
conn.CloseWithError(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var message mDNS.Msg
|
var message mDNS.Msg
|
||||||
err = message.Unpack(buffer.Bytes())
|
err = message.Unpack(buffer.Bytes())
|
||||||
buffer.Release()
|
buffer.Release()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
conn.Close(err)
|
t.Logger.Debug("discarded malformed UDP response: ", err)
|
||||||
return
|
continue
|
||||||
}
|
}
|
||||||
conn.access.RLock()
|
|
||||||
callback, loaded := conn.callbacks[message.Id]
|
t.callbackAccess.RLock()
|
||||||
conn.access.RUnlock()
|
callback, loaded := t.callbacks[message.Id]
|
||||||
|
t.callbackAccess.RUnlock()
|
||||||
|
|
||||||
if !loaded {
|
if !loaded {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
callback.access.Lock()
|
callback.access.Lock()
|
||||||
select {
|
select {
|
||||||
case <-callback.done:
|
case <-callback.done:
|
||||||
default:
|
default:
|
||||||
callback.message = &message
|
callback.response = &message
|
||||||
close(callback.done)
|
close(callback.done)
|
||||||
}
|
}
|
||||||
callback.access.Unlock()
|
callback.access.Unlock()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type dnsConnection struct {
|
|
||||||
net.Conn
|
|
||||||
access sync.RWMutex
|
|
||||||
done chan struct{}
|
|
||||||
closeOnce sync.Once
|
|
||||||
err error
|
|
||||||
queryId uint16
|
|
||||||
callbacks map[uint16]*dnsCallback
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *dnsConnection) Close(err error) {
|
|
||||||
c.closeOnce.Do(func() {
|
|
||||||
c.err = err
|
|
||||||
close(c.done)
|
|
||||||
})
|
|
||||||
c.Conn.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
type dnsCallback struct {
|
|
||||||
access sync.Mutex
|
|
||||||
message *mDNS.Msg
|
|
||||||
done chan struct{}
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -46,6 +46,9 @@ func (p *platformTransport) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (p *platformTransport) Reset() {
|
||||||
|
}
|
||||||
|
|
||||||
func (p *platformTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
func (p *platformTransport) Exchange(ctx context.Context, message *mDNS.Msg) (*mDNS.Msg, error) {
|
||||||
response := &ExchangeContext{
|
response := &ExchangeContext{
|
||||||
context: ctx,
|
context: ctx,
|
||||||
|
|||||||
@@ -110,6 +110,16 @@ func (t *Transport) Close() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *Transport) Reset() {
|
||||||
|
t.linkAccess.RLock()
|
||||||
|
defer t.linkAccess.RUnlock()
|
||||||
|
for _, servers := range t.linkServers {
|
||||||
|
for _, server := range servers.Servers {
|
||||||
|
server.Reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (t *Transport) updateTransports(link *TransportLink) error {
|
func (t *Transport) updateTransports(link *TransportLink) error {
|
||||||
t.linkAccess.Lock()
|
t.linkAccess.Lock()
|
||||||
defer t.linkAccess.Unlock()
|
defer t.linkAccess.Unlock()
|
||||||
|
|||||||
Reference in New Issue
Block a user