Fix DNS exchange failure and recursion deadlock in connector

Co-authored-by: everyx <lunt.luo@gmail.com>
This commit is contained in:
世界
2026-03-06 14:53:03 +08:00
parent 84019b06d9
commit 27c5b0b1af
5 changed files with 373 additions and 18 deletions

View File

@@ -4,6 +4,9 @@ import (
"context"
"net"
"sync"
"time"
E "github.com/sagernet/sing/common/exceptions"
)
type ConnectorCallbacks[T any] struct {
@@ -16,10 +19,11 @@ type Connector[T any] struct {
dial func(ctx context.Context) (T, error)
callbacks ConnectorCallbacks[T]
access sync.Mutex
connection T
hasConnection bool
connecting chan struct{}
access sync.Mutex
connection T
hasConnection bool
connectionCancel context.CancelFunc
connecting chan struct{}
closeCtx context.Context
closed bool
@@ -47,6 +51,10 @@ func NewSingleflightConnector(closeCtx context.Context, dial func(context.Contex
})
}
type contextKeyConnecting struct{}
var errRecursiveConnectorDial = E.New("recursive connector dial")
func (c *Connector[T]) Get(ctx context.Context) (T, error) {
var zero T
for {
@@ -64,6 +72,14 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
}
c.hasConnection = false
if c.connectionCancel != nil {
c.connectionCancel()
c.connectionCancel = nil
}
if isRecursiveConnectorDial(ctx, c) {
c.access.Unlock()
return zero, errRecursiveConnectorDial
}
if c.connecting != nil {
connecting := c.connecting
@@ -79,10 +95,16 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
}
}
if err := ctx.Err(); err != nil {
c.access.Unlock()
return zero, err
}
c.connecting = make(chan struct{})
c.access.Unlock()
connection, err := c.dialWithCancellation(ctx)
dialContext := context.WithValue(ctx, contextKeyConnecting{}, c)
connection, cancel, err := c.dialWithCancellation(dialContext)
c.access.Lock()
close(c.connecting)
@@ -94,13 +116,21 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
}
if c.closed {
cancel()
c.callbacks.Close(connection)
c.access.Unlock()
return zero, ErrTransportClosed
}
if err = ctx.Err(); err != nil {
cancel()
c.callbacks.Close(connection)
c.access.Unlock()
return zero, err
}
c.connection = connection
c.hasConnection = true
c.connectionCancel = cancel
result := c.connection
c.access.Unlock()
@@ -108,19 +138,63 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
}
}
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, error) {
dialCtx, cancel := context.WithCancel(ctx)
defer cancel()
func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T]) bool {
dialConnector, loaded := ctx.Value(contextKeyConnecting{}).(*Connector[T])
return loaded && dialConnector == connector
}
go func() {
select {
case <-c.closeCtx.Done():
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) {
var zero T
if err := ctx.Err(); err != nil {
return zero, nil, err
}
connCtx, cancel := context.WithCancel(c.closeCtx)
var (
stateAccess sync.Mutex
dialComplete bool
)
stopCancel := context.AfterFunc(ctx, func() {
stateAccess.Lock()
if !dialComplete {
cancel()
case <-dialCtx.Done():
}
}()
stateAccess.Unlock()
})
select {
case <-ctx.Done():
stateAccess.Lock()
dialComplete = true
stateAccess.Unlock()
stopCancel()
cancel()
return zero, nil, ctx.Err()
default:
}
return c.dial(dialCtx)
connection, err := c.dial(valueContext{connCtx, ctx})
stateAccess.Lock()
dialComplete = true
stateAccess.Unlock()
stopCancel()
if err != nil {
cancel()
return zero, nil, err
}
return connection, cancel, nil
}
type valueContext struct {
context.Context
parent context.Context
}
func (v valueContext) Value(key any) any {
return v.parent.Value(key)
}
func (v valueContext) Deadline() (time.Time, bool) {
return v.parent.Deadline()
}
func (c *Connector[T]) Close() error {
@@ -132,6 +206,10 @@ func (c *Connector[T]) Close() error {
}
c.closed = true
if c.connectionCancel != nil {
c.connectionCancel()
c.connectionCancel = nil
}
if c.hasConnection {
c.callbacks.Close(c.connection)
c.hasConnection = false
@@ -144,6 +222,10 @@ func (c *Connector[T]) Reset() {
c.access.Lock()
defer c.access.Unlock()
if c.connectionCancel != nil {
c.connectionCancel()
c.connectionCancel = nil
}
if c.hasConnection {
c.callbacks.Reset(c.connection)
c.hasConnection = false

View File

@@ -0,0 +1,263 @@
package transport
import (
"context"
"sync/atomic"
"testing"
"time"
"github.com/stretchr/testify/require"
)
type testConnectorConnection struct{}
func TestConnectorRecursiveGetFailsFast(t *testing.T) {
t.Parallel()
var (
dialCount atomic.Int32
closeCount atomic.Int32
connector *Connector[*testConnectorConnection]
)
dial := func(ctx context.Context) (*testConnectorConnection, error) {
dialCount.Add(1)
_, err := connector.Get(ctx)
if err != nil {
return nil, err
}
return &testConnectorConnection{}, nil
}
connector = NewConnector(context.Background(), dial, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {
closeCount.Add(1)
},
Reset: func(connection *testConnectorConnection) {
closeCount.Add(1)
},
})
_, err := connector.Get(context.Background())
require.ErrorIs(t, err, errRecursiveConnectorDial)
require.EqualValues(t, 1, dialCount.Load())
require.EqualValues(t, 0, closeCount.Load())
}
func TestConnectorRecursiveGetAcrossConnectorsAllowed(t *testing.T) {
t.Parallel()
var (
outerDialCount atomic.Int32
innerDialCount atomic.Int32
outerConnector *Connector[*testConnectorConnection]
innerConnector *Connector[*testConnectorConnection]
)
innerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
innerDialCount.Add(1)
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
outerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
outerDialCount.Add(1)
_, err := innerConnector.Get(ctx)
if err != nil {
return nil, err
}
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
_, err := outerConnector.Get(context.Background())
require.NoError(t, err)
require.EqualValues(t, 1, outerDialCount.Load())
require.EqualValues(t, 1, innerDialCount.Load())
}
func TestConnectorDialContextPreservesValueAndDeadline(t *testing.T) {
t.Parallel()
type contextKey struct{}
var (
dialValue any
dialDeadline time.Time
dialHasDeadline bool
)
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialValue = ctx.Value(contextKey{})
dialDeadline, dialHasDeadline = ctx.Deadline()
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
deadline := time.Now().Add(time.Minute)
requestContext, cancel := context.WithDeadline(context.WithValue(context.Background(), contextKey{}, "test-value"), deadline)
defer cancel()
_, err := connector.Get(requestContext)
require.NoError(t, err)
require.Equal(t, "test-value", dialValue)
require.True(t, dialHasDeadline)
require.WithinDuration(t, deadline, dialDeadline, time.Second)
}
func TestConnectorDialSkipsCanceledRequest(t *testing.T) {
t.Parallel()
var dialCount atomic.Int32
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialCount.Add(1)
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
requestContext, cancel := context.WithCancel(context.Background())
cancel()
_, err := connector.Get(requestContext)
require.ErrorIs(t, err, context.Canceled)
require.EqualValues(t, 0, dialCount.Load())
}
func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) {
t.Parallel()
var (
dialCount atomic.Int32
closeCount atomic.Int32
)
dialStarted := make(chan struct{}, 1)
releaseDial := make(chan struct{})
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialCount.Add(1)
select {
case dialStarted <- struct{}{}:
default:
}
<-releaseDial
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {
closeCount.Add(1)
},
Reset: func(connection *testConnectorConnection) {},
})
requestContext, cancel := context.WithCancel(context.Background())
result := make(chan error, 1)
go func() {
_, err := connector.Get(requestContext)
result <- err
}()
<-dialStarted
cancel()
close(releaseDial)
err := <-result
require.ErrorIs(t, err, context.Canceled)
require.EqualValues(t, 1, dialCount.Load())
require.EqualValues(t, 1, closeCount.Load())
_, err = connector.Get(context.Background())
require.NoError(t, err)
require.EqualValues(t, 2, dialCount.Load())
}
func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) {
t.Parallel()
var dialContext context.Context
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialContext = ctx
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
requestContext, cancel := context.WithCancel(context.Background())
_, err := connector.Get(requestContext)
require.NoError(t, err)
require.NotNil(t, dialContext)
cancel()
select {
case <-dialContext.Done():
t.Fatal("dial context canceled by request context after successful dial")
case <-time.After(100 * time.Millisecond):
}
err = connector.Close()
require.NoError(t, err)
}
func TestConnectorDialContextCanceledOnClose(t *testing.T) {
t.Parallel()
var dialContext context.Context
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
dialContext = ctx
return &testConnectorConnection{}, nil
}, ConnectorCallbacks[*testConnectorConnection]{
IsClosed: func(connection *testConnectorConnection) bool {
return false
},
Close: func(connection *testConnectorConnection) {},
Reset: func(connection *testConnectorConnection) {},
})
_, err := connector.Get(context.Background())
require.NoError(t, err)
require.NotNil(t, dialContext)
select {
case <-dialContext.Done():
t.Fatal("dial context canceled before connector close")
default:
}
err = connector.Close()
require.NoError(t, err)
select {
case <-dialContext.Done():
case <-time.After(time.Second):
t.Fatal("dial context not canceled after connector close")
}
}

