Align cloudflare runtime behavior with cloudflared

This commit is contained in:
世界
2026-03-24 21:03:47 +08:00
parent 6e35f4da89
commit 1320b737b9
6 changed files with 210 additions and 13 deletions

View File

@@ -26,3 +26,13 @@ func TestValidateRegistrationResultRejectsNonRemoteManaged(t *testing.T) {
t.Fatalf("unexpected error: %v", err)
}
}
func TestNormalizeProtocolAcceptsAuto(t *testing.T) {
protocol, err := normalizeProtocol("auto")
if err != nil {
t.Fatal(err)
}
if protocol != "" {
t.Fatalf("expected auto protocol to normalize to empty string, got %q", protocol)
}
}

View File

@@ -265,7 +265,7 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos
respWriter.WriteResponse(err, nil)
return
}
i.handleStreamService(ctx, stream, respWriter, request, metadata, service.Destination)
i.handleStreamService(ctx, stream, respWriter, request, metadata, service)
case ResolvedServiceUnix, ResolvedServiceUnixTLS, ResolvedServiceHelloWorld:
if request.Type == ConnectionTypeHTTP {
i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service)
@@ -279,7 +279,7 @@ func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteClos
respWriter.WriteResponse(err, nil)
return
}
i.handleBastionStream(ctx, stream, respWriter, request, metadata)
i.handleBastionStream(ctx, stream, respWriter, request, metadata, service)
case ResolvedServiceSocksProxy:
if request.Type != ConnectionTypeWebsocket {
err := E.New("socks-proxy service requires websocket request type")

View File

@@ -97,3 +97,30 @@ func TestGetRegionalServiceName(t *testing.T) {
t.Fatalf("expected regional service us-%s, got %s", edgeSRVService, got)
}
}
func TestInitialEdgeAddrIndex(t *testing.T) {
if got := initialEdgeAddrIndex(0, 4); got != 0 {
t.Fatalf("expected conn 0 to get index 0, got %d", got)
}
if got := initialEdgeAddrIndex(3, 4); got != 3 {
t.Fatalf("expected conn 3 to get index 3, got %d", got)
}
if got := initialEdgeAddrIndex(5, 4); got != 1 {
t.Fatalf("expected conn 5 to wrap to index 1, got %d", got)
}
if got := initialEdgeAddrIndex(2, 1); got != 0 {
t.Fatalf("expected single-address pool to always return 0, got %d", got)
}
}
func TestRotateEdgeAddrIndex(t *testing.T) {
if got := rotateEdgeAddrIndex(0, 4); got != 1 {
t.Fatalf("expected index 0 to rotate to 1, got %d", got)
}
if got := rotateEdgeAddrIndex(3, 4); got != 0 {
t.Fatalf("expected last index to wrap to 0, got %d", got)
}
if got := rotateEdgeAddrIndex(0, 1); got != 0 {
t.Fatalf("expected single-address pool to stay at 0, got %d", got)
}
}

View File

