Fix cloudflared test and protocol parity

This commit is contained in:
世界
2026-03-25 15:52:43 +08:00
parent 7ca692d8c2
commit 4497f61323
8 changed files with 203 additions and 8 deletions

View File

@@ -24,7 +24,9 @@ import (
)
const (
h2EdgeSNI = "h2.cftunnel.com"
h2EdgeSNI = "h2.cftunnel.com"
h2ResponseMetaCloudflared = `{"src":"cloudflared"}`
h2ResponseMetaCloudflaredLimited = `{"src":"cloudflared","flow_rate_limited":true}`
)
// HTTP2Connection manages a single HTTP/2 connection to the Cloudflare edge.
@@ -357,7 +359,11 @@ func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Meta
w.headersSent = true
if responseError != nil {
w.writer.Header().Set(h2HeaderResponseMeta, `{"src":"cloudflared"}`)
if hasFlowConnectRateLimited(metadata) {
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflaredLimited)
} else {
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared)
}
w.writer.WriteHeader(http.StatusBadGateway)
w.flusher.Flush()
return nil

View File

@@ -182,7 +182,7 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser
if !i.flowLimiter.Acquire(limit) {
err := E.New("too many active flows")
i.logger.ErrorContext(ctx, err)
respWriter.WriteResponse(err, nil)
respWriter.WriteResponse(err, flowConnectRateLimitedMetadata())
return
}
defer i.flowLimiter.Release(limit)
@@ -341,7 +341,7 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
return
}
httpRequest = applyOriginRequest(httpRequest, service.OriginRequest)
httpRequest = normalizeOriginRequest(request.Type, httpRequest, service.OriginRequest)
requestCtx := httpRequest.Context()
if service.OriginRequest.ConnectTimeout > 0 {
var cancel context.CancelFunc
@@ -489,12 +489,33 @@ func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig
request.Header.Set("X-Forwarded-Host", request.Host)
request.Host = originRequest.HTTPHostHeader
}
if originRequest.DisableChunkedEncoding && request.Header.Get("Content-Length") != "" {
if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil {
request.ContentLength = contentLength
request.TransferEncoding = nil
return request
}
func normalizeOriginRequest(connectType ConnectionType, request *http.Request, originRequest OriginRequestConfig) *http.Request {
request = applyOriginRequest(request, originRequest)
switch connectType {
case ConnectionTypeWebsocket:
request.Header.Set("Connection", "Upgrade")
request.Header.Set("Upgrade", "websocket")
request.Header.Set("Sec-Websocket-Version", "13")
request.ContentLength = 0
request.Body = nil
default:
if originRequest.DisableChunkedEncoding {
request.TransferEncoding = []string{"gzip", "deflate"}
if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil {
request.ContentLength = contentLength
}
}
request.Header.Set("Connection", "keep-alive")
}
if _, exists := request.Header["User-Agent"]; !exists {
request.Header.Set("User-Agent", "")
}
return request
}

View File

@@ -124,3 +124,25 @@ func TestRotateEdgeAddrIndex(t *testing.T) {
t.Fatalf("expected single-address pool to stay at 0, got %d", got)
}
}
func TestEffectiveHAConnections(t *testing.T) {
tests := []struct {
name string
requested int
available int
expected int
}{
{name: "requested below available", requested: 2, available: 4, expected: 2},
{name: "requested equals available", requested: 4, available: 4, expected: 4},
{name: "requested above available", requested: 5, available: 3, expected: 3},
{name: "no available edges", requested: 4, available: 0, expected: 0},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
if actual := effectiveHAConnections(testCase.requested, testCase.available); actual != testCase.expected {
t.Fatalf("effectiveHAConnections(%d, %d) = %d, want %d", testCase.requested, testCase.available, actual, testCase.expected)
}
})
}
}

View File