View File

@@ -106,7 +106,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
cancel(err)
return nil, err
}
return NewGRPCConn(stream), nil
return NewGRPCConn(stream, cancel), nil
}
func (c *Client) Close() error {

View File

@@ -1,8 +1,10 @@
package v2raygrpc
import (
"context"
"net"
"os"
"sync"
"time"
"github.com/sagernet/sing/common/baderror"
@@ -14,16 +16,19 @@ var _ net.Conn = (*GRPCConn)(nil)
type GRPCConn struct {
GunService
cache []byte
cache []byte
cancel context.CancelCauseFunc
closeOnce sync.Once
}
func NewGRPCConn(service GunService) *GRPCConn {
func NewGRPCConn(service GunService, cancel context.CancelCauseFunc) *GRPCConn {
//nolint:staticcheck
if client, isClient := service.(GunService_TunClient); isClient {
service = &clientConnWrapper{client}
}
return &GRPCConn{
GunService: service,
cancel: cancel,
}
}
@@ -54,6 +59,11 @@ func (c *GRPCConn) Write(b []byte) (n int, err error) {
}
func (c *GRPCConn) Close() error {
c.closeOnce.Do(func() {
if c.cancel != nil {
c.cancel(nil)
}
})
return nil
}

View File

@@ -52,7 +52,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
}
func (s *Server) Tun(server GunService_TunServer) error {
conn := NewGRPCConn(server)
conn := NewGRPCConn(server, nil)
var source M.Socksaddr
if remotePeer, loaded := peer.FromContext(server.Context()); loaded {
source = M.SocksaddrFromNet(remotePeer.Addr)