mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-12 01:57:18 +10:00
Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c0d45aebfa | ||
|
|
4e0a953b98 | ||
|
|
27c5b0b1af | ||
|
|
84019b06d9 |
Submodule clients/android updated: 172199dfc3...7777469b5d
Submodule clients/apple updated: 16800708dd...c19945f65b
@@ -4,6 +4,9 @@ import (
|
||||
"context"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
E "github.com/sagernet/sing/common/exceptions"
|
||||
)
|
||||
|
||||
type ConnectorCallbacks[T any] struct {
|
||||
@@ -16,10 +19,11 @@ type Connector[T any] struct {
|
||||
dial func(ctx context.Context) (T, error)
|
||||
callbacks ConnectorCallbacks[T]
|
||||
|
||||
access sync.Mutex
|
||||
connection T
|
||||
hasConnection bool
|
||||
connecting chan struct{}
|
||||
access sync.Mutex
|
||||
connection T
|
||||
hasConnection bool
|
||||
connectionCancel context.CancelFunc
|
||||
connecting chan struct{}
|
||||
|
||||
closeCtx context.Context
|
||||
closed bool
|
||||
@@ -47,6 +51,10 @@ func NewSingleflightConnector(closeCtx context.Context, dial func(context.Contex
|
||||
})
|
||||
}
|
||||
|
||||
type contextKeyConnecting struct{}
|
||||
|
||||
var errRecursiveConnectorDial = E.New("recursive connector dial")
|
||||
|
||||
func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
||||
var zero T
|
||||
for {
|
||||
@@ -64,6 +72,14 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
||||
}
|
||||
|
||||
c.hasConnection = false
|
||||
if c.connectionCancel != nil {
|
||||
c.connectionCancel()
|
||||
c.connectionCancel = nil
|
||||
}
|
||||
if isRecursiveConnectorDial(ctx, c) {
|
||||
c.access.Unlock()
|
||||
return zero, errRecursiveConnectorDial
|
||||
}
|
||||
|
||||
if c.connecting != nil {
|
||||
connecting := c.connecting
|
||||
@@ -79,10 +95,16 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := ctx.Err(); err != nil {
|
||||
c.access.Unlock()
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.connecting = make(chan struct{})
|
||||
c.access.Unlock()
|
||||
|
||||
connection, err := c.dialWithCancellation(ctx)
|
||||
dialContext := context.WithValue(ctx, contextKeyConnecting{}, c)
|
||||
connection, cancel, err := c.dialWithCancellation(dialContext)
|
||||
|
||||
c.access.Lock()
|
||||
close(c.connecting)
|
||||
@@ -94,13 +116,21 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
||||
}
|
||||
|
||||
if c.closed {
|
||||
cancel()
|
||||
c.callbacks.Close(connection)
|
||||
c.access.Unlock()
|
||||
return zero, ErrTransportClosed
|
||||
}
|
||||
if err = ctx.Err(); err != nil {
|
||||
cancel()
|
||||
c.callbacks.Close(connection)
|
||||
c.access.Unlock()
|
||||
return zero, err
|
||||
}
|
||||
|
||||
c.connection = connection
|
||||
c.hasConnection = true
|
||||
c.connectionCancel = cancel
|
||||
result := c.connection
|
||||
c.access.Unlock()
|
||||
|
||||
@@ -108,19 +138,63 @@ func (c *Connector[T]) Get(ctx context.Context) (T, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, error) {
|
||||
dialCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
func isRecursiveConnectorDial[T any](ctx context.Context, connector *Connector[T]) bool {
|
||||
dialConnector, loaded := ctx.Value(contextKeyConnecting{}).(*Connector[T])
|
||||
return loaded && dialConnector == connector
|
||||
}
|
||||
|
||||
go func() {
|
||||
select {
|
||||
case <-c.closeCtx.Done():
|
||||
func (c *Connector[T]) dialWithCancellation(ctx context.Context) (T, context.CancelFunc, error) {
|
||||
var zero T
|
||||
if err := ctx.Err(); err != nil {
|
||||
return zero, nil, err
|
||||
}
|
||||
connCtx, cancel := context.WithCancel(c.closeCtx)
|
||||
|
||||
var (
|
||||
stateAccess sync.Mutex
|
||||
dialComplete bool
|
||||
)
|
||||
stopCancel := context.AfterFunc(ctx, func() {
|
||||
stateAccess.Lock()
|
||||
if !dialComplete {
|
||||
cancel()
|
||||
case <-dialCtx.Done():
|
||||
}
|
||||
}()
|
||||
stateAccess.Unlock()
|
||||
})
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
stateAccess.Lock()
|
||||
dialComplete = true
|
||||
stateAccess.Unlock()
|
||||
stopCancel()
|
||||
cancel()
|
||||
return zero, nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
return c.dial(dialCtx)
|
||||
connection, err := c.dial(valueContext{connCtx, ctx})
|
||||
stateAccess.Lock()
|
||||
dialComplete = true
|
||||
stateAccess.Unlock()
|
||||
stopCancel()
|
||||
if err != nil {
|
||||
cancel()
|
||||
return zero, nil, err
|
||||
}
|
||||
return connection, cancel, nil
|
||||
}
|
||||
|
||||
type valueContext struct {
|
||||
context.Context
|
||||
parent context.Context
|
||||
}
|
||||
|
||||
func (v valueContext) Value(key any) any {
|
||||
return v.parent.Value(key)
|
||||
}
|
||||
|
||||
func (v valueContext) Deadline() (time.Time, bool) {
|
||||
return v.parent.Deadline()
|
||||
}
|
||||
|
||||
func (c *Connector[T]) Close() error {
|
||||
@@ -132,6 +206,10 @@ func (c *Connector[T]) Close() error {
|
||||
}
|
||||
c.closed = true
|
||||
|
||||
if c.connectionCancel != nil {
|
||||
c.connectionCancel()
|
||||
c.connectionCancel = nil
|
||||
}
|
||||
if c.hasConnection {
|
||||
c.callbacks.Close(c.connection)
|
||||
c.hasConnection = false
|
||||
@@ -144,6 +222,10 @@ func (c *Connector[T]) Reset() {
|
||||
c.access.Lock()
|
||||
defer c.access.Unlock()
|
||||
|
||||
if c.connectionCancel != nil {
|
||||
c.connectionCancel()
|
||||
c.connectionCancel = nil
|
||||
}
|
||||
if c.hasConnection {
|
||||
c.callbacks.Reset(c.connection)
|
||||
c.hasConnection = false
|
||||
|
||||
263
dns/transport/connector_test.go
Normal file
263
dns/transport/connector_test.go
Normal file
@@ -0,0 +1,263 @@
|
||||
package transport
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testConnectorConnection struct{}
|
||||
|
||||
func TestConnectorRecursiveGetFailsFast(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
dialCount atomic.Int32
|
||||
closeCount atomic.Int32
|
||||
connector *Connector[*testConnectorConnection]
|
||||
)
|
||||
|
||||
dial := func(ctx context.Context) (*testConnectorConnection, error) {
|
||||
dialCount.Add(1)
|
||||
_, err := connector.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &testConnectorConnection{}, nil
|
||||
}
|
||||
|
||||
connector = NewConnector(context.Background(), dial, ConnectorCallbacks[*testConnectorConnection]{
|
||||
IsClosed: func(connection *testConnectorConnection) bool {
|
||||
return false
|
||||
},
|
||||
Close: func(connection *testConnectorConnection) {
|
||||
closeCount.Add(1)
|
||||
},
|
||||
Reset: func(connection *testConnectorConnection) {
|
||||
closeCount.Add(1)
|
||||
},
|
||||
})
|
||||
|
||||
_, err := connector.Get(context.Background())
|
||||
require.ErrorIs(t, err, errRecursiveConnectorDial)
|
||||
require.EqualValues(t, 1, dialCount.Load())
|
||||
require.EqualValues(t, 0, closeCount.Load())
|
||||
}
|
||||
|
||||
func TestConnectorRecursiveGetAcrossConnectorsAllowed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
outerDialCount atomic.Int32
|
||||
innerDialCount atomic.Int32
|
||||
outerConnector *Connector[*testConnectorConnection]
|
||||
innerConnector *Connector[*testConnectorConnection]
|
||||
)
|
||||
|
||||
innerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||
innerDialCount.Add(1)
|
||||
return &testConnectorConnection{}, nil
|
||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||
IsClosed: func(connection *testConnectorConnection) bool {
|
||||
return false
|
||||
},
|
||||
Close: func(connection *testConnectorConnection) {},
|
||||
Reset: func(connection *testConnectorConnection) {},
|
||||
})
|
||||
|
||||
outerConnector = NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||
outerDialCount.Add(1)
|
||||
_, err := innerConnector.Get(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &testConnectorConnection{}, nil
|
||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||
IsClosed: func(connection *testConnectorConnection) bool {
|
||||
return false
|
||||
},
|
||||
Close: func(connection *testConnectorConnection) {},
|
||||
Reset: func(connection *testConnectorConnection) {},
|
||||
})
|
||||
|
||||
_, err := outerConnector.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, outerDialCount.Load())
|
||||
require.EqualValues(t, 1, innerDialCount.Load())
|
||||
}
|
||||
|
||||
func TestConnectorDialContextPreservesValueAndDeadline(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type contextKey struct{}
|
||||
|
||||
var (
|
||||
dialValue any
|
||||
dialDeadline time.Time
|
||||
dialHasDeadline bool
|
||||
)
|
||||
|
||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||
dialValue = ctx.Value(contextKey{})
|
||||
dialDeadline, dialHasDeadline = ctx.Deadline()
|
||||
return &testConnectorConnection{}, nil
|
||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||
IsClosed: func(connection *testConnectorConnection) bool {
|
||||
return false
|
||||
},
|
||||
Close: func(connection *testConnectorConnection) {},
|
||||
Reset: func(connection *testConnectorConnection) {},
|
||||
})
|
||||
|
||||
deadline := time.Now().Add(time.Minute)
|
||||
requestContext, cancel := context.WithDeadline(context.WithValue(context.Background(), contextKey{}, "test-value"), deadline)
|
||||
defer cancel()
|
||||
|
||||
_, err := connector.Get(requestContext)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test-value", dialValue)
|
||||
require.True(t, dialHasDeadline)
|
||||
require.WithinDuration(t, deadline, dialDeadline, time.Second)
|
||||
}
|
||||
|
||||
func TestConnectorDialSkipsCanceledRequest(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var dialCount atomic.Int32
|
||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||
dialCount.Add(1)
|
||||
return &testConnectorConnection{}, nil
|
||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||
IsClosed: func(connection *testConnectorConnection) bool {
|
||||
return false
|
||||
},
|
||||
Close: func(connection *testConnectorConnection) {},
|
||||
Reset: func(connection *testConnectorConnection) {},
|
||||
})
|
||||
|
||||
requestContext, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
_, err := connector.Get(requestContext)
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
require.EqualValues(t, 0, dialCount.Load())
|
||||
}
|
||||
|
||||
func TestConnectorCanceledRequestDoesNotCacheConnection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var (
|
||||
dialCount atomic.Int32
|
||||
closeCount atomic.Int32
|
||||
)
|
||||
dialStarted := make(chan struct{}, 1)
|
||||
releaseDial := make(chan struct{})
|
||||
|
||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||
dialCount.Add(1)
|
||||
select {
|
||||
case dialStarted <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
<-releaseDial
|
||||
return &testConnectorConnection{}, nil
|
||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||
IsClosed: func(connection *testConnectorConnection) bool {
|
||||
return false
|
||||
},
|
||||
Close: func(connection *testConnectorConnection) {
|
||||
closeCount.Add(1)
|
||||
},
|
||||
Reset: func(connection *testConnectorConnection) {},
|
||||
})
|
||||
|
||||
requestContext, cancel := context.WithCancel(context.Background())
|
||||
result := make(chan error, 1)
|
||||
go func() {
|
||||
_, err := connector.Get(requestContext)
|
||||
result <- err
|
||||
}()
|
||||
|
||||
<-dialStarted
|
||||
cancel()
|
||||
close(releaseDial)
|
||||
|
||||
err := <-result
|
||||
require.ErrorIs(t, err, context.Canceled)
|
||||
require.EqualValues(t, 1, dialCount.Load())
|
||||
require.EqualValues(t, 1, closeCount.Load())
|
||||
|
||||
_, err = connector.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, dialCount.Load())
|
||||
}
|
||||
|
||||
func TestConnectorDialContextNotCanceledByRequestContextAfterDial(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var dialContext context.Context
|
||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||
dialContext = ctx
|
||||
return &testConnectorConnection{}, nil
|
||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||
IsClosed: func(connection *testConnectorConnection) bool {
|
||||
return false
|
||||
},
|
||||
Close: func(connection *testConnectorConnection) {},
|
||||
Reset: func(connection *testConnectorConnection) {},
|
||||
})
|
||||
|
||||
requestContext, cancel := context.WithCancel(context.Background())
|
||||
_, err := connector.Get(requestContext)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dialContext)
|
||||
|
||||
cancel()
|
||||
|
||||
select {
|
||||
case <-dialContext.Done():
|
||||
t.Fatal("dial context canceled by request context after successful dial")
|
||||
case <-time.After(100 * time.Millisecond):
|
||||
}
|
||||
|
||||
err = connector.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestConnectorDialContextCanceledOnClose(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var dialContext context.Context
|
||||
connector := NewConnector(context.Background(), func(ctx context.Context) (*testConnectorConnection, error) {
|
||||
dialContext = ctx
|
||||
return &testConnectorConnection{}, nil
|
||||
}, ConnectorCallbacks[*testConnectorConnection]{
|
||||
IsClosed: func(connection *testConnectorConnection) bool {
|
||||
return false
|
||||
},
|
||||
Close: func(connection *testConnectorConnection) {},
|
||||
Reset: func(connection *testConnectorConnection) {},
|
||||
})
|
||||
|
||||
_, err := connector.Get(context.Background())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, dialContext)
|
||||
|
||||
select {
|
||||
case <-dialContext.Done():
|
||||
t.Fatal("dial context canceled before connector close")
|
||||
default:
|
||||
}
|
||||
|
||||
err = connector.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
select {
|
||||
case <-dialContext.Done():
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("dial context not canceled after connector close")
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,10 @@
|
||||
icon: material/alert-decagram
|
||||
---
|
||||
|
||||
#### 1.13.2
|
||||
|
||||
* Fixes and improvements
|
||||
|
||||
#### 1.13.1
|
||||
|
||||
* Fixes and improvements
|
||||
|
||||
2
go.mod
2
go.mod
@@ -33,7 +33,7 @@ require (
|
||||
github.com/sagernet/gomobile v0.1.12
|
||||
github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4
|
||||
github.com/sagernet/sing v0.8.1
|
||||
github.com/sagernet/sing v0.8.2
|
||||
github.com/sagernet/sing-mux v0.3.4
|
||||
github.com/sagernet/sing-quic v0.6.0
|
||||
github.com/sagernet/sing-shadowsocks v0.2.8
|
||||
|
||||
4
go.sum
4
go.sum
@@ -236,8 +236,8 @@ github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNen
|
||||
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4 h1:6qvrUW79S+CrPwWz6cMePXohgjHoKxLo3c+MDhNwc3o=
|
||||
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4/go.mod h1:OqILvS182CyOol5zNNo6bguvOGgXzV459+chpRaUC+4=
|
||||
github.com/sagernet/sing v0.8.1 h1:Li+zg4xdiMsvdX4j50TPqmSG8LF/TB9US2qlAN40izU=
|
||||
github.com/sagernet/sing v0.8.1/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing v0.8.2 h1:kX1IH9SWJv4S0T9M8O+HNahWgbOuY1VauxbF7NU5lOg=
|
||||
github.com/sagernet/sing v0.8.2/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
|
||||
github.com/sagernet/sing-mux v0.3.4 h1:ZQplKl8MNXutjzbMVtWvWG31fohhgOfCuUZR4dVQ8+s=
|
||||
github.com/sagernet/sing-mux v0.3.4/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk=
|
||||
github.com/sagernet/sing-quic v0.6.0 h1:dhrFnP45wgVKEOT1EvtsToxdzRnHIDIAgj6WHV9pLyM=
|
||||
|
||||
@@ -106,7 +106,7 @@ func (c *Client) DialContext(ctx context.Context) (net.Conn, error) {
|
||||
cancel(err)
|
||||
return nil, err
|
||||
}
|
||||
return NewGRPCConn(stream), nil
|
||||
return NewGRPCConn(stream, cancel), nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
package v2raygrpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/sagernet/sing/common/baderror"
|
||||
@@ -14,16 +16,19 @@ var _ net.Conn = (*GRPCConn)(nil)
|
||||
|
||||
type GRPCConn struct {
|
||||
GunService
|
||||
cache []byte
|
||||
cache []byte
|
||||
cancel context.CancelCauseFunc
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func NewGRPCConn(service GunService) *GRPCConn {
|
||||
func NewGRPCConn(service GunService, cancel context.CancelCauseFunc) *GRPCConn {
|
||||
//nolint:staticcheck
|
||||
if client, isClient := service.(GunService_TunClient); isClient {
|
||||
service = &clientConnWrapper{client}
|
||||
}
|
||||
return &GRPCConn{
|
||||
GunService: service,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +59,11 @@ func (c *GRPCConn) Write(b []byte) (n int, err error) {
|
||||
}
|
||||
|
||||
func (c *GRPCConn) Close() error {
|
||||
c.closeOnce.Do(func() {
|
||||
if c.cancel != nil {
|
||||
c.cancel(nil)
|
||||
}
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -52,7 +52,7 @@ func NewServer(ctx context.Context, logger logger.ContextLogger, options option.
|
||||
}
|
||||
|
||||
func (s *Server) Tun(server GunService_TunServer) error {
|
||||
conn := NewGRPCConn(server)
|
||||
conn := NewGRPCConn(server, nil)
|
||||
var source M.Socksaddr
|
||||
if remotePeer, loaded := peer.FromContext(server.Context()); loaded {
|
||||
source = M.SocksaddrFromNet(remotePeer.Addr)
|
||||
|
||||
@@ -136,10 +136,12 @@ func (s *Server) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
|
||||
s.handler.NewConnectionEx(DupContext(request.Context()), conn, source, M.Socksaddr{}, nil)
|
||||
} else {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
flusher := writer.(http.Flusher)
|
||||
flusher.Flush()
|
||||
done := make(chan struct{})
|
||||
conn := NewHTTP2Wrapper(&ServerHTTPConn{
|
||||
NewHTTPConn(request.Body, writer),
|
||||
writer.(http.Flusher),
|
||||
flusher,
|
||||
})
|
||||
s.handler.NewConnectionEx(request.Context(), conn, source, M.Socksaddr{}, N.OnceClose(func(it error) {
|
||||
close(done)
|
||||
|
||||
Reference in New Issue
Block a user