mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
Fix connector canceled dial cleanup
This commit is contained in:
@@ -55,6 +55,12 @@ type contextKeyConnecting struct{}
|
||||
|
||||
var errRecursiveConnectorDial = E.New("recursive connector dial")
|
||||
|
||||
type connectorDialResult[T any] struct {
|
||||
connection T
|
||||
cancel context.CancelFunc
|
||||
err error
|
||||
}
|
||||
|
||||
func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
||||
var zero T
|
||||
for {
|
||||
@@ -100,41 +106,37 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.connecting = make(chan struct{})
|
||||
connecting := make(chan struct{})
|
||||
c.connecting = connecting
|
||||
dialContext := context.WithValue(ctx, contextKeyConnecting{}, c)
|
||||
dialResult := make(chan connectorDialResult[T], 1)
|
||||
c.access.Unlock()
|
||||
|
||||
dialContext := context.WithValue(ctx, contextKeyConnecting{}, c)
|
||||
connection, cancel, err := c.dialWithCancellation(dialContext)
|
||||
go func() {
|
||||
connection, cancel, err := c.dialWithCancellation(dialContext)
|
||||
dialResult <- connectorDialResult[T]{
|
||||
connection: connection,
|
||||
cancel: cancel,
|
||||
err: err,
|
||||
}
|
||||
}()
|
||||
|
||||
c.access.Lock()
|
||||
close(c.connecting)
|
||||
c.connecting = nil
|
||||
|
||||
if err != nil {
|
||||
c.access.Unlock()
|
||||
return zero, err
|
||||
}
|
||||
|
||||
if c.closed {
|
||||
cancel()
|
||||
c.callbacks.Close(connection)
|
||||
c.access.Unlock()
|
||||
select {
|
||||
case result := <-dialResult:
|
||||
return c.completeDial(ctx, connecting, result)
|
||||
case <-ctx.Done():
|
||||
go func() {
|
||||
result := <-dialResult
|
||||
_, _ = c.completeDial(ctx, connecting, result)
|
||||
}()
|
||||
return zero, ctx.Err()
|
||||
case <-c.closeCtx.Done():
|
||||
go func() {
|
||||
result := <-dialResult
|
||||
_, _ = c.completeDial(ctx, connecting, result)
|
||||
}()
|
||||
return zero, ErrTransportClosed
|
||||
}
|
||||
if err = ctx.Err(); err != nil {
|
||||
cancel()
|
||||
c.callbacks.Close(connection)
|
||||
c.access.Unlock()
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.connection = connection
|
||||
c.hasConnection = true
|
||||
c.connectionCancel = cancel
|
||||
result := c.connection
|
||||
c.access.Unlock()
|
||||
|
||||
return result, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -143,6 +145,38 @@ func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T
|
||||
return loaded && dialConnector == connector
|
||||
}
|
||||
|
||||
func (c *Connector[T]) completeDial(ctx context.Context, connecting chan struct{}, result connectorDialResult[T]) (T, error) {
|
||||
var zero T
|
||||
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
defer func() {
|
||||
if c.connecting == connecting {
|
||||
c.connecting = nil
|
||||
}
|
||||
close(connecting)
|
||||
}()
|
||||
|
||||
if result.err != nil {
|
||||
return zero, result.err
|
||||
}
|
||||
if c.closed || c.closeCtx.Err() != nil {
|
||||
result.cancel()
|
||||
c.callbacks.Close(result.connection)
|
||||
return zero, ErrTransportClosed
|
||||
}
|
||||
if err := ctx.Err(); err != nil {
|
||||
result.cancel()
|
||||
c.callbacks.Close(result.connection)
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.connection = result.connection
|
||||
c.hasConnection = true
|
||||
c.connectionCancel = result.cancel
|
||||
return c.connection, nil
|
||||
}
|
||||
|
||||
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) {
|
||||
var zero T
|
||||
if err := ctx.Err(); err != nil {
|
||||
|
||||
@@ -188,13 +188,157 @@ func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) {
|
||||
err := <-result
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
require.EqualValues(t, 1, dialCount.Load())
|
||||
require.EqualValues(t, 1, closeCount.Load())
|
||||
require.Eventually(t, func() bool {
|
||||
return closeCount.Load() == 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
_, err = connector.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, dialCount.Load())
|
||||
}
|
||||
|
||||
func TestConnectorCanceledRequestReturnsBeforeIgnoredDialCompletes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
dialCount atomic.Int32
|
||||
closeCount atomic.Int32
|
||||
)
|
||||
dialStarted := make(chan struct{}, 1)
|
||||
releaseDial := make(chan struct{})
|
||||
|
||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||
dialCount.Add(1)
|
||||
select {
|
||||
case dialStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-releaseDial
|
||||
return &testConnectorConnection{}, nil
|
||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||
IsClosed: func(connection *testConnectorConnection) bool {
|
||||
return false
|
||||
},
|
||||
Close: func(connection *testConnectorConnection) {
|
||||
closeCount.Add(1)
|
||||
},
|
||||
Reset: func(connection *testConnectorConnection) {},
|
||||
})
|
||||
|
||||
requestContext, cancel := context.WithCancel(context.Background())
|
||||
result := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := connector.Get(requestContext)
|
||||
result <- err
|
||||
}()
|
||||
|
||||
<-dialStarted
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case err := <-result:
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("Get did not return after request cancel")
|
||||
}
|
||||
|
||||
require.EqualValues(t, 1, dialCount.Load())
|
||||
require.EqualValues(t, 0, closeCount.Load())
|
||||
|
||||
close(releaseDial)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return closeCount.Load() == 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
_, err := connector.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, dialCount.Load())
|
||||
}
|
||||
|
||||
func TestConnectorWaiterDoesNotStartNewDialBeforeCanceledDialCompletes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
dialCount atomic.Int32
|
||||
closeCount atomic.Int32
|
||||
)
|
||||
firstDialStarted := make(chan struct{}, 1)
|
||||
secondDialStarted := make(chan struct{}, 1)
|
||||
releaseFirstDial := make(chan struct{})
|
||||
|
||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||
attempt := dialCount.Add(1)
|
||||
switch attempt {
|
||||
case 1:
|
||||
select {
|
||||
case firstDialStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-releaseFirstDial
|
||||
case 2:
|
||||
select {
|
||||
case secondDialStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
return &testConnectorConnection{}, nil
|
||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||
IsClosed: func(connection *testConnectorConnection) bool {
|
||||
return false
|
||||
},
|
||||
Close: func(connection *testConnectorConnection) {
|
||||
closeCount.Add(1)
|
||||
},
|
||||
Reset: func(connection *testConnectorConnection) {},
|
||||
})
|
||||
|
||||
requestContext, cancel := context.WithCancel(context.Background())
|
||||
firstResult := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := connector.Get(requestContext)
|
||||
firstResult <- err
|
||||
}()
|
||||
|
||||
<-firstDialStarted
|
||||
cancel()
|
||||
|
||||
secondResult := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := connector.Get(context.Background())
|
||||
secondResult <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-secondDialStarted:
|
||||
t.Fatal("second dial started before first dial completed")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
select {
|
||||
case err := <-firstResult:
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("first Get did not return after request cancel")
|
||||
}
|
||||
|
||||
close(releaseFirstDial)
|
||||
|
||||
require.Eventually(t, func() bool {
|
||||
return closeCount.Load() == 1
|
||||
}, time.Second, 10*time.Millisecond)
|
||||
|
||||
select {
|
||||
case <-secondDialStarted:
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("second dial did not start after first dial completed")
|
||||
}
|
||||
|
||||
err := <-secondResult
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, dialCount.Load())
|
||||
}
|
||||
|
||||
func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user