@@ -86,9 +86,9 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
haConnections = 4
}
protocol := options.Protocol
if protocol != "" && protocol != "quic" && protocol != "http2" {
return nil, E.New("unsupported protocol: ", protocol, ", expected quic or http2")
protocol, err := normalizeProtocol(options.Protocol)
if err != nil {
return nil, err
}
edgeIPVersion := options.EdgeIPVersion
@@ -283,6 +283,7 @@ const (
func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, features []string) {
defer i.done.Done()
edgeIndex := initialEdgeAddrIndex(connIndex, len(edgeAddrs))
retries := 0
for {
select {
@@ -291,7 +292,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe
default:
}
edgeAddr := edgeAddrs[rand.Intn(len(edgeAddrs))]
edgeAddr := edgeAddrs[edgeIndex]
err := i.serveConnection(connIndex, edgeAddr, features, uint8(retries))
if err == nil || i.ctx.Err() != nil {
return
@@ -303,6 +304,7 @@ func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr, fe
}
retries++
edgeIndex = rotateEdgeAddrIndex(edgeIndex, len(edgeAddrs))
backoff := backoffDuration(retries)
var retryableErr *RetryableError
if errors.As(err, &retryableErr) && retryableErr.Delay > 0 {
@@ -410,6 +412,20 @@ func backoffDuration(retries int) time.Duration {
return backoff/2 + jitter
}
func initialEdgeAddrIndex(connIndex uint8, size int) int {
if size <= 1 {
return 0
}
return int(connIndex) % size
}
func rotateEdgeAddrIndex(current int, size int) int {
if size <= 1 {
return 0
}
return (current + 1) % size
}
func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr {
var result []*EdgeAddr
for _, region := range regions {
@@ -430,3 +446,13 @@ func parseToken(token string) (Credentials, error) {
}
return tunnelToken.ToCredentials(), nil
}
func normalizeProtocol(protocol string) (string, error) {
if protocol == "auto" {
return "", nil
}
if protocol != "" && protocol != "quic" && protocol != "http2" {
return "", E.New("unsupported protocol: ", protocol, ", expected auto, quic or http2")
}
return protocol, nil
}

View File

@@ -32,20 +32,20 @@ const (
socksReplyCommandNotSupported = 7
)
func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext) {
func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
destination, err := resolveBastionDestination(request)
if err != nil {
respWriter.WriteResponse(err, nil)
return
}
i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination))
i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination), service.OriginRequest.ProxyType)
}
func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, destination M.Socksaddr) {
i.handleRouterBackedStream(ctx, stream, respWriter, request, destination)
func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
i.handleRouterBackedStream(ctx, stream, respWriter, request, service.Destination, service.OriginRequest.ProxyType)
}
func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr) {
func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr, proxyType string) {
targetConn, cleanup, err := i.dialRouterTCP(ctx, destination)
if err != nil {
respWriter.WriteResponse(err, nil)
@@ -61,6 +61,12 @@ func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWr
wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide)
defer wsConn.Close()
if isSocksProxyType(proxyType) {
if err := serveFixedSocksStream(ctx, wsConn, targetConn); err != nil && !E.IsClosedOrCanceled(err) {
i.logger.DebugContext(ctx, "socks-over-websocket stream closed: ", err)
}
return
}
_ = bufio.CopyConn(ctx, wsConn, targetConn)
}
@@ -101,6 +107,67 @@ func websocketResponseHeaders(request *ConnectRequest) http.Header {
return header
}
func isSocksProxyType(proxyType string) bool {
lower := strings.ToLower(strings.TrimSpace(proxyType))
return lower == "socks" || lower == "socks5"
}
func serveFixedSocksStream(ctx context.Context, conn net.Conn, targetConn net.Conn) error {
version := make([]byte, 1)
if _, err := io.ReadFull(conn, version); err != nil {
return err
}
if version[0] != 5 {
return E.New("unsupported SOCKS version: ", version[0])
}
methodCount := make([]byte, 1)
if _, err := io.ReadFull(conn, methodCount); err != nil {
return err
}
methods := make([]byte, int(methodCount[0]))
if _, err := io.ReadFull(conn, methods); err != nil {
return err
}
var supportsNoAuth bool
for _, method := range methods {
if method == 0 {
supportsNoAuth = true
break
}
}
if !supportsNoAuth {
_, err := conn.Write([]byte{5, 255})
if err != nil {
return err
}
return E.New("unknown authentication type")
}
if _, err := conn.Write([]byte{5, 0}); err != nil {
return err
}
requestHeader := make([]byte, 4)
if _, err := io.ReadFull(conn, requestHeader); err != nil {
return err
}
if requestHeader[0] != 5 {
return E.New("unsupported SOCKS request version: ", requestHeader[0])
}
if requestHeader[1] != 1 {
_ = writeSocksReply(conn, socksReplyCommandNotSupported)
return E.New("unsupported SOCKS command: ", requestHeader[1])
}
if _, err := readSocksDestination(conn, requestHeader[3]); err != nil {
return err
}
if err := writeSocksReply(conn, socksReplySuccess); err != nil {
return err
}
return bufio.CopyConn(ctx, conn, targetConn)
}
func requestHeaderValue(request *ConnectRequest, headerName string) string {
for _, entry := range request.Metadata {
if !strings.HasPrefix(entry.Key, metadataHTTPHeader+":") {

View File

@@ -201,7 +201,7 @@ func TestHandleBastionStream(t *testing.T) {
done := make(chan struct{})
go func() {
defer close(done)
inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{})
inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{})
}()
select {
@@ -438,7 +438,10 @@ func TestHandleStreamService(t *testing.T) {
done := make(chan struct{})
go func() {
defer close(done)
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, M.ParseSocksaddr(listener.Addr().String()))
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceStream,
Destination: M.ParseSocksaddr(listener.Addr().String()),
})
}()
select {
@@ -473,3 +476,67 @@ func TestHandleStreamService(t *testing.T) {
t.Fatal("stream service did not exit")
}
}
func TestHandleStreamServiceProxyTypeSocks(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
serverSide, clientSide := net.Pipe()
defer clientSide.Close()
inboundInstance := newSpecialServiceInbound(t)
request := &ConnectRequest{
Type: ConnectionTypeWebsocket,
Metadata: []Metadata{
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
},
}
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
done := make(chan struct{})
go func() {
defer close(done)
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceStream,
Destination: M.ParseSocksaddr(listener.Addr().String()),
OriginRequest: OriginRequestConfig{
ProxyType: "socks",
},
})
}()
select {
case <-respWriter.done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for stream service connect response")
}
if respWriter.err != nil {
t.Fatal(respWriter.err)
}
if respWriter.status != http.StatusSwitchingProtocols {
t.Fatalf("expected 101 response, got %d", respWriter.status)
}
writeSocksAuth(t, clientSide)
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
if len(data) != 10 || data[1] != socksReplySuccess {
t.Fatalf("unexpected socks connect response: %v", data)
}
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
t.Fatal(err)
}
data, _, err := wsutil.ReadServerData(clientSide)
if err != nil {
t.Fatal(err)
}
if string(data) != "hello" {
t.Fatalf("expected echoed payload, got %q", string(data))
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("socks stream service did not exit")
}
}