mirror of
https://github.com/SagerNet/sing-box.git
synced 2026-04-13 20:28:32 +10:00
309 lines
9.0 KiB
Go
309 lines
9.0 KiB
Go
//go:build with_cloudflared
|
|
|
|
package cloudflare
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/rand"
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/pem"
|
|
"io"
|
|
"math/big"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/sagernet/sing-box/adapter"
|
|
)
|
|
|
|
func TestOriginTLSServerName(t *testing.T) {
|
|
t.Run("origin server name overrides host", func(t *testing.T) {
|
|
serverName := originTLSServerName(OriginRequestConfig{
|
|
OriginServerName: "origin.example.com",
|
|
MatchSNIToHost: true,
|
|
}, "request.example.com")
|
|
if serverName != "origin.example.com" {
|
|
t.Fatalf("expected origin.example.com, got %s", serverName)
|
|
}
|
|
})
|
|
|
|
t.Run("match sni to host strips port", func(t *testing.T) {
|
|
serverName := originTLSServerName(OriginRequestConfig{
|
|
MatchSNIToHost: true,
|
|
}, "request.example.com:443")
|
|
if serverName != "request.example.com" {
|
|
t.Fatalf("expected request.example.com, got %s", serverName)
|
|
}
|
|
})
|
|
|
|
t.Run("match sni to host uses http host header", func(t *testing.T) {
|
|
serverName := originTLSServerName(OriginRequestConfig{
|
|
MatchSNIToHost: true,
|
|
}, effectiveOriginHost(OriginRequestConfig{
|
|
HTTPHostHeader: "origin.example.com",
|
|
MatchSNIToHost: true,
|
|
}, "request.example.com"))
|
|
if serverName != "origin.example.com" {
|
|
t.Fatalf("expected origin.example.com, got %s", serverName)
|
|
}
|
|
})
|
|
|
|
t.Run("match sni to host strips port from http host header", func(t *testing.T) {
|
|
serverName := originTLSServerName(OriginRequestConfig{
|
|
MatchSNIToHost: true,
|
|
}, effectiveOriginHost(OriginRequestConfig{
|
|
HTTPHostHeader: "origin.example.com:8443",
|
|
MatchSNIToHost: true,
|
|
}, "request.example.com"))
|
|
if serverName != "origin.example.com" {
|
|
t.Fatalf("expected origin.example.com, got %s", serverName)
|
|
}
|
|
})
|
|
|
|
t.Run("disabled match keeps empty server name", func(t *testing.T) {
|
|
serverName := originTLSServerName(OriginRequestConfig{}, "request.example.com")
|
|
if serverName != "" {
|
|
t.Fatalf("expected empty server name, got %s", serverName)
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestNewOriginTLSConfigErrorsOnMissingCAPool(t *testing.T) {
|
|
originalBaseLoader := loadOriginCABasePool
|
|
loadOriginCABasePool = func() (*x509.CertPool, error) {
|
|
return x509.NewCertPool(), nil
|
|
}
|
|
defer func() {
|
|
loadOriginCABasePool = originalBaseLoader
|
|
}()
|
|
|
|
_, err := newOriginTLSConfig(OriginRequestConfig{
|
|
CAPool: "/path/does/not/exist.pem",
|
|
}, "request.example.com")
|
|
if err == nil {
|
|
t.Fatal("expected error for missing ca pool")
|
|
}
|
|
}
|
|
|
|
func TestNewOriginTLSConfigAppendsCustomCAInsteadOfReplacingBasePool(t *testing.T) {
|
|
basePEM, baseCert := createTestCertificatePEM(t, "base")
|
|
customPEM, customCert := createTestCertificatePEM(t, "custom")
|
|
|
|
basePool := x509.NewCertPool()
|
|
if !basePool.AppendCertsFromPEM(basePEM) {
|
|
t.Fatal("expected base cert to append")
|
|
}
|
|
|
|
originalBaseLoader := loadOriginCABasePool
|
|
loadOriginCABasePool = func() (*x509.CertPool, error) {
|
|
return basePool, nil
|
|
}
|
|
defer func() {
|
|
loadOriginCABasePool = originalBaseLoader
|
|
}()
|
|
|
|
caFile := writeTempPEM(t, customPEM)
|
|
tlsConfig, err := newOriginTLSConfig(OriginRequestConfig{
|
|
CAPool: caFile,
|
|
}, "request.example.com")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if tlsConfig.RootCAs == nil {
|
|
t.Fatal("expected root CA pool")
|
|
}
|
|
subjects := tlsConfig.RootCAs.Subjects()
|
|
if len(subjects) != 2 {
|
|
t.Fatalf("expected 2 subjects, got %d", len(subjects))
|
|
}
|
|
if !containsSubject(subjects, baseCert.RawSubject) {
|
|
t.Fatal("expected base subject to remain in pool")
|
|
}
|
|
if !containsSubject(subjects, customCert.RawSubject) {
|
|
t.Fatal("expected custom subject to be appended to pool")
|
|
}
|
|
}
|
|
|
|
func TestOriginTransportUsesProxyFromEnvironmentOnly(t *testing.T) {
|
|
originalProxyFromEnvironment := proxyFromEnvironment
|
|
proxyFromEnvironment = func(request *http.Request) (*url.URL, error) {
|
|
return url.Parse("http://proxy.example.com:8080")
|
|
}
|
|
defer func() {
|
|
proxyFromEnvironment = originalProxyFromEnvironment
|
|
}()
|
|
|
|
inbound := &Inbound{}
|
|
transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{
|
|
Kind: ResolvedServiceUnix,
|
|
UnixPath: "/tmp/test.sock",
|
|
OriginRequest: OriginRequestConfig{
|
|
ProxyAddress: "127.0.0.1",
|
|
ProxyPort: 8081,
|
|
ProxyType: "http",
|
|
},
|
|
}, "")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer cleanup()
|
|
|
|
proxyURL, err := transport.Proxy(&http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}})
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" {
|
|
t.Fatalf("expected environment proxy URL, got %#v", proxyURL)
|
|
}
|
|
}
|
|
|
|
func TestNewDirectOriginTransportNoHappyEyeballs(t *testing.T) {
|
|
inbound := &Inbound{}
|
|
transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{
|
|
Kind: ResolvedServiceUnix,
|
|
UnixPath: "/tmp/test.sock",
|
|
OriginRequest: OriginRequestConfig{
|
|
NoHappyEyeballs: true,
|
|
},
|
|
}, "")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
defer cleanup()
|
|
if transport.Proxy == nil {
|
|
t.Fatal("expected proxy function to be configured from environment")
|
|
}
|
|
if transport.DialContext == nil {
|
|
t.Fatal("expected custom direct dial context")
|
|
}
|
|
}
|
|
|
|
func TestNewRouterOriginTransportPropagatesTLSConfigError(t *testing.T) {
|
|
originalBaseLoader := loadOriginCABasePool
|
|
loadOriginCABasePool = func() (*x509.CertPool, error) {
|
|
return x509.NewCertPool(), nil
|
|
}
|
|
defer func() {
|
|
loadOriginCABasePool = originalBaseLoader
|
|
}()
|
|
|
|
inbound := &Inbound{}
|
|
_, _, err := inbound.newRouterOriginTransport(context.Background(), adapter.InboundContext{}, OriginRequestConfig{
|
|
CAPool: "/path/does/not/exist.pem",
|
|
}, "")
|
|
if err == nil {
|
|
t.Fatal("expected transport build error")
|
|
}
|
|
}
|
|
|
|
func TestNormalizeOriginRequestSetsKeepAliveAndEmptyUserAgent(t *testing.T) {
|
|
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", http.NoBody)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{})
|
|
if connection := request.Header.Get("Connection"); connection != "keep-alive" {
|
|
t.Fatalf("expected keep-alive connection header, got %q", connection)
|
|
}
|
|
if values, exists := request.Header["User-Agent"]; !exists || len(values) != 1 || values[0] != "" {
|
|
t.Fatalf("expected empty User-Agent header, got %#v", request.Header["User-Agent"])
|
|
}
|
|
}
|
|
|
|
func TestNormalizeOriginRequestDisableChunkedEncoding(t *testing.T) {
|
|
request, err := http.NewRequest(http.MethodPost, "https://example.com/path", strings.NewReader("payload"))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
request.TransferEncoding = []string{"chunked"}
|
|
request.Header.Set("Content-Length", "7")
|
|
|
|
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{
|
|
DisableChunkedEncoding: true,
|
|
})
|
|
if len(request.TransferEncoding) != 2 || request.TransferEncoding[0] != "gzip" || request.TransferEncoding[1] != "deflate" {
|
|
t.Fatalf("unexpected transfer encoding: %#v", request.TransferEncoding)
|
|
}
|
|
if request.ContentLength != 7 {
|
|
t.Fatalf("expected content length 7, got %d", request.ContentLength)
|
|
}
|
|
}
|
|
|
|
func TestNormalizeOriginRequestWebsocket(t *testing.T) {
|
|
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", io.NopCloser(strings.NewReader("payload")))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
request = normalizeOriginRequest(ConnectionTypeWebsocket, request, OriginRequestConfig{})
|
|
if connection := request.Header.Get("Connection"); connection != "Upgrade" {
|
|
t.Fatalf("expected websocket connection header, got %q", connection)
|
|
}
|
|
if upgrade := request.Header.Get("Upgrade"); upgrade != "websocket" {
|
|
t.Fatalf("expected websocket upgrade header, got %q", upgrade)
|
|
}
|
|
if version := request.Header.Get("Sec-Websocket-Version"); version != "13" {
|
|
t.Fatalf("expected websocket version 13, got %q", version)
|
|
}
|
|
if request.ContentLength != 0 {
|
|
t.Fatalf("expected websocket content length 0, got %d", request.ContentLength)
|
|
}
|
|
if request.Body != nil {
|
|
t.Fatal("expected websocket body to be nil")
|
|
}
|
|
}
|
|
|
|
func createTestCertificatePEM(t *testing.T, commonName string) ([]byte, *x509.Certificate) {
|
|
t.Helper()
|
|
|
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
template := &x509.Certificate{
|
|
SerialNumber: big.NewInt(time.Now().UnixNano()),
|
|
Subject: pkix.Name{
|
|
CommonName: commonName,
|
|
},
|
|
NotBefore: time.Now().Add(-time.Hour),
|
|
NotAfter: time.Now().Add(time.Hour),
|
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
|
|
BasicConstraintsValid: true,
|
|
IsCA: true,
|
|
}
|
|
der, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
certificate, err := x509.ParseCertificate(der)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), certificate
|
|
}
|
|
|
|
func writeTempPEM(t *testing.T, pemData []byte) string {
|
|
t.Helper()
|
|
path := t.TempDir() + "/ca.pem"
|
|
if err := os.WriteFile(path, pemData, 0o600); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return path
|
|
}
|
|
|
|
func containsSubject(subjects [][]byte, want []byte) bool {
|
|
for _, subject := range subjects {
|
|
if bytes.Equal(subject, want) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|