Files
sing-box/protocol/cloudflare/rpc_stream_test.go
2026-03-31 15:32:57 +08:00

252 lines
6.6 KiB
Go

//go:build with_cloudflared
package cloudflare
import (
"context"
"io"
"net"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
)
type blockingRPCStream struct {
closed chan struct{}
}
func newBlockingRPCStream() *blockingRPCStream {
return &blockingRPCStream{closed: make(chan struct{})}
}
func (s *blockingRPCStream) Read(_ []byte) (int, error) {
<-s.closed
return 0, io.EOF
}
func (s *blockingRPCStream) Write(p []byte) (int, error) {
return len(p), nil
}
func (s *blockingRPCStream) Close() error {
select {
case <-s.closed:
default:
close(s.closed)
}
return nil
}
type blockingPacketDialRouter struct {
testRouter
entered chan struct{}
release chan struct{}
}
func (r *blockingPacketDialRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) {
select {
case <-r.entered:
default:
close(r.entered)
}
select {
case <-r.release:
return newBlockingPacketConn(), nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func newRPCInbound(t *testing.T, router adapter.Router) *Inbound {
t.Helper()
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.router = router
return inboundInstance
}
func newRPCClientPair(t *testing.T, ctx context.Context) (tunnelrpc.CloudflaredServer, io.Closer, io.Closer, net.Conn, net.Conn) {
t.Helper()
serverSide, clientSide := net.Pipe()
transport := safeTransport(clientSide)
clientConn := newRPCClientConn(transport, ctx)
client := tunnelrpc.CloudflaredServer{Client: clientConn.Bootstrap(ctx)}
return client, clientConn, transport, serverSide, clientSide
}
func TestServeRPCStreamRespectsContextDeadline(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
stream := newBlockingRPCStream()
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
ServeRPCStream(ctx, stream, inboundInstance, NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger), inboundInstance.logger)
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected ServeRPCStream to exit after context deadline")
}
}
func TestServeV3RPCStreamRespectsContextDeadline(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
stream := newBlockingRPCStream()
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
ServeV3RPCStream(ctx, stream, inboundInstance, inboundInstance.logger)
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected ServeV3RPCStream to exit after context deadline")
}
}
func TestV2RPCAckAllowsConcurrentDispatch(t *testing.T) {
router := &blockingPacketDialRouter{
entered: make(chan struct{}),
release: make(chan struct{}),
}
inboundInstance := newRPCInbound(t, router)
muxer := NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client, clientConn, transport, serverSide, clientSide := newRPCClientPair(t, ctx)
defer clientConn.Close()
defer transport.Close()
defer clientSide.Close()
done := make(chan struct{})
go func() {
ServeRPCStream(ctx, serverSide, inboundInstance, muxer, inboundInstance.logger)
close(done)
}()
registerPromise := client.RegisterUdpSession(ctx, func(p tunnelrpc.SessionManager_registerUdpSession_Params) error {
sessionID := uuid.New()
if err := p.SetSessionId(sessionID[:]); err != nil {
return err
}
if err := p.SetDstIp([]byte{127, 0, 0, 1}); err != nil {
return err
}
p.SetDstPort(53)
p.SetCloseAfterIdleHint(int64(time.Second))
return p.SetTraceContext("")
})
select {
case <-router.entered:
case <-time.After(time.Second):
t.Fatal("expected register RPC to enter the blocking dial")
}
updateCtx, updateCancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer updateCancel()
updatePromise := client.UpdateConfiguration(updateCtx, func(p tunnelrpc.ConfigurationManager_updateConfiguration_Params) error {
p.SetVersion(1)
return p.SetConfig([]byte(`{"ingress":[{"service":"http_status:503"}]}`))
})
if _, err := updatePromise.Result().Struct(); err != nil {
t.Fatalf("expected concurrent update RPC to succeed, got %v", err)
}
close(router.release)
if _, err := registerPromise.Result().Struct(); err != nil {
t.Fatalf("expected register RPC to complete, got %v", err)
}
cancel()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected ServeRPCStream to exit")
}
}
func TestV3RPCAckAllowsConcurrentDispatch(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client, clientConn, transport, serverSide, clientSide := newRPCClientPair(t, ctx)
defer clientConn.Close()
defer transport.Close()
defer clientSide.Close()
done := make(chan struct{})
go func() {
ServeV3RPCStream(ctx, serverSide, inboundInstance, inboundInstance.logger)
close(done)
}()
inboundInstance.configManager.access.Lock()
updatePromise := client.UpdateConfiguration(ctx, func(p tunnelrpc.ConfigurationManager_updateConfiguration_Params) error {
p.SetVersion(1)
return p.SetConfig([]byte(`{"ingress":[{"service":"http_status:503"}]}`))
})
time.Sleep(20 * time.Millisecond)
registerCtx, registerCancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer registerCancel()
registerPromise := client.RegisterUdpSession(registerCtx, func(p tunnelrpc.SessionManager_registerUdpSession_Params) error {
sessionID := uuid.New()
if err := p.SetSessionId(sessionID[:]); err != nil {
return err
}
if err := p.SetDstIp([]byte{127, 0, 0, 1}); err != nil {
return err
}
p.SetDstPort(53)
p.SetCloseAfterIdleHint(int64(time.Second))
return p.SetTraceContext("")
})
registerResult, err := registerPromise.Result().Struct()
if err != nil {
t.Fatalf("expected concurrent v3 register RPC to succeed, got %v", err)
}
resultErr, err := registerResult.Err()
if err != nil {
t.Fatal(err)
}
if resultErr != errUnsupportedDatagramV3UDPRegistration.Error() {
t.Fatalf("unexpected registration error %q", resultErr)
}
inboundInstance.configManager.access.Unlock()
if _, err := updatePromise.Result().Struct(); err != nil {
t.Fatalf("expected update RPC to complete, got %v", err)
}
cancel()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected ServeV3RPCStream to exit")
}
}