Files
sing-box/protocol/cloudflare/origin_request_test.go
2026-03-31 15:32:56 +08:00

309 lines
9.0 KiB
Go

//go:build with_cloudflared
package cloudflare
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"io"
"math/big"
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
)
func TestOriginTLSServerName(t *testing.T) {
t.Run("origin server name overrides host", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{
OriginServerName: "origin.example.com",
MatchSNIToHost: true,
}, "request.example.com")
if serverName != "origin.example.com" {
t.Fatalf("expected origin.example.com, got %s", serverName)
}
})
t.Run("match sni to host strips port", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{
MatchSNIToHost: true,
}, "request.example.com:443")
if serverName != "request.example.com" {
t.Fatalf("expected request.example.com, got %s", serverName)
}
})
t.Run("match sni to host uses http host header", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{
MatchSNIToHost: true,
}, effectiveOriginHost(OriginRequestConfig{
HTTPHostHeader: "origin.example.com",
MatchSNIToHost: true,
}, "request.example.com"))
if serverName != "origin.example.com" {
t.Fatalf("expected origin.example.com, got %s", serverName)
}
})
t.Run("match sni to host strips port from http host header", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{
MatchSNIToHost: true,
}, effectiveOriginHost(OriginRequestConfig{
HTTPHostHeader: "origin.example.com:8443",
MatchSNIToHost: true,
}, "request.example.com"))
if serverName != "origin.example.com" {
t.Fatalf("expected origin.example.com, got %s", serverName)
}
})
t.Run("disabled match keeps empty server name", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{}, "request.example.com")
if serverName != "" {
t.Fatalf("expected empty server name, got %s", serverName)
}
})
}
func TestNewOriginTLSConfigErrorsOnMissingCAPool(t *testing.T) {
originalBaseLoader := loadOriginCABasePool
loadOriginCABasePool = func() (*x509.CertPool, error) {
return x509.NewCertPool(), nil
}
defer func() {
loadOriginCABasePool = originalBaseLoader
}()
_, err := newOriginTLSConfig(OriginRequestConfig{
CAPool: "/path/does/not/exist.pem",
}, "request.example.com")
if err == nil {
t.Fatal("expected error for missing ca pool")
}
}
func TestNewOriginTLSConfigAppendsCustomCAInsteadOfReplacingBasePool(t *testing.T) {
basePEM, baseCert := createTestCertificatePEM(t, "base")
customPEM, customCert := createTestCertificatePEM(t, "custom")
basePool := x509.NewCertPool()
if !basePool.AppendCertsFromPEM(basePEM) {
t.Fatal("expected base cert to append")
}
originalBaseLoader := loadOriginCABasePool
loadOriginCABasePool = func() (*x509.CertPool, error) {
return basePool, nil
}
defer func() {
loadOriginCABasePool = originalBaseLoader
}()
caFile := writeTempPEM(t, customPEM)
tlsConfig, err := newOriginTLSConfig(OriginRequestConfig{
CAPool: caFile,
}, "request.example.com")
if err != nil {
t.Fatal(err)
}
if tlsConfig.RootCAs == nil {
t.Fatal("expected root CA pool")
}
subjects := tlsConfig.RootCAs.Subjects()
if len(subjects) != 2 {
t.Fatalf("expected 2 subjects, got %d", len(subjects))
}
if !containsSubject(subjects, baseCert.RawSubject) {
t.Fatal("expected base subject to remain in pool")
}
if !containsSubject(subjects, customCert.RawSubject) {
t.Fatal("expected custom subject to be appended to pool")
}
}
func TestOriginTransportUsesProxyFromEnvironmentOnly(t *testing.T) {
originalProxyFromEnvironment := proxyFromEnvironment
proxyFromEnvironment = func(request *http.Request) (*url.URL, error) {
return url.Parse("http://proxy.example.com:8080")
}
defer func() {
proxyFromEnvironment = originalProxyFromEnvironment
}()
inbound := &Inbound{}
transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{
Kind: ResolvedServiceUnix,
UnixPath: "/tmp/test.sock",
OriginRequest: OriginRequestConfig{
ProxyAddress: "127.0.0.1",
ProxyPort: 8081,
ProxyType: "http",
},
}, "")
if err != nil {
t.Fatal(err)
}
defer cleanup()
proxyURL, err := transport.Proxy(&http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}})
if err != nil {
t.Fatal(err)
}
if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" {
t.Fatalf("expected environment proxy URL, got %#v", proxyURL)
}
}
func TestNewDirectOriginTransportNoHappyEyeballs(t *testing.T) {
inbound := &Inbound{}
transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{
Kind: ResolvedServiceUnix,
UnixPath: "/tmp/test.sock",
OriginRequest: OriginRequestConfig{
NoHappyEyeballs: true,
},
}, "")
if err != nil {
t.Fatal(err)
}
defer cleanup()
if transport.Proxy == nil {
t.Fatal("expected proxy function to be configured from environment")
}
if transport.DialContext == nil {
t.Fatal("expected custom direct dial context")
}
}
func TestNewRouterOriginTransportPropagatesTLSConfigError(t *testing.T) {
originalBaseLoader := loadOriginCABasePool
loadOriginCABasePool = func() (*x509.CertPool, error) {
return x509.NewCertPool(), nil
}
defer func() {
loadOriginCABasePool = originalBaseLoader
}()
inbound := &Inbound{}
_, _, err := inbound.newRouterOriginTransport(context.Background(), adapter.InboundContext{}, OriginRequestConfig{
CAPool: "/path/does/not/exist.pem",
}, "")
if err == nil {
t.Fatal("expected transport build error")
}
}
func TestNormalizeOriginRequestSetsKeepAliveAndEmptyUserAgent(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", http.NoBody)
if err != nil {
t.Fatal(err)
}
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{})
if connection := request.Header.Get("Connection"); connection != "keep-alive" {
t.Fatalf("expected keep-alive connection header, got %q", connection)
}
if values, exists := request.Header["User-Agent"]; !exists || len(values) != 1 || values[0] != "" {
t.Fatalf("expected empty User-Agent header, got %#v", request.Header["User-Agent"])
}
}
func TestNormalizeOriginRequestDisableChunkedEncoding(t *testing.T) {
request, err := http.NewRequest(http.MethodPost, "https://example.com/path", strings.NewReader("payload"))
if err != nil {
t.Fatal(err)
}
request.TransferEncoding = []string{"chunked"}
request.Header.Set("Content-Length", "7")
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{
DisableChunkedEncoding: true,
})
if len(request.TransferEncoding) != 2 || request.TransferEncoding[0] != "gzip" || request.TransferEncoding[1] != "deflate" {
t.Fatalf("unexpected transfer encoding: %#v", request.TransferEncoding)
}
if request.ContentLength != 7 {
t.Fatalf("expected content length 7, got %d", request.ContentLength)
}
}
func TestNormalizeOriginRequestWebsocket(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", io.NopCloser(strings.NewReader("payload")))
if err != nil {
t.Fatal(err)
}
request = normalizeOriginRequest(ConnectionTypeWebsocket, request, OriginRequestConfig{})
if connection := request.Header.Get("Connection"); connection != "Upgrade" {
t.Fatalf("expected websocket connection header, got %q", connection)
}
if upgrade := request.Header.Get("Upgrade"); upgrade != "websocket" {
t.Fatalf("expected websocket upgrade header, got %q", upgrade)
}
if version := request.Header.Get("Sec-Websocket-Version"); version != "13" {
t.Fatalf("expected websocket version 13, got %q", version)
}
if request.ContentLength != 0 {
t.Fatalf("expected websocket content length 0, got %d", request.ContentLength)
}
if request.Body != nil {
t.Fatal("expected websocket body to be nil")
}
}
func createTestCertificatePEM(t *testing.T, commonName string) ([]byte, *x509.Certificate) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().UnixNano()),
Subject: pkix.Name{
CommonName: commonName,
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
IsCA: true,
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatal(err)
}
certificate, err := x509.ParseCertificate(der)
if err != nil {
t.Fatal(err)
}
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), certificate
}
func writeTempPEM(t *testing.T, pemData []byte) string {
t.Helper()
path := t.TempDir() + "/ca.pem"
if err := os.WriteFile(path, pemData, 0o600); err != nil {
t.Fatal(err)
}
return path
}
func containsSubject(subjects [][]byte, want []byte) bool {
for _, subject := range subjects {
if bytes.Equal(subject, want) {
return true
}
}
return false
}