diff --git a/protocol/cloudflare/config_decode_test.go b/protocol/cloudflare/config_decode_test.go index 4addd1f99..0c6f83473 100644 --- a/protocol/cloudflare/config_decode_test.go +++ b/protocol/cloudflare/config_decode_test.go @@ -26,3 +26,13 @@ func TestValidateRegistrationResultRejectsNonRemoteManaged(t *testing.T) { t.Fatalf("unexpected error: %v", err) } } + +func TestNormalizeProtocolAcceptsAuto(t *testing.T) { + protocol, err := normalizeProtocol("auto") + if err != nil { + t.Fatal(err) + } + if protocol != "" { + t.Fatalf("expected auto protocol to normalize to empty string, got %q", protocol) + } +} diff --git a/protocol/cloudflare/dispatch.go b/protocol/cloudflare/dispatch.go index 5cedfe2a2..10b430076 100644 --- a/protocol/cloudflare/dispatch.go +++ b/protocol/cloudflare/dispatch.go @@ -265,7 +265,7 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos respWriter.WriteResponse(err, nil) return } - i.handleStreamService(ctx, stream, respWriter, request, metadata, service.Destination) + i.handleStreamService(ctx, stream, respWriter, request, metadata, service) case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld: if request.Type == ConnectionTypeHTTP { i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service) @@ -279,7 +279,7 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos respWriter.WriteResponse(err, nil) return } - i.handleBastionStream(ctx, stream, respWriter, request, metadata) + i.handleBastionStream(ctx, stream, respWriter, request, metadata, service) case ResolvedServiceSocksProxy: if request.Type != ConnectionTypeWebsocket { err := E.New("socks-proxy service requires websocket request type") diff --git a/protocol/cloudflare/edge_discovery_test.go b/protocol/cloudflare/edge_discovery_test.go index 28dda352e..f3e6e8df5 100644 --- a/protocol/cloudflare/edge_discovery_test.go +++ b/protocol/cloudflare/edge_discovery_test.go @@ -97,3 +97,30 @@ func TestGetRegionalServiceName(t *testing.T) { t.Fatalf("expected regional service us-%s, got %s", edgeSRVService, got) } } + +func TestInitialEdgeAddrIndex(t *testing.T) { + if got := initialEdgeAddrIndex(0, 4); got != 0 { + t.Fatalf("expected conn 0 to get index 0, got %d", got) + } + if got := initialEdgeAddrIndex(3, 4); got != 3 { + t.Fatalf("expected conn 3 to get index 3, got %d", got) + } + if got := initialEdgeAddrIndex(5, 4); got != 1 { + t.Fatalf("expected conn 5 to wrap to index 1, got %d", got) + } + if got := initialEdgeAddrIndex(2, 1); got != 0 { + t.Fatalf("expected single-address pool to always return 0, got %d", got) + } +} + +func TestRotateEdgeAddrIndex(t *testing.T) { + if got := rotateEdgeAddrIndex(0, 4); got != 1 { + t.Fatalf("expected index 0 to rotate to 1, got %d", got) + } + if got := rotateEdgeAddrIndex(3, 4); got != 0 { + t.Fatalf("expected last index to wrap to 0, got %d", got) + } + if got := rotateEdgeAddrIndex(0, 1); got != 0 { + t.Fatalf("expected single-address pool to stay at 0, got %d", got) + } +} diff --git a/protocol/cloudflare/inbound.go b/protocol/cloudflare/inbound.go index ad4645818..42e0b46a3 100644 --- a/protocol/cloudflare/inbound.go +++ b/protocol/cloudflare/inbound.go @@ -86,9 +86,9 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo haConnections = 4 } - protocol := options.Protocol - if protocol != "" && protocol != "quic" && protocol != "http2" { - return nil, E.New("unsupported protocol: ", protocol, ", expected quic or http2") + protocol, err := normalizeProtocol(options.Protocol) + if err != nil { + return nil, err } edgeIPVersion := options.EdgeIPVersion @@ -283,6 +283,7 @@ const ( func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) { defer i.done.Done() + edgeIndex := initialEdgeAddrIndex(connIndex, len(edgeAddrs)) retries := 0 for { select { @@ -291,7 +292,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe default: } - edgeAddr := edgeAddrs[rand.Intn(len(edgeAddrs))] + edgeAddr := edgeAddrs[edgeIndex] err := i.serveConnection(connIndex, edgeAddr, features, uint8(retries)) if err == nil || i.ctx.Err() != nil { return @@ -303,6 +304,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe } retries++ + edgeIndex = rotateEdgeAddrIndex(edgeIndex, len(edgeAddrs)) backoff := backoffDuration(retries) var retryableErr *RetryableError if errors.As(err, &retryableErr) && retryableErr.Delay > 0 { @@ -410,6 +412,20 @@ func backoffDuration(retries int) time.Duration { return backoff/2 + jitter } +func initialEdgeAddrIndex(connIndex uint8, size int) int { + if size <= 1 { + return 0 + } + return int(connIndex) % size +} + +func rotateEdgeAddrIndex(current int, size int) int { + if size <= 1 { + return 0 + } + return (current + 1) % size +} + func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr { var result []*EdgeAddr for _, region := range regions { @@ -430,3 +446,13 @@ func parseToken(token string) (Credentials, error) { } return tunnelToken.ToCredentials(), nil } + +func normalizeProtocol(protocol string) (string, error) { + if protocol == "auto" { + return "", nil + } + if protocol != "" && protocol != "quic" && protocol != "http2" { + return "", E.New("unsupported protocol: ", protocol, ", expected auto, quic or http2") + } + return protocol, nil +} diff --git a/protocol/cloudflare/special_service.go b/protocol/cloudflare/special_service.go index c5b5e0f4d..6c6d142ae 100644 --- a/protocol/cloudflare/special_service.go +++ b/protocol/cloudflare/special_service.go @@ -32,20 +32,20 @@ const ( socksReplyCommandNotSupported = 7 ) -func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) { +func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { destination, err := resolveBastionDestination(request) if err != nil { respWriter.WriteResponse(err, nil) return } - i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination)) + i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination), service.OriginRequest.ProxyType) } -func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, destination M.Socksaddr) { - i.handleRouterBackedStream(ctx, stream, respWriter, request, destination) +func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) { + i.handleRouterBackedStream(ctx, stream, respWriter, request, service.Destination, service.OriginRequest.ProxyType) } -func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr) { +func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr, proxyType string) { targetConn, cleanup, err := i.dialRouterTCP(ctx, destination) if err != nil { respWriter.WriteResponse(err, nil) @@ -61,6 +61,12 @@ func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWr wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide) defer wsConn.Close() + if isSocksProxyType(proxyType) { + if err := serveFixedSocksStream(ctx, wsConn, targetConn); err != nil && !E.IsClosedOrCanceled(err) { + i.logger.DebugContext(ctx, "socks-over-websocket stream closed: ", err) + } + return + } _ = bufio.CopyConn(ctx, wsConn, targetConn) } @@ -101,6 +107,67 @@ func websocketResponseHeaders(request *ConnectRequest) http.Header { return header } +func isSocksProxyType(proxyType string) bool { + lower := strings.ToLower(strings.TrimSpace(proxyType)) + return lower == "socks" || lower == "socks5" +} + +func serveFixedSocksStream(ctx context.Context, conn net.Conn, targetConn net.Conn) error { + version := make([]byte, 1) + if _, err := io.ReadFull(conn, version); err != nil { + return err + } + if version[0] != 5 { + return E.New("unsupported SOCKS version: ", version[0]) + } + + methodCount := make([]byte, 1) + if _, err := io.ReadFull(conn, methodCount); err != nil { + return err + } + methods := make([]byte, int(methodCount[0])) + if _, err := io.ReadFull(conn, methods); err != nil { + return err + } + + var supportsNoAuth bool + for _, method := range methods { + if method == 0 { + supportsNoAuth = true + break + } + } + if !supportsNoAuth { + _, err := conn.Write([]byte{5, 255}) + if err != nil { + return err + } + return E.New("unknown authentication type") + } + if _, err := conn.Write([]byte{5, 0}); err != nil { + return err + } + + requestHeader := make([]byte, 4) + if _, err := io.ReadFull(conn, requestHeader); err != nil { + return err + } + if requestHeader[0] != 5 { + return E.New("unsupported SOCKS request version: ", requestHeader[0]) + } + if requestHeader[1] != 1 { + _ = writeSocksReply(conn, socksReplyCommandNotSupported) + return E.New("unsupported SOCKS command: ", requestHeader[1]) + } + if _, err := readSocksDestination(conn, requestHeader[3]); err != nil { + return err + } + if err := writeSocksReply(conn, socksReplySuccess); err != nil { + return err + } + return bufio.CopyConn(ctx, conn, targetConn) +} + func requestHeaderValue(request *ConnectRequest, headerName string) string { for _, entry := range request.Metadata { if !strings.HasPrefix(entry.Key, metadataHTTPHeader+":") { diff --git a/protocol/cloudflare/special_service_test.go b/protocol/cloudflare/special_service_test.go index 2df966afa..8c39543c2 100644 --- a/protocol/cloudflare/special_service_test.go +++ b/protocol/cloudflare/special_service_test.go @@ -201,7 +201,7 @@ func TestHandleBastionStream(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}) + inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{}) }() select { @@ -438,7 +438,10 @@ func TestHandleStreamService(t *testing.T) { done := make(chan struct{}) go func() { defer close(done) - inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, M.ParseSocksaddr(listener.Addr().String())) + inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceStream, + Destination: M.ParseSocksaddr(listener.Addr().String()), + }) }() select { @@ -473,3 +476,67 @@ func TestHandleStreamService(t *testing.T) { t.Fatal("stream service did not exit") } } + +func TestHandleStreamServiceProxyTypeSocks(t *testing.T) { + listener := startEchoListener(t) + defer listener.Close() + + serverSide, clientSide := net.Pipe() + defer clientSide.Close() + + inboundInstance := newSpecialServiceInbound(t) + request := &ConnectRequest{ + Type: ConnectionTypeWebsocket, + Metadata: []Metadata{ + {Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="}, + }, + } + respWriter := &fakeConnectResponseWriter{done: make(chan struct{})} + + done := make(chan struct{}) + go func() { + defer close(done) + inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{ + Kind: ResolvedServiceStream, + Destination: M.ParseSocksaddr(listener.Addr().String()), + OriginRequest: OriginRequestConfig{ + ProxyType: "socks", + }, + }) + }() + + select { + case <-respWriter.done: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for stream service connect response") + } + if respWriter.err != nil { + t.Fatal(respWriter.err) + } + if respWriter.status != http.StatusSwitchingProtocols { + t.Fatalf("expected 101 response, got %d", respWriter.status) + } + + writeSocksAuth(t, clientSide) + data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String()) + if len(data) != 10 || data[1] != socksReplySuccess { + t.Fatalf("unexpected socks connect response: %v", data) + } + + if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil { + t.Fatal(err) + } + data, _, err := wsutil.ReadServerData(clientSide) + if err != nil { + t.Fatal(err) + } + if string(data) != "hello" { + t.Fatalf("expected echoed payload, got %q", string(data)) + } + _ = clientSide.Close() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatal("socks stream service did not exit") + } +}