@@ -6,6 +6,8 @@ import (
"context"
"encoding/binary"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/sagernet/sing-box/adapter"
@@ -17,6 +19,17 @@ import (
"github.com/google/uuid"
)
type captureConnectMetadataWriter struct {
err error
metadata []Metadata
}
func (w *captureConnectMetadataWriter) WriteResponse(responseError error, metadata []Metadata) error {
w.err = responseError
w.metadata = append([]Metadata(nil), metadata...)
return nil
}
func newLimitedInbound(t *testing.T, limit uint64) *Inbound {
t.Helper()
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
@@ -56,6 +69,45 @@ func TestHandleTCPStreamRespectsMaxActiveFlows(t *testing.T) {
}
}
func TestHandleTCPStreamRateLimitMetadata(t *testing.T) {
inboundInstance := newLimitedInbound(t, 1)
if !inboundInstance.flowLimiter.Acquire(1) {
t.Fatal("failed to pre-acquire limiter")
}
stream, peer := net.Pipe()
defer stream.Close()
defer peer.Close()
respWriter := &captureConnectMetadataWriter{}
inboundInstance.handleTCPStream(context.Background(), stream, respWriter, adapter.InboundContext{})
if respWriter.err == nil {
t.Fatal("expected too many active flows error")
}
if !hasFlowConnectRateLimited(respWriter.metadata) {
t.Fatal("expected flow rate limit metadata")
}
}
func TestHTTP2ResponseWriterFlowRateLimitedMeta(t *testing.T) {
recorder := httptest.NewRecorder()
writer := &http2ResponseWriter{
writer: recorder,
flusher: recorder,
}
err := writer.WriteResponse(context.DeadlineExceeded, flowConnectRateLimitedMetadata())
if err != nil {
t.Fatal(err)
}
if recorder.Code != http.StatusBadGateway {
t.Fatalf("expected %d, got %d", http.StatusBadGateway, recorder.Code)
}
if meta := recorder.Header().Get(h2HeaderResponseMeta); meta != h2ResponseMetaCloudflaredLimited {
t.Fatalf("unexpected response meta: %q", meta)
}
}
func TestDatagramV2RegisterSessionRespectsMaxActiveFlows(t *testing.T) {
inboundInstance := newLimitedInbound(t, 1)
if !inboundInstance.flowLimiter.Acquire(1) {

View File

@@ -202,6 +202,8 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
datagramV3Manager: NewDatagramV3SessionManager(),
connectedIndices: make(map[uint8]struct{}),
connectedNotify: make(chan uint8, haConnections),
controlDialer: N.SystemDialer,
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer},
}

View File

@@ -162,6 +162,10 @@ func (i *Inbound) Start(stage adapter.StartStage) error {
if len(edgeAddrs) == 0 {
return E.New("no edge addresses available")
}
if cappedHAConnections := effectiveHAConnections(i.haConnections, len(edgeAddrs)); cappedHAConnections != i.haConnections {
i.logger.Info("requested ", i.haConnections, " HA connections but only ", cappedHAConnections, " edge addresses are available")
i.haConnections = cappedHAConnections
}
i.datagramVersion = resolveDatagramVersion(i.ctx, i.credentials.AccountTag, i.datagramVersion)
features := DefaultFeatures(i.datagramVersion)
@@ -385,6 +389,16 @@ func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr {
return result
}
func effectiveHAConnections(requested, available int) int {
if available <= 0 {
return 0
}
if requested > available {
return available
}
return requested
}
func parseToken(token string) (Credentials, error) {
data, err := base64.StdEncoding.DecodeString(token)
if err != nil {

View File

@@ -3,8 +3,10 @@
package cloudflare
import (
"io"
"net/http"
"net/url"
"strings"
"testing"
)
@@ -75,3 +77,61 @@ func TestNewDirectOriginTransportNoHappyEyeballs(t *testing.T) {
t.Fatal("expected custom direct dial context")
}
}
func TestNormalizeOriginRequestSetsKeepAliveAndEmptyUserAgent(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", http.NoBody)
if err != nil {
t.Fatal(err)
}
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{})
if connection := request.Header.Get("Connection"); connection != "keep-alive" {
t.Fatalf("expected keep-alive connection header, got %q", connection)
}
if values, exists := request.Header["User-Agent"]; !exists || len(values) != 1 || values[0] != "" {
t.Fatalf("expected empty User-Agent header, got %#v", request.Header["User-Agent"])
}
}
func TestNormalizeOriginRequestDisableChunkedEncoding(t *testing.T) {
request, err := http.NewRequest(http.MethodPost, "https://example.com/path", strings.NewReader("payload"))
if err != nil {
t.Fatal(err)
}
request.TransferEncoding = []string{"chunked"}
request.Header.Set("Content-Length", "7")
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{
DisableChunkedEncoding: true,
})
if len(request.TransferEncoding) != 2 || request.TransferEncoding[0] != "gzip" || request.TransferEncoding[1] != "deflate" {
t.Fatalf("unexpected transfer encoding: %#v", request.TransferEncoding)
}
if request.ContentLength != 7 {
t.Fatalf("expected content length 7, got %d", request.ContentLength)
}
}
func TestNormalizeOriginRequestWebsocket(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", io.NopCloser(strings.NewReader("payload")))
if err != nil {
t.Fatal(err)
}
request = normalizeOriginRequest(ConnectionTypeWebsocket, request, OriginRequestConfig{})
if connection := request.Header.Get("Connection"); connection != "Upgrade" {
t.Fatalf("expected websocket connection header, got %q", connection)
}
if upgrade := request.Header.Get("Upgrade"); upgrade != "websocket" {
t.Fatalf("expected websocket upgrade header, got %q", upgrade)
}
if version := request.Header.Get("Sec-Websocket-Version"); version != "13" {
t.Fatalf("expected websocket version 13, got %q", version)
}
if request.ContentLength != 0 {
t.Fatalf("expected websocket content length 0, got %d", request.ContentLength)
}
if request.Body != nil {
t.Fatal("expected websocket body to be nil")
}
}

View File

@@ -31,6 +31,8 @@ const (
StreamTypeRPC
)
const metadataFlowConnectRateLimited = "FlowConnectRateLimited"
// ConnectionType indicates the proxied connection type within a data stream.
type ConnectionType uint16
@@ -59,6 +61,22 @@ type Metadata struct {
Val string `capnp:"val"`
}
func flowConnectRateLimitedMetadata() []Metadata {
return []Metadata{{
Key: metadataFlowConnectRateLimited,
Val: "true",
}}
}
func hasFlowConnectRateLimited(metadata []Metadata) bool {
for _, entry := range metadata {
if entry.Key == metadataFlowConnectRateLimited && entry.Val == "true" {
return true
}
}
return false
}
// ConnectRequest is sent by the edge at the start of a data stream.
type ConnectRequest struct {
Dest string `capnp:"dest"`