mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
Align cloudflare runtime behavior with cloudflared
This commit is contained in:
@@ -26,3 +26,13 @@ func TestValidateRegistrationResultRejectsNonRemoteManaged(t *testing.T) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeProtocolAcceptsAuto(t *testing.T) {
|
||||
protocol, err := normalizeProtocol("auto")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if protocol != "" {
|
||||
t.Fatalf("expected auto protocol to normalize to empty string, got %q", protocol)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,7 +265,7 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
i.handleStreamService(ctx, stream, respWriter, request, metadata, service.Destination)
|
||||
i.handleStreamService(ctx, stream, respWriter, request, metadata, service)
|
||||
case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld:
|
||||
if request.Type == ConnectionTypeHTTP {
|
||||
i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service)
|
||||
@@ -279,7 +279,7 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
i.handleBastionStream(ctx, stream, respWriter, request, metadata)
|
||||
i.handleBastionStream(ctx, stream, respWriter, request, metadata, service)
|
||||
case ResolvedServiceSocksProxy:
|
||||
if request.Type != ConnectionTypeWebsocket {
|
||||
err := E.New("socks-proxy service requires websocket request type")
|
||||
|
||||
@@ -97,3 +97,30 @@ func TestGetRegionalServiceName(t *testing.T) {
|
||||
t.Fatalf("expected regional service us-%s, got %s", edgeSRVService, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInitialEdgeAddrIndex(t *testing.T) {
|
||||
if got := initialEdgeAddrIndex(0, 4); got != 0 {
|
||||
t.Fatalf("expected conn 0 to get index 0, got %d", got)
|
||||
}
|
||||
if got := initialEdgeAddrIndex(3, 4); got != 3 {
|
||||
t.Fatalf("expected conn 3 to get index 3, got %d", got)
|
||||
}
|
||||
if got := initialEdgeAddrIndex(5, 4); got != 1 {
|
||||
t.Fatalf("expected conn 5 to wrap to index 1, got %d", got)
|
||||
}
|
||||
if got := initialEdgeAddrIndex(2, 1); got != 0 {
|
||||
t.Fatalf("expected single-address pool to always return 0, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRotateEdgeAddrIndex(t *testing.T) {
|
||||
if got := rotateEdgeAddrIndex(0, 4); got != 1 {
|
||||
t.Fatalf("expected index 0 to rotate to 1, got %d", got)
|
||||
}
|
||||
if got := rotateEdgeAddrIndex(3, 4); got != 0 {
|
||||
t.Fatalf("expected last index to wrap to 0, got %d", got)
|
||||
}
|
||||
if got := rotateEdgeAddrIndex(0, 1); got != 0 {
|
||||
t.Fatalf("expected single-address pool to stay at 0, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,9 +86,9 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
|
||||
haConnections = 4
|
||||
}
|
||||
|
||||
protocol := options.Protocol
|
||||
if protocol != "" && protocol != "quic" && protocol != "http2" {
|
||||
return nil, E.New("unsupported protocol: ", protocol, ", expected quic or http2")
|
||||
protocol, err := normalizeProtocol(options.Protocol)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
edgeIPVersion := options.EdgeIPVersion
|
||||
@@ -283,6 +283,7 @@ const (
|
||||
func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) {
|
||||
defer i.done.Done()
|
||||
|
||||
edgeIndex := initialEdgeAddrIndex(connIndex, len(edgeAddrs))
|
||||
retries := 0
|
||||
for {
|
||||
select {
|
||||
@@ -291,7 +292,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe
|
||||
default:
|
||||
}
|
||||
|
||||
edgeAddr := edgeAddrs[rand.Intn(len(edgeAddrs))]
|
||||
edgeAddr := edgeAddrs[edgeIndex]
|
||||
err := i.serveConnection(connIndex, edgeAddr, features, uint8(retries))
|
||||
if err == nil || i.ctx.Err() != nil {
|
||||
return
|
||||
@@ -303,6 +304,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe
|
||||
}
|
||||
|
||||
retries++
|
||||
edgeIndex = rotateEdgeAddrIndex(edgeIndex, len(edgeAddrs))
|
||||
backoff := backoffDuration(retries)
|
||||
var retryableErr *RetryableError
|
||||
if errors.As(err, &retryableErr) && retryableErr.Delay > 0 {
|
||||
@@ -410,6 +412,20 @@ func backoffDuration(retries int) time.Duration {
|
||||
return backoff/2 + jitter
|
||||
}
|
||||
|
||||
func initialEdgeAddrIndex(connIndex uint8, size int) int {
|
||||
if size <= 1 {
|
||||
return 0
|
||||
}
|
||||
return int(connIndex) % size
|
||||
}
|
||||
|
||||
func rotateEdgeAddrIndex(current int, size int) int {
|
||||
if size <= 1 {
|
||||
return 0
|
||||
}
|
||||
return (current + 1) % size
|
||||
}
|
||||
|
||||
func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr {
|
||||
var result []*EdgeAddr
|
||||
for _, region := range regions {
|
||||
@@ -430,3 +446,13 @@ func parseToken(token string) (Credentials, error) {
|
||||
}
|
||||
return tunnelToken.ToCredentials(), nil
|
||||
}
|
||||
|
||||
func normalizeProtocol(protocol string) (string, error) {
|
||||
if protocol == "auto" {
|
||||
return "", nil
|
||||
}
|
||||
if protocol != "" && protocol != "quic" && protocol != "http2" {
|
||||
return "", E.New("unsupported protocol: ", protocol, ", expected auto, quic or http2")
|
||||
}
|
||||
return protocol, nil
|
||||
}
|
||||
|
||||
@@ -32,20 +32,20 @@ const (
|
||||
socksReplyCommandNotSupported = 7
|
||||
)
|
||||
|
||||
func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) {
|
||||
func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
destination, err := resolveBastionDestination(request)
|
||||
if err != nil {
|
||||
respWriter.WriteResponse(err, nil)
|
||||
return
|
||||
}
|
||||
i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination))
|
||||
i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination), service.OriginRequest.ProxyType)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, destination M.Socksaddr) {
|
||||
i.handleRouterBackedStream(ctx, stream, respWriter, request, destination)
|
||||
func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
|
||||
i.handleRouterBackedStream(ctx, stream, respWriter, request, service.Destination, service.OriginRequest.ProxyType)
|
||||
}
|
||||
|
||||
func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr) {
|
||||
func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr, proxyType string) {
|
||||
targetConn, cleanup, err := i.dialRouterTCP(ctx, destination)
|
||||
if err != nil {
|
||||
respWriter.WriteResponse(err, nil)
|
||||
@@ -61,6 +61,12 @@ func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWr
|
||||
|
||||
wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide)
|
||||
defer wsConn.Close()
|
||||
if isSocksProxyType(proxyType) {
|
||||
if err := serveFixedSocksStream(ctx, wsConn, targetConn); err != nil && !E.IsClosedOrCanceled(err) {
|
||||
i.logger.DebugContext(ctx, "socks-over-websocket stream closed: ", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
_ = bufio.CopyConn(ctx, wsConn, targetConn)
|
||||
}
|
||||
|
||||
@@ -101,6 +107,67 @@ func websocketResponseHeaders(request *ConnectRequest) http.Header {
|
||||
return header
|
||||
}
|
||||
|
||||
func isSocksProxyType(proxyType string) bool {
|
||||
lower := strings.ToLower(strings.TrimSpace(proxyType))
|
||||
return lower == "socks" || lower == "socks5"
|
||||
}
|
||||
|
||||
func serveFixedSocksStream(ctx context.Context, conn net.Conn, targetConn net.Conn) error {
|
||||
version := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, version); err != nil {
|
||||
return err
|
||||
}
|
||||
if version[0] != 5 {
|
||||
return E.New("unsupported SOCKS version: ", version[0])
|
||||
}
|
||||
|
||||
methodCount := make([]byte, 1)
|
||||
if _, err := io.ReadFull(conn, methodCount); err != nil {
|
||||
return err
|
||||
}
|
||||
methods := make([]byte, int(methodCount[0]))
|
||||
if _, err := io.ReadFull(conn, methods); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var supportsNoAuth bool
|
||||
for _, method := range methods {
|
||||
if method == 0 {
|
||||
supportsNoAuth = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !supportsNoAuth {
|
||||
_, err := conn.Write([]byte{5, 255})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return E.New("unknown authentication type")
|
||||
}
|
||||
if _, err := conn.Write([]byte{5, 0}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
requestHeader := make([]byte, 4)
|
||||
if _, err := io.ReadFull(conn, requestHeader); err != nil {
|
||||
return err
|
||||
}
|
||||
if requestHeader[0] != 5 {
|
||||
return E.New("unsupported SOCKS request version: ", requestHeader[0])
|
||||
}
|
||||
if requestHeader[1] != 1 {
|
||||
_ = writeSocksReply(conn, socksReplyCommandNotSupported)
|
||||
return E.New("unsupported SOCKS command: ", requestHeader[1])
|
||||
}
|
||||
if _, err := readSocksDestination(conn, requestHeader[3]); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := writeSocksReply(conn, socksReplySuccess); err != nil {
|
||||
return err
|
||||
}
|
||||
return bufio.CopyConn(ctx, conn, targetConn)
|
||||
}
|
||||
|
||||
func requestHeaderValue(request *ConnectRequest, headerName string) string {
|
||||
for _, entry := range request.Metadata {
|
||||
if !strings.HasPrefix(entry.Key, metadataHTTPHeader+":") {
|
||||
|
||||
@@ -201,7 +201,7 @@ func TestHandleBastionStream(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{})
|
||||
inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{})
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -438,7 +438,10 @@ func TestHandleStreamService(t *testing.T) {
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, M.ParseSocksaddr(listener.Addr().String()))
|
||||
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceStream,
|
||||
Destination: M.ParseSocksaddr(listener.Addr().String()),
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
@@ -473,3 +476,67 @@ func TestHandleStreamService(t *testing.T) {
|
||||
t.Fatal("stream service did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleStreamServiceProxyTypeSocks(t *testing.T) {
|
||||
listener := startEchoListener(t)
|
||||
defer listener.Close()
|
||||
|
||||
serverSide, clientSide := net.Pipe()
|
||||
defer clientSide.Close()
|
||||
|
||||
inboundInstance := newSpecialServiceInbound(t)
|
||||
request := &ConnectRequest{
|
||||
Type: ConnectionTypeWebsocket,
|
||||
Metadata: []Metadata{
|
||||
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
|
||||
},
|
||||
}
|
||||
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
|
||||
|
||||
done := make(chan struct{})
|
||||
go func() {
|
||||
defer close(done)
|
||||
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
|
||||
Kind: ResolvedServiceStream,
|
||||
Destination: M.ParseSocksaddr(listener.Addr().String()),
|
||||
OriginRequest: OriginRequestConfig{
|
||||
ProxyType: "socks",
|
||||
},
|
||||
})
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-respWriter.done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("timed out waiting for stream service connect response")
|
||||
}
|
||||
if respWriter.err != nil {
|
||||
t.Fatal(respWriter.err)
|
||||
}
|
||||
if respWriter.status != http.StatusSwitchingProtocols {
|
||||
t.Fatalf("expected 101 response, got %d", respWriter.status)
|
||||
}
|
||||
|
||||
writeSocksAuth(t, clientSide)
|
||||
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
|
||||
if len(data) != 10 || data[1] != socksReplySuccess {
|
||||
t.Fatalf("unexpected socks connect response: %v", data)
|
||||
}
|
||||
|
||||
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data, _, err := wsutil.ReadServerData(clientSide)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(data) != "hello" {
|
||||
t.Fatalf("expected echoed payload, got %q", string(data))
|
||||
}
|
||||
_ = clientSide.Close()
|
||||
select {
|
||||
case <-done:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("socks stream service did not exit")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user