Fix cloudflared test and protocol parity
This commit is contained in:
@@ -24,7 +24,9 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
h2EdgeSNI = "h2.cftunnel.com"
|
||||
h2EdgeSNI = "h2.cftunnel.com"
|
||||
h2ResponseMetaCloudflared = `{"src":"cloudflared"}`
|
||||
h2ResponseMetaCloudflaredLimited = `{"src":"cloudflared","flow_rate_limited":true}`
|
||||
)
|
||||
|
||||
// HTTP2Connection manages a single HTTP/2 connection to the Cloudflare edge.
|
||||
@@ -357,7 +359,11 @@ func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Meta
|
||||
w.headersSent = true
|
||||
|
||||
if responseError != nil {
|
||||
w.writer.Header().Set(h2HeaderResponseMeta, `{"src":"cloudflared"}`)
|
||||
if hasFlowConnectRateLimited(metadata) {
|
||||
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflaredLimited)
|
||||
} else {
|
||||
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared)
|
||||
}
|
||||
w.writer.WriteHeader(http.StatusBadGateway)
|
||||
w.flusher.Flush()
|
||||
return nil
|
||||
|
||||
@@ -182,7 +182,7 @@ func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser
|
||||
if !i.flowLimiter.Acquire(limit) {
|
||||
err := E.New("too many active flows")
|
||||
i.logger.ErrorContext(ctx, err)
|
||||
respWriter.WriteResponse(err, nil)
|
||||
respWriter.WriteResponse(err, flowConnectRateLimitedMetadata())
|
||||
return
|
||||
}
|
||||
defer i.flowLimiter.Release(limit)
|
||||
@@ -341,7 +341,7 @@ func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser,
|
||||
return
|
||||
}
|
||||
|
||||
httpRequest = applyOriginRequest(httpRequest, service.OriginRequest)
|
||||
httpRequest = normalizeOriginRequest(request.Type, httpRequest, service.OriginRequest)
|
||||
requestCtx := httpRequest.Context()
|
||||
if service.OriginRequest.ConnectTimeout > 0 {
|
||||
var cancel context.CancelFunc
|
||||
@@ -489,12 +489,33 @@ func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig
|
||||
request.Header.Set("X-Forwarded-Host", request.Host)
|
||||
request.Host = originRequest.HTTPHostHeader
|
||||
}
|
||||
if originRequest.DisableChunkedEncoding && request.Header.Get("Content-Length") != "" {
|
||||
if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil {
|
||||
request.ContentLength = contentLength
|
||||
request.TransferEncoding = nil
|
||||
return request
|
||||
}
|
||||
|
||||
func normalizeOriginRequest(connectType ConnectionType, request *http.Request, originRequest OriginRequestConfig) *http.Request {
|
||||
request = applyOriginRequest(request, originRequest)
|
||||
|
||||
switch connectType {
|
||||
case ConnectionTypeWebsocket:
|
||||
request.Header.Set("Connection", "Upgrade")
|
||||
request.Header.Set("Upgrade", "websocket")
|
||||
request.Header.Set("Sec-Websocket-Version", "13")
|
||||
request.ContentLength = 0
|
||||
request.Body = nil
|
||||
default:
|
||||
if originRequest.DisableChunkedEncoding {
|
||||
request.TransferEncoding = []string{"gzip", "deflate"}
|
||||
if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil {
|
||||
request.ContentLength = contentLength
|
||||
}
|
||||
}
|
||||
request.Header.Set("Connection", "keep-alive")
|
||||
}
|
||||
|
||||
if _, exists := request.Header["User-Agent"]; !exists {
|
||||
request.Header.Set("User-Agent", "")
|
||||
}
|
||||
|
||||
return request
|
||||
}
|
||||
|
||||
|
||||
@@ -124,3 +124,25 @@ func TestRotateEdgeAddrIndex(t *testing.T) {
|
||||
t.Fatalf("expected single-address pool to stay at 0, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEffectiveHAConnections(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
requested int
|
||||
available int
|
||||
expected int
|
||||
}{
|
||||
{name: "requested below available", requested: 2, available: 4, expected: 2},
|
||||
{name: "requested equals available", requested: 4, available: 4, expected: 4},
|
||||
{name: "requested above available", requested: 5, available: 3, expected: 3},
|
||||
{name: "no available edges", requested: 4, available: 0, expected: 0},
|
||||
}
|
||||
|
||||
for _, testCase := range tests {
|
||||
t.Run(testCase.name, func(t *testing.T) {
|
||||
if actual := effectiveHAConnections(testCase.requested, testCase.available); actual != testCase.expected {
|
||||
t.Fatalf("effectiveHAConnections(%d, %d) = %d, want %d", testCase.requested, testCase.available, actual, testCase.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sagernet/sing-box/adapter"
|
||||
@@ -17,6 +19,17 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type captureConnectMetadataWriter struct {
|
||||
err error
|
||||
metadata []Metadata
|
||||
}
|
||||
|
||||
func (w *captureConnectMetadataWriter) WriteResponse(responseError error, metadata []Metadata) error {
|
||||
w.err = responseError
|
||||
w.metadata = append([]Metadata(nil), metadata...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func newLimitedInbound(t *testing.T, limit uint64) *Inbound {
|
||||
t.Helper()
|
||||
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
|
||||
@@ -56,6 +69,45 @@ func TestHandleTCPStreamRespectsMaxActiveFlows(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleTCPStreamRateLimitMetadata(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 1)
|
||||
if !inboundInstance.flowLimiter.Acquire(1) {
|
||||
t.Fatal("failed to pre-acquire limiter")
|
||||
}
|
||||
|
||||
stream, peer := net.Pipe()
|
||||
defer stream.Close()
|
||||
defer peer.Close()
|
||||
|
||||
respWriter := &captureConnectMetadataWriter{}
|
||||
inboundInstance.handleTCPStream(context.Background(), stream, respWriter, adapter.InboundContext{})
|
||||
if respWriter.err == nil {
|
||||
t.Fatal("expected too many active flows error")
|
||||
}
|
||||
if !hasFlowConnectRateLimited(respWriter.metadata) {
|
||||
t.Fatal("expected flow rate limit metadata")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHTTP2ResponseWriterFlowRateLimitedMeta(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
writer := &http2ResponseWriter{
|
||||
writer: recorder,
|
||||
flusher: recorder,
|
||||
}
|
||||
|
||||
err := writer.WriteResponse(context.DeadlineExceeded, flowConnectRateLimitedMetadata())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if recorder.Code != http.StatusBadGateway {
|
||||
t.Fatalf("expected %d, got %d", http.StatusBadGateway, recorder.Code)
|
||||
}
|
||||
if meta := recorder.Header().Get(h2HeaderResponseMeta); meta != h2ResponseMetaCloudflaredLimited {
|
||||
t.Fatalf("unexpected response meta: %q", meta)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDatagramV2RegisterSessionRespectsMaxActiveFlows(t *testing.T) {
|
||||
inboundInstance := newLimitedInbound(t, 1)
|
||||
if !inboundInstance.flowLimiter.Acquire(1) {
|
||||
|
||||
@@ -202,6 +202,8 @@ func newTestInbound(t *testing.T, token string, protocol string, haConnections i
|
||||
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
|
||||
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
|
||||
datagramV3Manager: NewDatagramV3SessionManager(),
|
||||
connectedIndices: make(map[uint8]struct{}),
|
||||
connectedNotify: make(chan uint8, haConnections),
|
||||
controlDialer: N.SystemDialer,
|
||||
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer},
|
||||
}
|
||||
|
||||
@@ -162,6 +162,10 @@ func (i *Inbound) Start(stage adapter.StartStage) error {
|
||||
if len(edgeAddrs) == 0 {
|
||||
return E.New("no edge addresses available")
|
||||
}
|
||||
if cappedHAConnections := effectiveHAConnections(i.haConnections, len(edgeAddrs)); cappedHAConnections != i.haConnections {
|
||||
i.logger.Info("requested ", i.haConnections, " HA connections but only ", cappedHAConnections, " edge addresses are available")
|
||||
i.haConnections = cappedHAConnections
|
||||
}
|
||||
|
||||
i.datagramVersion = resolveDatagramVersion(i.ctx, i.credentials.AccountTag, i.datagramVersion)
|
||||
features := DefaultFeatures(i.datagramVersion)
|
||||
@@ -385,6 +389,16 @@ func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr {
|
||||
return result
|
||||
}
|
||||
|
||||
func effectiveHAConnections(requested, available int) int {
|
||||
if available <= 0 {
|
||||
return 0
|
||||
}
|
||||
if requested > available {
|
||||
return available
|
||||
}
|
||||
return requested
|
||||
}
|
||||
|
||||
func parseToken(token string) (Credentials, error) {
|
||||
data, err := base64.StdEncoding.DecodeString(token)
|
||||
if err != nil {
|
||||
|
||||
@@ -3,8 +3,10 @@
|
||||
package cloudflare
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -75,3 +77,61 @@ func TestNewDirectOriginTransportNoHappyEyeballs(t *testing.T) {
|
||||
t.Fatal("expected custom direct dial context")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOriginRequestSetsKeepAliveAndEmptyUserAgent(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", http.NoBody)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{})
|
||||
if connection := request.Header.Get("Connection"); connection != "keep-alive" {
|
||||
t.Fatalf("expected keep-alive connection header, got %q", connection)
|
||||
}
|
||||
if values, exists := request.Header["User-Agent"]; !exists || len(values) != 1 || values[0] != "" {
|
||||
t.Fatalf("expected empty User-Agent header, got %#v", request.Header["User-Agent"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOriginRequestDisableChunkedEncoding(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodPost, "https://example.com/path", strings.NewReader("payload"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
request.TransferEncoding = []string{"chunked"}
|
||||
request.Header.Set("Content-Length", "7")
|
||||
|
||||
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{
|
||||
DisableChunkedEncoding: true,
|
||||
})
|
||||
if len(request.TransferEncoding) != 2 || request.TransferEncoding[0] != "gzip" || request.TransferEncoding[1] != "deflate" {
|
||||
t.Fatalf("unexpected transfer encoding: %#v", request.TransferEncoding)
|
||||
}
|
||||
if request.ContentLength != 7 {
|
||||
t.Fatalf("expected content length 7, got %d", request.ContentLength)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeOriginRequestWebsocket(t *testing.T) {
|
||||
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", io.NopCloser(strings.NewReader("payload")))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
request = normalizeOriginRequest(ConnectionTypeWebsocket, request, OriginRequestConfig{})
|
||||
if connection := request.Header.Get("Connection"); connection != "Upgrade" {
|
||||
t.Fatalf("expected websocket connection header, got %q", connection)
|
||||
}
|
||||
if upgrade := request.Header.Get("Upgrade"); upgrade != "websocket" {
|
||||
t.Fatalf("expected websocket upgrade header, got %q", upgrade)
|
||||
}
|
||||
if version := request.Header.Get("Sec-Websocket-Version"); version != "13" {
|
||||
t.Fatalf("expected websocket version 13, got %q", version)
|
||||
}
|
||||
if request.ContentLength != 0 {
|
||||
t.Fatalf("expected websocket content length 0, got %d", request.ContentLength)
|
||||
}
|
||||
if request.Body != nil {
|
||||
t.Fatal("expected websocket body to be nil")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -31,6 +31,8 @@ const (
|
||||
StreamTypeRPC
|
||||
)
|
||||
|
||||
const metadataFlowConnectRateLimited = "FlowConnectRateLimited"
|
||||
|
||||
// ConnectionType indicates the proxied connection type within a data stream.
|
||||
type ConnectionType uint16
|
||||
|
||||
@@ -59,6 +61,22 @@ type Metadata struct {
|
||||
Val string `capnp:"val"`
|
||||
}
|
||||
|
||||
func flowConnectRateLimitedMetadata() []Metadata {
|
||||
return []Metadata{{
|
||||
Key: metadataFlowConnectRateLimited,
|
||||
Val: "true",
|
||||
}}
|
||||
}
|
||||
|
||||
func hasFlowConnectRateLimited(metadata []Metadata) bool {
|
||||
for _, entry := range metadata {
|
||||
if entry.Key == metadataFlowConnectRateLimited && entry.Val == "true" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// ConnectRequest is sent by the edge at the start of a data stream.
|
||||
type ConnectRequest struct {
|
||||
Dest string `capnp:"dest"`
|
||||
|
||||
Reference in New Issue
Block a user