mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-11 17:47:20 +10:00
408 lines
11 KiB
Go
408 lines
11 KiB
Go
package transport
|
|
|
|
import (
|
|
"context"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
type testConnectorConnection struct{}
|
|
|
|
func TestConnectorRecursiveGetFailsFast(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
dialCount atomic.Int32
|
|
closeCount atomic.Int32
|
|
connector *Connector[*testConnectorConnection]
|
|
)
|
|
|
|
dial := func(ctx context.Context) (*testConnectorConnection, error) {
|
|
dialCount.Add(1)
|
|
_, err := connector.Get(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &testConnectorConnection{}, nil
|
|
}
|
|
|
|
connector = NewConnector(context.Background(), dial, ConnectorCallbacks[*testConnectorConnection]{
|
|
IsClosed: func(connection *testConnectorConnection) bool {
|
|
return false
|
|
},
|
|
Close: func(connection *testConnectorConnection) {
|
|
closeCount.Add(1)
|
|
},
|
|
Reset: func(connection *testConnectorConnection) {
|
|
closeCount.Add(1)
|
|
},
|
|
})
|
|
|
|
_, err := connector.Get(context.Background())
|
|
require.ErrorIs(t, err, errRecursiveConnectorDial)
|
|
require.EqualValues(t, 1, dialCount.Load())
|
|
require.EqualValues(t, 0, closeCount.Load())
|
|
}
|
|
|
|
func TestConnectorRecursiveGetAcrossConnectorsAllowed(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
outerDialCount atomic.Int32
|
|
innerDialCount atomic.Int32
|
|
outerConnector *Connector[*testConnectorConnection]
|
|
innerConnector *Connector[*testConnectorConnection]
|
|
)
|
|
|
|
innerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
innerDialCount.Add(1)
|
|
return &testConnectorConnection{}, nil
|
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
IsClosed: func(connection *testConnectorConnection) bool {
|
|
return false
|
|
},
|
|
Close: func(connection *testConnectorConnection) {},
|
|
Reset: func(connection *testConnectorConnection) {},
|
|
})
|
|
|
|
outerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
outerDialCount.Add(1)
|
|
_, err := innerConnector.Get(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &testConnectorConnection{}, nil
|
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
IsClosed: func(connection *testConnectorConnection) bool {
|
|
return false
|
|
},
|
|
Close: func(connection *testConnectorConnection) {},
|
|
Reset: func(connection *testConnectorConnection) {},
|
|
})
|
|
|
|
_, err := outerConnector.Get(context.Background())
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 1, outerDialCount.Load())
|
|
require.EqualValues(t, 1, innerDialCount.Load())
|
|
}
|
|
|
|
func TestConnectorDialContextPreservesValueAndDeadline(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
type contextKey struct{}
|
|
|
|
var (
|
|
dialValue any
|
|
dialDeadline time.Time
|
|
dialHasDeadline bool
|
|
)
|
|
|
|
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
dialValue = ctx.Value(contextKey{})
|
|
dialDeadline, dialHasDeadline = ctx.Deadline()
|
|
return &testConnectorConnection{}, nil
|
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
IsClosed: func(connection *testConnectorConnection) bool {
|
|
return false
|
|
},
|
|
Close: func(connection *testConnectorConnection) {},
|
|
Reset: func(connection *testConnectorConnection) {},
|
|
})
|
|
|
|
deadline := time.Now().Add(time.Minute)
|
|
requestContext, cancel := context.WithDeadline(context.WithValue(context.Background(), contextKey{}, "test-value"), deadline)
|
|
defer cancel()
|
|
|
|
_, err := connector.Get(requestContext)
|
|
require.NoError(t, err)
|
|
require.Equal(t, "test-value", dialValue)
|
|
require.True(t, dialHasDeadline)
|
|
require.WithinDuration(t, deadline, dialDeadline, time.Second)
|
|
}
|
|
|
|
func TestConnectorDialSkipsCanceledRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var dialCount atomic.Int32
|
|
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
dialCount.Add(1)
|
|
return &testConnectorConnection{}, nil
|
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
IsClosed: func(connection *testConnectorConnection) bool {
|
|
return false
|
|
},
|
|
Close: func(connection *testConnectorConnection) {},
|
|
Reset: func(connection *testConnectorConnection) {},
|
|
})
|
|
|
|
requestContext, cancel := context.WithCancel(context.Background())
|
|
cancel()
|
|
|
|
_, err := connector.Get(requestContext)
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
require.EqualValues(t, 0, dialCount.Load())
|
|
}
|
|
|
|
func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
dialCount atomic.Int32
|
|
closeCount atomic.Int32
|
|
)
|
|
dialStarted := make(chan struct{}, 1)
|
|
releaseDial := make(chan struct{})
|
|
|
|
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
dialCount.Add(1)
|
|
select {
|
|
case dialStarted <- struct{}{}:
|
|
default:
|
|
}
|
|
<-releaseDial
|
|
return &testConnectorConnection{}, nil
|
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
IsClosed: func(connection *testConnectorConnection) bool {
|
|
return false
|
|
},
|
|
Close: func(connection *testConnectorConnection) {
|
|
closeCount.Add(1)
|
|
},
|
|
Reset: func(connection *testConnectorConnection) {},
|
|
})
|
|
|
|
requestContext, cancel := context.WithCancel(context.Background())
|
|
result := make(chan error, 1)
|
|
go func() {
|
|
_, err := connector.Get(requestContext)
|
|
result <- err
|
|
}()
|
|
|
|
<-dialStarted
|
|
cancel()
|
|
close(releaseDial)
|
|
|
|
err := <-result
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
require.EqualValues(t, 1, dialCount.Load())
|
|
require.Eventually(t, func() bool {
|
|
return closeCount.Load() == 1
|
|
}, time.Second, 10*time.Millisecond)
|
|
|
|
_, err = connector.Get(context.Background())
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 2, dialCount.Load())
|
|
}
|
|
|
|
func TestConnectorCanceledRequestReturnsBeforeIgnoredDialCompletes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
dialCount atomic.Int32
|
|
closeCount atomic.Int32
|
|
)
|
|
dialStarted := make(chan struct{}, 1)
|
|
releaseDial := make(chan struct{})
|
|
|
|
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
dialCount.Add(1)
|
|
select {
|
|
case dialStarted <- struct{}{}:
|
|
default:
|
|
}
|
|
<-releaseDial
|
|
return &testConnectorConnection{}, nil
|
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
IsClosed: func(connection *testConnectorConnection) bool {
|
|
return false
|
|
},
|
|
Close: func(connection *testConnectorConnection) {
|
|
closeCount.Add(1)
|
|
},
|
|
Reset: func(connection *testConnectorConnection) {},
|
|
})
|
|
|
|
requestContext, cancel := context.WithCancel(context.Background())
|
|
result := make(chan error, 1)
|
|
go func() {
|
|
_, err := connector.Get(requestContext)
|
|
result <- err
|
|
}()
|
|
|
|
<-dialStarted
|
|
cancel()
|
|
|
|
select {
|
|
case err := <-result:
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("Get did not return after request cancel")
|
|
}
|
|
|
|
require.EqualValues(t, 1, dialCount.Load())
|
|
require.EqualValues(t, 0, closeCount.Load())
|
|
|
|
close(releaseDial)
|
|
|
|
require.Eventually(t, func() bool {
|
|
return closeCount.Load() == 1
|
|
}, time.Second, 10*time.Millisecond)
|
|
|
|
_, err := connector.Get(context.Background())
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 2, dialCount.Load())
|
|
}
|
|
|
|
func TestConnectorWaiterDoesNotStartNewDialBeforeCanceledDialCompletes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var (
|
|
dialCount atomic.Int32
|
|
closeCount atomic.Int32
|
|
)
|
|
firstDialStarted := make(chan struct{}, 1)
|
|
secondDialStarted := make(chan struct{}, 1)
|
|
releaseFirstDial := make(chan struct{})
|
|
|
|
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
attempt := dialCount.Add(1)
|
|
switch attempt {
|
|
case 1:
|
|
select {
|
|
case firstDialStarted <- struct{}{}:
|
|
default:
|
|
}
|
|
<-releaseFirstDial
|
|
case 2:
|
|
select {
|
|
case secondDialStarted <- struct{}{}:
|
|
default:
|
|
}
|
|
}
|
|
return &testConnectorConnection{}, nil
|
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
IsClosed: func(connection *testConnectorConnection) bool {
|
|
return false
|
|
},
|
|
Close: func(connection *testConnectorConnection) {
|
|
closeCount.Add(1)
|
|
},
|
|
Reset: func(connection *testConnectorConnection) {},
|
|
})
|
|
|
|
requestContext, cancel := context.WithCancel(context.Background())
|
|
firstResult := make(chan error, 1)
|
|
go func() {
|
|
_, err := connector.Get(requestContext)
|
|
firstResult <- err
|
|
}()
|
|
|
|
<-firstDialStarted
|
|
cancel()
|
|
|
|
secondResult := make(chan error, 1)
|
|
go func() {
|
|
_, err := connector.Get(context.Background())
|
|
secondResult <- err
|
|
}()
|
|
|
|
select {
|
|
case <-secondDialStarted:
|
|
t.Fatal("second dial started before first dial completed")
|
|
case <-time.After(100 * time.Millisecond):
|
|
}
|
|
|
|
select {
|
|
case err := <-firstResult:
|
|
require.ErrorIs(t, err, context.Canceled)
|
|
case <-time.After(time.Second):
|
|
t.Fatal("first Get did not return after request cancel")
|
|
}
|
|
|
|
close(releaseFirstDial)
|
|
|
|
require.Eventually(t, func() bool {
|
|
return closeCount.Load() == 1
|
|
}, time.Second, 10*time.Millisecond)
|
|
|
|
select {
|
|
case <-secondDialStarted:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("second dial did not start after first dial completed")
|
|
}
|
|
|
|
err := <-secondResult
|
|
require.NoError(t, err)
|
|
require.EqualValues(t, 2, dialCount.Load())
|
|
}
|
|
|
|
func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var dialContext context.Context
|
|
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
dialContext = ctx
|
|
return &testConnectorConnection{}, nil
|
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
IsClosed: func(connection *testConnectorConnection) bool {
|
|
return false
|
|
},
|
|
Close: func(connection *testConnectorConnection) {},
|
|
Reset: func(connection *testConnectorConnection) {},
|
|
})
|
|
|
|
requestContext, cancel := context.WithCancel(context.Background())
|
|
_, err := connector.Get(requestContext)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, dialContext)
|
|
|
|
cancel()
|
|
|
|
select {
|
|
case <-dialContext.Done():
|
|
t.Fatal("dial context canceled by request context after successful dial")
|
|
case <-time.After(100 * time.Millisecond):
|
|
}
|
|
|
|
err = connector.Close()
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestConnectorDialContextCanceledOnClose(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var dialContext context.Context
|
|
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
|
dialContext = ctx
|
|
return &testConnectorConnection{}, nil
|
|
}, ConnectorCallbacks[*testConnectorConnection]{
|
|
IsClosed: func(connection *testConnectorConnection) bool {
|
|
return false
|
|
},
|
|
Close: func(connection *testConnectorConnection) {},
|
|
Reset: func(connection *testConnectorConnection) {},
|
|
})
|
|
|
|
_, err := connector.Get(context.Background())
|
|
require.NoError(t, err)
|
|
require.NotNil(t, dialContext)
|
|
|
|
select {
|
|
case <-dialContext.Done():
|
|
t.Fatal("dial context canceled before connector close")
|
|
default:
|
|
}
|
|
|
|
err = connector.Close()
|
|
require.NoError(t, err)
|
|
|
|
select {
|
|
case <-dialContext.Done():
|
|
case <-time.After(time.Second):
|
|
t.Fatal("dial context not canceled after connector close")
|
|
}
|
|
}
|