Refactor cloudflared to depend on sing-cloudflared

Replace the inline cloudflared implementation with a thin adapter
wrapping github.com/sagernet/sing-cloudflared. The protocol/cloudflare
package is reduced to a single inbound.go that bridges the external
Service to the sing-box router.
This commit is contained in:
世界
2026-04-08 14:15:22 +08:00
parent d5a2fd7e95
commit 2e3cb87263
61 changed files with 123 additions and 17423 deletions

13
go.mod
View File

@@ -14,7 +14,6 @@ require (
github.com/go-chi/render v1.0.3
github.com/godbus/dbus/v5 v5.2.2
github.com/gofrs/uuid/v5 v5.4.0
github.com/google/uuid v1.6.0
github.com/insomniacslk/dhcp v0.0.0-20260220084031-5adc3eb26f91
github.com/jsimonetti/rtnetlink v1.4.0
github.com/keybase/go-keychain v0.0.1
@@ -38,13 +37,14 @@ require (
github.com/sagernet/gomobile v0.1.12
github.com/sagernet/gvisor v0.0.0-20250811.0-sing-box-mod.1
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4
github.com/sagernet/sing v0.8.3
github.com/sagernet/sing v0.8.4
github.com/sagernet/sing-cloudflared v0.0.0-20260407120610-7715dc2523fa
github.com/sagernet/sing-mux v0.3.4
github.com/sagernet/sing-quic v0.6.2-0.20260330152607-bf674c163212
github.com/sagernet/sing-shadowsocks v0.2.8
github.com/sagernet/sing-shadowsocks2 v0.2.1
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11
github.com/sagernet/sing-tun v0.8.7-0.20260323120017-8eb4e8acfc2d
github.com/sagernet/sing-tun v0.8.7-0.20260402180740-11f6e77ec6c6
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1
github.com/sagernet/smux v1.5.50-sing-box-mod.1
github.com/sagernet/tailscale v1.92.4-sing-box-1.13-mod.7
@@ -64,7 +64,6 @@ require (
google.golang.org/grpc v1.79.1
google.golang.org/protobuf v1.36.11
howett.net/plist v1.0.1
zombiezen.com/go/capnproto2 v2.18.2+incompatible
)
require (
@@ -75,7 +74,7 @@ require (
github.com/andybalholm/brotli v1.1.0 // indirect
github.com/cenkalti/backoff/v4 v4.3.0 // indirect
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 // indirect
github.com/coreos/go-oidc/v3 v3.12.0 // indirect
github.com/coreos/go-oidc/v3 v3.17.0 // indirect
github.com/database64128/netx-go v0.1.1 // indirect
github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect
github.com/dblohm7/wingoes v0.0.0-20240119213807-a09d6be7affa // indirect
@@ -95,6 +94,7 @@ require (
github.com/google/go-cmp v0.7.0 // indirect
github.com/google/go-querystring v1.1.0 // indirect
github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 // indirect
github.com/google/uuid v1.6.0 // indirect
github.com/hashicorp/yamux v0.1.2 // indirect
github.com/hdevalence/ed25519consensus v0.2.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
@@ -102,6 +102,7 @@ require (
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/mdlayher/socket v0.5.1 // indirect
github.com/mitchellh/go-ps v1.0.0 // indirect
github.com/philhofer/fwd v1.2.0 // indirect
github.com/pierrec/lz4/v4 v4.1.21 // indirect
github.com/pires/go-proxyproto v0.8.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
@@ -151,7 +152,6 @@ require (
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/tinylib/msgp v1.6.3 // indirect
github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 // indirect
github.com/x448/float16 v0.8.4 // indirect
github.com/zeebo/blake3 v0.2.4 // indirect
@@ -169,4 +169,5 @@ require (
google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
lukechampine.com/blake3 v1.3.0 // indirect
zombiezen.com/go/capnproto2 v2.18.2+incompatible // indirect
)

16
go.sum
View File

@@ -28,8 +28,8 @@ github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9
github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg=
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0=
github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q=
github.com/coreos/go-oidc/v3 v3.12.0 h1:sJk+8G2qq94rDI6ehZ71Bol3oUHy63qNYmkiSjrc/Jo=
github.com/coreos/go-oidc/v3 v3.12.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
github.com/coreos/go-oidc/v3 v3.17.0 h1:hWBGaQfbi0iVviX4ibC7bk8OKT5qNr4klBaCHVNvehc=
github.com/coreos/go-oidc/v3 v3.17.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8=
github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g=
github.com/cretz/bine v0.2.0 h1:8GiDRGlTgz+o8H9DSnsl+5MeBK4HsExxgl6WgzOCuZo=
github.com/cretz/bine v0.2.0/go.mod h1:WU4o9QR9wWp8AVKtTM1XD5vUHkEqnf2vVSo6dBqbetI=
@@ -112,6 +112,8 @@ github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zt
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/letsencrypt/challtestsrv v1.4.2 h1:0ON3ldMhZyWlfVNYYpFuWRTmZNnyfiL9Hh5YzC3JVwU=
github.com/letsencrypt/challtestsrv v1.4.2/go.mod h1:GhqMqcSoeGpYd5zX5TgwA6er/1MbWzx/o7yuuVya+Wk=
github.com/letsencrypt/pebble/v2 v2.10.0 h1:Wq6gYXlsY6ubqI3hhxsTzdyotvfdjFBxuwYqCLCnj/U=
@@ -240,8 +242,10 @@ github.com/sagernet/nftables v0.3.0-beta.4 h1:kbULlAwAC3jvdGAC1P5Fa3GSxVwQJibNen
github.com/sagernet/nftables v0.3.0-beta.4/go.mod h1:OQXAjvjNGGFxaTgVCSTRIhYB5/llyVDeapVoENYBDS8=
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4 h1:6qvrUW79S+CrPwWz6cMePXohgjHoKxLo3c+MDhNwc3o=
github.com/sagernet/quic-go v0.59.0-sing-box-mod.4/go.mod h1:OqILvS182CyOol5zNNo6bguvOGgXzV459+chpRaUC+4=
github.com/sagernet/sing v0.8.3 h1:zGMy9M1deBPEew9pCYIUHKeE+/lDQ5A2CBqjBjjzqkA=
github.com/sagernet/sing v0.8.3/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing v0.8.4 h1:Fj+jlY3F8vhcRfz/G/P3Dwcs5wqnmyNPT7u1RVVmjFI=
github.com/sagernet/sing v0.8.4/go.mod h1:ARkL0gM13/Iv5VCZmci/NuoOlePoIsW0m7BWfln/Hak=
github.com/sagernet/sing-cloudflared v0.0.0-20260407120610-7715dc2523fa h1:165HiOfgfofJIirEp1NGSmsoJAi+++WhR29IhtAu4A4=
github.com/sagernet/sing-cloudflared v0.0.0-20260407120610-7715dc2523fa/go.mod h1:bH2NKX+NpDTY1Zkxfboxw6MXB/ZywaNLmrDJYgKMJ2Y=
github.com/sagernet/sing-mux v0.3.4 h1:ZQplKl8MNXutjzbMVtWvWG31fohhgOfCuUZR4dVQ8+s=
github.com/sagernet/sing-mux v0.3.4/go.mod h1:QvlKMyNBNrQoyX4x+gq028uPbLM2XeRpWtDsWBJbFSk=
github.com/sagernet/sing-quic v0.6.2-0.20260330152607-bf674c163212 h1:7mFOUqy+DyOj7qKGd1X54UMXbnbJiiMileK/tn17xYc=
@@ -252,8 +256,8 @@ github.com/sagernet/sing-shadowsocks2 v0.2.1 h1:dWV9OXCeFPuYGHb6IRqlSptVnSzOelnq
github.com/sagernet/sing-shadowsocks2 v0.2.1/go.mod h1:RnXS0lExcDAovvDeniJ4IKa2IuChrdipolPYWBv9hWQ=
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11 h1:tK+75l64tm9WvEFrYRE1t0YxoFdWQqw/h7Uhzj0vJ+w=
github.com/sagernet/sing-shadowtls v0.2.1-0.20250503051639-fcd445d33c11/go.mod h1:sWqKnGlMipCHaGsw1sTTlimyUpgzP4WP3pjhCsYt9oA=
github.com/sagernet/sing-tun v0.8.7-0.20260323120017-8eb4e8acfc2d h1:vi0j6301f6H8t2GYgAC2PA2AdnGdMwkP34B4+N03Qt4=
github.com/sagernet/sing-tun v0.8.7-0.20260323120017-8eb4e8acfc2d/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs=
github.com/sagernet/sing-tun v0.8.7-0.20260402180740-11f6e77ec6c6 h1:HV2I7DicF5Ar8v6F55f03W5FviBB7jgvLhJSDwbFvbk=
github.com/sagernet/sing-tun v0.8.7-0.20260402180740-11f6e77ec6c6/go.mod h1:pLCo4o+LacXEzz0bhwhJkKBjLlKOGPBNOAZ97ZVZWzs=
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1 h1:aSwUNYUkVyVvdmBSufR8/nRFonwJeKSIROxHcm5br9o=
github.com/sagernet/sing-vmess v0.2.8-0.20250909125414-3aed155119a1/go.mod h1:P11scgTxMxVVQ8dlM27yNm3Cro40mD0+gHbnqrNGDuY=
github.com/sagernet/smux v1.5.50-sing-box-mod.1 h1:XkJcivBC9V4wBjiGXIXZ229aZCU1hzcbp6kSkkyQ478=

View File

@@ -6,6 +6,7 @@ type CloudflaredInboundOptions struct {
Token string `json:"token,omitempty"`
HAConnections int `json:"ha_connections,omitempty"`
Protocol string `json:"protocol,omitempty"`
PostQuantum bool `json:"post_quantum,omitempty"`
ControlDialer DialerOptions `json:"control_dialer,omitempty"`
TunnelDialer DialerOptions `json:"tunnel_dialer,omitempty"`
EdgeIPVersion int `json:"edge_ip_version,omitempty"`

View File

@@ -1,120 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"sync"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/coreos/go-oidc/v3/oidc"
)
const accessJWTAssertionHeader = "Cf-Access-Jwt-Assertion"
var newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
issuerURL := accessIssuerURL(access.TeamName, access.Environment)
client := &http.Client{
Transport: &http.Transport{
DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
return dialer.DialContext(ctx, network, M.ParseSocksaddr(address))
},
},
}
keySet := oidc.NewRemoteKeySet(oidc.ClientContext(context.Background(), client), issuerURL+"/cdn-cgi/access/certs")
verifier := oidc.NewVerifier(issuerURL, keySet, &oidc.Config{
SkipClientIDCheck: true,
})
return &oidcAccessValidator{
verifier: verifier,
audTags: append([]string(nil), access.AudTag...),
}, nil
}
type accessValidator interface {
Validate(ctx context.Context, request *http.Request) error
}
type oidcAccessValidator struct {
verifier *oidc.IDTokenVerifier
audTags []string
}
func (v *oidcAccessValidator) Validate(ctx context.Context, request *http.Request) error {
accessJWT := request.Header.Get(accessJWTAssertionHeader)
if accessJWT == "" {
return E.New("missing access jwt assertion")
}
token, err := v.verifier.Verify(ctx, accessJWT)
if err != nil {
return err
}
if accessTokenAudienceAllowed(token.Audience, v.audTags) {
return nil
}
return E.New("access token audience does not match configured aud_tag")
}
func accessTokenAudienceAllowed(tokenAudience []string, configuredAudTags []string) bool {
for _, tokenAudTag := range tokenAudience {
for _, configuredAudTag := range configuredAudTags {
if configuredAudTag == tokenAudTag {
return true
}
}
}
return false
}
func accessIssuerURL(teamName string, environment string) string {
if strings.EqualFold(environment, "fed") || strings.EqualFold(environment, "fips") {
return fmt.Sprintf("https://%s.fed.cloudflareaccess.com", teamName)
}
return fmt.Sprintf("https://%s.cloudflareaccess.com", teamName)
}
func validateAccessConfiguration(access AccessConfig) error {
if !access.Required {
return nil
}
if access.TeamName == "" && len(access.AudTag) > 0 {
return E.New("access.team_name cannot be blank when access.aud_tag is present")
}
return nil
}
func accessValidatorKey(access AccessConfig) string {
return access.TeamName + "|" + access.Environment + "|" + strings.Join(access.AudTag, ",")
}
type accessValidatorCache struct {
access sync.RWMutex
values map[string]accessValidator
dialer N.Dialer
}
func (c *accessValidatorCache) Get(accessConfig AccessConfig) (accessValidator, error) {
key := accessValidatorKey(accessConfig)
c.access.RLock()
validator, loaded := c.values[key]
c.access.RUnlock()
if loaded {
return validator, nil
}
validator, err := newAccessValidator(accessConfig, c.dialer)
if err != nil {
return nil, err
}
c.access.Lock()
c.values[key] = validator
c.access.Unlock()
return validator, nil
}

View File

@@ -1,191 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"net/http"
"testing"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
type fakeAccessValidator struct {
err error
}
func (v *fakeAccessValidator) Validate(ctx context.Context, request *http.Request) error {
return v.err
}
func newAccessTestInbound(t *testing.T) *Inbound {
t.Helper()
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
if err != nil {
t.Fatal(err)
}
return &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
logger: logFactory.NewLogger("test"),
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer},
router: &testRouter{},
controlDialer: N.SystemDialer,
}
}
func TestValidateAccessConfiguration(t *testing.T) {
err := validateAccessConfiguration(AccessConfig{
Required: true,
AudTag: []string{"aud"},
})
if err == nil {
t.Fatal("expected access config validation error")
}
}
func TestAccessTokenAudienceAllowed(t *testing.T) {
testCases := []struct {
name string
tokenAudience []string
configuredTags []string
expected bool
}{
{
name: "matching audience",
tokenAudience: []string{"aud-1", "aud-2"},
configuredTags: []string{"aud-2"},
expected: true,
},
{
name: "empty configured tags rejected",
tokenAudience: []string{"aud-1"},
configuredTags: nil,
expected: false,
},
{
name: "non matching audience rejected",
tokenAudience: []string{"aud-1"},
configuredTags: []string{"aud-2"},
expected: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
allowed := accessTokenAudienceAllowed(testCase.tokenAudience, testCase.configuredTags)
if allowed != testCase.expected {
t.Fatalf("accessTokenAudienceAllowed(%v, %v) = %v, want %v", testCase.tokenAudience, testCase.configuredTags, allowed, testCase.expected)
}
})
}
}
func TestRoundTripHTTPAccessDenied(t *testing.T) {
originalFactory := newAccessValidator
defer func() {
newAccessValidator = originalFactory
}()
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
return &fakeAccessValidator{err: E.New("forbidden")}, nil
}
inboundInstance := newAccessTestInbound(t)
respWriter := &fakeConnectResponseWriter{}
request := &ConnectRequest{
Type: ConnectionTypeHTTP,
Dest: "http://127.0.0.1:8083/test",
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodGet},
{Key: metadataHTTPHost, Val: "example.com"},
},
}
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceHTTP,
Destination: M.ParseSocksaddr("127.0.0.1:8083"),
OriginRequest: OriginRequestConfig{
Access: AccessConfig{
Required: true,
TeamName: "team",
},
},
})
if respWriter.status != http.StatusForbidden {
t.Fatalf("expected 403, got %d", respWriter.status)
}
}
func TestHandleHTTPServiceStatusAccessDenied(t *testing.T) {
originalFactory := newAccessValidator
defer func() {
newAccessValidator = originalFactory
}()
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
return &fakeAccessValidator{err: E.New("forbidden")}, nil
}
inboundInstance := newAccessTestInbound(t)
respWriter := &fakeConnectResponseWriter{}
request := &ConnectRequest{
Type: ConnectionTypeHTTP,
Dest: "https://example.com/status",
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodGet},
{Key: metadataHTTPHost, Val: "example.com"},
},
}
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceStatus,
OriginRequest: OriginRequestConfig{
Access: AccessConfig{
Required: true,
TeamName: "team",
},
},
StatusCode: 404,
})
if respWriter.status != http.StatusForbidden {
t.Fatalf("expected 403, got %d", respWriter.status)
}
}
func TestHandleHTTPServiceStreamAccessDenied(t *testing.T) {
originalFactory := newAccessValidator
defer func() {
newAccessValidator = originalFactory
}()
newAccessValidator = func(access AccessConfig, dialer N.Dialer) (accessValidator, error) {
return &fakeAccessValidator{err: E.New("forbidden")}, nil
}
inboundInstance := newAccessTestInbound(t)
respWriter := &fakeConnectResponseWriter{}
request := &ConnectRequest{
Type: ConnectionTypeWebsocket,
Dest: "https://example.com/ws",
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodGet},
{Key: metadataHTTPHost, Val: "example.com"},
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
},
}
inboundInstance.handleHTTPService(context.Background(), nil, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceStream,
Destination: M.ParseSocksaddr("127.0.0.1:8080"),
OriginRequest: OriginRequestConfig{
Access: AccessConfig{
Required: true,
TeamName: "team",
},
},
})
if respWriter.status != http.StatusForbidden {
t.Fatalf("expected 403, got %d", respWriter.status)
}
}

View File

@@ -1,75 +0,0 @@
-----BEGIN CERTIFICATE-----
MIICiTCCAi6gAwIBAgIUXZP3MWb8MKwBE1Qbawsp1sfA/Y4wCgYIKoZIzj0EAwIw
gY8xCzAJBgNVBAYTAlVTMRMwEQYDVQQIEwpDYWxpZm9ybmlhMRYwFAYDVQQHEw1T
YW4gRnJhbmNpc2NvMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTgwNgYDVQQL
Ey9DbG91ZEZsYXJlIE9yaWdpbiBTU0wgRUNDIENlcnRpZmljYXRlIEF1dGhvcml0
eTAeFw0xOTA4MjMyMTA4MDBaFw0yOTA4MTUxNzAwMDBaMIGPMQswCQYDVQQGEwJV
UzETMBEGA1UECBMKQ2FsaWZvcm5pYTEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzEZ
MBcGA1UEChMQQ2xvdWRGbGFyZSwgSW5jLjE4MDYGA1UECxMvQ2xvdWRGbGFyZSBP
cmlnaW4gU1NMIEVDQyBDZXJ0aWZpY2F0ZSBBdXRob3JpdHkwWTATBgcqhkjOPQIB
BggqhkjOPQMBBwNCAASR+sGALuaGshnUbcxKry+0LEXZ4NY6JUAtSeA6g87K3jaA
xpIg9G50PokpfWkhbarLfpcZu0UAoYy2su0EhN7wo2YwZDAOBgNVHQ8BAf8EBAMC
AQYwEgYDVR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUhTBdOypw1O3VkmcH/es5
tBoOOKcwHwYDVR0jBBgwFoAUhTBdOypw1O3VkmcH/es5tBoOOKcwCgYIKoZIzj0E
AwIDSQAwRgIhAKilfntP2ILGZjwajktkBtXE1pB4Y/fjAfLkIRUzrI15AiEA5UCL
XYZZ9m2c3fKwIenMMojL1eqydsgqj/wK4p5kagQ=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIEADCCAuigAwIBAgIID+rOSdTGfGcwDQYJKoZIhvcNAQELBQAwgYsxCzAJBgNV
BAYTAlVTMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91
ZEZsYXJlIE9yaWdpbiBTU0wgQ2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQH
Ew1TYW4gRnJhbmNpc2NvMRMwEQYDVQQIEwpDYWxpZm9ybmlhMB4XDTE5MDgyMzIx
MDgwMFoXDTI5MDgxNTE3MDAwMFowgYsxCzAJBgNVBAYTAlVTMRkwFwYDVQQKExBD
bG91ZEZsYXJlLCBJbmMuMTQwMgYDVQQLEytDbG91ZEZsYXJlIE9yaWdpbiBTU0wg
Q2VydGlmaWNhdGUgQXV0aG9yaXR5MRYwFAYDVQQHEw1TYW4gRnJhbmNpc2NvMRMw
EQYDVQQIEwpDYWxpZm9ybmlhMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKC
AQEAwEiVZ/UoQpHmFsHvk5isBxRehukP8DG9JhFev3WZtG76WoTthvLJFRKFCHXm
V6Z5/66Z4S09mgsUuFwvJzMnE6Ej6yIsYNCb9r9QORa8BdhrkNn6kdTly3mdnykb
OomnwbUfLlExVgNdlP0XoRoeMwbQ4598foiHblO2B/LKuNfJzAMfS7oZe34b+vLB
yrP/1bgCSLdc1AxQc1AC0EsQQhgcyTJNgnG4va1c7ogPlwKyhbDyZ4e59N5lbYPJ
SmXI/cAe3jXj1FBLJZkwnoDKe0v13xeF+nF32smSH0qB7aJX2tBMW4TWtFPmzs5I
lwrFSySWAdwYdgxw180yKU0dvwIDAQABo2YwZDAOBgNVHQ8BAf8EBAMCAQYwEgYD
VR0TAQH/BAgwBgEB/wIBAjAdBgNVHQ4EFgQUJOhTV118NECHqeuU27rhFnj8KaQw
HwYDVR0jBBgwFoAUJOhTV118NECHqeuU27rhFnj8KaQwDQYJKoZIhvcNAQELBQAD
ggEBAHwOf9Ur1l0Ar5vFE6PNrZWrDfQIMyEfdgSKofCdTckbqXNTiXdgbHs+TWoQ
wAB0pfJDAHJDXOTCWRyTeXOseeOi5Btj5CnEuw3P0oXqdqevM1/+uWp0CM35zgZ8
VD4aITxity0djzE6Qnx3Syzz+ZkoBgTnNum7d9A66/V636x4vTeqbZFBr9erJzgz
hhurjcoacvRNhnjtDRM0dPeiCJ50CP3wEYuvUzDHUaowOsnLCjQIkWbR7Ni6KEIk
MOz2U0OBSif3FTkhCgZWQKOOLo1P42jHC3ssUZAtVNXrCk3fw9/E15k8NPkBazZ6
0iykLhH1trywrKRMVw67F44IE8Y=
-----END CERTIFICATE-----
-----BEGIN CERTIFICATE-----
MIIGCjCCA/KgAwIBAgIIV5G6lVbCLmEwDQYJKoZIhvcNAQENBQAwgZAxCzAJBgNV
BAYTAlVTMRkwFwYDVQQKExBDbG91ZEZsYXJlLCBJbmMuMRQwEgYDVQQLEwtPcmln
aW4gUHVsbDEWMBQGA1UEBxMNU2FuIEZyYW5jaXNjbzETMBEGA1UECBMKQ2FsaWZv
cm5pYTEjMCEGA1UEAxMab3JpZ2luLXB1bGwuY2xvdWRmbGFyZS5uZXQwHhcNMTkx
MDEwMTg0NTAwWhcNMjkxMTAxMTcwMDAwWjCBkDELMAkGA1UEBhMCVVMxGTAXBgNV
BAoTEENsb3VkRmxhcmUsIEluYy4xFDASBgNVBAsTC09yaWdpbiBQdWxsMRYwFAYD
VQQHEw1TYW4gRnJhbmNpc2NvMRMwEQYDVQQIEwpDYWxpZm9ybmlhMSMwIQYDVQQD
ExpvcmlnaW4tcHVsbC5jbG91ZGZsYXJlLm5ldDCCAiIwDQYJKoZIhvcNAQEBBQAD
ggIPADCCAgoCggIBAN2y2zojYfl0bKfhp0AJBFeV+jQqbCw3sHmvEPwLmqDLqynI
42tZXR5y914ZB9ZrwbL/K5O46exd/LujJnV2b3dzcx5rtiQzso0xzljqbnbQT20e
ihx/WrF4OkZKydZzsdaJsWAPuplDH5P7J82q3re88jQdgE5hqjqFZ3clCG7lxoBw
hLaazm3NJJlUfzdk97ouRvnFGAuXd5cQVx8jYOOeU60sWqmMe4QHdOvpqB91bJoY
QSKVFjUgHeTpN8tNpKJfb9LIn3pun3bC9NKNHtRKMNX3Kl/sAPq7q/AlndvA2Kw3
Dkum2mHQUGdzVHqcOgea9BGjLK2h7SuX93zTWL02u799dr6Xkrad/WShHchfjjRn
aL35niJUDr02YJtPgxWObsrfOU63B8juLUphW/4BOjjJyAG5l9j1//aUGEi/sEe5
lqVv0P78QrxoxR+MMXiJwQab5FB8TG/ac6mRHgF9CmkX90uaRh+OC07XjTdfSKGR
PpM9hB2ZhLol/nf8qmoLdoD5HvODZuKu2+muKeVHXgw2/A6wM7OwrinxZiyBk5Hh
CvaADH7PZpU6z/zv5NU5HSvXiKtCzFuDu4/Zfi34RfHXeCUfHAb4KfNRXJwMsxUa
+4ZpSAX2G6RnGU5meuXpU5/V+DQJp/e69XyyY6RXDoMywaEFlIlXBqjRRA2pAgMB
AAGjZjBkMA4GA1UdDwEB/wQEAwIBBjASBgNVHRMBAf8ECDAGAQH/AgECMB0GA1Ud
DgQWBBRDWUsraYuA4REzalfNVzjann3F6zAfBgNVHSMEGDAWgBRDWUsraYuA4REz
alfNVzjann3F6zANBgkqhkiG9w0BAQ0FAAOCAgEAkQ+T9nqcSlAuW/90DeYmQOW1
QhqOor5psBEGvxbNGV2hdLJY8h6QUq48BCevcMChg/L1CkznBNI40i3/6heDn3IS
zVEwXKf34pPFCACWVMZxbQjkNRTiH8iRur9EsaNQ5oXCPJkhwg2+IFyoPAAYURoX
VcI9SCDUa45clmYHJ/XYwV1icGVI8/9b2JUqklnOTa5tugwIUi5sTfipNcJXHhgz
6BKYDl0/UP0lLKbsUETXeTGDiDpxZYIgbcFrRDDkHC6BSvdWVEiH5b9mH2BON60z
0O0j8EEKTwi9jnafVtZQXP/D8yoVowdFDjXcKkOPF/1gIh9qrFR6GdoPVgB3SkLc
5ulBqZaCHm563jsvWb/kXJnlFxW+1bsO9BDD6DweBcGdNurgmH625wBXksSdD7y/
fakk8DagjbjKShYlPEFOAqEcliwjF45eabL0t27MJV61O/jHzHL3dknXeE4BDa2j
bA+JbyJeUMtU7KMsxvx82RmhqBEJJDBCJ3scVptvhDMRrtqDBW5JShxoAOcpFQGm
iYWicn46nPDjgTU0bX1ZPpTpryXbvciVL5RkVBuyX2ntcOLDPlZWgxZCBp96x07F
AnOzKgZk4RzZPNAxCXERVxajn/FLcOhglVAKo5H0ac+AitlQ0ip55D2/mf8o72tM
fVQ6VpyjEXdiIXWUq/o=
-----END CERTIFICATE-----

View File

@@ -1,72 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"testing"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common/json"
)
func TestNewInboundRequiresToken(t *testing.T) {
_, err := NewInbound(context.Background(), nil, log.NewNOPFactory().NewLogger("test"), "test", option.CloudflaredInboundOptions{})
if err == nil {
t.Fatal("expected missing token error")
}
}
func TestValidateRegistrationResultRejectsNonRemoteManaged(t *testing.T) {
err := validateRegistrationResult(&RegistrationResult{TunnelIsRemotelyManaged: false})
if err == nil {
t.Fatal("expected unsupported tunnel error")
}
if err != ErrNonRemoteManagedTunnelUnsupported {
t.Fatalf("unexpected error: %v", err)
}
}
func TestNormalizeProtocolAutoUsesTokenStyleSentinel(t *testing.T) {
protocol, err := normalizeProtocol("auto")
if err != nil {
t.Fatal(err)
}
if protocol != "" {
t.Fatalf("expected auto protocol to normalize to token-style empty sentinel, got %q", protocol)
}
}
func TestResolveGracePeriodDefaultsToThirtySeconds(t *testing.T) {
if got := resolveGracePeriod(nil); got != 30*time.Second {
t.Fatalf("expected default grace period 30s, got %s", got)
}
}
func TestResolveGracePeriodPreservesExplicitZero(t *testing.T) {
var options option.CloudflaredInboundOptions
if err := json.Unmarshal([]byte(`{"grace_period":"0s"}`), &options); err != nil {
t.Fatal(err)
}
if options.GracePeriod == nil {
t.Fatal("expected explicit grace period to be set")
}
if got := resolveGracePeriod(options.GracePeriod); got != 0 {
t.Fatalf("expected explicit zero grace period, got %s", got)
}
}
func TestResolveGracePeriodPreservesNonZeroValue(t *testing.T) {
var options option.CloudflaredInboundOptions
if err := json.Unmarshal([]byte(`{"grace_period":"45s"}`), &options); err != nil {
t.Fatal(err)
}
if options.GracePeriod == nil {
t.Fatal("expected explicit grace period to be set")
}
if got := resolveGracePeriod(options.GracePeriod); got != 45*time.Second {
t.Fatalf("expected grace period 45s, got %s", got)
}
}

View File

@@ -1,268 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"errors"
"io"
"net"
"testing"
"time"
"github.com/sagernet/quic-go"
"github.com/google/uuid"
)
type stubNetConn struct {
closed chan struct{}
}
func newStubNetConn() *stubNetConn {
return &stubNetConn{closed: make(chan struct{})}
}
func (c *stubNetConn) Read(_ []byte) (int, error) { <-c.closed; return 0, io.EOF }
func (c *stubNetConn) Write(b []byte) (int, error) { return len(b), nil }
func (c *stubNetConn) Close() error { closeOnce(c.closed); return nil }
func (c *stubNetConn) LocalAddr() net.Addr { return &net.TCPAddr{} }
func (c *stubNetConn) RemoteAddr() net.Addr { return &net.TCPAddr{} }
func (c *stubNetConn) SetDeadline(time.Time) error { return nil }
func (c *stubNetConn) SetReadDeadline(time.Time) error { return nil }
func (c *stubNetConn) SetWriteDeadline(time.Time) error { return nil }
type stubQUICConn struct {
closed chan string
}
func newStubQUICConn() *stubQUICConn {
return &stubQUICConn{closed: make(chan string, 1)}
}
func (c *stubQUICConn) OpenStream() (*quic.Stream, error) { return nil, errors.New("unused") }
func (c *stubQUICConn) AcceptStream(context.Context) (*quic.Stream, error) {
return nil, errors.New("unused")
}
func (c *stubQUICConn) ReceiveDatagram(context.Context) ([]byte, error) {
return nil, errors.New("unused")
}
func (c *stubQUICConn) SendDatagram([]byte) error { return nil }
func (c *stubQUICConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
func (c *stubQUICConn) CloseWithError(_ quic.ApplicationErrorCode, reason string) error {
select {
case c.closed <- reason:
default:
}
return nil
}
type mockRegistrationClient struct {
unregisterCalled chan struct{}
closed chan struct{}
}
func newMockRegistrationClient() *mockRegistrationClient {
return &mockRegistrationClient{
unregisterCalled: make(chan struct{}, 1),
closed: make(chan struct{}, 1),
}
}
func (c *mockRegistrationClient) RegisterConnection(context.Context, TunnelAuth, uuid.UUID, uint8, *RegistrationConnectionOptions) (*RegistrationResult, error) {
return &RegistrationResult{}, nil
}
func (c *mockRegistrationClient) Unregister(context.Context) error {
select {
case c.unregisterCalled <- struct{}{}:
default:
}
return nil
}
func (c *mockRegistrationClient) Close() error {
select {
case c.closed <- struct{}{}:
default:
}
return nil
}
func closeOnce(ch chan struct{}) {
select {
case <-ch:
default:
close(ch)
}
}
func TestHTTP2GracefulShutdownWaitsForActiveRequests(t *testing.T) {
conn := newStubNetConn()
registrationClient := newMockRegistrationClient()
connection := &HTTP2Connection{
conn: conn,
gracePeriod: 200 * time.Millisecond,
registrationClient: registrationClient,
registrationResult: &RegistrationResult{},
serveCancel: func() {},
}
connection.activeRequests.Add(1)
done := make(chan struct{})
go func() {
connection.gracefulShutdown()
close(done)
}()
select {
case <-registrationClient.unregisterCalled:
case <-time.After(time.Second):
t.Fatal("expected unregister call")
}
select {
case <-conn.closed:
t.Fatal("connection closed before active requests completed")
case <-time.After(50 * time.Millisecond):
}
connection.activeRequests.Done()
select {
case <-conn.closed:
case <-time.After(time.Second):
t.Fatal("expected connection close after active requests finished")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected graceful shutdown to finish")
}
}
func TestHTTP2GracefulShutdownTimesOut(t *testing.T) {
conn := newStubNetConn()
registrationClient := newMockRegistrationClient()
connection := &HTTP2Connection{
conn: conn,
gracePeriod: 50 * time.Millisecond,
registrationClient: registrationClient,
registrationResult: &RegistrationResult{},
serveCancel: func() {},
}
connection.activeRequests.Add(1)
done := make(chan struct{})
go func() {
connection.gracefulShutdown()
close(done)
}()
select {
case <-conn.closed:
case <-time.After(500 * time.Millisecond):
t.Fatal("expected connection close after grace timeout")
}
connection.activeRequests.Done()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected graceful shutdown to finish after request completion")
}
}
func TestQUICGracefulShutdownWaitsForDrainWindow(t *testing.T) {
conn := newStubQUICConn()
registrationClient := newMockRegistrationClient()
serveCancelCalled := make(chan struct{}, 1)
connection := &QUICConnection{
conn: conn,
gracePeriod: 80 * time.Millisecond,
registrationClient: registrationClient,
registrationResult: &RegistrationResult{},
serveCancel: func() {
select {
case serveCancelCalled <- struct{}{}:
default:
}
},
}
done := make(chan struct{})
go func() {
connection.gracefulShutdown()
close(done)
}()
select {
case <-registrationClient.unregisterCalled:
case <-time.After(time.Second):
t.Fatal("expected unregister call")
}
select {
case <-conn.closed:
t.Fatal("connection closed before grace window elapsed")
case <-time.After(20 * time.Millisecond):
}
select {
case reason := <-conn.closed:
if reason != "graceful shutdown" {
t.Fatalf("unexpected close reason: %q", reason)
}
case <-time.After(time.Second):
t.Fatal("expected graceful close")
}
select {
case <-serveCancelCalled:
case <-time.After(time.Second):
t.Fatal("expected serve cancel to be called")
}
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected graceful shutdown to finish")
}
}
func TestQUICGracefulShutdownStopsWaitingWhenServeContextEnds(t *testing.T) {
conn := newStubQUICConn()
registrationClient := newMockRegistrationClient()
serveCtx, cancelServe := context.WithCancel(context.Background())
connection := &QUICConnection{
conn: conn,
gracePeriod: time.Second,
registrationClient: registrationClient,
registrationResult: &RegistrationResult{},
serveCtx: serveCtx,
serveCancel: func() {},
}
done := make(chan struct{})
go func() {
connection.gracefulShutdown()
close(done)
}()
select {
case <-registrationClient.unregisterCalled:
case <-time.After(time.Second):
t.Fatal("expected unregister call")
}
cancelServe()
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Fatal("expected graceful shutdown to stop waiting once serve context ends")
}
}

View File

@@ -1,522 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"crypto/tls"
"io"
"math"
"net"
"net/http"
"runtime/debug"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
M "github.com/sagernet/sing/common/metadata"
"github.com/google/uuid"
"golang.org/x/net/http2"
)
const (
h2EdgeSNI = "h2.cftunnel.com"
h2ResponseMetaCloudflared = `{"src":"cloudflared"}`
h2ResponseMetaCloudflaredLimited = `{"src":"cloudflared","flow_rate_limited":true}`
contentTypeHeader = "content-type"
contentLengthHeader = "content-length"
transferEncodingHeader = "transfer-encoding"
chunkTransferEncoding = "chunked"
sseContentType = "text/event-stream"
grpcContentType = "application/grpc"
ndjsonContentType = "application/x-ndjson"
)
var flushableContentTypes = []string{sseContentType, grpcContentType, ndjsonContentType}
// HTTP2Connection manages a single HTTP/2 connection to the Cloudflare edge.
// Uses role reversal: we dial the edge as a TLS client but serve HTTP/2 as server.
type HTTP2Connection struct {
conn net.Conn
server *http2.Server
logger log.ContextLogger
edgeAddr *EdgeAddr
connIndex uint8
credentials Credentials
connectorID uuid.UUID
features []string
gracePeriod time.Duration
inbound *Inbound
numPreviousAttempts uint8
registrationClient registrationRPCClient
registrationResult *RegistrationResult
controlStreamErr error
activeRequests sync.WaitGroup
serveCancel context.CancelFunc
registrationClose sync.Once
shutdownOnce sync.Once
closeOnce sync.Once
}
// NewHTTP2Connection dials the edge and establishes an HTTP/2 connection with role reversal.
func NewHTTP2Connection(
ctx context.Context,
edgeAddr *EdgeAddr,
connIndex uint8,
credentials Credentials,
connectorID uuid.UUID,
features []string,
numPreviousAttempts uint8,
gracePeriod time.Duration,
inbound *Inbound,
logger log.ContextLogger,
) (*HTTP2Connection, error) {
rootCAs, err := cloudflareRootCertPool()
if err != nil {
return nil, E.Cause(err, "load Cloudflare root CAs")
}
tlsConfig := newEdgeTLSConfig(rootCAs, h2EdgeSNI, nil)
tcpConn, err := inbound.tunnelDialer.DialContext(ctx, "tcp", M.SocksaddrFrom(edgeAddr.TCP.AddrPort().Addr(), edgeAddr.TCP.AddrPort().Port()))
if err != nil {
return nil, E.Cause(err, "dial edge TCP")
}
tlsConn := tls.Client(tcpConn, tlsConfig)
err = tlsConn.HandshakeContext(ctx)
if err != nil {
tcpConn.Close()
return nil, E.Cause(err, "TLS handshake")
}
return &HTTP2Connection{
conn: tlsConn,
server: &http2.Server{
MaxConcurrentStreams: math.MaxUint32,
},
logger: logger,
edgeAddr: edgeAddr,
connIndex: connIndex,
credentials: credentials,
connectorID: connectorID,
features: features,
numPreviousAttempts: numPreviousAttempts,
gracePeriod: gracePeriod,
inbound: inbound,
}, nil
}
// Serve runs the HTTP/2 server. Blocks until the context is cancelled or the connection ends.
func (c *HTTP2Connection) Serve(ctx context.Context) error {
serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx))
c.serveCancel = serveCancel
shutdownDone := make(chan struct{})
go func() {
<-ctx.Done()
c.gracefulShutdown()
close(shutdownDone)
}()
c.server.ServeConn(c.conn, &http2.ServeConnOpts{
Context: serveCtx,
Handler: c,
})
if ctx.Err() != nil {
<-shutdownDone
return ctx.Err()
}
if c.controlStreamErr != nil {
return c.controlStreamErr
}
if c.registrationResult == nil {
return E.New("edge connection closed before registration")
}
return E.New("edge connection closed")
}
func (c *HTTP2Connection) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if r.Header.Get(h2HeaderUpgrade) == h2UpgradeControlStream {
c.handleControlStream(r.Context(), r, w)
return
}
c.activeRequests.Add(1)
defer c.activeRequests.Done()
switch {
case r.Header.Get(h2HeaderUpgrade) == h2UpgradeWebsocket:
c.handleH2DataStream(r.Context(), r, w, ConnectionTypeWebsocket)
case r.Header.Get(h2HeaderTCPSrc) != "":
c.handleH2DataStream(r.Context(), r, w, ConnectionTypeTCP)
case r.Header.Get(h2HeaderUpgrade) == h2UpgradeConfiguration:
c.handleConfigurationUpdate(r, w)
default:
c.handleH2DataStream(r.Context(), r, w, ConnectionTypeHTTP)
}
}
func (c *HTTP2Connection) handleControlStream(ctx context.Context, r *http.Request, w http.ResponseWriter) {
flusher, ok := w.(http.Flusher)
if !ok {
c.logger.Error("response writer does not support flushing")
return
}
w.WriteHeader(http.StatusOK)
flusher.Flush()
stream := newHTTP2Stream(r.Body, &http2FlushWriter{w: w, flusher: flusher})
c.registrationClient = NewRegistrationClient(ctx, stream)
host, _, _ := net.SplitHostPort(c.conn.LocalAddr().String())
originLocalIP := net.ParseIP(host)
options := BuildConnectionOptions(c.connectorID, c.features, c.numPreviousAttempts, originLocalIP)
result, err := c.registrationClient.RegisterConnection(
ctx, c.credentials.Auth(), c.credentials.TunnelID, c.connIndex, options,
)
if err != nil {
c.controlStreamErr = err
c.logger.Error("register connection: ", err)
go c.forceClose()
return
}
if err := validateRegistrationResult(result); err != nil {
c.controlStreamErr = err
c.logger.Error("register connection: ", err)
go c.forceClose()
return
}
c.registrationResult = result
c.inbound.notifyConnected(c.connIndex, "http2")
c.logger.Info("connected to ", result.Location,
" (connection ", result.ConnectionID, ")")
<-ctx.Done()
}
func (c *HTTP2Connection) handleH2DataStream(ctx context.Context, r *http.Request, w http.ResponseWriter, connectionType ConnectionType) {
r.Header.Del(h2HeaderUpgrade)
r.Header.Del(h2HeaderTCPSrc)
flusher, ok := w.(http.Flusher)
if !ok {
c.logger.Error("response writer does not support flushing")
return
}
var destination string
if connectionType == ConnectionTypeTCP {
destination = r.Host
if destination == "" && r.URL != nil {
destination = r.URL.Host
}
} else {
if r.URL.Scheme == "" {
r.URL.Scheme = "http"
}
if r.URL.Host == "" {
r.URL.Host = r.Host
}
destination = r.URL.String()
}
request := &ConnectRequest{
Dest: destination,
Type: connectionType,
}
request.Metadata = append(request.Metadata, Metadata{
Key: metadataHTTPMethod,
Val: r.Method,
})
request.Metadata = append(request.Metadata, Metadata{
Key: metadataHTTPHost,
Val: r.Host,
})
for name, values := range r.Header {
for _, value := range values {
request.Metadata = append(request.Metadata, Metadata{
Key: metadataHTTPHeader + ":" + name,
Val: value,
})
}
}
flushState := &http2FlushState{shouldFlush: connectionType != ConnectionTypeHTTP}
stream := &http2DataStream{
reader: r.Body,
writer: w,
flusher: flusher,
state: flushState,
logger: c.logger,
}
respWriter := &http2ResponseWriter{
writer: w,
flusher: flusher,
flushState: flushState,
}
c.inbound.dispatchRequest(ctx, stream, respWriter, request)
}
type h2ConfigurationUpdateBody struct {
Version int32 `json:"version"`
Config json.RawMessage `json:"config"`
}
func (c *HTTP2Connection) handleConfigurationUpdate(r *http.Request, w http.ResponseWriter) {
var body h2ConfigurationUpdateBody
err := json.NewDecoder(r.Body).Decode(&body)
if err != nil {
c.logger.Error("decode configuration update: ", err)
w.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared)
w.WriteHeader(http.StatusBadGateway)
return
}
result := c.inbound.ApplyConfig(body.Version, body.Config)
w.WriteHeader(http.StatusOK)
if result.Err != nil {
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":` + strconv.Quote(result.Err.Error()) + `}`))
return
}
w.Write([]byte(`{"lastAppliedVersion":` + strconv.FormatInt(int64(result.LastAppliedVersion), 10) + `,"err":null}`))
}
func (c *HTTP2Connection) gracefulShutdown() {
c.shutdownOnce.Do(func() {
if c.registrationClient == nil || c.registrationResult == nil {
c.closeNow()
return
}
unregisterCtx, cancel := context.WithTimeout(context.Background(), c.gracePeriod)
err := c.registrationClient.Unregister(unregisterCtx)
cancel()
if err != nil {
c.logger.Debug("failed to unregister: ", err)
}
c.closeRegistrationClient()
c.waitForActiveRequests(c.gracePeriod)
c.closeNow()
})
}
func (c *HTTP2Connection) forceClose() {
c.shutdownOnce.Do(func() {
c.closeNow()
})
}
func (c *HTTP2Connection) waitForActiveRequests(timeout time.Duration) {
if timeout <= 0 {
c.activeRequests.Wait()
return
}
done := make(chan struct{})
go func() {
c.activeRequests.Wait()
close(done)
}()
timer := time.NewTimer(timeout)
defer timer.Stop()
select {
case <-done:
case <-timer.C:
}
}
func (c *HTTP2Connection) closeRegistrationClient() {
c.registrationClose.Do(func() {
if c.registrationClient != nil {
_ = c.registrationClient.Close()
}
})
}
func (c *HTTP2Connection) closeNow() {
c.closeOnce.Do(func() {
_ = c.conn.Close()
if c.serveCancel != nil {
c.serveCancel()
}
c.closeRegistrationClient()
c.activeRequests.Wait()
})
}
// Close closes the HTTP/2 connection.
func (c *HTTP2Connection) Close() error {
c.forceClose()
return nil
}
// http2Stream wraps an HTTP/2 request body (reader) and a flush-writer (writer) as an io.ReadWriteCloser.
// Used for the control stream.
type http2Stream struct {
reader io.ReadCloser
writer io.Writer
}
func newHTTP2Stream(reader io.ReadCloser, writer io.Writer) *http2Stream {
return &http2Stream{reader: reader, writer: writer}
}
func (s *http2Stream) Read(p []byte) (int, error) { return s.reader.Read(p) }
func (s *http2Stream) Write(p []byte) (int, error) { return s.writer.Write(p) }
func (s *http2Stream) Close() error { return s.reader.Close() }
// http2FlushWriter wraps an http.ResponseWriter and flushes after every write.
type http2FlushWriter struct {
w http.ResponseWriter
flusher http.Flusher
}
func (w *http2FlushWriter) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
if err == nil {
w.flusher.Flush()
}
return n, err
}
// http2DataStream wraps an HTTP/2 request/response pair as io.ReadWriteCloser for data streams.
type http2DataStream struct {
reader io.ReadCloser
writer http.ResponseWriter
flusher http.Flusher
state *http2FlushState
logger log.ContextLogger
}
func (s *http2DataStream) Read(p []byte) (int, error) {
return s.reader.Read(p)
}
func (s *http2DataStream) Write(p []byte) (n int, err error) {
defer func() {
if recovered := recover(); recovered != nil {
if s.logger != nil {
s.logger.Debug("recovered from HTTP/2 data stream panic: ", recovered, "\n", string(debug.Stack()))
}
n = 0
err = io.ErrClosedPipe
}
}()
n, err = s.writer.Write(p)
if err == nil && s.state != nil && s.state.shouldFlush {
s.flusher.Flush()
}
return n, err
}
func (s *http2DataStream) Close() error {
return s.reader.Close()
}
// http2ResponseWriter translates ConnectResponse metadata to HTTP/2 response headers.
type http2ResponseWriter struct {
writer http.ResponseWriter
flusher http.Flusher
headersSent bool
flushState *http2FlushState
}
func (w *http2ResponseWriter) AddTrailer(name, value string) {
if !w.headersSent {
return
}
w.writer.Header().Add(http2.TrailerPrefix+name, value)
}
func (w *http2ResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
if w.headersSent {
return nil
}
w.headersSent = true
if responseError != nil {
if hasFlowConnectRateLimited(metadata) {
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflaredLimited)
} else {
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaCloudflared)
}
w.writer.WriteHeader(http.StatusBadGateway)
w.flusher.Flush()
return nil
}
statusCode := http.StatusOK
userHeaders := make(http.Header)
for _, entry := range metadata {
if entry.Key == metadataHTTPStatus {
code, err := strconv.Atoi(entry.Val)
if err == nil {
statusCode = code
}
continue
}
if strings.HasPrefix(entry.Key, metadataHTTPHeader+":") {
headerName := strings.TrimPrefix(entry.Key, metadataHTTPHeader+":")
lower := strings.ToLower(headerName)
if lower == "content-length" {
w.writer.Header().Set(headerName, entry.Val)
}
if !isControlResponseHeader(lower) || isWebsocketClientHeader(lower) {
userHeaders.Add(headerName, entry.Val)
}
}
}
w.writer.Header().Set(h2HeaderResponseUser, SerializeHeaders(userHeaders))
w.writer.Header().Set(h2HeaderResponseMeta, h2ResponseMetaOrigin)
if w.flushState != nil && shouldFlushHTTPHeaders(userHeaders) {
w.flushState.shouldFlush = true
}
if statusCode == http.StatusSwitchingProtocols {
statusCode = http.StatusOK
}
w.writer.WriteHeader(statusCode)
if w.flushState != nil && w.flushState.shouldFlush {
w.flusher.Flush()
}
return nil
}
type http2FlushState struct {
shouldFlush bool
}
func shouldFlushHTTPHeaders(headers http.Header) bool {
if headers.Get(contentLengthHeader) == "" {
return true
}
if transferEncoding := strings.ToLower(headers.Get(transferEncodingHeader)); transferEncoding != "" && strings.Contains(transferEncoding, chunkTransferEncoding) {
return true
}
contentType := strings.ToLower(headers.Get(contentTypeHeader))
for _, flushable := range flushableContentTypes {
if strings.HasPrefix(contentType, flushable) {
return true
}
}
return false
}

View File

@@ -1,191 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"bytes"
"io"
"net/http"
"testing"
"github.com/sagernet/sing-box/log"
)
type captureHTTP2Writer struct {
header http.Header
flushCount int
statusCode int
body []byte
panicWrite bool
}
func (w *captureHTTP2Writer) Header() http.Header {
if w.header == nil {
w.header = make(http.Header)
}
return w.header
}
func (w *captureHTTP2Writer) WriteHeader(statusCode int) {
w.statusCode = statusCode
}
func (w *captureHTTP2Writer) Write(p []byte) (int, error) {
if w.panicWrite {
panic("write after close")
}
w.body = append(w.body, p...)
return len(p), nil
}
func (w *captureHTTP2Writer) Flush() {
w.flushCount++
}
func TestHTTP2NonStreamingResponseDoesNotFlush(t *testing.T) {
writer := &captureHTTP2Writer{}
flushState := &http2FlushState{}
respWriter := &http2ResponseWriter{
writer: writer,
flusher: writer,
flushState: flushState,
}
err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusOK, http.Header{
"Content-Type": []string{"application/json"},
"Content-Length": []string{"2"},
}))
if err != nil {
t.Fatal(err)
}
if writer.flushCount != 0 {
t.Fatalf("expected no header flush for non-streaming response, got %d", writer.flushCount)
}
stream := &http2DataStream{
writer: writer,
flusher: writer,
state: flushState,
logger: log.NewNOPFactory().NewLogger("test"),
}
if _, err := stream.Write([]byte("ok")); err != nil {
t.Fatal(err)
}
if writer.flushCount != 0 {
t.Fatalf("expected no body flush for non-streaming response, got %d", writer.flushCount)
}
}
func TestHTTP2StreamingResponsesFlush(t *testing.T) {
testCases := []struct {
name string
header http.Header
}{
{
name: "sse",
header: http.Header{
"Content-Type": []string{"text/event-stream"},
"Content-Length": []string{"1"},
},
},
{
name: "grpc",
header: http.Header{
"Content-Type": []string{"application/grpc"},
"Content-Length": []string{"1"},
},
},
{
name: "ndjson",
header: http.Header{
"Content-Type": []string{"application/x-ndjson"},
"Content-Length": []string{"1"},
},
},
{
name: "chunked",
header: http.Header{
"Content-Type": []string{"application/json"},
"Content-Length": []string{"-1"},
"Transfer-Encoding": []string{"chunked"},
},
},
{
name: "no-content-length",
header: http.Header{
"Content-Type": []string{"application/json"},
},
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
writer := &captureHTTP2Writer{}
flushState := &http2FlushState{}
respWriter := &http2ResponseWriter{
writer: writer,
flusher: writer,
flushState: flushState,
}
err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusOK, testCase.header))
if err != nil {
t.Fatal(err)
}
if writer.flushCount == 0 {
t.Fatal("expected header flush for streaming response")
}
stream := &http2DataStream{
writer: writer,
flusher: writer,
state: flushState,
logger: log.NewNOPFactory().NewLogger("test"),
}
if _, err := stream.Write([]byte("chunk")); err != nil {
t.Fatal(err)
}
if writer.flushCount < 2 {
t.Fatalf("expected body flush for streaming response, got %d flushes", writer.flushCount)
}
})
}
}
func TestHTTP2DataStreamWriteRecoversPanic(t *testing.T) {
writer := &captureHTTP2Writer{panicWrite: true}
stream := &http2DataStream{
writer: writer,
flusher: writer,
state: &http2FlushState{shouldFlush: true},
logger: log.NewNOPFactory().NewLogger("test"),
}
_, err := stream.Write([]byte("panic"))
if err != io.ErrClosedPipe {
t.Fatalf("expected io.ErrClosedPipe, got %v", err)
}
}
func TestHandleConfigurationUpdateDecodeFailureReturnsBadGateway(t *testing.T) {
writer := &captureHTTP2Writer{}
connection := &HTTP2Connection{
logger: log.NewNOPFactory().NewLogger("test"),
}
request, err := http.NewRequest(http.MethodPost, "https://example.com", bytes.NewBufferString("{"))
if err != nil {
t.Fatal(err)
}
connection.handleConfigurationUpdate(request, writer)
if writer.statusCode != http.StatusBadGateway {
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, writer.statusCode)
}
if meta := writer.Header().Get(h2HeaderResponseMeta); meta != h2ResponseMetaCloudflared {
t.Fatalf("unexpected response meta: %q", meta)
}
if len(writer.body) != 0 {
t.Fatalf("expected empty response body, got %q", string(writer.body))
}
}

View File

@@ -1,436 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"fmt"
"io"
"net"
"sync"
"sync/atomic"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/log"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
)
const (
quicEdgeSNI = "quic.cftunnel.com"
quicEdgeALPN = "argotunnel"
quicHandshakeIdleTimeout = 5 * time.Second
quicMaxIdleTimeout = 5 * time.Second
quicKeepAlivePeriod = 1 * time.Second
)
func quicInitialPacketSize(ipVersion int) uint16 {
initialPacketSize := uint16(1252)
if ipVersion == 4 {
initialPacketSize = 1232
}
return initialPacketSize
}
// QUICConnection manages a single QUIC connection to the Cloudflare edge.
type QUICConnection struct {
conn quicConnection
logger log.ContextLogger
edgeAddr *EdgeAddr
connIndex uint8
credentials Credentials
connectorID uuid.UUID
datagramVersion string
features []string
numPreviousAttempts uint8
gracePeriod time.Duration
registrationClient registrationRPCClient
registrationResult *RegistrationResult
onConnected func()
serveCtx context.Context
serveCancel context.CancelFunc
registrationClose sync.Once
shutdownOnce sync.Once
closeOnce sync.Once
}
type quicStreamHandle interface {
io.Reader
io.Writer
io.Closer
CancelRead(code quic.StreamErrorCode)
CancelWrite(code quic.StreamErrorCode)
SetWriteDeadline(t time.Time) error
}
type quicConnection interface {
OpenStream() (*quic.Stream, error)
AcceptStream(ctx context.Context) (*quic.Stream, error)
ReceiveDatagram(ctx context.Context) ([]byte, error)
SendDatagram(data []byte) error
LocalAddr() net.Addr
CloseWithError(code quic.ApplicationErrorCode, reason string) error
}
type closeableQUICConn struct {
*quic.Conn
udpConn *net.UDPConn
}
func (c *closeableQUICConn) CloseWithError(code quic.ApplicationErrorCode, reason string) error {
err := c.Conn.CloseWithError(code, reason)
_ = c.udpConn.Close()
return err
}
// NewQUICConnection dials the edge and establishes a QUIC connection.
func NewQUICConnection(
ctx context.Context,
edgeAddr *EdgeAddr,
connIndex uint8,
credentials Credentials,
connectorID uuid.UUID,
datagramVersion string,
features []string,
numPreviousAttempts uint8,
gracePeriod time.Duration,
tunnelDialer N.Dialer,
onConnected func(),
logger log.ContextLogger,
) (*QUICConnection, error) {
rootCAs, err := cloudflareRootCertPool()
if err != nil {
return nil, E.Cause(err, "load Cloudflare root CAs")
}
tlsConfig := newEdgeTLSConfig(rootCAs, quicEdgeSNI, []string{quicEdgeALPN})
quicConfig := &quic.Config{
HandshakeIdleTimeout: quicHandshakeIdleTimeout,
MaxIdleTimeout: quicMaxIdleTimeout,
KeepAlivePeriod: quicKeepAlivePeriod,
MaxIncomingStreams: 1 << 60,
MaxIncomingUniStreams: 1 << 60,
EnableDatagrams: true,
InitialPacketSize: quicInitialPacketSize(edgeAddr.IPVersion),
}
udpConn, err := createUDPConnForConnIndex(ctx, edgeAddr, tunnelDialer)
if err != nil {
return nil, E.Cause(err, "listen UDP for QUIC edge")
}
conn, err := quic.Dial(ctx, udpConn, edgeAddr.UDP, tlsConfig, quicConfig)
if err != nil {
udpConn.Close()
return nil, E.Cause(err, "dial QUIC edge")
}
return &QUICConnection{
conn: &closeableQUICConn{Conn: conn, udpConn: udpConn},
logger: logger,
edgeAddr: edgeAddr,
connIndex: connIndex,
credentials: credentials,
connectorID: connectorID,
datagramVersion: datagramVersion,
features: features,
numPreviousAttempts: numPreviousAttempts,
gracePeriod: gracePeriod,
onConnected: onConnected,
}, nil
}
// createUDPConnForConnIndex creates a UDP socket for QUIC via the tunnel dialer.
// Unlike cloudflared, we do not attempt to reuse previously-bound ports across
// reconnects — the dialer interface does not support specifying local ports,
// and fixed port binding is not important for our use case.
// We also do not apply Darwin-specific udp4/udp6 network selection to work around
// quic-go#3793 (DF bit on macOS dual-stack); the dialer controls network selection
// and this is a non-critical platform-specific limitation.
func createUDPConnForConnIndex(ctx context.Context, edgeAddr *EdgeAddr, tunnelDialer N.Dialer) (*net.UDPConn, error) {
packetConn, err := tunnelDialer.ListenPacket(ctx, M.SocksaddrFrom(edgeAddr.UDP.AddrPort().Addr(), edgeAddr.UDP.AddrPort().Port()))
if err != nil {
return nil, err
}
udpConn, ok := packetConn.(*net.UDPConn)
if !ok {
packetConn.Close()
return nil, fmt.Errorf("unexpected packet conn type %T", packetConn)
}
return udpConn, nil
}
// Serve runs the QUIC connection: registers, accepts streams, handles datagrams.
// Blocks until the context is cancelled or a fatal error occurs.
func (q *QUICConnection) Serve(ctx context.Context, handler StreamHandler) error {
controlStream, err := q.conn.OpenStream()
if err != nil {
return E.Cause(err, "open control stream")
}
err = q.register(ctx, controlStream)
if err != nil {
controlStream.Close()
q.Close()
return err
}
q.logger.Info("connected to ", q.registrationResult.Location,
" (connection ", q.registrationResult.ConnectionID, ")")
serveCtx, serveCancel := context.WithCancel(context.WithoutCancel(ctx))
q.serveCtx = serveCtx
q.serveCancel = serveCancel
errChan := make(chan error, 2)
go func() {
errChan <- q.acceptStreams(serveCtx, handler)
}()
go func() {
errChan <- q.handleDatagrams(serveCtx, handler)
}()
select {
case <-ctx.Done():
q.gracefulShutdown()
<-errChan
return ctx.Err()
case err = <-errChan:
q.forceClose()
if ctx.Err() != nil {
return ctx.Err()
}
return err
}
}
func (q *QUICConnection) register(ctx context.Context, stream *quic.Stream) error {
q.registrationClient = NewRegistrationClient(ctx, newStreamReadWriteCloser(stream))
host, _, _ := net.SplitHostPort(q.conn.LocalAddr().String())
originLocalIP := net.ParseIP(host)
options := BuildConnectionOptions(q.connectorID, q.features, q.numPreviousAttempts, originLocalIP)
result, err := q.registrationClient.RegisterConnection(
ctx, q.credentials.Auth(), q.credentials.TunnelID, q.connIndex, options,
)
if err != nil {
return E.Cause(err, "register connection")
}
if err := validateRegistrationResult(result); err != nil {
return err
}
q.registrationResult = result
if q.onConnected != nil {
q.onConnected()
}
return nil
}
func (q *QUICConnection) acceptStreams(ctx context.Context, handler StreamHandler) error {
for {
stream, err := q.conn.AcceptStream(ctx)
if err != nil {
return E.Cause(err, "accept stream")
}
go q.handleStream(ctx, stream, handler)
}
}
func (q *QUICConnection) handleStream(ctx context.Context, stream quicStreamHandle, handler StreamHandler) {
rwc := newStreamReadWriteCloser(stream)
defer rwc.Close()
streamType, err := ReadStreamSignature(rwc)
if err != nil {
q.logger.Debug("failed to read stream signature: ", err)
stream.CancelWrite(0)
return
}
switch streamType {
case StreamTypeData:
request, err := ReadConnectRequest(rwc)
if err != nil {
q.logger.Debug("failed to read connect request: ", err)
stream.CancelWrite(0)
return
}
handler.HandleDataStream(ctx, &nopCloserReadWriter{ReadWriteCloser: rwc}, request, q.connIndex)
case StreamTypeRPC:
handler.HandleRPCStreamWithSender(ctx, rwc, q.connIndex, q)
}
}
func (q *QUICConnection) handleDatagrams(ctx context.Context, handler StreamHandler) error {
for {
datagram, err := q.conn.ReceiveDatagram(ctx)
if err != nil {
return E.Cause(err, "receive datagram")
}
handler.HandleDatagram(ctx, datagram, q)
}
}
// SendDatagram sends a QUIC datagram to the edge.
func (q *QUICConnection) SendDatagram(data []byte) error {
return q.conn.SendDatagram(data)
}
func (q *QUICConnection) DatagramVersion() string {
return q.datagramVersion
}
func (q *QUICConnection) OpenRPCStream(ctx context.Context) (io.ReadWriteCloser, error) {
stream, err := q.conn.OpenStream()
if err != nil {
return nil, E.Cause(err, "open rpc stream")
}
rwc := newStreamReadWriteCloser(stream)
if err := WriteRPCStreamSignature(rwc); err != nil {
rwc.Close()
return nil, E.Cause(err, "write rpc stream signature")
}
return rwc, nil
}
func (q *QUICConnection) gracefulShutdown() {
q.shutdownOnce.Do(func() {
if q.registrationClient == nil || q.registrationResult == nil {
q.closeNow("connection closed")
return
}
ctx, cancel := context.WithTimeout(context.Background(), q.gracePeriod)
err := q.registrationClient.Unregister(ctx)
cancel()
if err != nil {
q.logger.Debug("failed to unregister: ", err)
}
q.closeRegistrationClient()
if q.gracePeriod > 0 {
waitCtx := q.serveCtx
if waitCtx == nil {
waitCtx = context.Background()
}
timer := time.NewTimer(q.gracePeriod)
defer timer.Stop()
select {
case <-timer.C:
case <-waitCtx.Done():
}
}
q.closeNow("graceful shutdown")
})
}
func (q *QUICConnection) forceClose() {
q.shutdownOnce.Do(func() {
q.closeNow("connection closed")
})
}
func (q *QUICConnection) closeRegistrationClient() {
q.registrationClose.Do(func() {
if q.registrationClient != nil {
_ = q.registrationClient.Close()
}
})
}
func (q *QUICConnection) closeNow(reason string) {
q.closeOnce.Do(func() {
if q.serveCancel != nil {
q.serveCancel()
}
q.closeRegistrationClient()
_ = q.conn.CloseWithError(0, reason)
})
}
// Close closes the QUIC connection immediately.
func (q *QUICConnection) Close() error {
q.forceClose()
return nil
}
// StreamHandler handles incoming edge streams and datagrams.
type StreamHandler interface {
HandleDataStream(ctx context.Context, stream io.ReadWriteCloser, request *ConnectRequest, connIndex uint8)
HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8)
HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender)
HandleDatagram(ctx context.Context, datagram []byte, sender DatagramSender)
}
// DatagramSender can send QUIC datagrams back to the edge.
type DatagramSender interface {
SendDatagram(data []byte) error
}
// streamReadWriteCloser adapts a *quic.Stream to io.ReadWriteCloser
// with mutex-protected writes and safe close semantics.
type streamReadWriteCloser struct {
stream quicStreamHandle
writeAccess sync.Mutex
}
func newStreamReadWriteCloser(stream quicStreamHandle) *streamReadWriteCloser {
return &streamReadWriteCloser{stream: stream}
}
func (s *streamReadWriteCloser) Read(p []byte) (int, error) {
return s.stream.Read(p)
}
func (s *streamReadWriteCloser) Write(p []byte) (int, error) {
s.writeAccess.Lock()
defer s.writeAccess.Unlock()
return s.stream.Write(p)
}
func (s *streamReadWriteCloser) Close() error {
_ = s.stream.SetWriteDeadline(time.Now())
s.writeAccess.Lock()
defer s.writeAccess.Unlock()
s.stream.CancelRead(0)
return s.stream.Close()
}
// nopCloserReadWriter lets handlers stop consuming the read side without closing
// the underlying stream write side. This matches cloudflared's QUIC HTTP behavior,
// where the request body can be closed before the response is fully written.
type nopCloserReadWriter struct {
io.ReadWriteCloser
sawEOF bool
closed uint32
}
func (n *nopCloserReadWriter) Read(p []byte) (int, error) {
if n.sawEOF {
return 0, io.EOF
}
if atomic.LoadUint32(&n.closed) > 0 {
return 0, fmt.Errorf("closed by handler")
}
readLen, err := n.ReadWriteCloser.Read(p)
if err == io.EOF {
n.sawEOF = true
}
return readLen, err
}
func (n *nopCloserReadWriter) Close() error {
atomic.StoreUint32(&n.closed, 1)
return nil
}

View File

@@ -1,129 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"io"
"strings"
"testing"
"time"
"github.com/sagernet/quic-go"
"github.com/sagernet/sing-box/log"
)
func TestQUICInitialPacketSize(t *testing.T) {
testCases := []struct {
name string
ipVersion int
expected uint16
}{
{name: "ipv4", ipVersion: 4, expected: 1232},
{name: "ipv6", ipVersion: 6, expected: 1252},
{name: "default", ipVersion: 0, expected: 1252},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
if actual := quicInitialPacketSize(testCase.ipVersion); actual != testCase.expected {
t.Fatalf("quicInitialPacketSize(%d) = %d, want %d", testCase.ipVersion, actual, testCase.expected)
}
})
}
}
type mockReadWriteCloser struct {
reader strings.Reader
writes []byte
}
func (m *mockReadWriteCloser) Read(p []byte) (int, error) {
return m.reader.Read(p)
}
func (m *mockReadWriteCloser) Write(p []byte) (int, error) {
m.writes = append(m.writes, p...)
return len(p), nil
}
func (m *mockReadWriteCloser) Close() error {
return nil
}
func TestNOPCloserReadWriterCloseOnlyStopsReads(t *testing.T) {
inner := &mockReadWriteCloser{reader: *strings.NewReader("payload")}
wrapper := &nopCloserReadWriter{ReadWriteCloser: inner}
if err := wrapper.Close(); err != nil {
t.Fatal(err)
}
if _, err := wrapper.Read(make([]byte, 1)); err == nil {
t.Fatal("expected read to fail after close")
}
if _, err := wrapper.Write([]byte("response")); err != nil {
t.Fatal(err)
}
if string(inner.writes) != "response" {
t.Fatalf("unexpected writes %q", inner.writes)
}
}
func TestNOPCloserReadWriterTracksEOF(t *testing.T) {
inner := &mockReadWriteCloser{reader: *strings.NewReader("")}
wrapper := &nopCloserReadWriter{ReadWriteCloser: inner}
if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF {
t.Fatalf("expected EOF, got %v", err)
}
if _, err := wrapper.Read(make([]byte, 1)); err != io.EOF {
t.Fatalf("expected cached EOF, got %v", err)
}
}
type fakeQUICStream struct {
reader strings.Reader
cancelWriteCount int
}
func (s *fakeQUICStream) Read(p []byte) (int, error) { return s.reader.Read(p) }
func (s *fakeQUICStream) Write(p []byte) (int, error) { return len(p), nil }
func (s *fakeQUICStream) Close() error { return nil }
func (s *fakeQUICStream) CancelRead(quic.StreamErrorCode) {}
func (s *fakeQUICStream) CancelWrite(quic.StreamErrorCode) {
s.cancelWriteCount++
}
func (s *fakeQUICStream) SetWriteDeadline(time.Time) error { return nil }
func TestHandleStreamCancelsWriteOnSignatureError(t *testing.T) {
stream := &fakeQUICStream{reader: *strings.NewReader("broken")}
connection := &QUICConnection{logger: log.NewNOPFactory().NewLogger("test")}
connection.handleStream(context.Background(), stream, nil)
if stream.cancelWriteCount != 1 {
t.Fatalf("expected CancelWrite on signature error, got %d", stream.cancelWriteCount)
}
}
type nopStreamHandler struct{}
func (nopStreamHandler) HandleDataStream(context.Context, io.ReadWriteCloser, *ConnectRequest, uint8) {
}
func (nopStreamHandler) HandleRPCStream(context.Context, io.ReadWriteCloser, uint8) {}
func (nopStreamHandler) HandleRPCStreamWithSender(context.Context, io.ReadWriteCloser, uint8, DatagramSender) {
}
func (nopStreamHandler) HandleDatagram(context.Context, []byte, DatagramSender) {}
func TestHandleStreamCancelsWriteOnConnectRequestError(t *testing.T) {
stream := &fakeQUICStream{
reader: *strings.NewReader(string(dataStreamSignature[:])),
}
connection := &QUICConnection{logger: log.NewNOPFactory().NewLogger("test")}
connection.handleStream(context.Background(), stream, nopStreamHandler{})
if stream.cancelWriteCount != 1 {
t.Fatalf("expected CancelWrite on connect request error, got %d", stream.cancelWriteCount)
}
}

View File

@@ -1,222 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"errors"
"io"
"net"
"runtime"
"time"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
E "github.com/sagernet/sing/common/exceptions"
"github.com/google/uuid"
"zombiezen.com/go/capnproto2/pogs"
"zombiezen.com/go/capnproto2/rpc"
)
const (
rpcTimeout = 5 * time.Second
)
var clientVersion = "sing-box " + C.Version
// RegistrationClient handles the Cap'n Proto RPC for tunnel registration.
type RegistrationClient struct {
client tunnelrpc.TunnelServer
rpcConn *rpc.Conn
transport rpc.Transport
}
type registrationRPCClient interface {
RegisterConnection(
ctx context.Context,
auth TunnelAuth,
tunnelID uuid.UUID,
connIndex uint8,
options *RegistrationConnectionOptions,
) (*RegistrationResult, error)
Unregister(ctx context.Context) error
Close() error
}
type permanentRegistrationError struct {
Err error
}
func (e *permanentRegistrationError) Error() string {
if e == nil || e.Err == nil {
return "permanent registration error"
}
return e.Err.Error()
}
func (e *permanentRegistrationError) Unwrap() error {
if e == nil {
return nil
}
return e.Err
}
func isPermanentRegistrationError(err error) bool {
var permanentErr *permanentRegistrationError
return errors.As(err, &permanentErr)
}
// NewRegistrationClient creates a Cap'n Proto RPC client over the given stream.
// The stream should be the first QUIC stream (control stream).
func NewRegistrationClient(ctx context.Context, stream io.ReadWriteCloser) *RegistrationClient {
transport := safeTransport(stream)
conn := newRPCClientConn(transport, ctx)
return &RegistrationClient{
client: tunnelrpc.TunnelServer{Client: conn.Bootstrap(ctx)},
rpcConn: conn,
transport: transport,
}
}
// RegisterConnection registers this tunnel connection with the edge.
func (c *RegistrationClient) RegisterConnection(
ctx context.Context,
auth TunnelAuth,
tunnelID uuid.UUID,
connIndex uint8,
options *RegistrationConnectionOptions,
) (*RegistrationResult, error) {
ctx, cancel := context.WithTimeout(ctx, rpcTimeout)
defer cancel()
promise := c.client.RegisterConnection(ctx, func(p tunnelrpc.RegistrationServer_registerConnection_Params) error {
// Marshal TunnelAuth
tunnelAuth, err := p.NewAuth()
if err != nil {
return err
}
authPogs := &RegistrationTunnelAuth{
AccountTag: auth.AccountTag,
TunnelSecret: auth.TunnelSecret,
}
err = pogs.Insert(tunnelrpc.TunnelAuth_TypeID, tunnelAuth.Struct, authPogs)
if err != nil {
return err
}
// Set tunnel ID
err = p.SetTunnelId(tunnelID[:])
if err != nil {
return err
}
// Set connection index
p.SetConnIndex(connIndex)
// Marshal ConnectionOptions
connOptions, err := p.NewOptions()
if err != nil {
return err
}
return pogs.Insert(tunnelrpc.ConnectionOptions_TypeID, connOptions.Struct, options)
})
response, err := promise.Result().Struct()
if err != nil {
return nil, E.Cause(err, "registration RPC")
}
result := response.Result()
switch result.Which() {
case tunnelrpc.ConnectionResponse_result_Which_error:
resultError, err := result.Error()
if err != nil {
return nil, E.Cause(err, "read registration error")
}
cause, _ := resultError.Cause()
registrationError := E.New(cause)
if resultError.ShouldRetry() {
return nil, &RetryableError{
Err: registrationError,
Delay: time.Duration(resultError.RetryAfter()),
}
}
return nil, &permanentRegistrationError{Err: registrationError}
case tunnelrpc.ConnectionResponse_result_Which_connectionDetails:
connDetails, err := result.ConnectionDetails()
if err != nil {
return nil, E.Cause(err, "read connection details")
}
uuidBytes, err := connDetails.Uuid()
if err != nil {
return nil, E.Cause(err, "read connection UUID")
}
connectionID, err := uuid.FromBytes(uuidBytes)
if err != nil {
return nil, E.Cause(err, "parse connection UUID")
}
location, _ := connDetails.LocationName()
return &RegistrationResult{
ConnectionID: connectionID,
Location: location,
TunnelIsRemotelyManaged: connDetails.TunnelIsRemotelyManaged(),
}, nil
default:
return nil, E.New("unexpected registration response type")
}
}
// Unregister sends the UnregisterConnection RPC.
func (c *RegistrationClient) Unregister(ctx context.Context) error {
promise := c.client.UnregisterConnection(ctx, nil)
_, err := promise.Struct()
return err
}
// Close closes the RPC connection and transport.
func (c *RegistrationClient) Close() error {
return E.Errors(
c.rpcConn.Close(),
c.transport.Close(),
)
}
func validateRegistrationResult(result *RegistrationResult) error {
if result == nil || result.TunnelIsRemotelyManaged {
return nil
}
return ErrNonRemoteManagedTunnelUnsupported
}
// BuildConnectionOptions creates the ConnectionOptions to send during registration.
func BuildConnectionOptions(connectorID uuid.UUID, features []string, numPreviousAttempts uint8, originLocalIP net.IP) *RegistrationConnectionOptions {
return &RegistrationConnectionOptions{
Client: RegistrationClientInfo{
ClientID: connectorID[:],
Features: features,
Version: clientVersion,
Arch: runtime.GOOS + "_" + runtime.GOARCH,
},
ReplaceExisting: false,
CompressionQuality: 0,
OriginLocalIP: originLocalIP,
NumPreviousAttempts: numPreviousAttempts,
}
}
// DefaultFeatures returns the feature strings to advertise.
func DefaultFeatures(datagramVersion string) []string {
features := []string{
"serialized_headers",
"support_datagram_v2",
"support_quic_eof",
"allow_remote_config",
}
if datagramVersion == "v3" {
features = append(features, "support_datagram_v3_2")
}
return features
}

View File

@@ -1,44 +0,0 @@
//go:build with_cloudflared
package cloudflare
import "github.com/google/uuid"
// Credentials contains all info needed to run a tunnel.
type Credentials struct {
AccountTag string `json:"AccountTag"`
TunnelSecret []byte `json:"TunnelSecret"`
TunnelID uuid.UUID `json:"TunnelID"`
Endpoint string `json:"Endpoint,omitempty"`
}
// TunnelToken is the compact token format used in the --token flag.
// Field names match cloudflared's JSON encoding.
type TunnelToken struct {
AccountTag string `json:"a"`
TunnelSecret []byte `json:"s"`
TunnelID uuid.UUID `json:"t"`
Endpoint string `json:"e,omitempty"`
}
func (t TunnelToken) ToCredentials() Credentials {
return Credentials{
AccountTag: t.AccountTag,
TunnelSecret: t.TunnelSecret,
TunnelID: t.TunnelID,
Endpoint: t.Endpoint,
}
}
// TunnelAuth is the authentication data sent during tunnel registration.
type TunnelAuth struct {
AccountTag string
TunnelSecret []byte
}
func (c *Credentials) Auth() TunnelAuth {
return TunnelAuth{
AccountTag: c.AccountTag,
TunnelSecret: c.TunnelSecret,
}
}

View File

@@ -1,43 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"encoding/base64"
"testing"
"github.com/google/uuid"
)
func TestParseToken(t *testing.T) {
tunnelID := uuid.New()
secret := []byte("test-secret-32-bytes-long-xxxxx")
tokenJSON := `{"a":"account123","t":"` + tunnelID.String() + `","s":"` + base64.StdEncoding.EncodeToString(secret) + `"}`
token := base64.StdEncoding.EncodeToString([]byte(tokenJSON))
credentials, err := parseToken(token)
if err != nil {
t.Fatal("parseToken: ", err)
}
if credentials.AccountTag != "account123" {
t.Error("expected AccountTag account123, got ", credentials.AccountTag)
}
if credentials.TunnelID != tunnelID {
t.Error("expected TunnelID ", tunnelID, ", got ", credentials.TunnelID)
}
}
func TestParseTokenInvalidBase64(t *testing.T) {
_, err := parseToken("not-valid-base64!!!")
if err == nil {
t.Fatal("expected error for invalid base64")
}
}
func TestParseTokenInvalidJSON(t *testing.T) {
token := base64.StdEncoding.EncodeToString([]byte("{bad json"))
_, err := parseToken(token)
if err == nil {
t.Fatal("expected error for invalid JSON")
}
}

View File

@@ -1,204 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"encoding/binary"
"io"
"net"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
)
type v2UnregisterCall struct {
sessionID uuid.UUID
message string
}
type captureRPCDatagramSender struct {
captureDatagramSender
}
type captureV2SessionRPCClient struct {
unregisterCh chan v2UnregisterCall
}
type blockingPacketConn struct {
closed chan struct{}
}
func newBlockingPacketConn() *blockingPacketConn {
return &blockingPacketConn{closed: make(chan struct{})}
}
func (c *blockingPacketConn) ReadPacket(_ *buf.Buffer) (M.Socksaddr, error) {
<-c.closed
return M.Socksaddr{}, io.EOF
}
func (c *blockingPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
buffer.Release()
return nil
}
func (c *blockingPacketConn) Close() error {
closeOnce(c.closed)
return nil
}
func (c *blockingPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
func (c *blockingPacketConn) SetDeadline(time.Time) error { return nil }
func (c *blockingPacketConn) SetReadDeadline(time.Time) error { return nil }
func (c *blockingPacketConn) SetWriteDeadline(time.Time) error { return nil }
type packetDialingRouter struct {
testRouter
packetConn N.PacketConn
}
func (r *packetDialingRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) {
return r.packetConn, nil
}
func (c *captureV2SessionRPCClient) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error {
c.unregisterCh <- v2UnregisterCall{sessionID: sessionID, message: message}
return nil
}
func (c *captureV2SessionRPCClient) Close() error { return nil }
func TestDatagramV2LocalCloseUnregistersRemote(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
sender := &captureRPCDatagramSender{}
muxer := NewDatagramV2Muxer(inboundInstance, sender, inboundInstance.logger)
unregisterCh := make(chan v2UnregisterCall, 1)
originalClientFactory := newV2SessionRPCClient
newV2SessionRPCClient = func(ctx context.Context, sender DatagramSender) (v2SessionRPCClient, error) {
return &captureV2SessionRPCClient{unregisterCh: unregisterCh}, nil
}
defer func() {
newV2SessionRPCClient = originalClientFactory
}()
sessionID := uuidTest(7)
if err := muxer.RegisterSession(context.Background(), sessionID, net.IPv4(127, 0, 0, 1), 53, time.Second); err != nil {
t.Fatal(err)
}
muxer.sessionAccess.RLock()
session := muxer.sessions[sessionID]
muxer.sessionAccess.RUnlock()
if session == nil {
t.Fatal("expected registered session")
}
session.closeWithReason("local close")
select {
case call := <-unregisterCh:
if call.sessionID != sessionID {
t.Fatalf("unexpected session id: %s", call.sessionID)
}
if call.message != "local close" {
t.Fatalf("unexpected message: %q", call.message)
}
case <-time.After(2 * time.Second):
t.Fatal("expected unregister rpc")
}
}
func TestDatagramV3RegistrationMigratesSender(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
sender1 := &captureDatagramSender{}
sender2 := &captureDatagramSender{}
muxer1 := NewDatagramV3Muxer(inboundInstance, sender1, inboundInstance.logger)
muxer2 := NewDatagramV3Muxer(inboundInstance, sender2, inboundInstance.logger)
requestID := RequestID{}
requestID[15] = 9
payload := make([]byte, 1+2+2+16+4)
payload[0] = 0
binary.BigEndian.PutUint16(payload[1:3], 53)
binary.BigEndian.PutUint16(payload[3:5], 30)
copy(payload[5:21], requestID[:])
copy(payload[21:25], []byte{127, 0, 0, 1})
muxer1.handleRegistration(context.Background(), payload)
session, exists := inboundInstance.datagramV3Manager.Get(requestID)
if !exists {
t.Fatal("expected v3 session after first registration")
}
muxer2.handleRegistration(context.Background(), payload)
session.senderAccess.RLock()
currentSender := session.sender
session.senderAccess.RUnlock()
if currentSender != sender2 {
t.Fatal("expected v3 session sender migration to second sender")
}
session.close()
}
func TestDatagramV3MigrationUpdatesSessionContext(t *testing.T) {
packetConn := newBlockingPacketConn()
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.router = &packetDialingRouter{packetConn: packetConn}
sender1 := &captureDatagramSender{}
sender2 := &captureDatagramSender{}
muxer1 := NewDatagramV3Muxer(inboundInstance, sender1, inboundInstance.logger)
muxer2 := NewDatagramV3Muxer(inboundInstance, sender2, inboundInstance.logger)
requestID := RequestID{}
requestID[15] = 10
payload := make([]byte, 1+2+2+16+4)
payload[0] = 0
binary.BigEndian.PutUint16(payload[1:3], 53)
binary.BigEndian.PutUint16(payload[3:5], 30)
copy(payload[5:21], requestID[:])
copy(payload[21:25], []byte{127, 0, 0, 1})
ctx1, cancel1 := context.WithCancel(context.Background())
muxer1.handleRegistration(ctx1, payload)
ctx2, cancel2 := context.WithCancel(context.Background())
muxer2.handleRegistration(ctx2, payload)
cancel1()
time.Sleep(50 * time.Millisecond)
session, exists := inboundInstance.datagramV3Manager.Get(requestID)
if !exists {
t.Fatal("expected session to survive old connection context cancellation")
}
session.senderAccess.RLock()
currentSender := session.sender
session.senderAccess.RUnlock()
if currentSender != sender2 {
t.Fatal("expected migrated sender to stay active")
}
cancel2()
deadline := time.After(time.Second)
for {
if _, exists := inboundInstance.datagramV3Manager.Get(requestID); !exists {
return
}
select {
case <-deadline:
t.Fatal("expected session to be removed after new context cancellation")
case <-time.After(10 * time.Millisecond):
}
}
}

View File

@@ -1,231 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"net"
"testing"
"time"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
"github.com/google/uuid"
capnp "zombiezen.com/go/capnproto2"
)
func newRegisterUDPSessionCall(t *testing.T, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) {
return newRegisterUDPSessionCallWithDstIP(t, []byte{127, 0, 0, 1}, traceContext)
}
func newRegisterUDPSessionCallWithDstIP(t *testing.T, dstIP []byte, traceContext string) (tunnelrpc.SessionManager_registerUdpSession, func() (tunnelrpc.RegisterUdpSessionResponse, error)) {
t.Helper()
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
params, err := tunnelrpc.NewSessionManager_registerUdpSession_Params(paramsSeg)
if err != nil {
t.Fatal(err)
}
sessionID := uuid.New()
if err := params.SetSessionId(sessionID[:]); err != nil {
t.Fatal(err)
}
if err := params.SetDstIp(dstIP); err != nil {
t.Fatal(err)
}
params.SetDstPort(53)
params.SetCloseAfterIdleHint(int64(30))
if err := params.SetTraceContext(traceContext); err != nil {
t.Fatal(err)
}
_, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
results, err := tunnelrpc.NewSessionManager_registerUdpSession_Results(resultsSeg)
if err != nil {
t.Fatal(err)
}
call := tunnelrpc.SessionManager_registerUdpSession{
Ctx: context.Background(),
Params: params,
Results: results,
}
return call, results.Result
}
func newUnregisterUDPSessionCall(t *testing.T) tunnelrpc.SessionManager_unregisterUdpSession {
t.Helper()
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
params, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Params(paramsSeg)
if err != nil {
t.Fatal(err)
}
sessionID := uuid.New()
if err := params.SetSessionId(sessionID[:]); err != nil {
t.Fatal(err)
}
if err := params.SetMessage("close"); err != nil {
t.Fatal(err)
}
_, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
results, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Results(resultsSeg)
if err != nil {
t.Fatal(err)
}
return tunnelrpc.SessionManager_unregisterUdpSession{
Ctx: context.Background(),
Params: params,
Results: results,
}
}
func newUnregisterUDPSessionCallForSession(t *testing.T, sessionID uuid.UUID, message string) tunnelrpc.SessionManager_unregisterUdpSession {
t.Helper()
_, paramsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
params, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Params(paramsSeg)
if err != nil {
t.Fatal(err)
}
if err := params.SetSessionId(sessionID[:]); err != nil {
t.Fatal(err)
}
if err := params.SetMessage(message); err != nil {
t.Fatal(err)
}
_, resultsSeg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
t.Fatal(err)
}
results, err := tunnelrpc.NewSessionManager_unregisterUdpSession_Results(resultsSeg)
if err != nil {
t.Fatal(err)
}
return tunnelrpc.SessionManager_unregisterUdpSession{
Ctx: context.Background(),
Params: params,
Results: results,
}
}
func TestV3RPCRegisterUDPSessionReturnsUnsupportedResult(t *testing.T) {
server := &cloudflaredV3Server{
inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")},
}
call, readResult := newRegisterUDPSessionCall(t, "trace-context")
if err := server.RegisterUdpSession(call); err != nil {
t.Fatal(err)
}
result, err := readResult()
if err != nil {
t.Fatal(err)
}
resultErr, err := result.Err()
if err != nil {
t.Fatal(err)
}
if resultErr != errUnsupportedDatagramV3UDPRegistration.Error() {
t.Fatalf("unexpected registration error %q", resultErr)
}
spans, err := result.Spans()
if err != nil {
t.Fatal(err)
}
if len(spans) != 0 {
t.Fatalf("expected empty spans, got %x", spans)
}
}
func TestV3RPCUnregisterUDPSessionReturnsUnsupportedError(t *testing.T) {
server := &cloudflaredV3Server{
inbound: &Inbound{Adapter: inbound.NewAdapter(C.TypeCloudflared, "test")},
}
err := server.UnregisterUdpSession(newUnregisterUDPSessionCall(t))
if err == nil {
t.Fatal("expected unsupported unregister error")
}
if err.Error() != errUnsupportedDatagramV3UDPUnregistration.Error() {
t.Fatalf("unexpected unregister error %v", err)
}
}
func TestV2RPCUnregisterUDPSessionPropagatesMessage(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.router = &packetDialingRouter{packetConn: newBlockingPacketConn()}
muxer := NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger)
sessionID := uuid.New()
if err := muxer.RegisterSession(context.Background(), sessionID, net.IPv4(127, 0, 0, 1), 53, time.Second); err != nil {
t.Fatal(err)
}
muxer.sessionAccess.RLock()
session := muxer.sessions[sessionID]
muxer.sessionAccess.RUnlock()
if session == nil {
t.Fatal("expected registered session")
}
server := &cloudflaredServer{
inbound: inboundInstance,
muxer: muxer,
ctx: context.Background(),
logger: inboundInstance.logger,
}
if err := server.UnregisterUdpSession(newUnregisterUDPSessionCallForSession(t, sessionID, "edge close")); err != nil {
t.Fatal(err)
}
if reason := session.closeReason(); reason != "edge close" {
t.Fatalf("expected close reason propagated from edge, got %q", reason)
}
}
func TestV2RPCRegisterUDPSessionRejectsMissingDestinationIP(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.router = &packetDialingRouter{packetConn: newBlockingPacketConn()}
server := &cloudflaredServer{
inbound: inboundInstance,
muxer: NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger),
ctx: context.Background(),
logger: inboundInstance.logger,
}
call, readResult := newRegisterUDPSessionCallWithDstIP(t, nil, "")
if err := server.RegisterUdpSession(call); err != nil {
t.Fatal(err)
}
result, err := readResult()
if err != nil {
t.Fatal(err)
}
resultErr, err := result.Err()
if err != nil {
t.Fatal(err)
}
if resultErr != "missing destination IP" {
t.Fatalf("unexpected result error %q", resultErr)
}
}

View File

@@ -1,78 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"errors"
"io"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
E "github.com/sagernet/sing/common/exceptions"
"zombiezen.com/go/capnproto2/server"
)
var (
errUnsupportedDatagramV3UDPRegistration = errors.New("datagram v3 does not support RegisterUdpSession RPC")
errUnsupportedDatagramV3UDPUnregistration = errors.New("datagram v3 does not support UnregisterUdpSession RPC")
)
type cloudflaredV3Server struct {
inbound *Inbound
logger log.ContextLogger
}
func (s *cloudflaredV3Server) RegisterUdpSession(call tunnelrpc.SessionManager_registerUdpSession) error {
result, err := call.Results.NewResult()
if err != nil {
return err
}
if err := result.SetErr(errUnsupportedDatagramV3UDPRegistration.Error()); err != nil {
return err
}
return result.SetSpans([]byte{})
}
func (s *cloudflaredV3Server) UnregisterUdpSession(call tunnelrpc.SessionManager_unregisterUdpSession) error {
return errUnsupportedDatagramV3UDPUnregistration
}
func (s *cloudflaredV3Server) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error {
server.Ack(call.Options)
version := call.Params.Version()
configData, _ := call.Params.Config()
updateResult := s.inbound.ApplyConfig(version, configData)
result, err := call.Results.NewResult()
if err != nil {
return err
}
result.SetLatestAppliedVersion(updateResult.LastAppliedVersion)
if updateResult.Err != nil {
result.SetErr(updateResult.Err.Error())
} else {
result.SetErr("")
}
return nil
}
// ServeV3RPCStream serves configuration updates on v3 and rejects legacy UDP RPCs.
func ServeV3RPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inbound, logger log.ContextLogger) {
srv := &cloudflaredV3Server{
inbound: inbound,
logger: logger,
}
client := tunnelrpc.CloudflaredServer_ServerToClient(srv)
transport := safeTransport(stream)
rpcConn := newRPCServerConn(transport, client.Client)
rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout)
defer cancel()
select {
case <-rpcConn.Done():
case <-rpcCtx.Done():
}
E.Errors(
rpcConn.Close(),
transport.Close(),
)
}

View File

@@ -1,568 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"io"
"net"
"net/netip"
"sync"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
"zombiezen.com/go/capnproto2/rpc"
"zombiezen.com/go/capnproto2/server"
)
// V2 wire format: [payload | 16B sessionID | 1B type] (suffix-based)
// DatagramV2Type identifies the type of a V2 datagram.
type DatagramV2Type byte
const (
DatagramV2TypeUDP DatagramV2Type = 0
DatagramV2TypeIP DatagramV2Type = 1
DatagramV2TypeIPWithTrace DatagramV2Type = 2
DatagramV2TypeTracingSpan DatagramV2Type = 3
sessionIDLength = 16
typeIDLength = 1
)
// DatagramV2Muxer handles V2 datagram demuxing and session management.
type DatagramV2Muxer struct {
inbound *Inbound
logger log.ContextLogger
sender DatagramSender
icmp *ICMPBridge
sessionAccess sync.RWMutex
sessions map[uuid.UUID]*udpSession
}
// NewDatagramV2Muxer creates a new V2 datagram muxer.
func NewDatagramV2Muxer(inbound *Inbound, sender DatagramSender, logger log.ContextLogger) *DatagramV2Muxer {
return &DatagramV2Muxer{
inbound: inbound,
logger: logger,
sender: sender,
icmp: NewICMPBridge(inbound, sender, icmpWireV2),
sessions: make(map[uuid.UUID]*udpSession),
}
}
type rpcStreamOpener interface {
OpenRPCStream(ctx context.Context) (io.ReadWriteCloser, error)
}
type v2SessionRPCClient interface {
UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error
Close() error
}
var newV2SessionRPCClient = func(ctx context.Context, sender DatagramSender) (v2SessionRPCClient, error) {
opener, ok := sender.(rpcStreamOpener)
if !ok {
return nil, E.New("sender does not support rpc streams")
}
stream, err := opener.OpenRPCStream(ctx)
if err != nil {
return nil, err
}
transport := safeTransport(stream)
conn := newRPCClientConn(transport, ctx)
return &capnpV2SessionRPCClient{
client: tunnelrpc.SessionManager{Client: conn.Bootstrap(ctx)},
rpcConn: conn,
transport: transport,
}, nil
}
type capnpV2SessionRPCClient struct {
client tunnelrpc.SessionManager
rpcConn *rpc.Conn
transport rpc.Transport
}
func (c *capnpV2SessionRPCClient) UnregisterSession(ctx context.Context, sessionID uuid.UUID, message string) error {
promise := c.client.UnregisterUdpSession(ctx, func(p tunnelrpc.SessionManager_unregisterUdpSession_Params) error {
if err := p.SetSessionId(sessionID[:]); err != nil {
return err
}
return p.SetMessage(message)
})
_, err := promise.Struct()
return err
}
func (c *capnpV2SessionRPCClient) Close() error {
return E.Errors(c.rpcConn.Close(), c.transport.Close())
}
// HandleDatagram demuxes an incoming V2 datagram.
func (m *DatagramV2Muxer) HandleDatagram(ctx context.Context, data []byte) {
if len(data) < typeIDLength {
return
}
datagramType := DatagramV2Type(data[len(data)-typeIDLength])
payload := data[:len(data)-typeIDLength]
switch datagramType {
case DatagramV2TypeUDP:
m.handleUDPDatagram(ctx, payload)
case DatagramV2TypeIP:
if err := m.icmp.HandleV2(ctx, datagramType, payload); err != nil {
m.logger.Debug("drop V2 ICMP datagram: ", err)
}
case DatagramV2TypeIPWithTrace:
if err := m.icmp.HandleV2(ctx, datagramType, payload); err != nil {
m.logger.Debug("drop V2 traced ICMP datagram: ", err)
}
case DatagramV2TypeTracingSpan:
// Tracing spans, ignore
}
}
func (m *DatagramV2Muxer) handleUDPDatagram(ctx context.Context, data []byte) {
if len(data) < sessionIDLength {
return
}
payload := data[:len(data)-sessionIDLength]
sessionID, err := uuid.FromBytes(data[len(data)-sessionIDLength:])
if err != nil {
m.logger.Debug("invalid session ID in V2 datagram: ", err)
return
}
m.sessionAccess.RLock()
session, exists := m.sessions[sessionID]
m.sessionAccess.RUnlock()
if !exists {
m.logger.Debug("unknown V2 UDP session: ", sessionID)
return
}
session.writeToOrigin(payload)
}
// RegisterSession registers a new UDP session from an RPC call.
func (m *DatagramV2Muxer) RegisterSession(
ctx context.Context,
sessionID uuid.UUID,
destinationIP net.IP,
destinationPort uint16,
closeAfterIdle time.Duration,
) error {
if destinationIP == nil {
return E.New("missing destination IP")
}
var destinationAddr netip.Addr
if ip4 := destinationIP.To4(); ip4 != nil {
destinationAddr = netip.AddrFrom4([4]byte(ip4))
} else if ip16 := destinationIP.To16(); ip16 != nil {
destinationAddr = netip.AddrFrom16([16]byte(ip16))
} else {
return E.New("invalid destination IP")
}
destination := netip.AddrPortFrom(destinationAddr, destinationPort)
if closeAfterIdle == 0 {
closeAfterIdle = 210 * time.Second
}
m.sessionAccess.Lock()
if _, exists := m.sessions[sessionID]; exists {
m.sessionAccess.Unlock()
return nil
}
limit := m.inbound.maxActiveFlows()
if !m.inbound.flowLimiter.Acquire(limit) {
m.sessionAccess.Unlock()
return E.New("too many active flows")
}
origin, err := m.inbound.dialWarpPacketConnection(ctx, destination)
if err != nil {
m.inbound.flowLimiter.Release(limit)
m.sessionAccess.Unlock()
return err
}
session := newUDPSession(sessionID, destination, closeAfterIdle, origin, m)
m.sessions[sessionID] = session
m.sessionAccess.Unlock()
m.logger.Info("registered V2 UDP session ", sessionID, " to ", destination)
go m.serveSession(ctx, session, limit)
return nil
}
// UnregisterSession removes a UDP session.
func (m *DatagramV2Muxer) UnregisterSession(sessionID uuid.UUID, message string) {
m.sessionAccess.Lock()
session, exists := m.sessions[sessionID]
if exists {
delete(m.sessions, sessionID)
}
m.sessionAccess.Unlock()
if exists {
session.markRemoteClosed(message)
session.close()
m.logger.Info("unregistered V2 UDP session ", sessionID)
}
}
func (m *DatagramV2Muxer) serveSession(ctx context.Context, session *udpSession, limit uint64) {
defer m.inbound.flowLimiter.Release(limit)
session.serve(ctx)
m.sessionAccess.Lock()
if current, exists := m.sessions[session.id]; exists && current == session {
delete(m.sessions, session.id)
}
m.sessionAccess.Unlock()
if !session.remoteClosed() {
unregisterCtx, cancel := context.WithTimeout(context.Background(), rpcTimeout)
defer cancel()
if err := m.unregisterRemoteSession(unregisterCtx, session.id, session.closeReason()); err != nil {
m.logger.Debug("failed to unregister V2 UDP session ", session.id, ": ", err)
}
}
}
// sendToEdge sends a V2 UDP datagram back to the edge.
func (m *DatagramV2Muxer) sendToEdge(sessionID uuid.UUID, payload []byte) {
data := make([]byte, len(payload)+sessionIDLength+typeIDLength)
copy(data, payload)
copy(data[len(payload):], sessionID[:])
data[len(data)-1] = byte(DatagramV2TypeUDP)
m.sender.SendDatagram(data)
}
// Close closes all sessions.
func (m *DatagramV2Muxer) Close() {
m.sessionAccess.Lock()
sessions := m.sessions
m.sessions = make(map[uuid.UUID]*udpSession)
m.sessionAccess.Unlock()
for _, session := range sessions {
session.close()
}
}
// udpSession represents a V2 UDP session.
type udpSession struct {
id uuid.UUID
destination netip.AddrPort
closeAfterIdle time.Duration
origin N.PacketConn
muxer *DatagramV2Muxer
writeChan chan []byte
closeOnce sync.Once
closeChan chan struct{}
activeAccess sync.RWMutex
activeAt time.Time
stateAccess sync.RWMutex
closedByRemote bool
closeReasonString string
}
func newUDPSession(id uuid.UUID, destination netip.AddrPort, closeAfterIdle time.Duration, origin N.PacketConn, muxer *DatagramV2Muxer) *udpSession {
return &udpSession{
id: id,
destination: destination,
closeAfterIdle: closeAfterIdle,
origin: origin,
muxer: muxer,
writeChan: make(chan []byte, 256),
closeChan: make(chan struct{}),
activeAt: time.Now(),
}
}
func (s *udpSession) writeToOrigin(payload []byte) {
data := make([]byte, len(payload))
copy(data, payload)
select {
case s.writeChan <- data:
default:
}
}
func (s *udpSession) close() {
s.closeOnce.Do(func() {
if s.origin != nil {
_ = s.origin.Close()
}
close(s.closeChan)
})
}
func (s *udpSession) serve(ctx context.Context) {
go s.readLoop()
go s.writeLoop()
tickInterval := s.closeAfterIdle / 2
if tickInterval <= 0 || tickInterval > 10*time.Second {
tickInterval = time.Second
}
ticker := time.NewTicker(tickInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
s.closeWithReason("connection closed")
case <-ticker.C:
if time.Since(s.lastActive()) >= s.closeAfterIdle {
s.closeWithReason("idle timeout")
}
case <-s.closeChan:
return
}
}
}
func (s *udpSession) readLoop() {
for {
buffer := buf.NewPacket()
_, err := s.origin.ReadPacket(buffer)
if err != nil {
buffer.Release()
s.closeWithReason(err.Error())
return
}
s.markActive()
s.muxer.sendToEdge(s.id, append([]byte(nil), buffer.Bytes()...))
buffer.Release()
}
}
func (s *udpSession) writeLoop() {
for {
select {
case payload := <-s.writeChan:
err := s.origin.WritePacket(buf.As(payload), M.SocksaddrFromNetIP(s.destination))
if err != nil {
s.closeWithReason(err.Error())
return
}
s.markActive()
case <-s.closeChan:
return
}
}
}
func (s *udpSession) markActive() {
s.activeAccess.Lock()
s.activeAt = time.Now()
s.activeAccess.Unlock()
}
func (s *udpSession) lastActive() time.Time {
s.activeAccess.RLock()
defer s.activeAccess.RUnlock()
return s.activeAt
}
func (s *udpSession) closeWithReason(reason string) {
s.stateAccess.Lock()
if s.closeReasonString == "" {
s.closeReasonString = reason
}
s.stateAccess.Unlock()
s.close()
}
func (s *udpSession) markRemoteClosed(message string) {
s.stateAccess.Lock()
s.closedByRemote = true
if message != "" {
s.closeReasonString = message
} else if s.closeReasonString == "" {
s.closeReasonString = "unregistered by edge"
}
s.stateAccess.Unlock()
}
func (s *udpSession) remoteClosed() bool {
s.stateAccess.RLock()
defer s.stateAccess.RUnlock()
return s.closedByRemote
}
func (s *udpSession) closeReason() string {
s.stateAccess.RLock()
defer s.stateAccess.RUnlock()
if s.closeReasonString == "" {
return "session closed"
}
return s.closeReasonString
}
// ReadPacket implements N.PacketConn - reads packets from the edge to forward to origin.
func (s *udpSession) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
select {
case data := <-s.writeChan:
_, err := buffer.Write(data)
return M.SocksaddrFromNetIP(s.destination), err
case <-s.closeChan:
return M.Socksaddr{}, io.EOF
}
}
// WritePacket implements N.PacketConn - receives packets from origin to forward to edge.
func (s *udpSession) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
s.muxer.sendToEdge(s.id, buffer.Bytes())
return nil
}
func (s *udpSession) Close() error {
s.close()
return nil
}
func (s *udpSession) LocalAddr() net.Addr { return nil }
func (s *udpSession) SetDeadline(_ time.Time) error { return nil }
func (s *udpSession) SetReadDeadline(_ time.Time) error { return nil }
func (s *udpSession) SetWriteDeadline(_ time.Time) error { return nil }
func (m *DatagramV2Muxer) unregisterRemoteSession(ctx context.Context, sessionID uuid.UUID, message string) error {
client, err := newV2SessionRPCClient(ctx, m.sender)
if err != nil {
return err
}
defer client.Close()
return client.UnregisterSession(ctx, sessionID, message)
}
// V2 RPC server implementation for HandleRPCStream.
type cloudflaredServer struct {
inbound *Inbound
muxer *DatagramV2Muxer
ctx context.Context
logger log.ContextLogger
}
func (s *cloudflaredServer) RegisterUdpSession(call tunnelrpc.SessionManager_registerUdpSession) error {
server.Ack(call.Options)
sessionIDBytes, err := call.Params.SessionId()
if err != nil {
return err
}
sessionID, err := uuid.FromBytes(sessionIDBytes)
if err != nil {
return err
}
destinationIP, err := call.Params.DstIp()
if err != nil {
return err
}
destinationPort := call.Params.DstPort()
closeAfterIdle := time.Duration(call.Params.CloseAfterIdleHint())
if _, traceErr := call.Params.TraceContext(); traceErr != nil {
return traceErr
}
if len(destinationIP) == 0 {
err = E.New("missing destination IP")
} else {
err = s.muxer.RegisterSession(s.ctx, sessionID, net.IP(destinationIP), destinationPort, closeAfterIdle)
}
result, allocErr := call.Results.NewResult()
if allocErr != nil {
return allocErr
}
if spansErr := result.SetSpans([]byte{}); spansErr != nil {
return spansErr
}
if err != nil {
result.SetErr(err.Error())
}
return nil
}
func (s *cloudflaredServer) UnregisterUdpSession(call tunnelrpc.SessionManager_unregisterUdpSession) error {
server.Ack(call.Options)
sessionIDBytes, err := call.Params.SessionId()
if err != nil {
return err
}
sessionID, err := uuid.FromBytes(sessionIDBytes)
if err != nil {
return err
}
message, err := call.Params.Message()
if err != nil {
return err
}
s.muxer.UnregisterSession(sessionID, message)
return nil
}
func (s *cloudflaredServer) UpdateConfiguration(call tunnelrpc.ConfigurationManager_updateConfiguration) error {
server.Ack(call.Options)
version := call.Params.Version()
configData, _ := call.Params.Config()
updateResult := s.inbound.ApplyConfig(version, configData)
result, err := call.Results.NewResult()
if err != nil {
return err
}
result.SetLatestAppliedVersion(updateResult.LastAppliedVersion)
if updateResult.Err != nil {
result.SetErr(updateResult.Err.Error())
} else {
result.SetErr("")
}
return nil
}
// ServeRPCStream handles an incoming V2 RPC stream (session management + configuration).
func ServeRPCStream(ctx context.Context, stream io.ReadWriteCloser, inbound *Inbound, muxer *DatagramV2Muxer, logger log.ContextLogger) {
srv := &cloudflaredServer{
inbound: inbound,
muxer: muxer,
ctx: ctx,
logger: logger,
}
client := tunnelrpc.CloudflaredServer_ServerToClient(srv)
transport := safeTransport(stream)
rpcConn := newRPCServerConn(transport, client.Client)
rpcCtx, cancel := context.WithTimeout(ctx, rpcTimeout)
defer cancel()
select {
case <-rpcConn.Done():
case <-rpcCtx.Done():
}
E.Errors(
rpcConn.Close(),
transport.Close(),
)
}

View File

@@ -1,483 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"encoding/binary"
"errors"
"net/netip"
"os"
"sync"
"time"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
// V3 wire format: [1B type | payload] (prefix-based)
// DatagramV3Type identifies the type of a V3 datagram.
type DatagramV3Type byte
const (
DatagramV3TypeRegistration DatagramV3Type = 0
DatagramV3TypePayload DatagramV3Type = 1
DatagramV3TypeICMP DatagramV3Type = 2
DatagramV3TypeRegistrationResponse DatagramV3Type = 3
// V3 registration header sizes
v3RegistrationFlagLen = 1
v3RegistrationPortLen = 2
v3RegistrationIdleLen = 2
v3RequestIDLength = 16
v3IPv4AddrLen = 4
v3IPv6AddrLen = 16
v3RegistrationBaseLen = 1 + v3RegistrationFlagLen + v3RegistrationPortLen + v3RegistrationIdleLen + v3RequestIDLength // 22
v3PayloadHeaderLen = 1 + v3RequestIDLength // 17
v3RegistrationRespLen = 1 + 1 + v3RequestIDLength + 2 // 20
maxV3UDPPayloadLen = 1280
// V3 registration flags
v3FlagIPv6 byte = 0x01
v3FlagTraced byte = 0x02
v3FlagBundle byte = 0x04
// V3 registration response types
v3ResponseOK byte = 0x00
v3ResponseDestinationUnreachable byte = 0x01
v3ResponseUnableToBindSocket byte = 0x02
v3ResponseTooManyActiveFlows byte = 0x03
v3ResponseErrorWithMsg byte = 0xFF
)
// RequestID is a 128-bit session identifier for V3.
type RequestID [v3RequestIDLength]byte
type v3RegistrationState uint8
const (
v3RegistrationNew v3RegistrationState = iota
v3RegistrationExisting
v3RegistrationMigrated
)
type DatagramV3SessionManager struct {
sessionAccess sync.RWMutex
sessions map[RequestID]*v3Session
}
func NewDatagramV3SessionManager() *DatagramV3SessionManager {
return &DatagramV3SessionManager{
sessions: make(map[RequestID]*v3Session),
}
}
// DatagramV3Muxer handles V3 datagram demuxing and session management.
type DatagramV3Muxer struct {
inbound *Inbound
logger log.ContextLogger
sender DatagramSender
icmp *ICMPBridge
}
// NewDatagramV3Muxer creates a new V3 datagram muxer.
func NewDatagramV3Muxer(inbound *Inbound, sender DatagramSender, logger log.ContextLogger) *DatagramV3Muxer {
return &DatagramV3Muxer{
inbound: inbound,
logger: logger,
sender: sender,
icmp: NewICMPBridge(inbound, sender, icmpWireV3),
}
}
// HandleDatagram demuxes an incoming V3 datagram.
func (m *DatagramV3Muxer) HandleDatagram(ctx context.Context, data []byte) {
if len(data) < 1 {
return
}
datagramType := DatagramV3Type(data[0])
payload := data[1:]
switch datagramType {
case DatagramV3TypeRegistration:
m.handleRegistration(ctx, payload)
case DatagramV3TypePayload:
m.handlePayload(payload)
case DatagramV3TypeICMP:
if err := m.icmp.HandleV3(ctx, payload); err != nil {
m.logger.Debug("drop V3 ICMP datagram: ", err)
}
case DatagramV3TypeRegistrationResponse:
// Unexpected - we never send registrations
m.logger.Debug("received unexpected V3 registration response")
}
}
func (m *DatagramV3Muxer) handleRegistration(ctx context.Context, data []byte) {
if len(data) < v3RegistrationFlagLen+v3RegistrationPortLen+v3RegistrationIdleLen+v3RequestIDLength {
m.logger.Debug("V3 registration too short")
return
}
flags := data[0]
destinationPort := binary.BigEndian.Uint16(data[1:3])
idleDurationSeconds := binary.BigEndian.Uint16(data[3:5])
var requestID RequestID
copy(requestID[:], data[5:5+v3RequestIDLength])
offset := 5 + v3RequestIDLength
var destination netip.AddrPort
if flags&v3FlagIPv6 != 0 {
if len(data) < offset+v3IPv6AddrLen {
m.sendRegistrationResponse(requestID, v3ResponseErrorWithMsg, "registration too short for IPv6")
return
}
var addr [16]byte
copy(addr[:], data[offset:offset+v3IPv6AddrLen])
destination = netip.AddrPortFrom(netip.AddrFrom16(addr), destinationPort)
offset += v3IPv6AddrLen
} else {
if len(data) < offset+v3IPv4AddrLen {
m.sendRegistrationResponse(requestID, v3ResponseErrorWithMsg, "registration too short for IPv4")
return
}
var addr [4]byte
copy(addr[:], data[offset:offset+v3IPv4AddrLen])
destination = netip.AddrPortFrom(netip.AddrFrom4(addr), destinationPort)
offset += v3IPv4AddrLen
}
closeAfterIdle := time.Duration(idleDurationSeconds) * time.Second
if closeAfterIdle == 0 {
closeAfterIdle = 210 * time.Second
}
if !destination.Addr().IsValid() || destination.Addr().IsUnspecified() || destination.Port() == 0 {
m.sendRegistrationResponse(requestID, v3ResponseDestinationUnreachable, "")
return
}
session, state, err := m.inbound.datagramV3Manager.Register(m.inbound, ctx, requestID, destination, closeAfterIdle, m.sender)
if err == errTooManyActiveFlows {
m.sendRegistrationResponse(requestID, v3ResponseTooManyActiveFlows, "")
return
}
if err != nil {
m.sendRegistrationResponse(requestID, v3ResponseUnableToBindSocket, "")
return
}
if state == v3RegistrationNew {
m.logger.Info("registered V3 UDP session to ", destination)
}
m.sendRegistrationResponse(requestID, v3ResponseOK, "")
// Handle bundled first payload
if flags&v3FlagBundle != 0 && len(data) > offset {
session.writeToOrigin(data[offset:])
}
}
func (m *DatagramV3Muxer) handlePayload(data []byte) {
if len(data) < v3RequestIDLength || len(data) > v3RequestIDLength+maxV3UDPPayloadLen {
return
}
var requestID RequestID
copy(requestID[:], data[:v3RequestIDLength])
payload := data[v3RequestIDLength:]
session, exists := m.inbound.datagramV3Manager.Get(requestID)
if !exists {
return
}
session.writeToOrigin(payload)
}
func (m *DatagramV3Muxer) sendRegistrationResponse(requestID RequestID, responseType byte, errorMessage string) {
errorBytes := []byte(errorMessage)
data := make([]byte, v3RegistrationRespLen+len(errorBytes))
data[0] = byte(DatagramV3TypeRegistrationResponse)
data[1] = responseType
copy(data[2:2+v3RequestIDLength], requestID[:])
binary.BigEndian.PutUint16(data[2+v3RequestIDLength:], uint16(len(errorBytes)))
copy(data[v3RegistrationRespLen:], errorBytes)
m.sender.SendDatagram(data)
}
func (m *DatagramV3Muxer) sendPayload(requestID RequestID, payload []byte) {
data := make([]byte, v3PayloadHeaderLen+len(payload))
data[0] = byte(DatagramV3TypePayload)
copy(data[1:1+v3RequestIDLength], requestID[:])
copy(data[v3PayloadHeaderLen:], payload)
m.sender.SendDatagram(data)
}
// Close closes all V3 sessions.
func (m *DatagramV3Muxer) Close() {}
// v3Session represents a V3 UDP session.
type v3Session struct {
id RequestID
destination netip.AddrPort
closeAfterIdle time.Duration
origin N.PacketConn
manager *DatagramV3SessionManager
inbound *Inbound
writeChan chan []byte
closeOnce sync.Once
closeChan chan struct{}
activeAccess sync.RWMutex
activeAt time.Time
senderAccess sync.RWMutex
sender DatagramSender
contextAccess sync.RWMutex
connCtx context.Context
contextChan chan context.Context
}
var errTooManyActiveFlows = errors.New("too many active flows")
func (m *DatagramV3SessionManager) Register(
inbound *Inbound,
ctx context.Context,
requestID RequestID,
destination netip.AddrPort,
closeAfterIdle time.Duration,
sender DatagramSender,
) (*v3Session, v3RegistrationState, error) {
m.sessionAccess.Lock()
if existing, exists := m.sessions[requestID]; exists {
if existing.sender == sender {
existing.updateContext(ctx)
existing.markActive()
m.sessionAccess.Unlock()
return existing, v3RegistrationExisting, nil
}
existing.migrate(sender, ctx)
existing.markActive()
m.sessionAccess.Unlock()
return existing, v3RegistrationMigrated, nil
}
limit := inbound.maxActiveFlows()
if !inbound.flowLimiter.Acquire(limit) {
m.sessionAccess.Unlock()
return nil, 0, errTooManyActiveFlows
}
origin, err := inbound.dialWarpPacketConnection(ctx, destination)
if err != nil {
inbound.flowLimiter.Release(limit)
m.sessionAccess.Unlock()
return nil, 0, err
}
session := &v3Session{
id: requestID,
destination: destination,
closeAfterIdle: closeAfterIdle,
origin: origin,
manager: m,
inbound: inbound,
writeChan: make(chan []byte, 512),
closeChan: make(chan struct{}),
activeAt: time.Now(),
sender: sender,
connCtx: ctx,
contextChan: make(chan context.Context, 1),
}
m.sessions[requestID] = session
m.sessionAccess.Unlock()
sessionCtx := ctx
if sessionCtx == nil {
sessionCtx = context.Background()
}
session.connCtx = sessionCtx
go session.serve(sessionCtx, limit)
return session, v3RegistrationNew, nil
}
func (m *DatagramV3SessionManager) Get(requestID RequestID) (*v3Session, bool) {
m.sessionAccess.RLock()
defer m.sessionAccess.RUnlock()
session, exists := m.sessions[requestID]
return session, exists
}
func (m *DatagramV3SessionManager) remove(session *v3Session) {
m.sessionAccess.Lock()
if current, exists := m.sessions[session.id]; exists && current == session {
delete(m.sessions, session.id)
}
m.sessionAccess.Unlock()
}
func (s *v3Session) serve(ctx context.Context, limit uint64) {
defer s.inbound.flowLimiter.Release(limit)
defer s.manager.remove(s)
go s.readLoop()
go s.writeLoop()
connCtx := ctx
tickInterval := s.closeAfterIdle / 2
if tickInterval <= 0 || tickInterval > 10*time.Second {
tickInterval = time.Second
}
ticker := time.NewTicker(tickInterval)
defer ticker.Stop()
for {
select {
case <-connCtx.Done():
if latestCtx := s.currentContext(); latestCtx != nil && latestCtx != connCtx {
connCtx = latestCtx
continue
}
s.close()
case newCtx := <-s.contextChan:
if newCtx != nil {
connCtx = newCtx
}
case <-ticker.C:
if time.Since(s.lastActive()) >= s.closeAfterIdle {
s.close()
}
case <-s.closeChan:
return
}
}
}
func (s *v3Session) readLoop() {
for {
buffer := buf.NewPacket()
_, err := s.origin.ReadPacket(buffer)
if err != nil {
buffer.Release()
s.close()
return
}
if buffer.Len() > maxV3UDPPayloadLen {
s.inbound.logger.Debug("drop oversized V3 UDP payload: ", buffer.Len())
buffer.Release()
continue
}
s.markActive()
if err := s.senderDatagram(append([]byte(nil), buffer.Bytes()...)); err != nil {
buffer.Release()
s.close()
return
}
buffer.Release()
}
}
func (s *v3Session) writeLoop() {
for {
select {
case payload := <-s.writeChan:
err := s.origin.WritePacket(buf.As(payload), M.SocksaddrFromNetIP(s.destination))
if err != nil {
if errors.Is(err, os.ErrDeadlineExceeded) {
s.inbound.logger.Debug("drop V3 UDP payload due to write deadline exceeded")
continue
}
s.close()
return
}
s.markActive()
case <-s.closeChan:
return
}
}
}
func (s *v3Session) writeToOrigin(payload []byte) {
data := make([]byte, len(payload))
copy(data, payload)
select {
case s.writeChan <- data:
default:
}
}
func (s *v3Session) senderDatagram(payload []byte) error {
data := make([]byte, v3PayloadHeaderLen+len(payload))
data[0] = byte(DatagramV3TypePayload)
copy(data[1:1+v3RequestIDLength], s.id[:])
copy(data[v3PayloadHeaderLen:], payload)
s.senderAccess.RLock()
sender := s.sender
s.senderAccess.RUnlock()
return sender.SendDatagram(data)
}
func (s *v3Session) setSender(sender DatagramSender) {
s.senderAccess.Lock()
s.sender = sender
s.senderAccess.Unlock()
}
func (s *v3Session) updateContext(ctx context.Context) {
if ctx == nil {
return
}
s.contextAccess.Lock()
s.connCtx = ctx
s.contextAccess.Unlock()
select {
case s.contextChan <- ctx:
default:
select {
case <-s.contextChan:
default:
}
s.contextChan <- ctx
}
}
func (s *v3Session) migrate(sender DatagramSender, ctx context.Context) {
s.setSender(sender)
s.updateContext(ctx)
}
func (s *v3Session) currentContext() context.Context {
s.contextAccess.RLock()
defer s.contextAccess.RUnlock()
return s.connCtx
}
func (s *v3Session) markActive() {
s.activeAccess.Lock()
s.activeAt = time.Now()
s.activeAccess.Unlock()
}
func (s *v3Session) lastActive() time.Time {
s.activeAccess.RLock()
defer s.activeAccess.RUnlock()
return s.activeAt
}
func (s *v3Session) close() {
s.closeOnce.Do(func() {
if s.origin != nil {
_ = s.origin.Close()
}
close(s.closeChan)
})
}

View File

@@ -1,232 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"encoding/binary"
"errors"
"io"
"net"
"net/netip"
"os"
"testing"
"time"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
func TestDatagramV3RegistrationDestinationUnreachable(t *testing.T) {
sender := &captureDatagramSender{}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
flowLimiter: &FlowLimiter{},
datagramV3Manager: NewDatagramV3SessionManager(),
}
muxer := NewDatagramV3Muxer(inboundInstance, sender, nil)
requestID := RequestID{}
requestID[15] = 1
payload := make([]byte, 1+2+2+16+4)
payload[0] = 0
binary.BigEndian.PutUint16(payload[1:3], 0)
binary.BigEndian.PutUint16(payload[3:5], 30)
copy(payload[5:21], requestID[:])
copy(payload[21:25], []byte{0, 0, 0, 0})
muxer.handleRegistration(context.Background(), payload)
if len(sender.sent) != 1 {
t.Fatalf("expected one registration response, got %d", len(sender.sent))
}
if sender.sent[0][0] != byte(DatagramV3TypeRegistrationResponse) || sender.sent[0][1] != v3ResponseDestinationUnreachable {
t.Fatalf("unexpected datagram response: %v", sender.sent[0])
}
}
func TestDatagramV3RegistrationErrorWithMessage(t *testing.T) {
sender := &captureDatagramSender{}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
flowLimiter: &FlowLimiter{},
datagramV3Manager: NewDatagramV3SessionManager(),
}
muxer := NewDatagramV3Muxer(inboundInstance, sender, nil)
requestID := RequestID{}
requestID[15] = 2
payload := make([]byte, 1+2+2+16+1)
payload[0] = 1
binary.BigEndian.PutUint16(payload[1:3], 53)
binary.BigEndian.PutUint16(payload[3:5], 30)
copy(payload[5:21], requestID[:])
payload[21] = 0xaa
muxer.handleRegistration(context.Background(), payload)
if len(sender.sent) != 1 {
t.Fatalf("expected one registration response, got %d", len(sender.sent))
}
if sender.sent[0][0] != byte(DatagramV3TypeRegistrationResponse) || sender.sent[0][1] != v3ResponseErrorWithMsg {
t.Fatalf("unexpected datagram response: %v", sender.sent[0])
}
}
type scriptedPacketConn struct {
reads [][]byte
index int
}
func (c *scriptedPacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
if c.index >= len(c.reads) {
return M.Socksaddr{}, io.EOF
}
_, err := buffer.Write(c.reads[c.index])
c.index++
return M.Socksaddr{}, err
}
func (c *scriptedPacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
buffer.Release()
return nil
}
func (c *scriptedPacketConn) Close() error { return nil }
func (c *scriptedPacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
func (c *scriptedPacketConn) SetDeadline(time.Time) error { return nil }
func (c *scriptedPacketConn) SetReadDeadline(time.Time) error { return nil }
func (c *scriptedPacketConn) SetWriteDeadline(time.Time) error { return nil }
type sizeLimitedSender struct {
sent [][]byte
max int
}
func (s *sizeLimitedSender) SendDatagram(data []byte) error {
if len(data) > s.max {
return errors.New("datagram too large")
}
s.sent = append(s.sent, append([]byte(nil), data...))
return nil
}
func TestDatagramV3ReadLoopDropsOversizedOriginPackets(t *testing.T) {
logger := log.NewNOPFactory().NewLogger("test")
sender := &sizeLimitedSender{max: v3PayloadHeaderLen + maxV3UDPPayloadLen}
session := &v3Session{
id: RequestID{},
destination: netip.MustParseAddrPort("127.0.0.1:53"),
origin: &scriptedPacketConn{reads: [][]byte{
make([]byte, maxV3UDPPayloadLen+1),
[]byte("ok"),
}},
inbound: &Inbound{
logger: logger,
},
writeChan: make(chan []byte, 1),
closeChan: make(chan struct{}),
contextChan: make(chan context.Context, 1),
sender: sender,
}
done := make(chan struct{})
go func() {
session.readLoop()
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected read loop to finish")
}
if len(sender.sent) != 1 {
t.Fatalf("expected one datagram after dropping oversized payload, got %d", len(sender.sent))
}
if len(sender.sent[0]) != v3PayloadHeaderLen+2 {
t.Fatalf("unexpected forwarded datagram length: %d", len(sender.sent[0]))
}
}
func TestDatagramV3HandlePayloadDropsOversizedPayload(t *testing.T) {
requestID := RequestID{}
requestID[15] = 9
session := &v3Session{
id: requestID,
writeChan: make(chan []byte, 1),
}
manager := NewDatagramV3SessionManager()
manager.sessions[requestID] = session
muxer := &DatagramV3Muxer{
inbound: &Inbound{
datagramV3Manager: manager,
},
}
payload := make([]byte, v3RequestIDLength+maxV3UDPPayloadLen+1)
copy(payload[:v3RequestIDLength], requestID[:])
muxer.handlePayload(payload)
select {
case <-session.writeChan:
t.Fatal("expected oversized payload to be dropped")
default:
}
}
type deadlinePacketConn struct {
err error
}
func (c *deadlinePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
buffer.Release()
return M.Socksaddr{}, io.EOF
}
func (c *deadlinePacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
buffer.Release()
return c.err
}
func (c *deadlinePacketConn) Close() error { return nil }
func (c *deadlinePacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
func (c *deadlinePacketConn) SetDeadline(time.Time) error { return nil }
func (c *deadlinePacketConn) SetReadDeadline(time.Time) error { return nil }
func (c *deadlinePacketConn) SetWriteDeadline(time.Time) error { return nil }
func TestDatagramV3WriteLoopDropsDeadlineExceeded(t *testing.T) {
session := &v3Session{
destination: netip.MustParseAddrPort("127.0.0.1:53"),
origin: &deadlinePacketConn{err: os.ErrDeadlineExceeded},
inbound: &Inbound{
logger: log.NewNOPFactory().NewLogger("test"),
},
writeChan: make(chan []byte, 1),
closeChan: make(chan struct{}),
}
done := make(chan struct{})
go func() {
session.writeLoop()
close(done)
}()
session.writeToOrigin([]byte("payload"))
time.Sleep(50 * time.Millisecond)
select {
case <-session.closeChan:
t.Fatal("expected session to remain open after deadline exceeded")
default:
}
session.close()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected write loop to exit after manual close")
}
}

View File

@@ -1,206 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
stdTLS "crypto/tls"
"fmt"
"io"
"net"
"net/http"
"net/url"
"os"
"testing"
"time"
boxTLS "github.com/sagernet/sing-box/common/tls"
"github.com/sagernet/sing-box/log"
)
func TestNewDirectOriginTransportUnix(t *testing.T) {
socketPath := fmt.Sprintf("/tmp/cf-origin-%d.sock", time.Now().UnixNano())
_ = os.Remove(socketPath)
t.Cleanup(func() { _ = os.Remove(socketPath) })
listener, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatal(err)
}
defer listener.Close()
go serveTestHTTPOverListener(listener, func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write([]byte("unix-ok"))
})
inboundInstance := &Inbound{}
transport, cleanup, err := inboundInstance.newDirectOriginTransport(ResolvedService{
Kind: ResolvedServiceUnix,
UnixPath: socketPath,
BaseURL: &url.URL{
Scheme: "http",
Host: "localhost",
},
}, "")
if err != nil {
t.Fatal(err)
}
defer cleanup()
client := &http.Client{Transport: transport}
resp, err := client.Get("http://localhost/")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != "unix-ok" {
t.Fatalf("unexpected response body: %q", string(body))
}
}
func TestNewDirectOriginTransportUnixTLS(t *testing.T) {
socketPath := fmt.Sprintf("/tmp/cf-origin-tls-%d.sock", time.Now().UnixNano())
_ = os.Remove(socketPath)
t.Cleanup(func() { _ = os.Remove(socketPath) })
listener, err := net.Listen("unix", socketPath)
if err != nil {
t.Fatal(err)
}
certificate, err := boxTLS.GenerateKeyPair(nil, nil, time.Now, "localhost")
if err != nil {
t.Fatal(err)
}
tlsListener := stdTLS.NewListener(listener, &stdTLS.Config{
Certificates: []stdTLS.Certificate{*certificate},
})
defer tlsListener.Close()
go serveTestHTTPOverListener(tlsListener, func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
_, _ = writer.Write([]byte("unix-tls-ok"))
})
inboundInstance := &Inbound{}
transport, cleanup, err := inboundInstance.newDirectOriginTransport(ResolvedService{
Kind: ResolvedServiceUnixTLS,
OriginRequest: OriginRequestConfig{
NoTLSVerify: true,
},
UnixPath: socketPath,
BaseURL: &url.URL{
Scheme: "https",
Host: "localhost",
},
}, "")
if err != nil {
t.Fatal(err)
}
defer cleanup()
client := &http.Client{Transport: transport}
resp, err := client.Get("https://localhost/")
if err != nil {
t.Fatal(err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != "unix-tls-ok" {
t.Fatalf("unexpected response body: %q", string(body))
}
}
func serveTestHTTPOverListener(listener net.Listener, handler func(http.ResponseWriter, *http.Request)) {
server := &http.Server{Handler: http.HandlerFunc(handler)}
_ = server.Serve(listener)
}
func TestDirectOriginTransportCacheReusesMatchingTransports(t *testing.T) {
inboundInstance := &Inbound{
directTransports: make(map[string]*http.Transport),
}
service := ResolvedService{
Kind: ResolvedServiceUnix,
UnixPath: "/tmp/test.sock",
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
}
transport1, _, err := inboundInstance.newDirectOriginTransport(service, "example.com")
if err != nil {
t.Fatal(err)
}
transport2, _, err := inboundInstance.newDirectOriginTransport(service, "example.com")
if err != nil {
t.Fatal(err)
}
if transport1 != transport2 {
t.Fatal("expected matching direct-origin transports to be reused")
}
transport3, _, err := inboundInstance.newDirectOriginTransport(service, "other.example.com")
if err != nil {
t.Fatal(err)
}
if transport3 == transport1 {
t.Fatal("expected different cache keys to produce different transports")
}
}
func TestApplyConfigClearsDirectOriginTransportCache(t *testing.T) {
configManager, err := NewConfigManager()
if err != nil {
t.Fatal(err)
}
inboundInstance := &Inbound{
logger: log.NewNOPFactory().NewLogger("test"),
configManager: configManager,
directTransports: make(map[string]*http.Transport),
}
service := ResolvedService{
Kind: ResolvedServiceUnix,
UnixPath: "/tmp/test.sock",
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
}
transport1, _, err := inboundInstance.newDirectOriginTransport(service, "example.com")
if err != nil {
t.Fatal(err)
}
result := inboundInstance.ApplyConfig(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
if result.Err != nil {
t.Fatal(result.Err)
}
transport2, _, err := inboundInstance.newDirectOriginTransport(service, "example.com")
if err != nil {
t.Fatal(err)
}
if transport1 == transport2 {
t.Fatal("expected ApplyConfig to clear direct-origin transport cache")
}
}
func TestNewDirectOriginTransportUsesCloudflaredDefaults(t *testing.T) {
inboundInstance := &Inbound{}
transport, cleanup, err := inboundInstance.newDirectOriginTransport(ResolvedService{
Kind: ResolvedServiceUnix,
UnixPath: "/tmp/test.sock",
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
}, "")
if err != nil {
t.Fatal(err)
}
defer cleanup()
if transport.ExpectContinueTimeout != time.Second {
t.Fatalf("expected ExpectContinueTimeout=1s, got %s", transport.ExpectContinueTimeout)
}
if transport.DisableCompression {
t.Fatal("expected compression to remain enabled by default")
}
}

View File

@@ -1,731 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"crypto/tls"
"encoding/json"
"io"
"net"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const (
metadataHTTPMethod = "HttpMethod"
metadataHTTPHost = "HttpHost"
metadataHTTPHeader = "HttpHeader"
metadataHTTPStatus = "HttpStatus"
)
var (
loadOriginCABasePool = cloudflareRootCertPool
readOriginCAFile = os.ReadFile
proxyFromEnvironment = http.ProxyFromEnvironment
)
// ConnectResponseWriter abstracts the response writing for both QUIC and HTTP/2.
type ConnectResponseWriter interface {
// WriteResponse sends the connect response (ack or error) with optional metadata.
WriteResponse(responseError error, metadata []Metadata) error
}
type connectResponseTrailerWriter interface {
AddTrailer(name, value string)
}
// quicResponseWriter writes ConnectResponse in QUIC data stream format (signature + capnp).
type quicResponseWriter struct {
stream io.Writer
}
func (w *quicResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
return WriteConnectResponse(w.stream, responseError, metadata...)
}
// HandleDataStream dispatches an incoming edge data stream (QUIC path).
func (i *Inbound) HandleDataStream(ctx context.Context, stream io.ReadWriteCloser, request *ConnectRequest, connIndex uint8) {
ctx = log.ContextWithNewID(ctx)
respWriter := &quicResponseWriter{stream: stream}
i.dispatchRequest(ctx, stream, respWriter, request)
}
// HandleRPCStream handles an incoming edge RPC stream (session management, configuration).
func (i *Inbound) HandleRPCStream(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8) {
i.logger.DebugContext(ctx, "received RPC stream on connection ", connIndex)
// V2 RPC streams are handled here - the edge calls RegisterUdpSession/UnregisterUdpSession
// We need the sender (DatagramSender) to find the muxer - but HandleRPCStream doesn't have it.
// The V2 muxer is looked up via GetOrCreateV2Muxer in HandleDatagram when first datagram arrives.
// For RPC, we need a different approach - see handleRPCStreamWithSender below.
}
// HandleRPCStreamWithSender handles an RPC stream with access to the DatagramSender for V2 muxer lookup.
func (i *Inbound) HandleRPCStreamWithSender(ctx context.Context, stream io.ReadWriteCloser, connIndex uint8, sender DatagramSender) {
switch datagramVersionForSender(sender) {
case "v3":
ServeV3RPCStream(ctx, stream, i, i.logger)
default:
muxer := i.getOrCreateV2Muxer(sender)
ServeRPCStream(ctx, stream, i, muxer, i.logger)
}
}
// HandleDatagram handles an incoming QUIC datagram.
func (i *Inbound) HandleDatagram(ctx context.Context, datagram []byte, sender DatagramSender) {
switch datagramVersionForSender(sender) {
case "v3":
muxer := i.getOrCreateV3Muxer(sender)
muxer.HandleDatagram(ctx, datagram)
default:
muxer := i.getOrCreateV2Muxer(sender)
muxer.HandleDatagram(ctx, datagram)
}
}
func (i *Inbound) getOrCreateV2Muxer(sender DatagramSender) *DatagramV2Muxer {
i.datagramMuxerAccess.Lock()
defer i.datagramMuxerAccess.Unlock()
muxer, exists := i.datagramV2Muxers[sender]
if !exists {
muxer = NewDatagramV2Muxer(i, sender, i.logger)
i.datagramV2Muxers[sender] = muxer
}
return muxer
}
func (i *Inbound) getOrCreateV3Muxer(sender DatagramSender) *DatagramV3Muxer {
i.datagramMuxerAccess.Lock()
defer i.datagramMuxerAccess.Unlock()
muxer, exists := i.datagramV3Muxers[sender]
if !exists {
muxer = NewDatagramV3Muxer(i, sender, i.logger)
i.datagramV3Muxers[sender] = muxer
}
return muxer
}
// RemoveDatagramMuxer cleans up muxers when a connection closes.
func (i *Inbound) RemoveDatagramMuxer(sender DatagramSender) {
i.datagramMuxerAccess.Lock()
if muxer, exists := i.datagramV2Muxers[sender]; exists {
muxer.Close()
delete(i.datagramV2Muxers, sender)
}
if muxer, exists := i.datagramV3Muxers[sender]; exists {
muxer.Close()
delete(i.datagramV3Muxers, sender)
}
i.datagramMuxerAccess.Unlock()
}
func (i *Inbound) dispatchRequest(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest) {
metadata := adapter.InboundContext{
Inbound: i.Tag(),
InboundType: i.Type(),
}
switch request.Type {
case ConnectionTypeTCP:
metadata.Destination = M.ParseSocksaddr(request.Dest)
i.handleTCPStream(ctx, stream, respWriter, metadata)
case ConnectionTypeHTTP, ConnectionTypeWebsocket:
service, originURL, err := i.resolveHTTPService(request.Dest)
if err != nil {
i.logger.ErrorContext(ctx, "resolve origin service: ", err)
respWriter.WriteResponse(err, nil)
return
}
request.Dest = originURL
i.handleHTTPService(ctx, stream, respWriter, request, metadata, service)
default:
i.logger.ErrorContext(ctx, "unknown connection type: ", request.Type)
}
}
func (i *Inbound) resolveHTTPService(requestURL string) (ResolvedService, string, error) {
parsedURL, err := url.Parse(requestURL)
if err != nil {
return ResolvedService{}, "", E.Cause(err, "parse request URL")
}
service, loaded := i.configManager.Resolve(parsedURL.Hostname(), parsedURL.Path)
if !loaded {
return ResolvedService{}, "", E.New("no ingress rule matched request host/path")
}
originURL, err := service.BuildRequestURL(requestURL)
if err != nil {
return ResolvedService{}, "", E.Cause(err, "build origin request URL")
}
return service, originURL, nil
}
func parseHTTPDestination(dest string) M.Socksaddr {
parsed, err := url.Parse(dest)
if err != nil {
return M.ParseSocksaddr(dest)
}
host := parsed.Hostname()
port := parsed.Port()
if port == "" {
switch parsed.Scheme {
case "https", "wss":
port = "443"
default:
port = "80"
}
}
return M.ParseSocksaddr(net.JoinHostPort(host, port))
}
func (i *Inbound) handleTCPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, metadata adapter.InboundContext) {
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound TCP connection to ", metadata.Destination)
limit := i.maxActiveFlows()
if !i.flowLimiter.Acquire(limit) {
err := E.New("too many active flows")
i.logger.ErrorContext(ctx, err)
respWriter.WriteResponse(err, flowConnectRateLimitedMetadata())
return
}
defer i.flowLimiter.Release(limit)
warpRouting := i.configManager.Snapshot().WarpRouting
targetConn, cleanup, err := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{
timeout: warpRouting.ConnectTimeout,
onHandshake: func(conn net.Conn) {
_ = applyTCPKeepAlive(conn, warpRouting.TCPKeepAlive)
},
})
if err != nil {
i.logger.ErrorContext(ctx, "dial tcp origin: ", err)
respWriter.WriteResponse(err, nil)
return
}
defer cleanup()
// Cloudflare expects an optimistic ACK here so the routed TCP path can sniff
// the real input stream before the outbound connection is fully established.
err = respWriter.WriteResponse(nil, nil)
if err != nil {
i.logger.ErrorContext(ctx, "write connect response: ", err)
return
}
err = bufio.CopyConn(ctx, newStreamConn(stream), targetConn)
if err != nil && !E.IsClosedOrCanceled(err) {
i.logger.DebugContext(ctx, "copy TCP stream: ", err)
}
}
func (i *Inbound) handleHTTPService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
validationRequest, err := buildMetadataOnlyHTTPRequest(ctx, request)
if err != nil {
i.logger.ErrorContext(ctx, "build request for access validation: ", err)
respWriter.WriteResponse(err, nil)
return
}
validationRequest = applyOriginRequest(validationRequest, service.OriginRequest)
if service.OriginRequest.Access.Required {
validator, err := i.accessCache.Get(service.OriginRequest.Access)
if err != nil {
i.logger.ErrorContext(ctx, "create access validator: ", err)
respWriter.WriteResponse(err, nil)
return
}
if err := validator.Validate(validationRequest.Context(), validationRequest); err != nil {
respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusForbidden, http.Header{}))
return
}
}
switch service.Kind {
case ResolvedServiceStatus:
err = respWriter.WriteResponse(nil, encodeResponseHeaders(service.StatusCode, http.Header{}))
if err != nil {
i.logger.ErrorContext(ctx, "write status service response: ", err)
}
return
case ResolvedServiceHTTP:
metadata.Destination = service.Destination
if request.Type == ConnectionTypeHTTP {
i.handleHTTPStream(ctx, stream, respWriter, request, metadata, service)
} else {
i.handleWebSocketStream(ctx, stream, respWriter, request, metadata, service)
}
case ResolvedServiceStream:
if request.Type != ConnectionTypeWebsocket {
err := E.New("stream service requires websocket request type")
i.logger.ErrorContext(ctx, err)
respWriter.WriteResponse(err, nil)
return
}
i.handleStreamService(ctx, stream, respWriter, request, metadata, service)
case ResolvedServiceUnix, ResolvedServiceUnixTLS:
if request.Type == ConnectionTypeHTTP {
i.handleDirectHTTPStream(ctx, stream, respWriter, request, metadata, service)
} else {
i.handleDirectWebSocketStream(ctx, stream, respWriter, request, metadata, service)
}
case ResolvedServiceBastion:
if request.Type != ConnectionTypeWebsocket {
err := E.New("bastion service requires websocket request type")
i.logger.ErrorContext(ctx, err)
respWriter.WriteResponse(err, nil)
return
}
i.handleBastionStream(ctx, stream, respWriter, request, metadata, service)
case ResolvedServiceSocksProxy:
if request.Type != ConnectionTypeWebsocket {
err := E.New("socks-proxy service requires websocket request type")
i.logger.ErrorContext(ctx, err)
respWriter.WriteResponse(err, nil)
return
}
i.handleSocksProxyStream(ctx, stream, respWriter, request, metadata, service)
default:
err := E.New("unsupported service kind for HTTP/WebSocket request")
i.logger.ErrorContext(ctx, err)
respWriter.WriteResponse(err, nil)
}
}
func (i *Inbound) handleHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound HTTP connection to ", metadata.Destination)
transport, cleanup, err := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost])
if err != nil {
i.logger.ErrorContext(ctx, "build origin transport: ", err)
respWriter.WriteResponse(err, nil)
return
}
defer cleanup()
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
}
func (i *Inbound) handleWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", metadata.Destination)
transport, cleanup, err := i.newRouterOriginTransport(ctx, metadata, service.OriginRequest, request.MetadataMap()[metadataHTTPHost])
if err != nil {
i.logger.ErrorContext(ctx, "build origin transport: ", err)
respWriter.WriteResponse(err, nil)
return
}
defer cleanup()
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
}
func (i *Inbound) handleDirectHTTPStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound HTTP connection to ", request.Dest)
transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost])
if err != nil {
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
respWriter.WriteResponse(err, nil)
return
}
defer cleanup()
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
}
func (i *Inbound) handleDirectWebSocketStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
metadata.Network = N.NetworkTCP
i.logger.InfoContext(ctx, "inbound WebSocket connection to ", request.Dest)
transport, cleanup, err := i.newDirectOriginTransport(service, request.MetadataMap()[metadataHTTPHost])
if err != nil {
i.logger.ErrorContext(ctx, "build direct origin transport: ", err)
respWriter.WriteResponse(err, nil)
return
}
defer cleanup()
i.roundTripHTTP(ctx, stream, respWriter, request, service, transport)
}
func (i *Inbound) roundTripHTTP(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, service ResolvedService, transport *http.Transport) {
httpRequest, err := buildHTTPRequestFromMetadata(ctx, request, stream)
if err != nil {
i.logger.ErrorContext(ctx, "build HTTP request: ", err)
respWriter.WriteResponse(err, nil)
return
}
httpRequest = normalizeOriginRequest(request.Type, httpRequest, service.OriginRequest)
requestCtx := httpRequest.Context()
if service.OriginRequest.ConnectTimeout > 0 {
var cancel context.CancelFunc
requestCtx, cancel = context.WithTimeout(requestCtx, service.OriginRequest.ConnectTimeout)
defer cancel()
httpRequest = httpRequest.WithContext(requestCtx)
}
httpClient := &http.Client{
Transport: transport,
CheckRedirect: func(request *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
},
}
response, err := httpClient.Do(httpRequest)
if err != nil {
i.logger.ErrorContext(ctx, "origin request: ", err)
respWriter.WriteResponse(err, nil)
return
}
defer response.Body.Close()
responseMetadata := encodeResponseHeaders(response.StatusCode, response.Header)
err = respWriter.WriteResponse(nil, responseMetadata)
if err != nil {
i.logger.ErrorContext(ctx, "write origin response headers: ", err)
return
}
if request.Type == ConnectionTypeWebsocket && response.StatusCode == http.StatusSwitchingProtocols {
rwc, ok := response.Body.(io.ReadWriteCloser)
if !ok {
i.logger.ErrorContext(ctx, "websocket origin response body is not duplex")
return
}
bidirectionalCopy(stream, rwc)
return
}
_, err = io.Copy(stream, response.Body)
if err != nil && !E.IsClosedOrCanceled(err) {
i.logger.DebugContext(ctx, "copy HTTP response body: ", err)
}
if trailerWriter, ok := respWriter.(connectResponseTrailerWriter); ok {
for name, values := range response.Trailer {
for _, value := range values {
trailerWriter.AddTrailer(name, value)
}
}
}
}
func (i *Inbound) newRouterOriginTransport(ctx context.Context, metadata adapter.InboundContext, originRequest OriginRequestConfig, requestHost string) (*http.Transport, func(), error) {
tlsConfig, err := newOriginTLSConfig(originRequest, effectiveOriginHost(originRequest, requestHost))
if err != nil {
return nil, nil, err
}
input, cleanup, _ := i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{})
transport := &http.Transport{
ExpectContinueTimeout: time.Second,
ForceAttemptHTTP2: originRequest.HTTP2Origin,
TLSHandshakeTimeout: originRequest.TLSTimeout,
IdleConnTimeout: originRequest.KeepAliveTimeout,
MaxIdleConns: originRequest.KeepAliveConnections,
MaxIdleConnsPerHost: originRequest.KeepAliveConnections,
Proxy: proxyFromEnvironment,
TLSClientConfig: tlsConfig,
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return input, nil
},
}
return transport, cleanup, nil
}
func (i *Inbound) newDirectOriginTransport(service ResolvedService, requestHost string) (*http.Transport, func(), error) {
cacheKey, err := directOriginTransportKey(service, requestHost)
if err != nil {
return nil, nil, E.Cause(err, "marshal direct origin transport key")
}
i.directTransportAccess.Lock()
if i.directTransports == nil {
i.directTransports = make(map[string]*http.Transport)
}
if transport, exists := i.directTransports[cacheKey]; exists {
i.directTransportAccess.Unlock()
return transport, func() {}, nil
}
i.directTransportAccess.Unlock()
dialer := &net.Dialer{
Timeout: service.OriginRequest.ConnectTimeout,
KeepAlive: service.OriginRequest.TCPKeepAlive,
}
if service.OriginRequest.NoHappyEyeballs {
dialer.FallbackDelay = -1
}
tlsConfig, err := newOriginTLSConfig(service.OriginRequest, effectiveOriginHost(service.OriginRequest, requestHost))
if err != nil {
return nil, nil, err
}
transport := &http.Transport{
ExpectContinueTimeout: time.Second,
ForceAttemptHTTP2: service.OriginRequest.HTTP2Origin,
TLSHandshakeTimeout: service.OriginRequest.TLSTimeout,
IdleConnTimeout: service.OriginRequest.KeepAliveTimeout,
MaxIdleConns: service.OriginRequest.KeepAliveConnections,
MaxIdleConnsPerHost: service.OriginRequest.KeepAliveConnections,
Proxy: proxyFromEnvironment,
TLSClientConfig: tlsConfig,
}
switch service.Kind {
case ResolvedServiceUnix, ResolvedServiceUnixTLS:
transport.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
return dialer.DialContext(ctx, "unix", service.UnixPath)
}
default:
return nil, nil, E.New("unsupported direct origin service")
}
i.directTransportAccess.Lock()
if i.directTransports == nil {
i.directTransports = make(map[string]*http.Transport)
}
if cached, exists := i.directTransports[cacheKey]; exists {
i.directTransportAccess.Unlock()
transport.CloseIdleConnections()
return cached, func() {}, nil
}
i.directTransports[cacheKey] = transport
i.directTransportAccess.Unlock()
return transport, func() {}, nil
}
type directOriginTransportCacheKey struct {
Kind ResolvedServiceKind `json:"kind"`
UnixPath string `json:"unix_path,omitempty"`
RequestHost string `json:"request_host,omitempty"`
Origin OriginRequestConfig `json:"origin"`
}
func directOriginTransportKey(service ResolvedService, requestHost string) (string, error) {
key := directOriginTransportCacheKey{
Kind: service.Kind,
UnixPath: service.UnixPath,
RequestHost: effectiveOriginHost(service.OriginRequest, requestHost),
Origin: service.OriginRequest,
}
data, err := json.Marshal(key)
if err != nil {
return "", err
}
return string(data), nil
}
func effectiveOriginHost(originRequest OriginRequestConfig, requestHost string) string {
if originRequest.HTTPHostHeader != "" {
return originRequest.HTTPHostHeader
}
return requestHost
}
func newOriginTLSConfig(originRequest OriginRequestConfig, requestHost string) (*tls.Config, error) {
rootCAs, err := loadOriginCABasePool()
if err != nil {
return nil, E.Cause(err, "load origin root CAs")
}
tlsConfig := &tls.Config{
InsecureSkipVerify: originRequest.NoTLSVerify, //nolint:gosec
ServerName: originTLSServerName(originRequest, requestHost),
RootCAs: rootCAs,
}
if originRequest.CAPool == "" {
return tlsConfig, nil
}
pemData, err := readOriginCAFile(originRequest.CAPool)
if err != nil {
return nil, E.Cause(err, "read origin ca pool")
}
if !tlsConfig.RootCAs.AppendCertsFromPEM(pemData) {
return nil, E.New("parse origin ca pool")
}
return tlsConfig, nil
}
func originTLSServerName(originRequest OriginRequestConfig, requestHost string) string {
if originRequest.OriginServerName != "" {
return originRequest.OriginServerName
}
if !originRequest.MatchSNIToHost {
return ""
}
if host, _, err := net.SplitHostPort(requestHost); err == nil {
return host
}
return requestHost
}
func applyOriginRequest(request *http.Request, originRequest OriginRequestConfig) *http.Request {
request = request.Clone(request.Context())
if originRequest.HTTPHostHeader != "" {
request.Header.Set("X-Forwarded-Host", request.Host)
request.Host = originRequest.HTTPHostHeader
}
return request
}
func normalizeOriginRequest(connectType ConnectionType, request *http.Request, originRequest OriginRequestConfig) *http.Request {
request = applyOriginRequest(request, originRequest)
switch connectType {
case ConnectionTypeWebsocket:
request.Header.Set("Connection", "Upgrade")
request.Header.Set("Upgrade", "websocket")
request.Header.Set("Sec-Websocket-Version", "13")
request.ContentLength = 0
request.Body = nil
default:
if originRequest.DisableChunkedEncoding {
request.TransferEncoding = []string{"gzip", "deflate"}
if contentLength, err := strconv.ParseInt(request.Header.Get("Content-Length"), 10, 64); err == nil {
request.ContentLength = contentLength
}
}
request.Header.Set("Connection", "keep-alive")
}
if _, exists := request.Header["User-Agent"]; !exists {
request.Header.Set("User-Agent", "")
}
return request
}
func buildMetadataOnlyHTTPRequest(ctx context.Context, connectRequest *ConnectRequest) (*http.Request, error) {
return buildHTTPRequestFromMetadata(ctx, &ConnectRequest{
Dest: connectRequest.Dest,
Type: connectRequest.Type,
Metadata: append([]Metadata(nil), connectRequest.Metadata...),
}, http.NoBody)
}
func bidirectionalCopy(left, right io.ReadWriteCloser) {
var closeOnce sync.Once
closeBoth := func() {
closeOnce.Do(func() {
common.Close(left, right)
})
}
done := make(chan struct{}, 2)
go func() {
io.Copy(left, right)
closeBoth()
done <- struct{}{}
}()
go func() {
io.Copy(right, left)
closeBoth()
done <- struct{}{}
}()
<-done
<-done
}
func buildHTTPRequestFromMetadata(ctx context.Context, connectRequest *ConnectRequest, body io.Reader) (*http.Request, error) {
metadataMap := connectRequest.MetadataMap()
method := metadataMap[metadataHTTPMethod]
host := metadataMap[metadataHTTPHost]
request, err := http.NewRequestWithContext(ctx, method, connectRequest.Dest, body)
if err != nil {
return nil, E.Cause(err, "create HTTP request")
}
request.Host = host
for _, entry := range connectRequest.Metadata {
if !strings.Contains(entry.Key, metadataHTTPHeader) {
continue
}
parts := strings.SplitN(entry.Key, ":", 2)
if len(parts) != 2 {
continue
}
request.Header.Add(parts[1], entry.Val)
}
contentLengthStr := request.Header.Get("Content-Length")
if contentLengthStr != "" {
request.ContentLength, err = strconv.ParseInt(contentLengthStr, 10, 64)
if err != nil {
return nil, E.Cause(err, "parse content-length")
}
}
if connectRequest.Type != ConnectionTypeWebsocket && !isTransferEncodingChunked(request) && request.ContentLength == 0 {
request.Body = http.NoBody
}
request.Header.Del("Cf-Cloudflared-Proxy-Connection-Upgrade")
return request, nil
}
func isTransferEncodingChunked(request *http.Request) bool {
for _, encoding := range request.TransferEncoding {
if strings.Contains(strings.ToLower(encoding), "chunked") {
return true
}
}
return strings.Contains(strings.ToLower(request.Header.Get("Transfer-Encoding")), "chunked")
}
func encodeResponseHeaders(statusCode int, header http.Header) []Metadata {
metadata := make([]Metadata, 0, len(header)+1)
metadata = append(metadata, Metadata{
Key: metadataHTTPStatus,
Val: strconv.Itoa(statusCode),
})
for name, values := range header {
for _, value := range values {
metadata = append(metadata, Metadata{
Key: metadataHTTPHeader + ":" + name,
Val: value,
})
}
}
return metadata
}
// streamConn wraps an io.ReadWriteCloser as a net.Conn.
type streamConn struct {
io.ReadWriteCloser
}
func newStreamConn(stream io.ReadWriteCloser) *streamConn {
return &streamConn{ReadWriteCloser: stream}
}
func (c *streamConn) LocalAddr() net.Addr { return nil }
func (c *streamConn) RemoteAddr() net.Addr { return nil }
func (c *streamConn) SetDeadline(_ time.Time) error { return nil }
func (c *streamConn) SetReadDeadline(_ time.Time) error { return nil }
func (c *streamConn) SetWriteDeadline(_ time.Time) error { return nil }
type datagramVersionedSender interface {
DatagramVersion() string
}
func datagramVersionForSender(sender DatagramSender) string {
versioned, ok := sender.(datagramVersionedSender)
if !ok {
return defaultDatagramVersion
}
version := versioned.DatagramVersion()
if version == "" {
return defaultDatagramVersion
}
return version
}

View File

@@ -1,137 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"net/http"
"testing"
)
func TestParseHTTPDestination(t *testing.T) {
tests := []struct {
name string
dest string
expected string
}{
{"http with port", "http://127.0.0.1:8083/path", "127.0.0.1:8083"},
{"https default port", "https://example.com", "example.com:443"},
{"http default port", "http://example.com", "example.com:80"},
{"wss default port", "wss://example.com/ws", "example.com:443"},
{"explicit port", "https://example.com:9443/api", "example.com:9443"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := parseHTTPDestination(tt.dest)
if result.String() != tt.expected {
t.Errorf("parseHTTPDestination(%q) = %q, want %q", tt.dest, result.String(), tt.expected)
}
})
}
}
func TestSerializeHeaders(t *testing.T) {
header := http.Header{}
header.Set("Content-Type", "text/html")
header.Set("X-Foo", "bar")
serialized := SerializeHeaders(header)
if serialized == "" {
t.Fatal("expected non-empty serialized headers")
}
decoded := make(map[string]string)
for _, pair := range splitNonEmpty(serialized, ";") {
parts := splitNonEmpty(pair, ":")
if len(parts) != 2 {
t.Fatalf("malformed pair: %q", pair)
}
name, err := headerEncoding.DecodeString(parts[0])
if err != nil {
t.Fatal("decode name: ", err)
}
value, err := headerEncoding.DecodeString(parts[1])
if err != nil {
t.Fatal("decode value: ", err)
}
decoded[string(name)] = string(value)
}
if decoded["Content-Type"] != "text/html" {
t.Error("expected Content-Type=text/html, got ", decoded["Content-Type"])
}
if decoded["X-Foo"] != "bar" {
t.Error("expected X-Foo=bar, got ", decoded["X-Foo"])
}
}
func splitNonEmpty(s string, sep string) []string {
var result []string
for _, part := range splitString(s, sep) {
if part != "" {
result = append(result, part)
}
}
return result
}
func splitString(s string, sep string) []string {
if len(sep) == 0 {
return []string{s}
}
var result []string
start := 0
for i := 0; i <= len(s)-len(sep); i++ {
if s[i:i+len(sep)] == sep {
result = append(result, s[start:i])
start = i + len(sep)
i += len(sep) - 1
}
}
result = append(result, s[start:])
return result
}
func TestIsControlResponseHeader(t *testing.T) {
tests := []struct {
name string
expected bool
}{
{":status", true},
{"cf-int-foo", true},
{"cf-cloudflared-response-meta", true},
{"cf-proxy-src", true},
{"content-type", false},
{"x-custom", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isControlResponseHeader(tt.name)
if result != tt.expected {
t.Errorf("isControlResponseHeader(%q) = %v, want %v", tt.name, result, tt.expected)
}
})
}
}
func TestIsWebsocketClientHeader(t *testing.T) {
tests := []struct {
name string
expected bool
}{
{"sec-websocket-accept", true},
{"connection", true},
{"upgrade", true},
{"content-type", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := isWebsocketClientHeader(tt.name)
if result != tt.expected {
t.Errorf("isWebsocketClientHeader(%q) = %v, want %v", tt.name, result, tt.expected)
}
})
}
}

View File

@@ -1,130 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"crypto/tls"
"net"
"time"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const (
edgeSRVService = "v2-origintunneld"
edgeSRVProto = "tcp"
edgeSRVName = "argotunnel.com"
dotServerName = "cloudflare-dns.com"
dotServerAddr = "1.1.1.1:853"
dotTimeout = 15 * time.Second
)
func getRegionalServiceName(region string) string {
if region == "" {
return edgeSRVService
}
return region + "-" + edgeSRVService
}
// EdgeAddr represents a Cloudflare edge server address.
type EdgeAddr struct {
TCP *net.TCPAddr
UDP *net.UDPAddr
IPVersion int // 4 or 6
}
// DiscoverEdge performs SRV-based edge discovery and returns addresses
// partitioned into regions (typically 2).
func DiscoverEdge(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) {
regions, err := lookupEdgeSRV(region)
if err != nil {
regions, err = lookupEdgeSRVWithDoT(ctx, region, controlDialer)
if err != nil {
return nil, E.Cause(err, "edge discovery")
}
}
if len(regions) == 0 {
return nil, E.New("edge discovery: no edge addresses found")
}
return regions, nil
}
func lookupEdgeSRV(region string) ([][]*EdgeAddr, error) {
_, addrs, err := net.LookupSRV(getRegionalServiceName(region), edgeSRVProto, edgeSRVName)
if err != nil {
return nil, err
}
return resolveSRVRecords(addrs)
}
func lookupEdgeSRVWithDoT(ctx context.Context, region string, controlDialer N.Dialer) ([][]*EdgeAddr, error) {
resolver := &net.Resolver{
PreferGo: true,
Dial: func(ctx context.Context, _, _ string) (net.Conn, error) {
conn, err := controlDialer.DialContext(ctx, "tcp", M.ParseSocksaddr(dotServerAddr))
if err != nil {
return nil, err
}
return tls.Client(conn, &tls.Config{ServerName: dotServerName}), nil
},
}
lookupCtx, cancel := context.WithTimeout(ctx, dotTimeout)
defer cancel()
_, addrs, err := resolver.LookupSRV(lookupCtx, getRegionalServiceName(region), edgeSRVProto, edgeSRVName)
if err != nil {
return nil, err
}
return resolveSRVRecords(addrs)
}
func resolveSRVRecords(records []*net.SRV) ([][]*EdgeAddr, error) {
var regions [][]*EdgeAddr
for _, record := range records {
ips, err := net.LookupIP(record.Target)
if err != nil {
return nil, E.Cause(err, "resolve SRV target: ", record.Target)
}
if len(ips) == 0 {
continue
}
edgeAddrs := make([]*EdgeAddr, 0, len(ips))
for _, ip := range ips {
ipVersion := 6
if ip.To4() != nil {
ipVersion = 4
}
edgeAddrs = append(edgeAddrs, &EdgeAddr{
TCP: &net.TCPAddr{IP: ip, Port: int(record.Port)},
UDP: &net.UDPAddr{IP: ip, Port: int(record.Port)},
IPVersion: ipVersion,
})
}
regions = append(regions, edgeAddrs)
}
return regions, nil
}
// FilterByIPVersion filters edge addresses to only include the specified IP version.
// version 0 means no filtering (auto).
func FilterByIPVersion(regions [][]*EdgeAddr, version int) [][]*EdgeAddr {
if version == 0 {
return regions
}
var filtered [][]*EdgeAddr
for _, region := range regions {
var addrs []*EdgeAddr
for _, addr := range region {
if addr.IPVersion == version {
addrs = append(addrs, addr)
}
}
if len(addrs) > 0 {
filtered = append(filtered, addrs)
}
}
return filtered
}

View File

@@ -1,148 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"net"
"testing"
N "github.com/sagernet/sing/common/network"
)
func TestDiscoverEdge(t *testing.T) {
regions, err := DiscoverEdge(context.Background(), "", N.SystemDialer)
if err != nil {
t.Fatal("DiscoverEdge: ", err)
}
if len(regions) == 0 {
t.Fatal("expected at least 1 region")
}
for i, region := range regions {
if len(region) == 0 {
t.Errorf("region %d is empty", i)
continue
}
for j, addr := range region {
if addr.TCP == nil {
t.Errorf("region %d addr %d: TCP is nil", i, j)
}
if addr.UDP == nil {
t.Errorf("region %d addr %d: UDP is nil", i, j)
}
if addr.IPVersion != 4 && addr.IPVersion != 6 {
t.Errorf("region %d addr %d: invalid IPVersion %d", i, j, addr.IPVersion)
}
}
}
}
func TestFilterByIPVersion(t *testing.T) {
v4Addr := &EdgeAddr{
TCP: &net.TCPAddr{IP: net.IPv4(1, 1, 1, 1), Port: 7844},
UDP: &net.UDPAddr{IP: net.IPv4(1, 1, 1, 1), Port: 7844},
IPVersion: 4,
}
v6Addr := &EdgeAddr{
TCP: &net.TCPAddr{IP: net.ParseIP("2606:4700::1"), Port: 7844},
UDP: &net.UDPAddr{IP: net.ParseIP("2606:4700::1"), Port: 7844},
IPVersion: 6,
}
mixed := [][]*EdgeAddr{{v4Addr, v6Addr}}
tests := []struct {
name string
version int
expected int
}{
{"auto", 0, 2},
{"v4 only", 4, 1},
{"v6 only", 6, 1},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := FilterByIPVersion(mixed, tt.version)
total := 0
for _, region := range result {
total += len(region)
}
if total != tt.expected {
t.Errorf("expected %d addrs, got %d", tt.expected, total)
}
})
}
t.Run("no match", func(t *testing.T) {
v4Only := [][]*EdgeAddr{{v4Addr}}
result := FilterByIPVersion(v4Only, 6)
if len(result) != 0 {
t.Error("expected empty result for no match")
}
})
t.Run("empty input", func(t *testing.T) {
result := FilterByIPVersion(nil, 4)
if len(result) != 0 {
t.Error("expected empty result for nil input")
}
})
}
func TestGetRegionalServiceName(t *testing.T) {
if got := getRegionalServiceName(""); got != edgeSRVService {
t.Fatalf("expected global service %s, got %s", edgeSRVService, got)
}
if got := getRegionalServiceName("us"); got != "us-"+edgeSRVService {
t.Fatalf("expected regional service us-%s, got %s", edgeSRVService, got)
}
}
func TestInitialEdgeAddrIndex(t *testing.T) {
if got := initialEdgeAddrIndex(0, 4); got != 0 {
t.Fatalf("expected conn 0 to get index 0, got %d", got)
}
if got := initialEdgeAddrIndex(3, 4); got != 3 {
t.Fatalf("expected conn 3 to get index 3, got %d", got)
}
if got := initialEdgeAddrIndex(5, 4); got != 1 {
t.Fatalf("expected conn 5 to wrap to index 1, got %d", got)
}
if got := initialEdgeAddrIndex(2, 1); got != 0 {
t.Fatalf("expected single-address pool to always return 0, got %d", got)
}
}
func TestRotateEdgeAddrIndex(t *testing.T) {
if got := rotateEdgeAddrIndex(0, 4); got != 1 {
t.Fatalf("expected index 0 to rotate to 1, got %d", got)
}
if got := rotateEdgeAddrIndex(3, 4); got != 0 {
t.Fatalf("expected last index to wrap to 0, got %d", got)
}
if got := rotateEdgeAddrIndex(0, 1); got != 0 {
t.Fatalf("expected single-address pool to stay at 0, got %d", got)
}
}
func TestEffectiveHAConnections(t *testing.T) {
tests := []struct {
name string
requested int
available int
expected int
}{
{name: "requested below available", requested: 2, available: 4, expected: 2},
{name: "requested equals available", requested: 4, available: 4, expected: 4},
{name: "requested above available", requested: 5, available: 3, expected: 3},
{name: "no available edges", requested: 4, available: 0, expected: 0},
}
for _, testCase := range tests {
t.Run(testCase.name, func(t *testing.T) {
if actual := effectiveHAConnections(testCase.requested, testCase.available); actual != testCase.expected {
t.Fatalf("effectiveHAConnections(%d, %d) = %d, want %d", testCase.requested, testCase.available, actual, testCase.expected)
}
})
}
}

View File

@@ -1,17 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"crypto/tls"
"crypto/x509"
)
func newEdgeTLSConfig(rootCAs *x509.CertPool, serverName string, nextProtos []string) *tls.Config {
return &tls.Config{
RootCAs: rootCAs,
ServerName: serverName,
NextProtos: nextProtos,
CurvePreferences: []tls.CurveID{tls.CurveP256},
}
}

View File

@@ -1,31 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"crypto/tls"
"crypto/x509"
"testing"
)
func TestNewEdgeTLSConfigUsesP256(t *testing.T) {
rootCAs := x509.NewCertPool()
config := newEdgeTLSConfig(rootCAs, h2EdgeSNI, nil)
if config.RootCAs != rootCAs {
t.Fatal("expected root CA pool to be preserved")
}
if config.ServerName != h2EdgeSNI {
t.Fatalf("expected server name %q, got %q", h2EdgeSNI, config.ServerName)
}
if len(config.CurvePreferences) != 1 || config.CurvePreferences[0] != tls.CurveP256 {
t.Fatalf("unexpected curve preferences: %#v", config.CurvePreferences)
}
}
func TestNewEdgeTLSConfigPreservesNextProtos(t *testing.T) {
config := newEdgeTLSConfig(x509.NewCertPool(), quicEdgeSNI, []string{quicEdgeALPN})
if len(config.NextProtos) != 1 || config.NextProtos[0] != quicEdgeALPN {
t.Fatalf("unexpected next protos: %#v", config.NextProtos)
}
}

View File

@@ -1,123 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"encoding/json"
"hash/fnv"
"net"
"sync"
"time"
)
const (
featureSelectorHostname = "cfd-features.argotunnel.com"
featureLookupTimeout = 10 * time.Second
defaultDatagramVersion = "v2"
defaultFeatureRefreshInterval = time.Hour
)
type cloudflaredFeaturesRecord struct {
DatagramV3Percentage uint32 `json:"dv3_2"`
}
var lookupCloudflaredFeatures = func(ctx context.Context) ([]byte, error) {
lookupCtx, cancel := context.WithTimeout(ctx, featureLookupTimeout)
defer cancel()
records, err := net.DefaultResolver.LookupTXT(lookupCtx, featureSelectorHostname)
if err != nil || len(records) == 0 {
return nil, err
}
return []byte(records[0]), nil
}
type featureSelector struct {
configured string
accountTag string
lookup func(context.Context) ([]byte, error)
refreshInterval time.Duration
currentDatagramVersion string
access sync.RWMutex
}
func newFeatureSelector(ctx context.Context, accountTag string, configured string) *featureSelector {
selector := &featureSelector{
configured: configured,
accountTag: accountTag,
lookup: lookupCloudflaredFeatures,
refreshInterval: defaultFeatureRefreshInterval,
currentDatagramVersion: defaultDatagramVersion,
}
if configured != "" {
selector.currentDatagramVersion = configured
return selector
}
_ = selector.refresh(ctx)
if selector.refreshInterval > 0 {
go selector.refreshLoop(ctx)
}
return selector
}
func (s *featureSelector) Snapshot() (string, []string) {
if s == nil {
return defaultDatagramVersion, DefaultFeatures(defaultDatagramVersion)
}
s.access.RLock()
defer s.access.RUnlock()
return s.currentDatagramVersion, DefaultFeatures(s.currentDatagramVersion)
}
func (s *featureSelector) refresh(ctx context.Context) error {
if s == nil || s.configured != "" {
return nil
}
record, err := s.lookup(ctx)
if err != nil {
return err
}
version, err := resolveRemoteDatagramVersion(s.accountTag, record)
if err != nil {
return err
}
s.access.Lock()
s.currentDatagramVersion = version
s.access.Unlock()
return nil
}
func (s *featureSelector) refreshLoop(ctx context.Context) {
ticker := time.NewTicker(s.refreshInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
_ = s.refresh(ctx)
}
}
}
func resolveRemoteDatagramVersion(accountTag string, record []byte) (string, error) {
var features cloudflaredFeaturesRecord
if err := json.Unmarshal(record, &features); err != nil {
return "", err
}
if accountEnabled(accountTag, features.DatagramV3Percentage) {
return "v3", nil
}
return defaultDatagramVersion, nil
}
func accountEnabled(accountTag string, percentage uint32) bool {
if percentage == 0 {
return false
}
hasher := fnv.New32a()
_, _ = hasher.Write([]byte(accountTag))
return percentage > hasher.Sum32()%100
}

View File

@@ -1,119 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"errors"
"slices"
"testing"
)
func TestFeatureSelectorConfiguredWins(t *testing.T) {
t.Helper()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
selector := newFeatureSelector(ctx, "account", "v3")
version, features := selector.Snapshot()
if version != "v3" {
t.Fatalf("expected configured version to win, got %s", version)
}
if !slices.Contains(features, "support_datagram_v3_2") {
t.Fatalf("expected v3 feature list, got %#v", features)
}
}
func TestFeatureSelectorInitialRemoteSelection(t *testing.T) {
selector := &featureSelector{
accountTag: "account",
lookup: func(context.Context) ([]byte, error) { return []byte(`{"dv3_2":100}`), nil },
currentDatagramVersion: defaultDatagramVersion,
}
if err := selector.refresh(context.Background()); err != nil {
t.Fatal(err)
}
version, _ := selector.Snapshot()
if version != "v3" {
t.Fatalf("expected auto-selected v3, got %s", version)
}
}
func TestFeatureSelectorRefreshUpdatesSnapshot(t *testing.T) {
record := []byte(`{"dv3_2":0}`)
selector := &featureSelector{
accountTag: "account",
currentDatagramVersion: defaultDatagramVersion,
lookup: func(context.Context) ([]byte, error) {
return record, nil
},
}
if err := selector.refresh(context.Background()); err != nil {
t.Fatal(err)
}
version, _ := selector.Snapshot()
if version != defaultDatagramVersion {
t.Fatalf("expected initial v2, got %s", version)
}
record = []byte(`{"dv3_2":100}`)
if err := selector.refresh(context.Background()); err != nil {
t.Fatal(err)
}
version, _ = selector.Snapshot()
if version != "v3" {
t.Fatalf("expected refreshed v3, got %s", version)
}
}
func TestFeatureSelectorRefreshFailureKeepsPreviousValue(t *testing.T) {
selector := &featureSelector{
accountTag: "account",
currentDatagramVersion: "v3",
lookup: func(context.Context) ([]byte, error) {
return nil, errors.New("lookup failed")
},
}
if err := selector.refresh(context.Background()); err == nil {
t.Fatal("expected refresh failure")
}
version, _ := selector.Snapshot()
if version != "v3" {
t.Fatalf("expected previous version to be retained, got %s", version)
}
}
func TestInboundUsesFreshFeatureSnapshotOnRetry(t *testing.T) {
inbound := &Inbound{
featureSelector: &featureSelector{
accountTag: "account",
currentDatagramVersion: defaultDatagramVersion,
},
}
version, features := inbound.currentConnectionFeatures()
if version != defaultDatagramVersion {
t.Fatalf("expected initial v2, got %s", version)
}
if slices.Contains(features, "support_datagram_v3_2") {
t.Fatalf("unexpected v3 feature list: %#v", features)
}
inbound.featureSelector.access.Lock()
inbound.featureSelector.currentDatagramVersion = "v3"
inbound.featureSelector.access.Unlock()
version, features = inbound.currentConnectionFeatures()
if version != "v3" {
t.Fatalf("expected refreshed v3, got %s", version)
}
if !slices.Contains(features, "support_datagram_v3_2") {
t.Fatalf("expected v3 feature list, got %#v", features)
}
}

View File

@@ -1,34 +0,0 @@
//go:build with_cloudflared
package cloudflare
import "sync"
type FlowLimiter struct {
access sync.Mutex
active uint64
}
func (l *FlowLimiter) Acquire(limit uint64) bool {
if limit == 0 {
return true
}
l.access.Lock()
defer l.access.Unlock()
if l.active >= limit {
return false
}
l.active++
return true
}
func (l *FlowLimiter) Release(limit uint64) {
if limit == 0 {
return
}
l.access.Lock()
defer l.access.Unlock()
if l.active > 0 {
l.active--
}
}

View File

@@ -1,160 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"encoding/binary"
"net"
"net/http"
"net/http/httptest"
"testing"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/google/uuid"
)
type captureConnectMetadataWriter struct {
err error
metadata []Metadata
}
func (w *captureConnectMetadataWriter) WriteResponse(responseError error, metadata []Metadata) error {
w.err = responseError
w.metadata = append([]Metadata(nil), metadata...)
return nil
}
func newLimitedInbound(t *testing.T, limit uint64) *Inbound {
t.Helper()
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
if err != nil {
t.Fatal(err)
}
configManager, err := NewConfigManager()
if err != nil {
t.Fatal(err)
}
ctx, cancel := context.WithCancel(context.Background())
t.Cleanup(cancel)
config := configManager.Snapshot()
config.WarpRouting.MaxActiveFlows = limit
configManager.activeConfig = config
return &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
ctx: ctx,
cancel: cancel,
router: &testRouter{},
logger: logFactory.NewLogger("test"),
configManager: configManager,
flowLimiter: &FlowLimiter{},
datagramV3Manager: NewDatagramV3SessionManager(),
connectionStates: make([]connectionState, 1),
successfulProtocols: make(map[string]struct{}),
directTransports: make(map[string]*http.Transport),
}
}
func TestHandleTCPStreamRespectsMaxActiveFlows(t *testing.T) {
inboundInstance := newLimitedInbound(t, 1)
if !inboundInstance.flowLimiter.Acquire(1) {
t.Fatal("failed to pre-acquire limiter")
}
stream, peer := net.Pipe()
defer stream.Close()
defer peer.Close()
respWriter := &fakeConnectResponseWriter{}
inboundInstance.handleTCPStream(context.Background(), stream, respWriter, adapter.InboundContext{})
if respWriter.err == nil {
t.Fatal("expected too many active flows error")
}
}
func TestHandleTCPStreamRateLimitMetadata(t *testing.T) {
inboundInstance := newLimitedInbound(t, 1)
if !inboundInstance.flowLimiter.Acquire(1) {
t.Fatal("failed to pre-acquire limiter")
}
stream, peer := net.Pipe()
defer stream.Close()
defer peer.Close()
respWriter := &captureConnectMetadataWriter{}
inboundInstance.handleTCPStream(context.Background(), stream, respWriter, adapter.InboundContext{})
if respWriter.err == nil {
t.Fatal("expected too many active flows error")
}
if !hasFlowConnectRateLimited(respWriter.metadata) {
t.Fatal("expected flow rate limit metadata")
}
}
func TestHTTP2ResponseWriterFlowRateLimitedMeta(t *testing.T) {
recorder := httptest.NewRecorder()
writer := &http2ResponseWriter{
writer: recorder,
flusher: recorder,
}
err := writer.WriteResponse(context.DeadlineExceeded, flowConnectRateLimitedMetadata())
if err != nil {
t.Fatal(err)
}
if recorder.Code != http.StatusBadGateway {
t.Fatalf("expected %d, got %d", http.StatusBadGateway, recorder.Code)
}
if meta := recorder.Header().Get(h2HeaderResponseMeta); meta != h2ResponseMetaCloudflaredLimited {
t.Fatalf("unexpected response meta: %q", meta)
}
}
func TestDatagramV2RegisterSessionRespectsMaxActiveFlows(t *testing.T) {
inboundInstance := newLimitedInbound(t, 1)
if !inboundInstance.flowLimiter.Acquire(1) {
t.Fatal("failed to pre-acquire limiter")
}
muxer := NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger)
err := muxer.RegisterSession(context.Background(), uuidTest(1), net.IPv4(1, 1, 1, 1), 53, 0)
if err == nil {
t.Fatal("expected too many active flows error")
}
}
func TestDatagramV3RegistrationTooManyActiveFlows(t *testing.T) {
inboundInstance := newLimitedInbound(t, 1)
if !inboundInstance.flowLimiter.Acquire(1) {
t.Fatal("failed to pre-acquire limiter")
}
sender := &captureDatagramSender{}
muxer := NewDatagramV3Muxer(inboundInstance, sender, inboundInstance.logger)
requestID := RequestID{}
requestID[15] = 1
payload := make([]byte, 1+1+2+2+16+4)
payload[0] = 0
binary.BigEndian.PutUint16(payload[1:3], 53)
binary.BigEndian.PutUint16(payload[3:5], 30)
copy(payload[5:21], requestID[:])
copy(payload[21:25], []byte{1, 1, 1, 1})
muxer.handleRegistration(context.Background(), payload)
if len(sender.sent) != 1 {
t.Fatalf("expected one registration response, got %d", len(sender.sent))
}
if sender.sent[0][0] != byte(DatagramV3TypeRegistrationResponse) || sender.sent[0][1] != v3ResponseTooManyActiveFlows {
t.Fatalf("unexpected v3 response: %v", sender.sent[0])
}
}
func uuidTest(last byte) uuid.UUID {
var value uuid.UUID
value[15] = last
return value
}

View File

@@ -1,55 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"encoding/base64"
"net/http"
"strings"
)
const (
h2HeaderUpgrade = "Cf-Cloudflared-Proxy-Connection-Upgrade"
h2HeaderTCPSrc = "Cf-Cloudflared-Proxy-Src"
h2HeaderResponseMeta = "Cf-Cloudflared-Response-Meta"
h2HeaderResponseUser = "Cf-Cloudflared-Response-Headers"
h2UpgradeControlStream = "control-stream"
h2UpgradeWebsocket = "websocket"
h2UpgradeConfiguration = "update-configuration"
h2ResponseMetaOrigin = `{"src":"origin"}`
)
var headerEncoding = base64.RawStdEncoding
// SerializeHeaders encodes HTTP/1 headers into base64 pairs: base64(name):base64(value);...
func SerializeHeaders(header http.Header) string {
var builder strings.Builder
for name, values := range header {
for _, value := range values {
if builder.Len() > 0 {
builder.WriteByte(';')
}
builder.WriteString(headerEncoding.EncodeToString([]byte(name)))
builder.WriteByte(':')
builder.WriteString(headerEncoding.EncodeToString([]byte(value)))
}
}
return builder.String()
}
// isControlResponseHeader returns true for headers that are internal control headers.
func isControlResponseHeader(name string) bool {
lower := strings.ToLower(name)
return strings.HasPrefix(lower, ":") ||
strings.HasPrefix(lower, "cf-int-") ||
strings.HasPrefix(lower, "cf-cloudflared-") ||
strings.HasPrefix(lower, "cf-proxy-")
}
// isWebsocketClientHeader returns true for headers needed by the client for WebSocket upgrade.
func isWebsocketClientHeader(name string) bool {
lower := strings.ToLower(name)
return lower == "sec-websocket-accept" ||
lower == "connection" ||
lower == "upgrade"
}

View File

@@ -1,247 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"os"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/bufio"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
)
func requireEnvVars(t *testing.T) (token string, testURL string) {
t.Helper()
token = os.Getenv("CF_TUNNEL_TOKEN")
testURL = os.Getenv("CF_TEST_URL")
if token == "" || testURL == "" {
t.Skip("CF_TUNNEL_TOKEN and CF_TEST_URL must be set")
}
return
}
func startOriginServer(t *testing.T) {
t.Helper()
mux := http.NewServeMux()
mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{"ok":true}`))
})
mux.HandleFunc("/echo", func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
io.Copy(w, r.Body)
})
mux.HandleFunc("/status/", func(w http.ResponseWriter, r *http.Request) {
codeStr := strings.TrimPrefix(r.URL.Path, "/status/")
code, err := strconv.Atoi(codeStr)
if err != nil {
code = 200
}
w.Header().Set("X-Custom", "test-value")
w.WriteHeader(code)
fmt.Fprintf(w, "status: %d", code)
})
mux.HandleFunc("/headers", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
json.NewEncoder(w).Encode(r.Header)
})
server := &http.Server{
Addr: "127.0.0.1:8083",
Handler: mux,
}
listener, err := net.Listen("tcp", server.Addr)
if err != nil {
t.Fatal("start origin server: ", err)
}
go server.Serve(listener)
t.Cleanup(func() {
server.Close()
})
}
type testRouter struct {
preMatch func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error)
}
func (r *testRouter) Start(stage adapter.StartStage) error { return nil }
func (r *testRouter) Close() error { return nil }
func (r *testRouter) RouteConnection(ctx context.Context, conn net.Conn, metadata adapter.InboundContext) error {
destination := metadata.Destination.String()
upstream, err := net.Dial("tcp", destination)
if err != nil {
conn.Close()
return err
}
go func() {
io.Copy(upstream, conn)
upstream.Close()
}()
io.Copy(conn, upstream)
conn.Close()
return nil
}
func (r *testRouter) RoutePacketConnection(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext) error {
return nil
}
func (r *testRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
destination := metadata.Destination.String()
upstream, err := net.Dial("tcp", destination)
if err != nil {
conn.Close()
onClose(err)
return
}
var once sync.Once
closeFn := func() {
once.Do(func() {
conn.Close()
upstream.Close()
})
}
go func() {
io.Copy(upstream, conn)
closeFn()
}()
io.Copy(conn, upstream)
closeFn()
onClose(nil)
}
func (r *testRouter) RoutePacketConnectionEx(ctx context.Context, conn N.PacketConn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
onClose(nil)
}
func (r *testRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) {
conn, err := net.Dial("udp", metadata.Destination.String())
if err != nil {
return nil, err
}
return bufio.NewUnbindPacketConn(conn), nil
}
func (r *testRouter) PreMatch(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
if r.preMatch != nil {
return r.preMatch(metadata, routeContext, timeout, supportBypass)
}
return nil, nil
}
func (r *testRouter) RuleSet(tag string) (adapter.RuleSet, bool) { return nil, false }
func (r *testRouter) Rules() []adapter.Rule { return nil }
func (r *testRouter) NeedFindProcess() bool { return false }
func (r *testRouter) NeedFindNeighbor() bool { return false }
func (r *testRouter) NeighborResolver() adapter.NeighborResolver { return nil }
func (r *testRouter) AppendTracker(tracker adapter.ConnectionTracker) {}
func (r *testRouter) ResetNetwork() {}
func newTestInbound(t *testing.T, token string, protocol string, haConnections int) *Inbound {
t.Helper()
credentials, err := parseToken(token)
if err != nil {
t.Fatal("parse token: ", err)
}
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
if err != nil {
t.Fatal("create logger: ", err)
}
configManager, err := NewConfigManager()
if err != nil {
t.Fatal("create config manager: ", err)
}
ctx, cancel := context.WithCancel(context.Background())
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
ctx: ctx,
cancel: cancel,
router: &testRouter{},
logger: logFactory.NewLogger("test"),
credentials: credentials,
connectorID: uuid.New(),
haConnections: haConnections,
protocol: protocol,
edgeIPVersion: 0,
datagramVersion: "",
featureSelector: newFeatureSelector(ctx, credentials.AccountTag, ""),
gracePeriod: 5 * time.Second,
configManager: configManager,
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
datagramV3Manager: NewDatagramV3SessionManager(),
connectedIndices: make(map[uint8]struct{}),
connectedNotify: make(chan uint8, haConnections),
controlDialer: N.SystemDialer,
tunnelDialer: N.SystemDialer,
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: N.SystemDialer},
connectionStates: make([]connectionState, haConnections),
successfulProtocols: make(map[string]struct{}),
directTransports: make(map[string]*http.Transport),
}
t.Cleanup(func() {
cancel()
inboundInstance.Close()
})
return inboundInstance
}
func waitForTunnel(t *testing.T, testURL string, timeout time.Duration) {
t.Helper()
deadline := time.Now().Add(timeout)
client := &http.Client{Timeout: 5 * time.Second}
var lastErr error
var lastStatus int
var lastBody string
for time.Now().Before(deadline) {
resp, err := client.Get(testURL + "/ping")
if err != nil {
lastErr = err
time.Sleep(500 * time.Millisecond)
continue
}
body, _ := io.ReadAll(resp.Body)
resp.Body.Close()
lastStatus = resp.StatusCode
lastBody = string(body)
if resp.StatusCode == http.StatusOK && lastBody == `{"ok":true}` {
return
}
time.Sleep(500 * time.Millisecond)
}
t.Fatalf("tunnel not ready after %s (lastErr=%v, lastStatus=%d, lastBody=%q)", timeout, lastErr, lastStatus, lastBody)
}

View File

@@ -1,577 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"encoding/binary"
"net/netip"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const (
icmpFlowTimeout = 30 * time.Second
icmpTraceIdentityLength = 16 + 8 + 1
defaultICMPPacketTTL = 255
icmpErrorHeaderLen = 8
ipv4TTLExceededQuoteLen = 548
ipv6TTLExceededQuoteLen = 1232
maxICMPPayloadLen = 1280
icmpv4TypeEchoRequest = 8
icmpv4TypeEchoReply = 0
icmpv4TypeTimeExceeded = 11
icmpv6TypeEchoRequest = 128
icmpv6TypeEchoReply = 129
icmpv6TypeTimeExceeded = 3
)
type ICMPTraceContext struct {
Traced bool
Identity []byte
}
type ICMPFlowKey struct {
IPVersion uint8
SourceIP netip.Addr
Destination netip.Addr
}
type ICMPRequestKey struct {
Flow ICMPFlowKey
Identifier uint16
Sequence uint16
}
type ICMPPacketInfo struct {
IPVersion uint8
Protocol uint8
SourceIP netip.Addr
Destination netip.Addr
ICMPType uint8
ICMPCode uint8
Identifier uint16
Sequence uint16
IPv4HeaderLen int
IPv4TTL uint8
IPv6HopLimit uint8
RawPacket []byte
}
func (i ICMPPacketInfo) FlowKey() ICMPFlowKey {
return ICMPFlowKey{
IPVersion: i.IPVersion,
SourceIP: i.SourceIP,
Destination: i.Destination,
}
}
func (i ICMPPacketInfo) RequestKey() ICMPRequestKey {
return ICMPRequestKey{
Flow: i.FlowKey(),
Identifier: i.Identifier,
Sequence: i.Sequence,
}
}
func (i ICMPPacketInfo) ReplyRequestKey() ICMPRequestKey {
return ICMPRequestKey{
Flow: ICMPFlowKey{
IPVersion: i.IPVersion,
SourceIP: i.Destination,
Destination: i.SourceIP,
},
Identifier: i.Identifier,
Sequence: i.Sequence,
}
}
func (i ICMPPacketInfo) IsEchoRequest() bool {
switch i.IPVersion {
case 4:
return i.ICMPType == icmpv4TypeEchoRequest && i.ICMPCode == 0
case 6:
return i.ICMPType == icmpv6TypeEchoRequest && i.ICMPCode == 0
default:
return false
}
}
func (i ICMPPacketInfo) IsEchoReply() bool {
switch i.IPVersion {
case 4:
return i.ICMPType == icmpv4TypeEchoReply && i.ICMPCode == 0
case 6:
return i.ICMPType == icmpv6TypeEchoReply && i.ICMPCode == 0
default:
return false
}
}
func (i ICMPPacketInfo) TTL() uint8 {
if i.IPVersion == 4 {
return i.IPv4TTL
}
return i.IPv6HopLimit
}
func (i ICMPPacketInfo) TTLExpired() bool {
return i.TTL() <= 1
}
func (i *ICMPPacketInfo) DecrementTTL() error {
switch i.IPVersion {
case 4:
if i.IPv4TTL == 0 || i.IPv4HeaderLen < 20 || len(i.RawPacket) < i.IPv4HeaderLen {
return E.New("invalid IPv4 packet TTL state")
}
i.IPv4TTL--
i.RawPacket[8] = i.IPv4TTL
binary.BigEndian.PutUint16(i.RawPacket[10:12], 0)
binary.BigEndian.PutUint16(i.RawPacket[10:12], checksum(i.RawPacket[:i.IPv4HeaderLen], 0))
case 6:
if i.IPv6HopLimit == 0 || len(i.RawPacket) < 40 {
return E.New("invalid IPv6 packet hop limit state")
}
i.IPv6HopLimit--
i.RawPacket[7] = i.IPv6HopLimit
default:
return E.New("unsupported IP version: ", i.IPVersion)
}
return nil
}
type icmpWireVersion uint8
const (
icmpWireV2 icmpWireVersion = iota + 1
icmpWireV3
)
type icmpFlowState struct {
writer *ICMPReplyWriter
lastActive time.Time
}
type traceEntry struct {
context ICMPTraceContext
createdAt time.Time
}
type ICMPReplyWriter struct {
sender DatagramSender
wireVersion icmpWireVersion
access sync.Mutex
traces map[ICMPRequestKey]traceEntry
}
func NewICMPReplyWriter(sender DatagramSender, wireVersion icmpWireVersion) *ICMPReplyWriter {
return &ICMPReplyWriter{
sender: sender,
wireVersion: wireVersion,
traces: make(map[ICMPRequestKey]traceEntry),
}
}
func (w *ICMPReplyWriter) RegisterRequestTrace(packetInfo ICMPPacketInfo, traceContext ICMPTraceContext) {
if !traceContext.Traced {
return
}
w.access.Lock()
w.traces[packetInfo.RequestKey()] = traceEntry{
context: traceContext,
createdAt: time.Now(),
}
w.access.Unlock()
}
func (w *ICMPReplyWriter) WritePacket(packet []byte) error {
packetInfo, err := ParseICMPPacket(packet)
if err != nil {
return err
}
if !packetInfo.IsEchoReply() {
return nil
}
requestKey := packetInfo.ReplyRequestKey()
w.access.Lock()
entry, loaded := w.traces[requestKey]
if loaded {
delete(w.traces, requestKey)
}
w.access.Unlock()
traceContext := entry.context
datagram, err := encodeICMPDatagram(packetInfo.RawPacket, w.wireVersion, traceContext)
if err != nil {
return err
}
return w.sender.SendDatagram(datagram)
}
func (w *ICMPReplyWriter) cleanupExpired(now time.Time) {
w.access.Lock()
defer w.access.Unlock()
for key, entry := range w.traces {
if now.After(entry.createdAt.Add(icmpFlowTimeout)) {
delete(w.traces, key)
}
}
}
type ICMPBridge struct {
inbound *Inbound
sender DatagramSender
wireVersion icmpWireVersion
routeMapping *tun.DirectRouteMapping
flowAccess sync.Mutex
flows map[ICMPFlowKey]*icmpFlowState
}
func NewICMPBridge(inbound *Inbound, sender DatagramSender, wireVersion icmpWireVersion) *ICMPBridge {
bridge := &ICMPBridge{
inbound: inbound,
sender: sender,
wireVersion: wireVersion,
routeMapping: tun.NewDirectRouteMapping(icmpFlowTimeout),
flows: make(map[ICMPFlowKey]*icmpFlowState),
}
if inbound != nil && inbound.ctx != nil {
go bridge.cleanupLoop(inbound.ctx)
}
return bridge
}
func (b *ICMPBridge) HandleV2(ctx context.Context, datagramType DatagramV2Type, payload []byte) error {
traceContext := ICMPTraceContext{}
switch datagramType {
case DatagramV2TypeIP:
case DatagramV2TypeIPWithTrace:
if len(payload) < icmpTraceIdentityLength {
return E.New("icmp trace payload is too short")
}
traceContext.Traced = true
traceContext.Identity = append([]byte(nil), payload[len(payload)-icmpTraceIdentityLength:]...)
payload = payload[:len(payload)-icmpTraceIdentityLength]
default:
return E.New("unsupported v2 icmp datagram type: ", datagramType)
}
return b.handlePacket(ctx, payload, traceContext)
}
func (b *ICMPBridge) HandleV3(ctx context.Context, payload []byte) error {
return b.handlePacket(ctx, payload, ICMPTraceContext{})
}
func (b *ICMPBridge) handlePacket(ctx context.Context, payload []byte, traceContext ICMPTraceContext) error {
packetInfo, err := ParseICMPPacket(payload)
if err != nil {
return err
}
if !packetInfo.IsEchoRequest() {
return nil
}
if packetInfo.TTLExpired() {
ttlExceededPacket, err := buildICMPTTLExceededPacket(packetInfo)
if err != nil {
return err
}
datagram, err := encodeICMPDatagram(ttlExceededPacket, b.wireVersion, traceContext)
if err != nil {
return err
}
return b.sender.SendDatagram(datagram)
}
if err := packetInfo.DecrementTTL(); err != nil {
return err
}
state := b.getFlowState(packetInfo.FlowKey())
state.lastActive = time.Now()
if traceContext.Traced {
state.writer.RegisterRequestTrace(packetInfo, traceContext)
}
action, err := b.routeMapping.Lookup(tun.DirectRouteSession{
Source: packetInfo.SourceIP,
Destination: packetInfo.Destination,
}, func(timeout time.Duration) (tun.DirectRouteDestination, error) {
metadata := adapter.InboundContext{
Inbound: b.inbound.Tag(),
InboundType: b.inbound.Type(),
IPVersion: packetInfo.IPVersion,
Network: N.NetworkICMP,
Source: M.SocksaddrFrom(packetInfo.SourceIP, 0),
Destination: M.SocksaddrFrom(packetInfo.Destination, 0),
OriginDestination: M.SocksaddrFrom(packetInfo.Destination, 0),
}
return b.inbound.router.PreMatch(metadata, state.writer, timeout, false)
})
if err != nil {
return nil
}
return action.WritePacket(buf.As(packetInfo.RawPacket).ToOwned())
}
func (b *ICMPBridge) getFlowState(key ICMPFlowKey) *icmpFlowState {
b.flowAccess.Lock()
defer b.flowAccess.Unlock()
state, loaded := b.flows[key]
if loaded {
return state
}
state = &icmpFlowState{
writer: NewICMPReplyWriter(b.sender, b.wireVersion),
}
b.flows[key] = state
return state
}
func (b *ICMPBridge) cleanupLoop(ctx context.Context) {
ticker := time.NewTicker(icmpFlowTimeout)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return
case now := <-ticker.C:
b.cleanupExpired(now)
}
}
}
func (b *ICMPBridge) cleanupExpired(now time.Time) {
b.flowAccess.Lock()
defer b.flowAccess.Unlock()
for key, state := range b.flows {
state.writer.cleanupExpired(now)
if now.After(state.lastActive.Add(icmpFlowTimeout)) {
delete(b.flows, key)
}
}
}
func ParseICMPPacket(packet []byte) (ICMPPacketInfo, error) {
if len(packet) < 1 {
return ICMPPacketInfo{}, E.New("empty IP packet")
}
version := packet[0] >> 4
switch version {
case 4:
return parseIPv4ICMPPacket(packet)
case 6:
return parseIPv6ICMPPacket(packet)
default:
return ICMPPacketInfo{}, E.New("unsupported IP version: ", version)
}
}
func parseIPv4ICMPPacket(packet []byte) (ICMPPacketInfo, error) {
if len(packet) < 20 {
return ICMPPacketInfo{}, E.New("IPv4 packet too short")
}
headerLen := int(packet[0]&0x0F) * 4
if headerLen < 20 || len(packet) < headerLen+8 {
return ICMPPacketInfo{}, E.New("invalid IPv4 header length")
}
if packet[9] != 1 {
return ICMPPacketInfo{}, E.New("IPv4 packet is not ICMP")
}
sourceIP, ok := netip.AddrFromSlice(packet[12:16])
if !ok {
return ICMPPacketInfo{}, E.New("invalid IPv4 source address")
}
destinationIP, ok := netip.AddrFromSlice(packet[16:20])
if !ok {
return ICMPPacketInfo{}, E.New("invalid IPv4 destination address")
}
return ICMPPacketInfo{
IPVersion: 4,
Protocol: 1,
SourceIP: sourceIP,
Destination: destinationIP,
ICMPType: packet[headerLen],
ICMPCode: packet[headerLen+1],
Identifier: binary.BigEndian.Uint16(packet[headerLen+4 : headerLen+6]),
Sequence: binary.BigEndian.Uint16(packet[headerLen+6 : headerLen+8]),
IPv4HeaderLen: headerLen,
IPv4TTL: packet[8],
RawPacket: append([]byte(nil), packet...),
}, nil
}
func parseIPv6ICMPPacket(packet []byte) (ICMPPacketInfo, error) {
if len(packet) < 48 {
return ICMPPacketInfo{}, E.New("IPv6 packet too short")
}
if packet[6] != 58 {
return ICMPPacketInfo{}, E.New("IPv6 packet is not ICMP")
}
sourceIP, ok := netip.AddrFromSlice(packet[8:24])
if !ok {
return ICMPPacketInfo{}, E.New("invalid IPv6 source address")
}
destinationIP, ok := netip.AddrFromSlice(packet[24:40])
if !ok {
return ICMPPacketInfo{}, E.New("invalid IPv6 destination address")
}
return ICMPPacketInfo{
IPVersion: 6,
Protocol: 58,
SourceIP: sourceIP,
Destination: destinationIP,
ICMPType: packet[40],
ICMPCode: packet[41],
Identifier: binary.BigEndian.Uint16(packet[44:46]),
Sequence: binary.BigEndian.Uint16(packet[46:48]),
IPv6HopLimit: packet[7],
RawPacket: append([]byte(nil), packet...),
}, nil
}
func maxEncodedICMPPacketLen(wireVersion icmpWireVersion, traceContext ICMPTraceContext) int {
limit := maxV3UDPPayloadLen
switch wireVersion {
case icmpWireV2:
limit -= typeIDLength
if traceContext.Traced {
limit -= len(traceContext.Identity)
}
case icmpWireV3:
limit -= 1
default:
return 0
}
if limit < 0 {
return 0
}
return limit
}
func buildICMPTTLExceededPacket(packetInfo ICMPPacketInfo) ([]byte, error) {
switch packetInfo.IPVersion {
case 4:
return buildIPv4ICMPTTLExceededPacket(packetInfo)
case 6:
return buildIPv6ICMPTTLExceededPacket(packetInfo)
default:
return nil, E.New("unsupported IP version: ", packetInfo.IPVersion)
}
}
func buildIPv4ICMPTTLExceededPacket(packetInfo ICMPPacketInfo) ([]byte, error) {
const headerLen = 20
if !packetInfo.SourceIP.Is4() || !packetInfo.Destination.Is4() {
return nil, E.New("TTL exceeded packet requires IPv4 addresses")
}
quotedLength := min(len(packetInfo.RawPacket), ipv4TTLExceededQuoteLen)
packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength)
packet[0] = 0x45
binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet)))
packet[8] = defaultICMPPacketTTL
packet[9] = 1
copy(packet[12:16], packetInfo.Destination.AsSlice())
copy(packet[16:20], packetInfo.SourceIP.AsSlice())
packet[20] = icmpv4TypeTimeExceeded
packet[21] = 0
copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength])
binary.BigEndian.PutUint16(packet[22:24], checksum(packet[20:], 0))
binary.BigEndian.PutUint16(packet[10:12], checksum(packet[:headerLen], 0))
return packet, nil
}
func buildIPv6ICMPTTLExceededPacket(packetInfo ICMPPacketInfo) ([]byte, error) {
const headerLen = 40
if !packetInfo.SourceIP.Is6() || !packetInfo.Destination.Is6() {
return nil, E.New("TTL exceeded packet requires IPv6 addresses")
}
quotedLength := min(len(packetInfo.RawPacket), ipv6TTLExceededQuoteLen)
packet := make([]byte, headerLen+icmpErrorHeaderLen+quotedLength)
packet[0] = 0x60
binary.BigEndian.PutUint16(packet[4:6], uint16(icmpErrorHeaderLen+quotedLength))
packet[6] = 58
packet[7] = defaultICMPPacketTTL
copy(packet[8:24], packetInfo.Destination.AsSlice())
copy(packet[24:40], packetInfo.SourceIP.AsSlice())
packet[40] = icmpv6TypeTimeExceeded
packet[41] = 0
copy(packet[headerLen+icmpErrorHeaderLen:], packetInfo.RawPacket[:quotedLength])
binary.BigEndian.PutUint16(packet[42:44], checksum(packet[40:], ipv6PseudoHeaderChecksum(packetInfo.Destination, packetInfo.SourceIP, uint32(icmpErrorHeaderLen+quotedLength), 58)))
return packet, nil
}
func encodeICMPDatagram(packet []byte, wireVersion icmpWireVersion, traceContext ICMPTraceContext) ([]byte, error) {
switch wireVersion {
case icmpWireV2:
return encodeV2ICMPDatagram(packet, traceContext)
case icmpWireV3:
return encodeV3ICMPDatagram(packet)
default:
return nil, E.New("unsupported icmp wire version: ", wireVersion)
}
}
func ipv6PseudoHeaderChecksum(source, destination netip.Addr, payloadLength uint32, nextHeader uint8) uint32 {
var sum uint32
sum = checksumSum(source.AsSlice(), sum)
sum = checksumSum(destination.AsSlice(), sum)
var lengthBytes [4]byte
binary.BigEndian.PutUint32(lengthBytes[:], payloadLength)
sum = checksumSum(lengthBytes[:], sum)
sum = checksumSum([]byte{0, 0, 0, nextHeader}, sum)
return sum
}
func checksumSum(data []byte, sum uint32) uint32 {
for len(data) >= 2 {
sum += uint32(binary.BigEndian.Uint16(data[:2]))
data = data[2:]
}
if len(data) == 1 {
sum += uint32(data[0]) << 8
}
return sum
}
func checksum(data []byte, initial uint32) uint16 {
sum := checksumSum(data, initial)
for sum > 0xffff {
sum = (sum >> 16) + (sum & 0xffff)
}
return ^uint16(sum)
}
func encodeV2ICMPDatagram(packet []byte, _ ICMPTraceContext) ([]byte, error) {
data := make([]byte, 0, len(packet)+1)
data = append(data, packet...)
data = append(data, byte(DatagramV2TypeIP))
return data, nil
}
func encodeV3ICMPDatagram(packet []byte) ([]byte, error) {
if len(packet) == 0 {
return nil, E.New("icmp payload is missing")
}
if len(packet) > maxICMPPayloadLen {
return nil, E.New("icmp payload is too large")
}
data := make([]byte, 0, len(packet)+1)
data = append(data, byte(DatagramV3TypeICMP))
data = append(data, packet...)
return data, nil
}

View File

@@ -1,478 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"bytes"
"context"
"encoding/binary"
"net/netip"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-tun"
"github.com/sagernet/sing/common/buf"
N "github.com/sagernet/sing/common/network"
)
type captureDatagramSender struct {
sent [][]byte
}
func (s *captureDatagramSender) SendDatagram(data []byte) error {
s.sent = append(s.sent, append([]byte(nil), data...))
return nil
}
type fakeDirectRouteDestination struct {
routeContext tun.DirectRouteContext
packets [][]byte
reply func(packet []byte) []byte
closed bool
}
func (d *fakeDirectRouteDestination) WritePacket(packet *buf.Buffer) error {
data := append([]byte(nil), packet.Bytes()...)
packet.Release()
d.packets = append(d.packets, data)
if d.reply != nil {
reply := d.reply(data)
if reply != nil {
return d.routeContext.WritePacket(reply)
}
}
return nil
}
func (d *fakeDirectRouteDestination) Close() error {
d.closed = true
return nil
}
func (d *fakeDirectRouteDestination) IsClosed() bool {
return d.closed
}
func TestICMPBridgeHandleV2RoutesEchoRequest(t *testing.T) {
var (
preMatchCalls int
captured adapter.InboundContext
destination *fakeDirectRouteDestination
)
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
preMatchCalls++
captured = metadata
destination = &fakeDirectRouteDestination{routeContext: routeContext}
return destination, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
sender := &captureDatagramSender{}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
source := netip.MustParseAddr("198.18.0.2")
target := netip.MustParseAddr("1.1.1.1")
packet1 := buildIPv4ICMPPacket(source, target, 8, 0, 1, 1)
packet2 := buildIPv4ICMPPacket(source, target, 8, 0, 1, 2)
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet1); err != nil {
t.Fatal(err)
}
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet2); err != nil {
t.Fatal(err)
}
if preMatchCalls != 1 {
t.Fatalf("expected one direct-route lookup, got %d", preMatchCalls)
}
if captured.Network != N.NetworkICMP {
t.Fatalf("expected NetworkICMP, got %s", captured.Network)
}
if captured.Source.Addr != source || captured.Destination.Addr != target {
t.Fatalf("unexpected metadata source/destination: %#v", captured)
}
if len(destination.packets) != 2 {
t.Fatalf("expected two packets written, got %d", len(destination.packets))
}
if len(sender.sent) != 0 {
t.Fatalf("expected no reply datagrams, got %d", len(sender.sent))
}
}
func TestICMPBridgeHandleV2TracedReply(t *testing.T) {
traceIdentity := bytes.Repeat([]byte{0x7a}, icmpTraceIdentityLength)
sender := &captureDatagramSender{}
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
return &fakeDirectRouteDestination{
routeContext: routeContext,
reply: buildEchoReply,
}, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
request := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 8, 0, 9, 7)
request = append(request, traceIdentity...)
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, request); err != nil {
t.Fatal(err)
}
if len(sender.sent) != 1 {
t.Fatalf("expected one reply datagram, got %d", len(sender.sent))
}
reply := sender.sent[0]
if reply[len(reply)-1] != byte(DatagramV2TypeIP) {
t.Fatalf("expected plain v2 IP reply, got type %d", reply[len(reply)-1])
}
if len(reply) != len(buildEchoReply(buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 8, 0, 9, 7)))+1 {
t.Fatalf("unexpected traced reply size: %d", len(reply))
}
}
func TestICMPBridgeHandleV3Reply(t *testing.T) {
sender := &captureDatagramSender{}
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
return &fakeDirectRouteDestination{
routeContext: routeContext,
reply: buildEchoReply,
}, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3)
request := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), 128, 0, 3, 5)
if err := bridge.HandleV3(context.Background(), request); err != nil {
t.Fatal(err)
}
if len(sender.sent) != 1 {
t.Fatalf("expected one reply datagram, got %d", len(sender.sent))
}
reply := sender.sent[0]
if reply[0] != byte(DatagramV3TypeICMP) {
t.Fatalf("expected v3 ICMP datagram, got %d", reply[0])
}
}
func TestICMPBridgeDecrementsIPv4TTLBeforeRouting(t *testing.T) {
var destination *fakeDirectRouteDestination
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
destination = &fakeDirectRouteDestination{routeContext: routeContext}
return destination, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV2)
packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 1, 1)
packet[8] = 5
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil {
t.Fatal(err)
}
if len(destination.packets) != 1 {
t.Fatalf("expected one routed packet, got %d", len(destination.packets))
}
if got := destination.packets[0][8]; got != 4 {
t.Fatalf("expected decremented IPv4 TTL, got %d", got)
}
}
func TestICMPBridgeDecrementsIPv6HopLimitBeforeRouting(t *testing.T) {
var destination *fakeDirectRouteDestination
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
destination = &fakeDirectRouteDestination{routeContext: routeContext}
return destination, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, &captureDatagramSender{}, icmpWireV3)
packet := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 1, 1)
packet[7] = 3
if err := bridge.HandleV3(context.Background(), packet); err != nil {
t.Fatal(err)
}
if len(destination.packets) != 1 {
t.Fatalf("expected one routed packet, got %d", len(destination.packets))
}
if got := destination.packets[0][7]; got != 2 {
t.Fatalf("expected decremented IPv6 hop limit, got %d", got)
}
}
func TestICMPBridgeHandleV2TTLExceededTracedReply(t *testing.T) {
var preMatchCalls int
traceIdentity := bytes.Repeat([]byte{0x6b}, icmpTraceIdentityLength)
sender := &captureDatagramSender{}
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
preMatchCalls++
return nil, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
source := netip.MustParseAddr("198.18.0.2")
target := netip.MustParseAddr("1.1.1.1")
packet := buildIPv4ICMPPacket(source, target, icmpv4TypeEchoRequest, 0, 1, 1)
packet[8] = 1
packet = append(packet, traceIdentity...)
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIPWithTrace, packet); err != nil {
t.Fatal(err)
}
if preMatchCalls != 0 {
t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls)
}
if len(sender.sent) != 1 {
t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent))
}
reply := sender.sent[0]
if reply[len(reply)-1] != byte(DatagramV2TypeIP) {
t.Fatalf("expected plain v2 reply, got type %d", reply[len(reply)-1])
}
rawReply := reply[:len(reply)-1]
packetInfo, err := ParseICMPPacket(rawReply)
if err != nil {
t.Fatal(err)
}
if packetInfo.ICMPType != icmpv4TypeTimeExceeded || packetInfo.ICMPCode != 0 {
t.Fatalf("expected IPv4 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode)
}
if packetInfo.SourceIP != target || packetInfo.Destination != source {
t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination)
}
if packetInfo.TTL() != 255 {
t.Fatalf("expected TTL exceeded packet TTL 255, got %d", packetInfo.TTL())
}
}
func TestICMPBridgeHandleV3TTLExceededReply(t *testing.T) {
var preMatchCalls int
sender := &captureDatagramSender{}
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
preMatchCalls++
return nil, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV3)
source := netip.MustParseAddr("2001:db8::2")
target := netip.MustParseAddr("2606:4700:4700::1111")
packet := buildIPv6ICMPPacket(source, target, icmpv6TypeEchoRequest, 0, 1, 1)
packet[7] = 1
if err := bridge.HandleV3(context.Background(), packet); err != nil {
t.Fatal(err)
}
if preMatchCalls != 0 {
t.Fatalf("expected TTL exceeded to bypass routing, got %d route lookups", preMatchCalls)
}
if len(sender.sent) != 1 {
t.Fatalf("expected one TTL exceeded reply, got %d", len(sender.sent))
}
if sender.sent[0][0] != byte(DatagramV3TypeICMP) {
t.Fatalf("expected v3 ICMP reply, got %d", sender.sent[0][0])
}
packetInfo, err := ParseICMPPacket(sender.sent[0][1:])
if err != nil {
t.Fatal(err)
}
if packetInfo.ICMPType != icmpv6TypeTimeExceeded || packetInfo.ICMPCode != 0 {
t.Fatalf("expected IPv6 time exceeded reply, got type=%d code=%d", packetInfo.ICMPType, packetInfo.ICMPCode)
}
if packetInfo.SourceIP != target || packetInfo.Destination != source {
t.Fatalf("unexpected TTL exceeded routing: src=%s dst=%s", packetInfo.SourceIP, packetInfo.Destination)
}
if packetInfo.TTL() != 255 {
t.Fatalf("expected TTL exceeded packet TTL 255, got %d", packetInfo.TTL())
}
}
func TestICMPBridgeDropsNonEcho(t *testing.T) {
var preMatchCalls int
router := &testRouter{
preMatch: func(metadata adapter.InboundContext, routeContext tun.DirectRouteContext, timeout time.Duration, supportBypass bool) (tun.DirectRouteDestination, error) {
preMatchCalls++
return nil, nil
},
}
inboundInstance := &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
}
sender := &captureDatagramSender{}
bridge := NewICMPBridge(inboundInstance, sender, icmpWireV2)
packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), 3, 0, 1, 1)
if err := bridge.HandleV2(context.Background(), DatagramV2TypeIP, packet); err != nil {
t.Fatal(err)
}
if preMatchCalls != 0 {
t.Fatalf("expected no route lookup, got %d", preMatchCalls)
}
if len(sender.sent) != 0 {
t.Fatalf("expected no sender datagrams, got %d", len(sender.sent))
}
}
func TestBuildICMPTTLExceededPacketUsesRFCQuoteLengths(t *testing.T) {
ipv4Packet := buildIPv4ICMPPacket(netip.MustParseAddr("198.18.0.2"), netip.MustParseAddr("1.1.1.1"), icmpv4TypeEchoRequest, 0, 1, 1)
ipv4Packet = append(ipv4Packet, bytes.Repeat([]byte{0xaa}, 4096)...)
ipv4Info, err := ParseICMPPacket(ipv4Packet)
if err != nil {
t.Fatal(err)
}
ipv4Reply, err := buildICMPTTLExceededPacket(ipv4Info)
if err != nil {
t.Fatal(err)
}
if len(ipv4Reply) != 20+icmpErrorHeaderLen+ipv4TTLExceededQuoteLen {
t.Fatalf("unexpected IPv4 TTL exceeded size: %d", len(ipv4Reply))
}
ipv6Packet := buildIPv6ICMPPacket(netip.MustParseAddr("2001:db8::2"), netip.MustParseAddr("2606:4700:4700::1111"), icmpv6TypeEchoRequest, 0, 1, 1)
ipv6Packet = append(ipv6Packet, bytes.Repeat([]byte{0xbb}, 4096)...)
ipv6Info, err := ParseICMPPacket(ipv6Packet)
if err != nil {
t.Fatal(err)
}
ipv6Reply, err := buildICMPTTLExceededPacket(ipv6Info)
if err != nil {
t.Fatal(err)
}
if len(ipv6Reply) != 40+icmpErrorHeaderLen+ipv6TTLExceededQuoteLen {
t.Fatalf("unexpected IPv6 TTL exceeded size: %d", len(ipv6Reply))
}
}
func TestEncodeV3ICMPDatagramRejectsEmptyPayload(t *testing.T) {
if _, err := encodeV3ICMPDatagram(nil); err == nil {
t.Fatal("expected empty payload to be rejected")
}
}
func TestEncodeV3ICMPDatagramRejectsOversizedPayload(t *testing.T) {
if _, err := encodeV3ICMPDatagram(make([]byte, maxICMPPayloadLen+1)); err == nil {
t.Fatal("expected oversized payload to be rejected")
}
}
func TestICMPBridgeCleanupExpired(t *testing.T) {
bridge := NewICMPBridge(&Inbound{}, &captureDatagramSender{}, icmpWireV2)
now := time.Now()
expiredKey := ICMPFlowKey{
IPVersion: 4,
SourceIP: netip.MustParseAddr("198.18.0.2"),
Destination: netip.MustParseAddr("1.1.1.1"),
}
expiredState := bridge.getFlowState(expiredKey)
expiredState.lastActive = now.Add(-icmpFlowTimeout - time.Second)
expiredState.writer.traces[ICMPRequestKey{Flow: expiredKey, Identifier: 1, Sequence: 1}] = traceEntry{
context: ICMPTraceContext{Traced: true, Identity: []byte{1}},
createdAt: now.Add(-icmpFlowTimeout - time.Second),
}
activeKey := ICMPFlowKey{
IPVersion: 6,
SourceIP: netip.MustParseAddr("2001:db8::2"),
Destination: netip.MustParseAddr("2606:4700:4700::1111"),
}
activeState := bridge.getFlowState(activeKey)
activeState.lastActive = now
activeState.writer.traces[ICMPRequestKey{Flow: activeKey, Identifier: 2, Sequence: 2}] = traceEntry{
context: ICMPTraceContext{Traced: true, Identity: []byte{2}},
createdAt: now,
}
bridge.cleanupExpired(now)
if _, exists := bridge.flows[expiredKey]; exists {
t.Fatal("expected expired flow to be removed")
}
if _, exists := bridge.flows[activeKey]; !exists {
t.Fatal("expected active flow to remain")
}
if len(activeState.writer.traces) != 1 {
t.Fatalf("expected active trace to remain, got %d", len(activeState.writer.traces))
}
}
func buildEchoReply(packet []byte) []byte {
info, err := ParseICMPPacket(packet)
if err != nil {
panic(err)
}
switch info.IPVersion {
case 4:
return buildIPv4ICMPPacket(info.Destination, info.SourceIP, 0, 0, info.Identifier, info.Sequence)
case 6:
return buildIPv6ICMPPacket(info.Destination, info.SourceIP, 129, 0, info.Identifier, info.Sequence)
default:
panic("unsupported version")
}
}
func buildIPv4ICMPPacket(source, destination netip.Addr, icmpType, icmpCode uint8, identifier, sequence uint16) []byte {
packet := make([]byte, 28)
packet[0] = 0x45
binary.BigEndian.PutUint16(packet[2:4], uint16(len(packet)))
packet[8] = 64
packet[9] = 1
copy(packet[12:16], source.AsSlice())
copy(packet[16:20], destination.AsSlice())
packet[20] = icmpType
packet[21] = icmpCode
binary.BigEndian.PutUint16(packet[24:26], identifier)
binary.BigEndian.PutUint16(packet[26:28], sequence)
return packet
}
func buildIPv6ICMPPacket(source, destination netip.Addr, icmpType, icmpCode uint8, identifier, sequence uint16) []byte {
packet := make([]byte, 48)
packet[0] = 0x60
binary.BigEndian.PutUint16(packet[4:6], 8)
packet[6] = 58
packet[7] = 64
copy(packet[8:24], source.AsSlice())
copy(packet[24:40], destination.AsSlice())
packet[40] = icmpType
packet[41] = icmpCode
binary.BigEndian.PutUint16(packet[44:46], identifier)
binary.BigEndian.PutUint16(packet[46:48], sequence)
return packet
}

View File

@@ -4,149 +4,32 @@ package cloudflare
import (
"context"
"encoding/base64"
"errors"
"io"
"math/rand"
"net/http"
"runtime/debug"
"net"
"sync"
"time"
cloudflared "github.com/sagernet/sing-cloudflared"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
boxDialer "github.com/sagernet/sing-box/common/dialer"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
"github.com/sagernet/sing/common"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
"github.com/sagernet/sing/common/json"
"github.com/sagernet/sing/common/json/badoption"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
"github.com/sagernet/sing/common/pipe"
tun "github.com/sagernet/sing-tun"
)
func RegisterInbound(registry *inbound.Registry) {
inbound.Register[option.CloudflaredInboundOptions](registry, C.TypeCloudflared, NewInbound)
}
var ErrNonRemoteManagedTunnelUnsupported = errors.New("cloudflared only supports remote-managed tunnels")
var (
newQUICConnection = NewQUICConnection
newHTTP2Connection = NewHTTP2Connection
serveQUICConnection = func(connection *QUICConnection, ctx context.Context, handler StreamHandler) error {
return connection.Serve(ctx, handler)
}
serveHTTP2Connection = func(connection *HTTP2Connection, ctx context.Context) error {
return connection.Serve(ctx)
}
)
type Inbound struct {
inbound.Adapter
ctx context.Context
cancel context.CancelFunc
router adapter.Router
logger log.ContextLogger
credentials Credentials
connectorID uuid.UUID
haConnections int
protocol string
region string
edgeIPVersion int
datagramVersion string
featureSelector *featureSelector
gracePeriod time.Duration
configManager *ConfigManager
flowLimiter *FlowLimiter
accessCache *accessValidatorCache
controlDialer N.Dialer
tunnelDialer N.Dialer
connectionAccess sync.Mutex
connections []io.Closer
done sync.WaitGroup
datagramMuxerAccess sync.Mutex
datagramV2Muxers map[DatagramSender]*DatagramV2Muxer
datagramV3Muxers map[DatagramSender]*DatagramV3Muxer
datagramV3Manager *DatagramV3SessionManager
connectedAccess sync.Mutex
connectedIndices map[uint8]struct{}
connectedNotify chan uint8
stateAccess sync.Mutex
connectionStates []connectionState
successfulProtocols map[string]struct{}
firstSuccessfulProtocol string
directTransportAccess sync.Mutex
directTransports map[string]*http.Transport
}
type connectionState struct {
protocol string
retries uint8
}
func resolveGracePeriod(value *badoption.Duration) time.Duration {
if value == nil {
return 30 * time.Second
}
return time.Duration(*value)
}
func connectionRetryDecision(err error) (retry bool, cancelAll bool) {
switch {
case err == nil:
return false, false
case errors.Is(err, ErrNonRemoteManagedTunnelUnsupported):
return false, true
case isPermanentRegistrationError(err):
return false, false
default:
return true, false
}
}
func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLogger, tag string, options option.CloudflaredInboundOptions) (adapter.Inbound, error) {
if options.Token == "" {
return nil, E.New("missing token")
}
credentials, err := parseToken(options.Token)
if err != nil {
return nil, E.Cause(err, "parse token")
}
haConnections := options.HAConnections
if haConnections <= 0 {
haConnections = 4
}
protocol, err := normalizeProtocol(options.Protocol)
if err != nil {
return nil, err
}
edgeIPVersion := options.EdgeIPVersion
if edgeIPVersion != 0 && edgeIPVersion != 4 && edgeIPVersion != 6 {
return nil, E.New("unsupported edge_ip_version: ", edgeIPVersion, ", expected 0, 4 or 6")
}
datagramVersion := options.DatagramVersion
if datagramVersion != "" && datagramVersion != "v2" && datagramVersion != "v3" {
return nil, E.New("unsupported datagram_version: ", datagramVersion, ", expected v2 or v3")
}
gracePeriod := resolveGracePeriod(options.GracePeriod)
configManager, err := NewConfigManager()
if err != nil {
return nil, E.Cause(err, "build cloudflared runtime config")
}
controlDialer, err := boxDialer.NewWithOptions(boxDialer.Options{
Context: ctx,
Options: options.ControlDialer,
@@ -163,457 +46,131 @@ func NewInbound(ctx context.Context, router adapter.Router, logger log.ContextLo
return nil, E.Cause(err, "build cloudflared tunnel dialer")
}
region := options.Region
if region != "" && credentials.Endpoint != "" {
return nil, E.New("region cannot be specified when credentials already include an endpoint")
service, err := cloudflared.NewService(cloudflared.ServiceOptions{
Logger: logger,
ConnectionDialer: &routerDialer{router: router, tag: tag},
ControlDialer: controlDialer,
TunnelDialer: tunnelDialer,
ICMPHandler: &icmpRouterHandler{router: router, tag: tag},
ConnContext: func(ctx context.Context) context.Context {
return adapter.WithContext(ctx, &adapter.InboundContext{
Inbound: tag,
InboundType: C.TypeCloudflared,
})
},
Token: options.Token,
HAConnections: options.HAConnections,
Protocol: options.Protocol,
PostQuantum: options.PostQuantum,
EdgeIPVersion: options.EdgeIPVersion,
DatagramVersion: options.DatagramVersion,
GracePeriod: resolveGracePeriod(options.GracePeriod),
Region: options.Region,
})
if err != nil {
return nil, err
}
if region == "" {
region = credentials.Endpoint
}
inboundCtx, cancel := context.WithCancel(ctx)
return &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, tag),
ctx: inboundCtx,
cancel: cancel,
router: router,
logger: logger,
credentials: credentials,
connectorID: uuid.New(),
haConnections: haConnections,
protocol: protocol,
region: region,
edgeIPVersion: edgeIPVersion,
datagramVersion: datagramVersion,
featureSelector: newFeatureSelector(inboundCtx, credentials.AccountTag, datagramVersion),
gracePeriod: gracePeriod,
configManager: configManager,
flowLimiter: &FlowLimiter{},
accessCache: &accessValidatorCache{values: make(map[string]accessValidator), dialer: controlDialer},
controlDialer: controlDialer,
tunnelDialer: tunnelDialer,
datagramV2Muxers: make(map[DatagramSender]*DatagramV2Muxer),
datagramV3Muxers: make(map[DatagramSender]*DatagramV3Muxer),
datagramV3Manager: NewDatagramV3SessionManager(),
connectedIndices: make(map[uint8]struct{}),
connectedNotify: make(chan uint8, haConnections),
connectionStates: make([]connectionState, haConnections),
successfulProtocols: make(map[string]struct{}),
directTransports: make(map[string]*http.Transport),
Adapter: inbound.NewAdapter(C.TypeCloudflared, tag),
service: service,
}, nil
}
type Inbound struct {
inbound.Adapter
service *cloudflared.Service
}
func (i *Inbound) Start(stage adapter.StartStage) error {
if stage != adapter.StartStateStart {
return nil
}
i.logger.Info("starting Cloudflare Tunnel with ", i.haConnections, " HA connections")
regions, err := DiscoverEdge(i.ctx, i.region, i.controlDialer)
if err != nil {
return E.Cause(err, "discover edge")
}
regions = FilterByIPVersion(regions, i.edgeIPVersion)
edgeAddrs := flattenRegions(regions)
if len(edgeAddrs) == 0 {
return E.New("no edge addresses available")
}
if cappedHAConnections := effectiveHAConnections(i.haConnections, len(edgeAddrs)); cappedHAConnections != i.haConnections {
i.logger.Info("requested ", i.haConnections, " HA connections but only ", cappedHAConnections, " edge addresses are available")
i.haConnections = cappedHAConnections
}
for connIndex := 0; connIndex < i.haConnections; connIndex++ {
i.initializeConnectionState(uint8(connIndex))
i.done.Add(1)
go i.superviseConnection(uint8(connIndex), edgeAddrs)
select {
case readyConnIndex := <-i.connectedNotify:
if readyConnIndex != uint8(connIndex) {
i.logger.Debug("received unexpected ready notification for connection ", readyConnIndex)
}
case <-time.After(firstConnectionReadyTimeout):
case <-i.ctx.Done():
if connIndex == 0 {
return i.ctx.Err()
}
return nil
}
}
return nil
}
func (i *Inbound) notifyConnected(connIndex uint8, protocol string) {
i.stateAccess.Lock()
if i.successfulProtocols == nil {
i.successfulProtocols = make(map[string]struct{})
}
i.ensureConnectionStateLocked(connIndex)
state := i.connectionStates[connIndex]
state.retries = 0
state.protocol = protocol
i.connectionStates[connIndex] = state
if protocol != "" {
i.successfulProtocols[protocol] = struct{}{}
if i.firstSuccessfulProtocol == "" {
i.firstSuccessfulProtocol = protocol
}
}
i.stateAccess.Unlock()
if i.connectedNotify == nil {
return
}
i.connectedAccess.Lock()
if _, loaded := i.connectedIndices[connIndex]; loaded {
i.connectedAccess.Unlock()
return
}
i.connectedIndices[connIndex] = struct{}{}
i.connectedAccess.Unlock()
i.connectedNotify <- connIndex
}
func (i *Inbound) ApplyConfig(version int32, config []byte) ConfigUpdateResult {
result := i.configManager.Apply(version, config)
if result.Err != nil {
i.logger.Error("update ingress configuration: ", result.Err)
return result
}
i.resetDirectOriginTransports()
i.logger.Info("updated ingress configuration (version ", result.LastAppliedVersion, ")")
return result
}
func (i *Inbound) maxActiveFlows() uint64 {
return i.configManager.Snapshot().WarpRouting.MaxActiveFlows
return i.service.Start()
}
func (i *Inbound) Close() error {
i.cancel()
i.done.Wait()
i.connectionAccess.Lock()
for _, connection := range i.connections {
connection.Close()
}
i.connections = nil
i.connectionAccess.Unlock()
i.resetDirectOriginTransports()
return nil
return i.service.Close()
}
const (
backoffBaseTime = time.Second
backoffMaxTime = 2 * time.Minute
firstConnectionReadyTimeout = 15 * time.Second
)
func (i *Inbound) superviseConnection(connIndex uint8, edgeAddrs []*EdgeAddr) {
defer i.done.Done()
edgeIndex := initialEdgeAddrIndex(connIndex, len(edgeAddrs))
for {
select {
case <-i.ctx.Done():
return
default:
}
edgeAddr := edgeAddrs[edgeIndex]
err := i.safeServeConnection(connIndex, edgeAddr)
if err == nil || i.ctx.Err() != nil {
return
}
retry, cancelAll := connectionRetryDecision(err)
if cancelAll {
i.logger.Error("connection ", connIndex, " failed permanently: ", err)
i.cancel()
return
}
if !retry {
i.logger.Error("connection ", connIndex, " failed permanently: ", err)
return
}
retries := i.incrementConnectionRetries(connIndex)
edgeIndex = rotateEdgeAddrIndex(edgeIndex, len(edgeAddrs))
backoff := backoffDuration(int(retries))
var retryableErr *RetryableError
if errors.As(err, &retryableErr) && retryableErr.Delay > 0 {
backoff = retryableErr.Delay
}
i.logger.Error("connection ", connIndex, " failed: ", err, ", retrying in ", backoff)
select {
case <-time.After(backoff):
case <-i.ctx.Done():
return
}
}
}
func (i *Inbound) serveConnection(connIndex uint8, edgeAddr *EdgeAddr) error {
state := i.connectionState(connIndex)
protocol := state.protocol
numPreviousAttempts := state.retries
datagramVersion, features := i.currentConnectionFeatures()
switch protocol {
case "quic":
err := i.serveQUIC(connIndex, edgeAddr, datagramVersion, features, numPreviousAttempts)
if err == nil || i.ctx.Err() != nil {
return err
}
if errors.Is(err, ErrNonRemoteManagedTunnelUnsupported) {
return err
}
if !i.protocolIsAuto() {
return err
}
if i.hasSuccessfulProtocol("quic") {
return err
}
i.setConnectionProtocol(connIndex, "http2")
i.logger.Warn("QUIC connection failed, falling back to HTTP/2: ", err)
return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts)
case "http2":
return i.serveHTTP2(connIndex, edgeAddr, features, numPreviousAttempts)
default:
return E.New("unsupported protocol: ", protocol)
}
}
func (i *Inbound) safeServeConnection(connIndex uint8, edgeAddr *EdgeAddr) (err error) {
defer func() {
if recovered := recover(); recovered != nil {
err = E.New("panic in serve connection: ", recovered, "\n", string(debug.Stack()))
}
}()
return i.serveConnection(connIndex, edgeAddr)
}
func (i *Inbound) serveQUIC(connIndex uint8, edgeAddr *EdgeAddr, datagramVersion string, features []string, numPreviousAttempts uint8) error {
i.logger.Info("connecting to edge via QUIC (connection ", connIndex, ")")
connection, err := newQUICConnection(
i.ctx, edgeAddr, connIndex,
i.credentials, i.connectorID, datagramVersion,
features, numPreviousAttempts, i.gracePeriod, i.tunnelDialer, func() {
i.notifyConnected(connIndex, "quic")
}, i.logger,
)
if err != nil {
return E.Cause(err, "create QUIC connection")
}
i.trackConnection(connection)
defer func() {
i.untrackConnection(connection)
i.RemoveDatagramMuxer(connection)
}()
return serveQUICConnection(connection, i.ctx, i)
}
func (i *Inbound) currentConnectionFeatures() (string, []string) {
if i.featureSelector != nil {
return i.featureSelector.Snapshot()
}
version := i.datagramVersion
if version == "" {
version = defaultDatagramVersion
}
return version, DefaultFeatures(version)
}
func (i *Inbound) serveHTTP2(connIndex uint8, edgeAddr *EdgeAddr, features []string, numPreviousAttempts uint8) error {
i.logger.Info("connecting to edge via HTTP/2 (connection ", connIndex, ")")
connection, err := newHTTP2Connection(
i.ctx, edgeAddr, connIndex,
i.credentials, i.connectorID,
features, numPreviousAttempts, i.gracePeriod, i, i.logger,
)
if err != nil {
return E.Cause(err, "create HTTP/2 connection")
}
i.trackConnection(connection)
defer i.untrackConnection(connection)
return serveHTTP2Connection(connection, i.ctx)
}
func (i *Inbound) initializeConnectionState(connIndex uint8) {
i.stateAccess.Lock()
defer i.stateAccess.Unlock()
i.ensureConnectionStateLocked(connIndex)
if i.connectionStates[connIndex].protocol == "" {
i.connectionStates[connIndex].protocol = i.initialProtocolLocked()
}
}
func (i *Inbound) connectionState(connIndex uint8) connectionState {
i.stateAccess.Lock()
defer i.stateAccess.Unlock()
i.ensureConnectionStateLocked(connIndex)
state := i.connectionStates[connIndex]
if state.protocol == "" {
state.protocol = i.initialProtocolLocked()
i.connectionStates[connIndex] = state
}
return state
}
func (i *Inbound) incrementConnectionRetries(connIndex uint8) uint8 {
i.stateAccess.Lock()
defer i.stateAccess.Unlock()
i.ensureConnectionStateLocked(connIndex)
state := i.connectionStates[connIndex]
state.retries++
i.connectionStates[connIndex] = state
return state.retries
}
func (i *Inbound) setConnectionProtocol(connIndex uint8, protocol string) {
i.stateAccess.Lock()
defer i.stateAccess.Unlock()
i.ensureConnectionStateLocked(connIndex)
state := i.connectionStates[connIndex]
state.protocol = protocol
i.connectionStates[connIndex] = state
}
func (i *Inbound) hasSuccessfulProtocol(protocol string) bool {
i.stateAccess.Lock()
defer i.stateAccess.Unlock()
if i.successfulProtocols == nil {
return false
}
_, ok := i.successfulProtocols[protocol]
return ok
}
func (i *Inbound) protocolIsAuto() bool {
return i.protocol == ""
}
func (i *Inbound) ensureConnectionStateLocked(connIndex uint8) {
requiredLen := int(connIndex) + 1
if len(i.connectionStates) >= requiredLen {
return
}
grown := make([]connectionState, requiredLen)
copy(grown, i.connectionStates)
i.connectionStates = grown
}
func (i *Inbound) initialProtocolLocked() string {
if i.protocol != "" {
return i.protocol
}
if i.firstSuccessfulProtocol != "" {
return i.firstSuccessfulProtocol
}
return "quic"
}
func (i *Inbound) resetDirectOriginTransports() {
i.directTransportAccess.Lock()
transports := i.directTransports
i.directTransports = make(map[string]*http.Transport)
i.directTransportAccess.Unlock()
for _, transport := range transports {
transport.CloseIdleConnections()
}
}
func (i *Inbound) trackConnection(connection io.Closer) {
i.connectionAccess.Lock()
defer i.connectionAccess.Unlock()
i.connections = append(i.connections, connection)
}
func (i *Inbound) untrackConnection(connection io.Closer) {
i.connectionAccess.Lock()
defer i.connectionAccess.Unlock()
for index, tracked := range i.connections {
if tracked == connection {
i.connections = append(i.connections[:index], i.connections[index+1:]...)
break
}
}
}
func backoffDuration(retries int) time.Duration {
backoff := backoffBaseTime * (1 << min(retries, 7))
if backoff > backoffMaxTime {
backoff = backoffMaxTime
}
// Add jitter: random duration in [backoff/2, backoff)
jitter := time.Duration(rand.Int63n(int64(backoff / 2)))
return backoff/2 + jitter
}
func initialEdgeAddrIndex(connIndex uint8, size int) int {
if size <= 1 {
func resolveGracePeriod(value *badoption.Duration) time.Duration {
if value == nil {
return 0
}
return int(connIndex) % size
return time.Duration(*value)
}
func rotateEdgeAddrIndex(current int, size int) int {
if size <= 1 {
return 0
}
return (current + 1) % size
// routerDialer bridges N.Dialer to the sing-box router for origin connections.
type routerDialer struct {
router adapter.Router
tag string
}
func flattenRegions(regions [][]*EdgeAddr) []*EdgeAddr {
var result []*EdgeAddr
for _, region := range regions {
result = append(result, region...)
func (d *routerDialer) DialContext(ctx context.Context, network string, destination M.Socksaddr) (net.Conn, error) {
input, output := pipe.Pipe()
done := make(chan struct{})
metadata := adapter.InboundContext{
Inbound: d.tag,
InboundType: C.TypeCloudflared,
Network: N.NetworkTCP,
Destination: destination,
}
return result
var closeOnce sync.Once
closePipe := func() {
closeOnce.Do(func() {
common.Close(input, output)
})
}
go d.router.RouteConnectionEx(ctx, output, metadata, N.OnceClose(func(it error) {
closePipe()
close(done)
}))
return input, nil
}
func effectiveHAConnections(requested, available int) int {
if available <= 0 {
return 0
func (d *routerDialer) ListenPacket(ctx context.Context, destination M.Socksaddr) (net.PacketConn, error) {
originDialer, ok := d.router.(routedOriginPacketDialer)
if !ok {
return nil, E.New("router does not support cloudflare routed packet dialing")
}
if requested > available {
return available
}
return requested
}
func parseToken(token string) (Credentials, error) {
data, err := base64.StdEncoding.DecodeString(token)
packetConn, err := originDialer.DialRoutePacketConnection(ctx, adapter.InboundContext{
Inbound: d.tag,
InboundType: C.TypeCloudflared,
Network: N.NetworkUDP,
Destination: destination,
UDPConnect: true,
})
if err != nil {
return Credentials{}, E.Cause(err, "decode token")
return nil, err
}
var tunnelToken TunnelToken
err = json.Unmarshal(data, &tunnelToken)
if err != nil {
return Credentials{}, E.Cause(err, "unmarshal token")
}
return tunnelToken.ToCredentials(), nil
return bufio.NewNetPacketConn(packetConn), nil
}
// "auto" does not choose a transport here. We normalize it to an empty
// sentinel so serveConnection can apply the token-style behavior later.
// In the token-provided, remotely-managed tunnel path supported here, that
// matches cloudflared's NewProtocolSelector(..., tunnelTokenProvided=true)
// branch rather than the non-token remote-percentage selector.
func normalizeProtocol(protocol string) (string, error) {
if protocol == "auto" {
return "", nil
}
if protocol != "" && protocol != "quic" && protocol != "http2" {
return "", E.New("unsupported protocol: ", protocol, ", expected auto, quic or http2")
}
return protocol, nil
type routedOriginPacketDialer interface {
DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error)
}
// icmpRouterHandler bridges cloudflared.ICMPHandler to router.PreMatch.
type icmpRouterHandler struct {
router adapter.Router
tag string
}
func (h *icmpRouterHandler) RouteICMPConnection(ctx context.Context, session tun.DirectRouteSession, routeContext tun.DirectRouteContext, timeout time.Duration) (tun.DirectRouteDestination, error) {
var ipVersion uint8
if session.Source.Is4() {
ipVersion = 4
} else {
ipVersion = 6
}
metadata := adapter.InboundContext{
Inbound: h.tag,
InboundType: C.TypeCloudflared,
IPVersion: ipVersion,
Network: N.NetworkICMP,
Source: M.SocksaddrFrom(session.Source, 0),
Destination: M.SocksaddrFrom(session.Destination, 0),
OriginDestination: M.SocksaddrFrom(session.Destination, 0),
}
return h.router.PreMatch(metadata, routeContext, timeout, false)
}

View File

@@ -1,228 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"errors"
"strings"
"testing"
"time"
"github.com/sagernet/sing-box/log"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
)
func restoreConnectionHooks(t *testing.T) {
t.Helper()
originalNewQUICConnection := newQUICConnection
originalNewHTTP2Connection := newHTTP2Connection
originalServeQUICConnection := serveQUICConnection
originalServeHTTP2Connection := serveHTTP2Connection
t.Cleanup(func() {
newQUICConnection = originalNewQUICConnection
newHTTP2Connection = originalNewHTTP2Connection
serveQUICConnection = originalServeQUICConnection
serveHTTP2Connection = originalServeHTTP2Connection
})
}
func TestServeConnectionAutoFallbackSticky(t *testing.T) {
restoreConnectionHooks(t)
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.protocol = ""
inboundInstance.initializeConnectionState(0)
var quicCalls, http2Calls int
newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) {
quicCalls++
return &QUICConnection{}, nil
}
serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error {
return errors.New("quic failed")
}
newHTTP2Connection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, []string, uint8, time.Duration, *Inbound, log.ContextLogger) (*HTTP2Connection, error) {
http2Calls++
return &HTTP2Connection{}, nil
}
serveHTTP2Connection = func(*HTTP2Connection, context.Context) error {
return errors.New("http2 failed")
}
if err := inboundInstance.serveConnection(0, &EdgeAddr{}); err == nil || err.Error() != "http2 failed" {
t.Fatalf("expected HTTP/2 fallback error, got %v", err)
}
if state := inboundInstance.connectionState(0); state.protocol != "http2" {
t.Fatalf("expected sticky HTTP/2 fallback, got %#v", state)
}
if err := inboundInstance.serveConnection(0, &EdgeAddr{}); err == nil || err.Error() != "http2 failed" {
t.Fatalf("expected second HTTP/2 error, got %v", err)
}
if quicCalls != 1 {
t.Fatalf("expected QUIC to be attempted once, got %d", quicCalls)
}
if http2Calls != 2 {
t.Fatalf("expected HTTP/2 to be attempted twice, got %d", http2Calls)
}
}
func TestSecondConnectionInitialProtocolUsesFirstSuccess(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.protocol = ""
inboundInstance.notifyConnected(0, "http2")
inboundInstance.initializeConnectionState(1)
if state := inboundInstance.connectionState(1); state.protocol != "http2" {
t.Fatalf("expected second connection to inherit HTTP/2, got %#v", state)
}
}
func TestServeConnectionSkipsFallbackWhenQUICAlreadySucceeded(t *testing.T) {
restoreConnectionHooks(t)
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.protocol = ""
inboundInstance.notifyConnected(0, "quic")
inboundInstance.initializeConnectionState(1)
var http2Calls int
quicErr := errors.New("quic failed")
newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) {
return &QUICConnection{}, nil
}
serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error {
return quicErr
}
newHTTP2Connection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, []string, uint8, time.Duration, *Inbound, log.ContextLogger) (*HTTP2Connection, error) {
http2Calls++
return &HTTP2Connection{}, nil
}
err := inboundInstance.serveConnection(1, &EdgeAddr{})
if !errors.Is(err, quicErr) {
t.Fatalf("expected QUIC error without fallback, got %v", err)
}
if http2Calls != 0 {
t.Fatalf("expected no HTTP/2 fallback, got %d calls", http2Calls)
}
if state := inboundInstance.connectionState(1); state.protocol != "quic" {
t.Fatalf("expected connection to remain on QUIC, got %#v", state)
}
}
func TestNotifyConnectedResetsRetries(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.protocol = ""
inboundInstance.initializeConnectionState(0)
inboundInstance.incrementConnectionRetries(0)
inboundInstance.incrementConnectionRetries(0)
inboundInstance.notifyConnected(0, "http2")
state := inboundInstance.connectionState(0)
if state.retries != 0 {
t.Fatalf("expected retries reset after success, got %d", state.retries)
}
if state.protocol != "http2" {
t.Fatalf("expected protocol to be pinned to success, got %q", state.protocol)
}
}
func TestSafeServeConnectionRecoversPanic(t *testing.T) {
restoreConnectionHooks(t)
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.protocol = "quic"
inboundInstance.initializeConnectionState(0)
newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) {
return &QUICConnection{}, nil
}
serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error {
panic("boom")
}
err := inboundInstance.safeServeConnection(0, &EdgeAddr{})
if err == nil || !strings.Contains(err.Error(), "panic in serve connection") {
t.Fatalf("expected recovered panic error, got %v", err)
}
}
func TestSuperviseConnectionStopsOnPermanentRegistrationError(t *testing.T) {
restoreConnectionHooks(t)
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.protocol = "quic"
inboundInstance.initializeConnectionState(0)
permanentErr := &permanentRegistrationError{Err: errors.New("permanent register error")}
newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) {
return &QUICConnection{}, nil
}
serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error {
return permanentErr
}
inboundInstance.done.Add(1)
done := make(chan struct{})
go func() {
inboundInstance.superviseConnection(0, []*EdgeAddr{{}})
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected connection supervision to stop")
}
if retries := inboundInstance.connectionState(0).retries; retries != 0 {
t.Fatalf("expected no retries for permanent registration error, got %d", retries)
}
select {
case <-inboundInstance.ctx.Done():
t.Fatal("expected permanent registration error to stop only this connection")
default:
}
}
func TestSuperviseConnectionCancelsInboundOnNonRemoteManagedError(t *testing.T) {
restoreConnectionHooks(t)
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.protocol = "quic"
inboundInstance.initializeConnectionState(0)
newQUICConnection = func(context.Context, *EdgeAddr, uint8, Credentials, uuid.UUID, string, []string, uint8, time.Duration, N.Dialer, func(), log.ContextLogger) (*QUICConnection, error) {
return &QUICConnection{}, nil
}
serveQUICConnection = func(*QUICConnection, context.Context, StreamHandler) error {
return ErrNonRemoteManagedTunnelUnsupported
}
inboundInstance.done.Add(1)
done := make(chan struct{})
go func() {
inboundInstance.superviseConnection(0, []*EdgeAddr{{}})
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected connection supervision to stop")
}
select {
case <-inboundInstance.ctx.Done():
case <-time.After(time.Second):
t.Fatal("expected inbound cancellation on non-remote-managed tunnel error")
}
}

View File

@@ -1,271 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"testing"
"github.com/sagernet/sing-box/log"
)
func newTestIngressInbound(t *testing.T) *Inbound {
t.Helper()
configManager, err := NewConfigManager()
if err != nil {
t.Fatal(err)
}
return &Inbound{
logger: log.NewNOPFactory().NewLogger("test"),
configManager: configManager,
}
}
func mustResolvedService(t *testing.T, rawService string) ResolvedService {
t.Helper()
service, err := parseResolvedService(rawService, defaultOriginRequestConfig())
if err != nil {
t.Fatal(err)
}
return service
}
func TestApplyConfig(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
config1 := []byte(`{"ingress":[{"hostname":"a.com","service":"http://localhost:80"},{"hostname":"b.com","service":"http://localhost:81"},{"service":"http_status:404"}]}`)
result := inboundInstance.ApplyConfig(1, config1)
if result.Err != nil {
t.Fatal(result.Err)
}
if result.LastAppliedVersion != 1 {
t.Fatalf("expected version 1, got %d", result.LastAppliedVersion)
}
service, loaded := inboundInstance.configManager.Resolve("a.com", "/")
if !loaded || service.Service != "http://localhost:80" {
t.Fatalf("expected a.com to resolve to localhost:80, got %#v, loaded=%v", service, loaded)
}
result = inboundInstance.ApplyConfig(1, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
if result.Err != nil {
t.Fatal(result.Err)
}
if result.LastAppliedVersion != 1 {
t.Fatalf("same version should keep current version, got %d", result.LastAppliedVersion)
}
service, loaded = inboundInstance.configManager.Resolve("b.com", "/")
if !loaded || service.Service != "http://localhost:81" {
t.Fatalf("expected old rules to remain, got %#v, loaded=%v", service, loaded)
}
result = inboundInstance.ApplyConfig(2, []byte(`{"ingress":[{"service":"http_status:503"}]}`))
if result.Err != nil {
t.Fatal(result.Err)
}
if result.LastAppliedVersion != 2 {
t.Fatalf("expected version 2, got %d", result.LastAppliedVersion)
}
service, loaded = inboundInstance.configManager.Resolve("anything.com", "/")
if !loaded || service.StatusCode != 503 {
t.Fatalf("expected catch-all status 503, got %#v, loaded=%v", service, loaded)
}
}
func TestApplyConfigInvalidJSON(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
result := inboundInstance.ApplyConfig(1, []byte("not json"))
if result.Err == nil {
t.Fatal("expected parse error")
}
if result.LastAppliedVersion != -1 {
t.Fatalf("expected version to stay -1, got %d", result.LastAppliedVersion)
}
}
func TestDefaultConfigIsCatchAll503(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
service, loaded := inboundInstance.configManager.Resolve("any.example.com", "/")
if !loaded {
t.Fatal("expected default config to resolve catch-all rule")
}
if service.StatusCode != 503 {
t.Fatalf("expected catch-all 503, got %#v", service)
}
}
func TestResolveExactAndWildcard(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
inboundInstance.configManager.activeConfig = RuntimeConfig{
Ingress: []compiledIngressRule{
{Hostname: "test.example.com", Service: mustResolvedService(t, "http://localhost:8080")},
{Hostname: "*.example.com", Service: mustResolvedService(t, "http://localhost:9090")},
{Service: mustResolvedService(t, "http_status:404")},
},
}
service, loaded := inboundInstance.configManager.Resolve("test.example.com", "/")
if !loaded || service.Service != "http://localhost:8080" {
t.Fatalf("expected exact match, got %#v, loaded=%v", service, loaded)
}
service, loaded = inboundInstance.configManager.Resolve("sub.example.com", "/")
if !loaded || service.Service != "http://localhost:9090" {
t.Fatalf("expected wildcard match, got %#v, loaded=%v", service, loaded)
}
service, loaded = inboundInstance.configManager.Resolve("unknown.test", "/")
if !loaded || service.StatusCode != 404 {
t.Fatalf("expected catch-all 404, got %#v, loaded=%v", service, loaded)
}
}
func TestResolveHTTPService(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
inboundInstance.configManager.activeConfig = RuntimeConfig{
Ingress: []compiledIngressRule{
{Hostname: "foo.com", Service: mustResolvedService(t, "http://127.0.0.1:8083")},
{Service: mustResolvedService(t, "http_status:404")},
},
}
service, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1")
if err != nil {
t.Fatal(err)
}
if service.Destination.String() != "127.0.0.1:8083" {
t.Fatalf("expected destination 127.0.0.1:8083, got %s", service.Destination)
}
if requestURL != "http://127.0.0.1:8083/path?q=1" {
t.Fatalf("expected rewritten URL, got %s", requestURL)
}
}
func TestResolveHTTPServiceStatus(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
inboundInstance.configManager.activeConfig = RuntimeConfig{
Ingress: []compiledIngressRule{
{Service: mustResolvedService(t, "http_status:404")},
},
}
service, requestURL, err := inboundInstance.resolveHTTPService("https://any.com/path")
if err != nil {
t.Fatal(err)
}
if service.StatusCode != 404 {
t.Fatalf("expected status 404, got %#v", service)
}
if requestURL != "https://any.com/path" {
t.Fatalf("status service should keep request URL, got %s", requestURL)
}
}
func TestParseResolvedServiceCanonicalizesWebSocketOrigin(t *testing.T) {
testCases := []struct {
rawService string
wantScheme string
}{
{rawService: "ws://127.0.0.1:8080", wantScheme: "http"},
{rawService: "wss://127.0.0.1:8443", wantScheme: "https"},
}
for _, testCase := range testCases {
t.Run(testCase.rawService, func(t *testing.T) {
service, err := parseResolvedService(testCase.rawService, defaultOriginRequestConfig())
if err != nil {
t.Fatal(err)
}
if service.BaseURL == nil {
t.Fatal("expected base URL")
}
if service.BaseURL.Scheme != testCase.wantScheme {
t.Fatalf("expected scheme %q, got %q", testCase.wantScheme, service.BaseURL.Scheme)
}
if service.Service != testCase.rawService {
t.Fatalf("expected raw service to stay %q, got %q", testCase.rawService, service.Service)
}
})
}
}
func TestParseResolvedServiceGenericStreamSchemeWithoutPort(t *testing.T) {
service, err := parseResolvedService("ftp://127.0.0.1", defaultOriginRequestConfig())
if err != nil {
t.Fatal(err)
}
if service.Kind != ResolvedServiceStream {
t.Fatalf("expected stream service, got %v", service.Kind)
}
if service.Destination.AddrString() != "127.0.0.1" {
t.Fatalf("expected destination host 127.0.0.1, got %s", service.Destination.AddrString())
}
if service.Destination.Port != 0 {
t.Fatalf("expected destination port 0, got %d", service.Destination.Port)
}
if service.StreamHasPort {
t.Fatal("expected generic stream service without port to report missing port")
}
}
func TestParseResolvedServiceGenericStreamSchemeWithPort(t *testing.T) {
service, err := parseResolvedService("ftp://127.0.0.1:21", defaultOriginRequestConfig())
if err != nil {
t.Fatal(err)
}
if service.Kind != ResolvedServiceStream {
t.Fatalf("expected stream service, got %v", service.Kind)
}
if service.Destination.String() != "127.0.0.1:21" {
t.Fatalf("expected destination 127.0.0.1:21, got %s", service.Destination)
}
if !service.StreamHasPort {
t.Fatal("expected generic stream service with explicit port to be dialable")
}
}
func TestParseResolvedServiceSSHDefaultPort(t *testing.T) {
service, err := parseResolvedService("ssh://127.0.0.1", defaultOriginRequestConfig())
if err != nil {
t.Fatal(err)
}
if service.Destination.String() != "127.0.0.1:22" {
t.Fatalf("expected destination 127.0.0.1:22, got %s", service.Destination)
}
if !service.StreamHasPort {
t.Fatal("expected ssh stream service to apply default port")
}
}
func TestParseResolvedServiceTCPDefaultPort(t *testing.T) {
service, err := parseResolvedService("tcp://127.0.0.1", defaultOriginRequestConfig())
if err != nil {
t.Fatal(err)
}
if service.Destination.String() != "127.0.0.1:7864" {
t.Fatalf("expected destination 127.0.0.1:7864, got %s", service.Destination)
}
if !service.StreamHasPort {
t.Fatal("expected tcp stream service to apply default port")
}
}
func TestResolveHTTPServiceWebSocketOrigin(t *testing.T) {
inboundInstance := newTestIngressInbound(t)
inboundInstance.configManager.activeConfig = RuntimeConfig{
Ingress: []compiledIngressRule{
{Hostname: "foo.com", Service: mustResolvedService(t, "ws://127.0.0.1:8083")},
{Service: mustResolvedService(t, "http_status:404")},
},
}
_, requestURL, err := inboundInstance.resolveHTTPService("https://foo.com/path?q=1")
if err != nil {
t.Fatal(err)
}
if requestURL != "http://127.0.0.1:8083/path?q=1" {
t.Fatalf("expected websocket origin to be canonicalized, got %s", requestURL)
}
}

View File

@@ -1,181 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
)
func TestQUICIntegration(t *testing.T) {
token, testURL := requireEnvVars(t)
startOriginServer(t)
inboundInstance := newTestInbound(t, token, "quic", 1)
err := inboundInstance.Start(adapter.StartStateStart)
if err != nil {
t.Fatal("Start: ", err)
}
waitForTunnel(t, testURL, 30*time.Second)
resp, err := http.Get(testURL + "/ping")
if err != nil {
t.Fatal("GET /ping: ", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatal("expected 200, got ", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal("read body: ", err)
}
if string(body) != `{"ok":true}` {
t.Error("unexpected body: ", string(body))
}
}
func TestHTTP2Integration(t *testing.T) {
token, testURL := requireEnvVars(t)
startOriginServer(t)
inboundInstance := newTestInbound(t, token, "http2", 1)
err := inboundInstance.Start(adapter.StartStateStart)
if err != nil {
t.Fatal("Start: ", err)
}
waitForTunnel(t, testURL, 30*time.Second)
resp, err := http.Get(testURL + "/ping")
if err != nil {
t.Fatal("GET /ping: ", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatal("expected 200, got ", resp.StatusCode)
}
}
func TestMultipleHAConnections(t *testing.T) {
token, testURL := requireEnvVars(t)
startOriginServer(t)
inboundInstance := newTestInbound(t, token, "quic", 2)
err := inboundInstance.Start(adapter.StartStateStart)
if err != nil {
t.Fatal("Start: ", err)
}
waitForTunnel(t, testURL, 30*time.Second)
// Allow time for second connection to register
time.Sleep(3 * time.Second)
inboundInstance.connectionAccess.Lock()
connCount := len(inboundInstance.connections)
inboundInstance.connectionAccess.Unlock()
if connCount < 2 {
t.Errorf("expected at least 2 connections, got %d", connCount)
}
resp, err := http.Get(testURL + "/ping")
if err != nil {
t.Fatal("GET /ping: ", err)
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatal("expected 200, got ", resp.StatusCode)
}
}
func TestHTTPResponseCorrectness(t *testing.T) {
token, testURL := requireEnvVars(t)
startOriginServer(t)
inboundInstance := newTestInbound(t, token, "quic", 1)
err := inboundInstance.Start(adapter.StartStateStart)
if err != nil {
t.Fatal("Start: ", err)
}
waitForTunnel(t, testURL, 30*time.Second)
t.Run("StatusCode", func(t *testing.T) {
resp, err := http.Get(testURL + "/status/201")
if err != nil {
t.Fatal("GET /status/201: ", err)
}
resp.Body.Close()
if resp.StatusCode != 201 {
t.Error("expected 201, got ", resp.StatusCode)
}
})
t.Run("CustomHeader", func(t *testing.T) {
resp, err := http.Get(testURL + "/status/200")
if err != nil {
t.Fatal("GET /status/200: ", err)
}
resp.Body.Close()
customHeader := resp.Header.Get("X-Custom")
if customHeader != "test-value" {
t.Error("expected X-Custom=test-value, got ", customHeader)
}
})
t.Run("PostEcho", func(t *testing.T) {
resp, err := http.Post(testURL+"/echo", "text/plain", strings.NewReader("payload"))
if err != nil {
t.Fatal("POST /echo: ", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatal("expected 200, got ", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal("read body: ", err)
}
if string(body) != "payload" {
t.Error("unexpected body: ", string(body))
}
})
}
func TestGracefulClose(t *testing.T) {
token, testURL := requireEnvVars(t)
startOriginServer(t)
inboundInstance := newTestInbound(t, token, "quic", 1)
err := inboundInstance.Start(adapter.StartStateStart)
if err != nil {
t.Fatal("Start: ", err)
}
waitForTunnel(t, testURL, 30*time.Second)
err = inboundInstance.Close()
if err != nil {
t.Fatal("Close: ", err)
}
if inboundInstance.ctx.Err() == nil {
t.Error("expected context to be cancelled after Close")
}
inboundInstance.connectionAccess.Lock()
remaining := inboundInstance.connections
inboundInstance.connectionAccess.Unlock()
if remaining != nil {
t.Error("expected connections to be nil after Close, got ", len(remaining))
}
}

View File

@@ -1,96 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"net"
"net/netip"
"sort"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
)
type compiledIPRule struct {
prefix netip.Prefix
ports []int
allow bool
}
type ipRulePolicy struct {
rules []compiledIPRule
}
func newIPRulePolicy(rawRules []IPRule) (*ipRulePolicy, error) {
policy := &ipRulePolicy{
rules: make([]compiledIPRule, 0, len(rawRules)),
}
for _, rawRule := range rawRules {
if rawRule.Prefix == "" {
return nil, E.New("ip_rule prefix cannot be blank")
}
prefix, err := netip.ParsePrefix(rawRule.Prefix)
if err != nil {
return nil, E.Cause(err, "parse ip_rule prefix")
}
ports := append([]int(nil), rawRule.Ports...)
sort.Ints(ports)
for _, port := range ports {
if port < 1 || port > 65535 {
return nil, E.New("invalid ip_rule port: ", port)
}
}
policy.rules = append(policy.rules, compiledIPRule{
prefix: prefix,
ports: ports,
allow: rawRule.Allow,
})
}
return policy, nil
}
func (p *ipRulePolicy) Allow(ctx context.Context, destination M.Socksaddr) (bool, error) {
if p == nil {
return false, nil
}
ipAddr, err := resolvePolicyDestination(ctx, destination)
if err != nil {
return false, err
}
port := int(destination.Port)
for _, rule := range p.rules {
if !rule.prefix.Contains(ipAddr) {
continue
}
if len(rule.ports) == 0 {
return rule.allow, nil
}
portIndex := sort.SearchInts(rule.ports, port)
if portIndex < len(rule.ports) && rule.ports[portIndex] == port {
return rule.allow, nil
}
}
return false, nil
}
func resolvePolicyDestination(ctx context.Context, destination M.Socksaddr) (netip.Addr, error) {
if destination.IsIP() {
return destination.Unwrap().Addr, nil
}
if !destination.IsFqdn() {
return netip.Addr{}, E.New("destination is neither IP nor FQDN")
}
ipAddrs, err := net.DefaultResolver.LookupIPAddr(ctx, destination.Fqdn)
if err != nil {
return netip.Addr{}, E.Cause(err, "resolve destination")
}
if len(ipAddrs) == 0 {
return netip.Addr{}, E.New("resolved destination is empty")
}
resolvedAddr, ok := netip.AddrFromSlice(ipAddrs[0].IP)
if !ok {
return netip.Addr{}, E.New("resolved destination is invalid")
}
return resolvedAddr.Unmap(), nil
}

View File

@@ -1,59 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"net/netip"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common/buf"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
const originUDPWriteTimeout = 200 * time.Millisecond
type udpWriteDeadlinePacketConn struct {
N.PacketConn
}
func (c *udpWriteDeadlinePacketConn) WritePacket(buffer *buf.Buffer, destination M.Socksaddr) error {
_ = c.PacketConn.SetWriteDeadline(time.Now().Add(originUDPWriteTimeout))
defer func() {
_ = c.PacketConn.SetWriteDeadline(time.Time{})
}()
return c.PacketConn.WritePacket(buffer, destination)
}
type routedOriginPacketDialer interface {
DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error)
}
func (i *Inbound) dialWarpPacketConnection(ctx context.Context, destination netip.AddrPort) (N.PacketConn, error) {
originDialer, ok := i.router.(routedOriginPacketDialer)
if !ok {
return nil, E.New("router does not support cloudflare routed packet dialing")
}
warpRouting := i.configManager.Snapshot().WarpRouting
if warpRouting.ConnectTimeout > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, warpRouting.ConnectTimeout)
defer cancel()
}
packetConn, err := originDialer.DialRoutePacketConnection(ctx, adapter.InboundContext{
Inbound: i.Tag(),
InboundType: i.Type(),
Network: N.NetworkUDP,
Destination: M.SocksaddrFromNetIP(destination),
UDPConnect: true,
})
if err != nil {
return nil, err
}
return &udpWriteDeadlinePacketConn{PacketConn: packetConn}, nil
}

View File

@@ -1,56 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"errors"
"net"
"testing"
"time"
"github.com/sagernet/sing/common/buf"
M "github.com/sagernet/sing/common/metadata"
)
type captureDeadlinePacketConn struct {
err error
deadlines []time.Time
}
func (c *captureDeadlinePacketConn) ReadPacket(buffer *buf.Buffer) (M.Socksaddr, error) {
buffer.Release()
return M.Socksaddr{}, errors.New("unused")
}
func (c *captureDeadlinePacketConn) WritePacket(buffer *buf.Buffer, _ M.Socksaddr) error {
buffer.Release()
return c.err
}
func (c *captureDeadlinePacketConn) Close() error { return nil }
func (c *captureDeadlinePacketConn) LocalAddr() net.Addr { return &net.UDPAddr{} }
func (c *captureDeadlinePacketConn) SetDeadline(time.Time) error { return nil }
func (c *captureDeadlinePacketConn) SetReadDeadline(time.Time) error { return nil }
func (c *captureDeadlinePacketConn) SetWriteDeadline(t time.Time) error {
c.deadlines = append(c.deadlines, t)
return nil
}
func TestDeadlinePacketConnWrapsWriteDeadline(t *testing.T) {
packetConn := &captureDeadlinePacketConn{}
wrapped := &udpWriteDeadlinePacketConn{PacketConn: packetConn}
if err := wrapped.WritePacket(buf.As([]byte("payload")), M.Socksaddr{}); err != nil {
t.Fatal(err)
}
if len(packetConn.deadlines) != 2 {
t.Fatalf("expected two deadline updates, got %d", len(packetConn.deadlines))
}
if packetConn.deadlines[0].IsZero() {
t.Fatal("expected first deadline to set a timeout")
}
if !packetConn.deadlines[1].IsZero() {
t.Fatal("expected second deadline to clear the timeout")
}
}

View File

@@ -1,337 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"bytes"
"context"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"io"
"math/big"
"net"
"net/http"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
N "github.com/sagernet/sing/common/network"
)
type noopRouteConnectionRouter struct {
testRouter
}
func (r *noopRouteConnectionRouter) RouteConnectionEx(_ context.Context, conn net.Conn, _ adapter.InboundContext, onClose N.CloseHandlerFunc) {
_ = conn.Close()
onClose(nil)
}
func TestOriginTLSServerName(t *testing.T) {
t.Run("origin server name overrides host", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{
OriginServerName: "origin.example.com",
MatchSNIToHost: true,
}, "request.example.com")
if serverName != "origin.example.com" {
t.Fatalf("expected origin.example.com, got %s", serverName)
}
})
t.Run("match sni to host strips port", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{
MatchSNIToHost: true,
}, "request.example.com:443")
if serverName != "request.example.com" {
t.Fatalf("expected request.example.com, got %s", serverName)
}
})
t.Run("match sni to host uses http host header", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{
MatchSNIToHost: true,
}, effectiveOriginHost(OriginRequestConfig{
HTTPHostHeader: "origin.example.com",
MatchSNIToHost: true,
}, "request.example.com"))
if serverName != "origin.example.com" {
t.Fatalf("expected origin.example.com, got %s", serverName)
}
})
t.Run("match sni to host strips port from http host header", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{
MatchSNIToHost: true,
}, effectiveOriginHost(OriginRequestConfig{
HTTPHostHeader: "origin.example.com:8443",
MatchSNIToHost: true,
}, "request.example.com"))
if serverName != "origin.example.com" {
t.Fatalf("expected origin.example.com, got %s", serverName)
}
})
t.Run("disabled match keeps empty server name", func(t *testing.T) {
serverName := originTLSServerName(OriginRequestConfig{}, "request.example.com")
if serverName != "" {
t.Fatalf("expected empty server name, got %s", serverName)
}
})
}
func TestNewOriginTLSConfigErrorsOnMissingCAPool(t *testing.T) {
originalBaseLoader := loadOriginCABasePool
loadOriginCABasePool = func() (*x509.CertPool, error) {
return x509.NewCertPool(), nil
}
defer func() {
loadOriginCABasePool = originalBaseLoader
}()
_, err := newOriginTLSConfig(OriginRequestConfig{
CAPool: "/path/does/not/exist.pem",
}, "request.example.com")
if err == nil {
t.Fatal("expected error for missing ca pool")
}
}
func TestNewOriginTLSConfigAppendsCustomCAInsteadOfReplacingBasePool(t *testing.T) {
basePEM, baseCert := createTestCertificatePEM(t, "base")
customPEM, customCert := createTestCertificatePEM(t, "custom")
basePool := x509.NewCertPool()
if !basePool.AppendCertsFromPEM(basePEM) {
t.Fatal("expected base cert to append")
}
originalBaseLoader := loadOriginCABasePool
loadOriginCABasePool = func() (*x509.CertPool, error) {
return basePool, nil
}
defer func() {
loadOriginCABasePool = originalBaseLoader
}()
caFile := writeTempPEM(t, customPEM)
tlsConfig, err := newOriginTLSConfig(OriginRequestConfig{
CAPool: caFile,
}, "request.example.com")
if err != nil {
t.Fatal(err)
}
if tlsConfig.RootCAs == nil {
t.Fatal("expected root CA pool")
}
subjects := tlsConfig.RootCAs.Subjects()
if len(subjects) != 2 {
t.Fatalf("expected 2 subjects, got %d", len(subjects))
}
if !containsSubject(subjects, baseCert.RawSubject) {
t.Fatal("expected base subject to remain in pool")
}
if !containsSubject(subjects, customCert.RawSubject) {
t.Fatal("expected custom subject to be appended to pool")
}
}
func TestOriginTransportUsesProxyFromEnvironmentOnly(t *testing.T) {
originalProxyFromEnvironment := proxyFromEnvironment
proxyFromEnvironment = func(request *http.Request) (*url.URL, error) {
return url.Parse("http://proxy.example.com:8080")
}
defer func() {
proxyFromEnvironment = originalProxyFromEnvironment
}()
inbound := &Inbound{}
transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{
Kind: ResolvedServiceUnix,
UnixPath: "/tmp/test.sock",
OriginRequest: OriginRequestConfig{
ProxyAddress: "127.0.0.1",
ProxyPort: 8081,
ProxyType: "http",
},
}, "")
if err != nil {
t.Fatal(err)
}
defer cleanup()
proxyURL, err := transport.Proxy(&http.Request{URL: &url.URL{Scheme: "http", Host: "example.com"}})
if err != nil {
t.Fatal(err)
}
if proxyURL == nil || proxyURL.String() != "http://proxy.example.com:8080" {
t.Fatalf("expected environment proxy URL, got %#v", proxyURL)
}
}
func TestNewDirectOriginTransportNoHappyEyeballs(t *testing.T) {
inbound := &Inbound{}
transport, cleanup, err := inbound.newDirectOriginTransport(ResolvedService{
Kind: ResolvedServiceUnix,
UnixPath: "/tmp/test.sock",
OriginRequest: OriginRequestConfig{
NoHappyEyeballs: true,
},
}, "")
if err != nil {
t.Fatal(err)
}
defer cleanup()
if transport.Proxy == nil {
t.Fatal("expected proxy function to be configured from environment")
}
if transport.DialContext == nil {
t.Fatal("expected custom direct dial context")
}
}
func TestNewRouterOriginTransportPropagatesTLSConfigError(t *testing.T) {
originalBaseLoader := loadOriginCABasePool
loadOriginCABasePool = func() (*x509.CertPool, error) {
return x509.NewCertPool(), nil
}
defer func() {
loadOriginCABasePool = originalBaseLoader
}()
inbound := &Inbound{}
_, _, err := inbound.newRouterOriginTransport(context.Background(), adapter.InboundContext{}, OriginRequestConfig{
CAPool: "/path/does/not/exist.pem",
}, "")
if err == nil {
t.Fatal("expected transport build error")
}
}
func TestNewRouterOriginTransportUsesCloudflaredDefaults(t *testing.T) {
inbound := &Inbound{
router: &noopRouteConnectionRouter{},
}
transport, cleanup, err := inbound.newRouterOriginTransport(context.Background(), adapter.InboundContext{}, OriginRequestConfig{}, "")
if err != nil {
t.Fatal(err)
}
defer cleanup()
if transport.ExpectContinueTimeout != time.Second {
t.Fatalf("expected ExpectContinueTimeout=1s, got %s", transport.ExpectContinueTimeout)
}
if transport.DisableCompression {
t.Fatal("expected compression to remain enabled by default")
}
}
func TestNormalizeOriginRequestSetsKeepAliveAndEmptyUserAgent(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", http.NoBody)
if err != nil {
t.Fatal(err)
}
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{})
if connection := request.Header.Get("Connection"); connection != "keep-alive" {
t.Fatalf("expected keep-alive connection header, got %q", connection)
}
if values, exists := request.Header["User-Agent"]; !exists || len(values) != 1 || values[0] != "" {
t.Fatalf("expected empty User-Agent header, got %#v", request.Header["User-Agent"])
}
}
func TestNormalizeOriginRequestDisableChunkedEncoding(t *testing.T) {
request, err := http.NewRequest(http.MethodPost, "https://example.com/path", strings.NewReader("payload"))
if err != nil {
t.Fatal(err)
}
request.TransferEncoding = []string{"chunked"}
request.Header.Set("Content-Length", "7")
request = normalizeOriginRequest(ConnectionTypeHTTP, request, OriginRequestConfig{
DisableChunkedEncoding: true,
})
if len(request.TransferEncoding) != 2 || request.TransferEncoding[0] != "gzip" || request.TransferEncoding[1] != "deflate" {
t.Fatalf("unexpected transfer encoding: %#v", request.TransferEncoding)
}
if request.ContentLength != 7 {
t.Fatalf("expected content length 7, got %d", request.ContentLength)
}
}
func TestNormalizeOriginRequestWebsocket(t *testing.T) {
request, err := http.NewRequest(http.MethodGet, "https://example.com/path", io.NopCloser(strings.NewReader("payload")))
if err != nil {
t.Fatal(err)
}
request = normalizeOriginRequest(ConnectionTypeWebsocket, request, OriginRequestConfig{})
if connection := request.Header.Get("Connection"); connection != "Upgrade" {
t.Fatalf("expected websocket connection header, got %q", connection)
}
if upgrade := request.Header.Get("Upgrade"); upgrade != "websocket" {
t.Fatalf("expected websocket upgrade header, got %q", upgrade)
}
if version := request.Header.Get("Sec-Websocket-Version"); version != "13" {
t.Fatalf("expected websocket version 13, got %q", version)
}
if request.ContentLength != 0 {
t.Fatalf("expected websocket content length 0, got %d", request.ContentLength)
}
if request.Body != nil {
t.Fatal("expected websocket body to be nil")
}
}
func createTestCertificatePEM(t *testing.T, commonName string) ([]byte, *x509.Certificate) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatal(err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().UnixNano()),
Subject: pkix.Name{
CommonName: commonName,
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
IsCA: true,
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
if err != nil {
t.Fatal(err)
}
certificate, err := x509.ParseCertificate(der)
if err != nil {
t.Fatal(err)
}
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), certificate
}
func writeTempPEM(t *testing.T, pemData []byte) string {
t.Helper()
path := t.TempDir() + "/ca.pem"
if err := os.WriteFile(path, pemData, 0o600); err != nil {
t.Fatal(err)
}
return path
}
func containsSubject(subjects [][]byte, want []byte) bool {
for _, subject := range subjects {
if bytes.Equal(subject, want) {
return true
}
}
return false
}

View File

@@ -1,78 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"bytes"
"context"
"io"
"net/http"
"testing"
)
func TestBuildHTTPRequestFromMetadataUsesNoBodyWhenLengthZeroWithoutChunked(t *testing.T) {
request, err := buildHTTPRequestFromMetadata(context.Background(), &ConnectRequest{
Dest: "http://example.com",
Type: ConnectionTypeHTTP,
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodGet},
{Key: metadataHTTPHost, Val: "cf.host"},
},
}, io.NopCloser(bytes.NewBuffer(nil)))
if err != nil {
t.Fatal(err)
}
if request.Body != http.NoBody {
t.Fatalf("expected http.NoBody, got %#v", request.Body)
}
}
func TestBuildHTTPRequestFromMetadataPreservesBodyWhenTransferEncodingChunked(t *testing.T) {
request, err := buildHTTPRequestFromMetadata(context.Background(), &ConnectRequest{
Dest: "http://example.com",
Type: ConnectionTypeHTTP,
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodPost},
{Key: metadataHTTPHost, Val: "cf.host"},
{Key: metadataHTTPHeader + ":Transfer-Encoding", Val: "chunked"},
},
}, io.NopCloser(bytes.NewBufferString("payload")))
if err != nil {
t.Fatal(err)
}
if request.Body == http.NoBody {
t.Fatal("expected request body to be preserved")
}
body, err := io.ReadAll(request.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != "payload" {
t.Fatalf("unexpected body %q", body)
}
}
func TestBuildHTTPRequestFromMetadataPreservesBodyWhenTransferEncodingContainsChunked(t *testing.T) {
request, err := buildHTTPRequestFromMetadata(context.Background(), &ConnectRequest{
Dest: "http://example.com",
Type: ConnectionTypeHTTP,
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodPost},
{Key: metadataHTTPHost, Val: "cf.host"},
{Key: metadataHTTPHeader + ":Transfer-Encoding", Val: "gzip,chunked"},
},
}, io.NopCloser(bytes.NewBufferString("payload")))
if err != nil {
t.Fatal(err)
}
if request.Body == http.NoBody {
t.Fatal("expected request body to be preserved")
}
body, err := io.ReadAll(request.Body)
if err != nil {
t.Fatal(err)
}
if string(body) != "payload" {
t.Fatalf("unexpected body %q", body)
}
}

View File

@@ -1,91 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/sagernet/sing-box/log"
)
type trailerCaptureResponseWriter struct {
status int
trailers http.Header
}
func (w *trailerCaptureResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
for _, entry := range metadata {
if entry.Key == metadataHTTPStatus {
w.status = http.StatusOK
}
}
return nil
}
func (w *trailerCaptureResponseWriter) AddTrailer(name, value string) {
if w.trailers == nil {
w.trailers = make(http.Header)
}
w.trailers.Add(name, value)
}
type captureReadWriteCloser struct {
body []byte
}
func (c *captureReadWriteCloser) Read(_ []byte) (int, error) {
return 0, io.EOF
}
func (c *captureReadWriteCloser) Write(p []byte) (int, error) {
c.body = append(c.body, p...)
return len(p), nil
}
func (c *captureReadWriteCloser) Close() error {
return nil
}
func TestRoundTripHTTPCopiesTrailers(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Trailer", "X-Test-Trailer")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
w.Header().Set("X-Test-Trailer", "trailer-value")
}))
defer server.Close()
transport, ok := server.Client().Transport.(*http.Transport)
if !ok {
t.Fatalf("unexpected transport type %T", server.Client().Transport)
}
inboundInstance := &Inbound{
logger: log.NewNOPFactory().NewLogger("test"),
}
stream := &captureReadWriteCloser{}
respWriter := &trailerCaptureResponseWriter{}
request := &ConnectRequest{
Dest: server.URL,
Type: ConnectionTypeHTTP,
Metadata: []Metadata{
{Key: metadataHTTPMethod, Val: http.MethodGet},
{Key: metadataHTTPHost, Val: "example.com"},
},
}
inboundInstance.roundTripHTTP(context.Background(), stream, respWriter, request, ResolvedService{
OriginRequest: defaultOriginRequestConfig(),
}, transport)
if got := respWriter.trailers.Get("X-Test-Trailer"); got != "trailer-value" {
t.Fatalf("expected copied trailer, got %q", got)
}
if string(stream.body) != "ok" {
t.Fatalf("unexpected response body %q", stream.body)
}
}

View File

@@ -1,24 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"crypto/x509"
_ "embed"
E "github.com/sagernet/sing/common/exceptions"
)
//go:embed cloudflare_ca.pem
var cloudflareRootCAPEM []byte
func cloudflareRootCertPool() (*x509.CertPool, error) {
pool, err := x509.SystemCertPool()
if err != nil {
pool = x509.NewCertPool()
}
if !pool.AppendCertsFromPEM(cloudflareRootCAPEM) {
return nil, E.New("failed to parse embedded Cloudflare root CAs")
}
return pool, nil
}

View File

@@ -1,90 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"net"
"sync"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing/common"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/sing/common/pipe"
)
type routedPipeTCPOptions struct {
timeout time.Duration
onHandshake func(net.Conn)
}
type routedPipeTCPConn struct {
net.Conn
handshakeOnce sync.Once
onHandshake func(net.Conn)
}
func (c *routedPipeTCPConn) ConnHandshakeSuccess(conn net.Conn) error {
if c.onHandshake != nil {
c.handshakeOnce.Do(func() {
c.onHandshake(conn)
})
}
return nil
}
func (i *Inbound) dialRouterTCPWithMetadata(ctx context.Context, metadata adapter.InboundContext, options routedPipeTCPOptions) (net.Conn, func(), error) {
input, output := pipe.Pipe()
routerConn := &routedPipeTCPConn{
Conn: output,
onHandshake: options.onHandshake,
}
done := make(chan struct{})
routeCtx := ctx
var cancel context.CancelFunc
if options.timeout > 0 {
routeCtx, cancel = context.WithTimeout(ctx, options.timeout)
}
var closeOnce sync.Once
closePipe := func() {
closeOnce.Do(func() {
if cancel != nil {
cancel()
}
common.Close(input, routerConn)
})
}
go i.router.RouteConnectionEx(routeCtx, routerConn, metadata, N.OnceClose(func(it error) {
closePipe()
close(done)
}))
return input, func() {
closePipe()
select {
case <-done:
case <-time.After(time.Second):
}
}, nil
}
func applyTCPKeepAlive(conn net.Conn, keepAlive time.Duration) error {
if keepAlive <= 0 {
return nil
}
type keepAliveConn interface {
SetKeepAlive(bool) error
SetKeepAlivePeriod(time.Duration) error
}
tcpConn, ok := conn.(keepAliveConn)
if !ok {
return nil
}
if err := tcpConn.SetKeepAlive(true); err != nil {
return err
}
return tcpConn.SetKeepAlivePeriod(keepAlive)
}

View File

@@ -1,165 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"io"
"net"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
)
func TestHandleTCPStreamUsesRouteConnectionEx(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
router := &countingRouter{}
inboundInstance := newSpecialServiceInboundWithRouter(t, router)
serverSide, clientSide := net.Pipe()
defer clientSide.Close()
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
responseDone := respWriter.done
finished := make(chan struct{})
go func() {
inboundInstance.handleTCPStream(context.Background(), serverSide, respWriter, adapter.InboundContext{
Destination: M.ParseSocksaddr(listener.Addr().String()),
})
close(finished)
}()
select {
case <-responseDone:
case <-time.After(time.Second):
t.Fatal("timed out waiting for connect response")
}
if respWriter.err != nil {
t.Fatal("unexpected response error: ", respWriter.err)
}
if err := clientSide.SetDeadline(time.Now().Add(time.Second)); err != nil {
t.Fatal(err)
}
payload := []byte("ping")
if _, err := clientSide.Write(payload); err != nil {
t.Fatal(err)
}
response := make([]byte, len(payload))
if _, err := io.ReadFull(clientSide, response); err != nil {
t.Fatal(err)
}
if string(response) != string(payload) {
t.Fatalf("unexpected echo payload: %q", string(response))
}
if router.count.Load() != 1 {
t.Fatalf("expected RouteConnectionEx to be used once, got %d", router.count.Load())
}
_ = clientSide.Close()
select {
case <-finished:
case <-time.After(time.Second):
t.Fatal("timed out waiting for TCP stream handler to exit")
}
}
func TestHandleTCPStreamWritesOptimisticAck(t *testing.T) {
router := &blockingRouteRouter{
started: make(chan struct{}),
release: make(chan struct{}),
}
inboundInstance := newSpecialServiceInboundWithRouter(t, router)
serverSide, clientSide := net.Pipe()
defer clientSide.Close()
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
responseDone := respWriter.done
finished := make(chan struct{})
go func() {
inboundInstance.handleTCPStream(context.Background(), serverSide, respWriter, adapter.InboundContext{
Destination: M.ParseSocksaddr("127.0.0.1:443"),
})
close(finished)
}()
select {
case <-router.started:
case <-time.After(time.Second):
t.Fatal("timed out waiting for router goroutine to start")
}
select {
case <-responseDone:
case <-time.After(time.Second):
t.Fatal("timed out waiting for optimistic connect response")
}
if respWriter.err != nil {
t.Fatal("unexpected response error: ", respWriter.err)
}
close(router.release)
_ = clientSide.Close()
select {
case <-finished:
case <-time.After(time.Second):
t.Fatal("timed out waiting for TCP stream handler to exit")
}
}
func TestRoutedPipeTCPConnHandshakeAppliesKeepAlive(t *testing.T) {
left, right := net.Pipe()
defer left.Close()
defer right.Close()
remoteConn := &keepAliveTestConn{Conn: right}
routerConn := &routedPipeTCPConn{
Conn: left,
onHandshake: func(conn net.Conn) {
_ = applyTCPKeepAlive(conn, 15*time.Second)
},
}
if err := routerConn.ConnHandshakeSuccess(remoteConn); err != nil {
t.Fatal(err)
}
if !remoteConn.enabled {
t.Fatal("expected keepalive to be enabled")
}
if remoteConn.period != 15*time.Second {
t.Fatalf("unexpected keepalive period: %s", remoteConn.period)
}
}
type blockingRouteRouter struct {
testRouter
started chan struct{}
release chan struct{}
}
func (r *blockingRouteRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
close(r.started)
<-r.release
_ = conn.Close()
onClose(nil)
}
type keepAliveTestConn struct {
net.Conn
enabled bool
period time.Duration
}
func (c *keepAliveTestConn) SetKeepAlive(enabled bool) error {
c.enabled = enabled
return nil
}
func (c *keepAliveTestConn) SetKeepAlivePeriod(period time.Duration) error {
c.period = period
return nil
}

View File

@@ -1,251 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"io"
"net"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
N "github.com/sagernet/sing/common/network"
"github.com/google/uuid"
)
type blockingRPCStream struct {
closed chan struct{}
}
func newBlockingRPCStream() *blockingRPCStream {
return &blockingRPCStream{closed: make(chan struct{})}
}
func (s *blockingRPCStream) Read(_ []byte) (int, error) {
<-s.closed
return 0, io.EOF
}
func (s *blockingRPCStream) Write(p []byte) (int, error) {
return len(p), nil
}
func (s *blockingRPCStream) Close() error {
select {
case <-s.closed:
default:
close(s.closed)
}
return nil
}
type blockingPacketDialRouter struct {
testRouter
entered chan struct{}
release chan struct{}
}
func (r *blockingPacketDialRouter) DialRoutePacketConnection(ctx context.Context, metadata adapter.InboundContext) (N.PacketConn, error) {
select {
case <-r.entered:
default:
close(r.entered)
}
select {
case <-r.release:
return newBlockingPacketConn(), nil
case <-ctx.Done():
return nil, ctx.Err()
}
}
func newRPCInbound(t *testing.T, router adapter.Router) *Inbound {
t.Helper()
inboundInstance := newLimitedInbound(t, 0)
inboundInstance.router = router
return inboundInstance
}
func newRPCClientPair(t *testing.T, ctx context.Context) (tunnelrpc.CloudflaredServer, io.Closer, io.Closer, net.Conn, net.Conn) {
t.Helper()
serverSide, clientSide := net.Pipe()
transport := safeTransport(clientSide)
clientConn := newRPCClientConn(transport, ctx)
client := tunnelrpc.CloudflaredServer{Client: clientConn.Bootstrap(ctx)}
return client, clientConn, transport, serverSide, clientSide
}
func TestServeRPCStreamRespectsContextDeadline(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
stream := newBlockingRPCStream()
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
ServeRPCStream(ctx, stream, inboundInstance, NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger), inboundInstance.logger)
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected ServeRPCStream to exit after context deadline")
}
}
func TestServeV3RPCStreamRespectsContextDeadline(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
stream := newBlockingRPCStream()
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
done := make(chan struct{})
go func() {
ServeV3RPCStream(ctx, stream, inboundInstance, inboundInstance.logger)
close(done)
}()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected ServeV3RPCStream to exit after context deadline")
}
}
func TestV2RPCAckAllowsConcurrentDispatch(t *testing.T) {
router := &blockingPacketDialRouter{
entered: make(chan struct{}),
release: make(chan struct{}),
}
inboundInstance := newRPCInbound(t, router)
muxer := NewDatagramV2Muxer(inboundInstance, &captureDatagramSender{}, inboundInstance.logger)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client, clientConn, transport, serverSide, clientSide := newRPCClientPair(t, ctx)
defer clientConn.Close()
defer transport.Close()
defer clientSide.Close()
done := make(chan struct{})
go func() {
ServeRPCStream(ctx, serverSide, inboundInstance, muxer, inboundInstance.logger)
close(done)
}()
registerPromise := client.RegisterUdpSession(ctx, func(p tunnelrpc.SessionManager_registerUdpSession_Params) error {
sessionID := uuid.New()
if err := p.SetSessionId(sessionID[:]); err != nil {
return err
}
if err := p.SetDstIp([]byte{127, 0, 0, 1}); err != nil {
return err
}
p.SetDstPort(53)
p.SetCloseAfterIdleHint(int64(time.Second))
return p.SetTraceContext("")
})
select {
case <-router.entered:
case <-time.After(time.Second):
t.Fatal("expected register RPC to enter the blocking dial")
}
updateCtx, updateCancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer updateCancel()
updatePromise := client.UpdateConfiguration(updateCtx, func(p tunnelrpc.ConfigurationManager_updateConfiguration_Params) error {
p.SetVersion(1)
return p.SetConfig([]byte(`{"ingress":[{"service":"http_status:503"}]}`))
})
if _, err := updatePromise.Result().Struct(); err != nil {
t.Fatalf("expected concurrent update RPC to succeed, got %v", err)
}
close(router.release)
if _, err := registerPromise.Result().Struct(); err != nil {
t.Fatalf("expected register RPC to complete, got %v", err)
}
cancel()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected ServeRPCStream to exit")
}
}
func TestV3RPCAckAllowsConcurrentDispatch(t *testing.T) {
inboundInstance := newLimitedInbound(t, 0)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
client, clientConn, transport, serverSide, clientSide := newRPCClientPair(t, ctx)
defer clientConn.Close()
defer transport.Close()
defer clientSide.Close()
done := make(chan struct{})
go func() {
ServeV3RPCStream(ctx, serverSide, inboundInstance, inboundInstance.logger)
close(done)
}()
inboundInstance.configManager.access.Lock()
updatePromise := client.UpdateConfiguration(ctx, func(p tunnelrpc.ConfigurationManager_updateConfiguration_Params) error {
p.SetVersion(1)
return p.SetConfig([]byte(`{"ingress":[{"service":"http_status:503"}]}`))
})
time.Sleep(20 * time.Millisecond)
registerCtx, registerCancel := context.WithTimeout(ctx, 500*time.Millisecond)
defer registerCancel()
registerPromise := client.RegisterUdpSession(registerCtx, func(p tunnelrpc.SessionManager_registerUdpSession_Params) error {
sessionID := uuid.New()
if err := p.SetSessionId(sessionID[:]); err != nil {
return err
}
if err := p.SetDstIp([]byte{127, 0, 0, 1}); err != nil {
return err
}
p.SetDstPort(53)
p.SetCloseAfterIdleHint(int64(time.Second))
return p.SetTraceContext("")
})
registerResult, err := registerPromise.Result().Struct()
if err != nil {
t.Fatalf("expected concurrent v3 register RPC to succeed, got %v", err)
}
resultErr, err := registerResult.Err()
if err != nil {
t.Fatal(err)
}
if resultErr != errUnsupportedDatagramV3UDPRegistration.Error() {
t.Fatalf("unexpected registration error %q", resultErr)
}
inboundInstance.configManager.access.Unlock()
if _, err := updatePromise.Result().Struct(); err != nil {
t.Fatalf("expected update RPC to complete, got %v", err)
}
cancel()
select {
case <-done:
case <-time.After(time.Second):
t.Fatal("expected ServeV3RPCStream to exit")
}
}

View File

@@ -1,715 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"encoding/json"
"net"
"net/url"
"regexp"
"strconv"
"strings"
"sync"
"time"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
"golang.org/x/net/idna"
)
const (
defaultHTTPConnectTimeout = 30 * time.Second
defaultTLSTimeout = 10 * time.Second
defaultTCPKeepAlive = 30 * time.Second
defaultKeepAliveTimeout = 90 * time.Second
defaultKeepAliveConnections = 100
defaultProxyAddress = "127.0.0.1"
defaultWarpRoutingConnectTime = 5 * time.Second
defaultWarpRoutingTCPKeepAlive = 30 * time.Second
)
type ResolvedServiceKind int
const (
ResolvedServiceHTTP ResolvedServiceKind = iota
ResolvedServiceStream
ResolvedServiceStatus
ResolvedServiceUnix
ResolvedServiceUnixTLS
ResolvedServiceBastion
ResolvedServiceSocksProxy
)
type ResolvedService struct {
Kind ResolvedServiceKind
Service string
Destination M.Socksaddr
StreamHasPort bool
BaseURL *url.URL
UnixPath string
StatusCode int
SocksPolicy *ipRulePolicy
OriginRequest OriginRequestConfig
}
func (s ResolvedService) RouterControlled() bool {
return s.Kind == ResolvedServiceHTTP || s.Kind == ResolvedServiceStream
}
func (s ResolvedService) BuildRequestURL(requestURL string) (string, error) {
switch s.Kind {
case ResolvedServiceHTTP, ResolvedServiceUnix, ResolvedServiceUnixTLS:
requestParsed, err := url.Parse(requestURL)
if err != nil {
return "", err
}
originURL := *s.BaseURL
originURL.Path = requestParsed.Path
originURL.RawPath = requestParsed.RawPath
originURL.RawQuery = requestParsed.RawQuery
originURL.Fragment = requestParsed.Fragment
return originURL.String(), nil
default:
return requestURL, nil
}
}
func canonicalizeHTTPOriginURL(parsedURL *url.URL) *url.URL {
if parsedURL == nil {
return nil
}
canonicalURL := *parsedURL
switch canonicalURL.Scheme {
case "ws":
canonicalURL.Scheme = "http"
case "wss":
canonicalURL.Scheme = "https"
}
return &canonicalURL
}
func isHTTPServiceScheme(scheme string) bool {
switch scheme {
case "http", "https", "ws", "wss":
return true
default:
return false
}
}
type compiledIngressRule struct {
Hostname string
PunycodeHostname string
Path *regexp.Regexp
Service ResolvedService
}
type RuntimeConfig struct {
Ingress []compiledIngressRule
OriginRequest OriginRequestConfig
WarpRouting WarpRoutingConfig
}
type OriginRequestConfig struct {
ConnectTimeout time.Duration
TLSTimeout time.Duration
TCPKeepAlive time.Duration
NoHappyEyeballs bool
KeepAliveTimeout time.Duration
KeepAliveConnections int
HTTPHostHeader string
OriginServerName string
MatchSNIToHost bool
CAPool string
NoTLSVerify bool
DisableChunkedEncoding bool
BastionMode bool
ProxyAddress string
ProxyPort uint
ProxyType string
IPRules []IPRule
HTTP2Origin bool
Access AccessConfig
}
type AccessConfig struct {
Required bool
TeamName string
AudTag []string
Environment string
}
type IPRule struct {
Prefix string
Ports []int
Allow bool
}
type WarpRoutingConfig struct {
ConnectTimeout time.Duration
MaxActiveFlows uint64
TCPKeepAlive time.Duration
}
type ConfigUpdateResult struct {
LastAppliedVersion int32
Err error
}
type ConfigManager struct {
access sync.RWMutex
currentVersion int32
activeConfig RuntimeConfig
}
func NewConfigManager() (*ConfigManager, error) {
config, err := defaultRuntimeConfig()
if err != nil {
return nil, err
}
return &ConfigManager{
currentVersion: -1,
activeConfig: config,
}, nil
}
func (m *ConfigManager) Snapshot() RuntimeConfig {
m.access.RLock()
defer m.access.RUnlock()
return m.activeConfig
}
func (m *ConfigManager) CurrentVersion() int32 {
m.access.RLock()
defer m.access.RUnlock()
return m.currentVersion
}
func (m *ConfigManager) Apply(version int32, raw []byte) ConfigUpdateResult {
m.access.Lock()
defer m.access.Unlock()
if version <= m.currentVersion {
return ConfigUpdateResult{LastAppliedVersion: m.currentVersion}
}
config, err := buildRemoteRuntimeConfig(raw)
if err != nil {
return ConfigUpdateResult{
LastAppliedVersion: m.currentVersion,
Err: err,
}
}
m.activeConfig = config
m.currentVersion = version
return ConfigUpdateResult{LastAppliedVersion: m.currentVersion}
}
func (m *ConfigManager) Resolve(hostname, path string) (ResolvedService, bool) {
m.access.RLock()
defer m.access.RUnlock()
return m.activeConfig.Resolve(hostname, path)
}
func (c RuntimeConfig) Resolve(hostname, path string) (ResolvedService, bool) {
host := stripPort(hostname)
for _, rule := range c.Ingress {
if !matchIngressRule(rule, host, path) {
continue
}
return rule.Service, true
}
return ResolvedService{}, false
}
func matchIngressRule(rule compiledIngressRule, hostname, path string) bool {
hostMatch := rule.Hostname == "" || rule.Hostname == "*" || matchIngressHost(rule.Hostname, hostname)
if !hostMatch && rule.PunycodeHostname != "" {
hostMatch = matchIngressHost(rule.PunycodeHostname, hostname)
}
if !hostMatch {
return false
}
return rule.Path == nil || rule.Path.MatchString(path)
}
func matchIngressHost(pattern, hostname string) bool {
if pattern == hostname {
return true
}
if strings.HasPrefix(pattern, "*.") {
return strings.HasSuffix(hostname, strings.TrimPrefix(pattern, "*"))
}
return false
}
func defaultRuntimeConfig() (RuntimeConfig, error) {
defaultOriginRequest := defaultOriginRequestConfig()
compiledRules, err := compileIngressRules(defaultOriginRequest, nil)
if err != nil {
return RuntimeConfig{}, err
}
return RuntimeConfig{
Ingress: compiledRules,
OriginRequest: defaultOriginRequest,
WarpRouting: WarpRoutingConfig{
ConnectTimeout: defaultWarpRoutingConnectTime,
TCPKeepAlive: defaultWarpRoutingTCPKeepAlive,
},
}, nil
}
func buildRemoteRuntimeConfig(raw []byte) (RuntimeConfig, error) {
var remote remoteConfigJSON
if err := json.Unmarshal(raw, &remote); err != nil {
return RuntimeConfig{}, E.Cause(err, "decode remote config")
}
defaultOriginRequest := originRequestFromRemote(remote.OriginRequest)
warpRouting := warpRoutingFromRemote(remote.WarpRouting)
var ingressRules []localIngressRule
for _, rule := range remote.Ingress {
ingressRules = append(ingressRules, localIngressRule{
Hostname: rule.Hostname,
Path: rule.Path,
Service: rule.Service,
OriginRequest: mergeRemoteOriginRequest(defaultOriginRequest, rule.OriginRequest),
})
}
compiledRules, err := compileIngressRules(defaultOriginRequest, ingressRules)
if err != nil {
return RuntimeConfig{}, err
}
return RuntimeConfig{
Ingress: compiledRules,
OriginRequest: defaultOriginRequest,
WarpRouting: warpRouting,
}, nil
}
type localIngressRule struct {
Hostname string
Path string
Service string
OriginRequest OriginRequestConfig
}
type remoteConfigJSON struct {
OriginRequest remoteOriginRequestJSON `json:"originRequest"`
Ingress []remoteIngressRuleJSON `json:"ingress"`
WarpRouting remoteWarpRoutingJSON `json:"warp-routing"`
}
type remoteIngressRuleJSON struct {
Hostname string `json:"hostname,omitempty"`
Path string `json:"path,omitempty"`
Service string `json:"service"`
OriginRequest remoteOriginRequestJSON `json:"originRequest,omitempty"`
}
type remoteOriginRequestJSON struct {
ConnectTimeout int64 `json:"connectTimeout,omitempty"`
TLSTimeout int64 `json:"tlsTimeout,omitempty"`
TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"`
NoHappyEyeballs *bool `json:"noHappyEyeballs,omitempty"`
KeepAliveTimeout int64 `json:"keepAliveTimeout,omitempty"`
KeepAliveConnections *int `json:"keepAliveConnections,omitempty"`
HTTPHostHeader string `json:"httpHostHeader,omitempty"`
OriginServerName string `json:"originServerName,omitempty"`
MatchSNIToHost *bool `json:"matchSNIToHost,omitempty"`
CAPool string `json:"caPool,omitempty"`
NoTLSVerify *bool `json:"noTLSVerify,omitempty"`
DisableChunkedEncoding *bool `json:"disableChunkedEncoding,omitempty"`
BastionMode *bool `json:"bastionMode,omitempty"`
ProxyAddress string `json:"proxyAddress,omitempty"`
ProxyPort *uint `json:"proxyPort,omitempty"`
ProxyType string `json:"proxyType,omitempty"`
IPRules []remoteIPRuleJSON `json:"ipRules,omitempty"`
HTTP2Origin *bool `json:"http2Origin,omitempty"`
Access *remoteAccessJSON `json:"access,omitempty"`
}
type remoteAccessJSON struct {
Required bool `json:"required,omitempty"`
TeamName string `json:"teamName,omitempty"`
AudTag []string `json:"audTag,omitempty"`
Environment string `json:"environment,omitempty"`
}
type remoteIPRuleJSON struct {
Prefix string `json:"prefix,omitempty"`
Ports []int `json:"ports,omitempty"`
Allow bool `json:"allow,omitempty"`
}
type remoteWarpRoutingJSON struct {
ConnectTimeout int64 `json:"connectTimeout,omitempty"`
MaxActiveFlows uint64 `json:"maxActiveFlows,omitempty"`
TCPKeepAlive int64 `json:"tcpKeepAlive,omitempty"`
}
func compileIngressRules(defaultOriginRequest OriginRequestConfig, rawRules []localIngressRule) ([]compiledIngressRule, error) {
if len(rawRules) == 0 {
rawRules = []localIngressRule{{
Service: "http_status:503",
OriginRequest: defaultOriginRequest,
}}
}
if !isCatchAllRule(rawRules[len(rawRules)-1].Hostname, rawRules[len(rawRules)-1].Path) {
return nil, E.New("the last ingress rule must be a catch-all rule")
}
compiled := make([]compiledIngressRule, 0, len(rawRules))
for index, rule := range rawRules {
if err := validateHostname(rule.Hostname, index == len(rawRules)-1); err != nil {
return nil, err
}
if err := validateAccessConfiguration(rule.OriginRequest.Access); err != nil {
return nil, err
}
service, err := parseResolvedService(rule.Service, rule.OriginRequest)
if err != nil {
return nil, err
}
var pathPattern *regexp.Regexp
if rule.Path != "" {
pathPattern, err = regexp.Compile(rule.Path)
if err != nil {
return nil, E.Cause(err, "compile ingress path regex")
}
}
punycode := ""
if rule.Hostname != "" && rule.Hostname != "*" {
punycodeValue, err := idna.Lookup.ToASCII(rule.Hostname)
if err == nil && punycodeValue != rule.Hostname {
punycode = punycodeValue
}
}
compiled = append(compiled, compiledIngressRule{
Hostname: rule.Hostname,
PunycodeHostname: punycode,
Path: pathPattern,
Service: service,
})
}
return compiled, nil
}
func parseResolvedService(rawService string, originRequest OriginRequestConfig) (ResolvedService, error) {
switch {
case rawService == "":
if originRequest.BastionMode {
return ResolvedService{
Kind: ResolvedServiceBastion,
Service: "bastion",
OriginRequest: originRequest,
}, nil
}
return ResolvedService{}, E.New("missing ingress service")
case strings.HasPrefix(rawService, "http_status:"):
statusCode, err := strconv.Atoi(strings.TrimPrefix(rawService, "http_status:"))
if err != nil {
return ResolvedService{}, E.Cause(err, "parse http_status service")
}
if statusCode < 100 || statusCode > 999 {
return ResolvedService{}, E.New("invalid http_status code: ", statusCode)
}
return ResolvedService{
Kind: ResolvedServiceStatus,
Service: rawService,
StatusCode: statusCode,
OriginRequest: originRequest,
}, nil
case rawService == "hello_world" || rawService == "hello-world":
return ResolvedService{}, E.New("unsupported ingress service: hello_world")
case rawService == "bastion":
return ResolvedService{
Kind: ResolvedServiceBastion,
Service: rawService,
OriginRequest: originRequest,
}, nil
case rawService == "socks-proxy":
policy, err := newIPRulePolicy(originRequest.IPRules)
if err != nil {
return ResolvedService{}, E.Cause(err, "compile socks-proxy ip rules")
}
return ResolvedService{
Kind: ResolvedServiceSocksProxy,
Service: rawService,
SocksPolicy: policy,
OriginRequest: originRequest,
}, nil
case strings.HasPrefix(rawService, "unix:"):
return ResolvedService{
Kind: ResolvedServiceUnix,
Service: rawService,
UnixPath: strings.TrimPrefix(rawService, "unix:"),
BaseURL: &url.URL{Scheme: "http", Host: "localhost"},
OriginRequest: originRequest,
}, nil
case strings.HasPrefix(rawService, "unix+tls:"):
return ResolvedService{
Kind: ResolvedServiceUnixTLS,
Service: rawService,
UnixPath: strings.TrimPrefix(rawService, "unix+tls:"),
BaseURL: &url.URL{Scheme: "https", Host: "localhost"},
OriginRequest: originRequest,
}, nil
}
parsedURL, err := url.Parse(rawService)
if err != nil {
return ResolvedService{}, E.Cause(err, "parse ingress service URL")
}
if parsedURL.Scheme == "" || parsedURL.Hostname() == "" {
return ResolvedService{}, E.New("ingress service must include scheme and hostname: ", rawService)
}
if parsedURL.Path != "" {
return ResolvedService{}, E.New("ingress service cannot include a path: ", rawService)
}
if isHTTPServiceScheme(parsedURL.Scheme) {
return ResolvedService{
Kind: ResolvedServiceHTTP,
Service: rawService,
Destination: parseHTTPServiceDestination(parsedURL),
BaseURL: canonicalizeHTTPOriginURL(parsedURL),
OriginRequest: originRequest,
}, nil
}
destination, hasPort := parseStreamServiceDestination(parsedURL)
return ResolvedService{
Kind: ResolvedServiceStream,
Service: rawService,
Destination: destination,
StreamHasPort: hasPort,
BaseURL: parsedURL,
OriginRequest: originRequest,
}, nil
}
func parseHTTPServiceDestination(parsedURL *url.URL) M.Socksaddr {
host := parsedURL.Hostname()
port := parsedURL.Port()
if port == "" {
switch parsedURL.Scheme {
case "https", "wss":
port = "443"
default:
port = "80"
}
}
return M.ParseSocksaddr(net.JoinHostPort(host, port))
}
func parseStreamServiceDestination(parsedURL *url.URL) (M.Socksaddr, bool) {
host := parsedURL.Hostname()
port := parsedURL.Port()
if port == "" {
switch parsedURL.Scheme {
case "ssh":
port = "22"
case "rdp":
port = "3389"
case "smb":
port = "445"
case "tcp":
port = "7864"
default:
return M.ParseSocksaddrHostPort(host, 0), false
}
}
return M.ParseSocksaddr(net.JoinHostPort(host, port)), true
}
func validateHostname(hostname string, isLast bool) error {
if hostname == "" || hostname == "*" {
if !isLast {
return E.New("only the last ingress rule may be a catch-all rule")
}
return nil
}
if strings.Count(hostname, "*") > 1 || (strings.Contains(hostname, "*") && !strings.HasPrefix(hostname, "*.")) {
return E.New("hostname wildcard must be in the form *.example.com")
}
if stripPort(hostname) != hostname {
return E.New("ingress hostname cannot contain a port")
}
return nil
}
func isCatchAllRule(hostname, path string) bool {
return (hostname == "" || hostname == "*") && path == ""
}
func stripPort(hostname string) string {
if host, _, err := net.SplitHostPort(hostname); err == nil {
return host
}
return hostname
}
func defaultOriginRequestConfig() OriginRequestConfig {
return OriginRequestConfig{
ConnectTimeout: defaultHTTPConnectTimeout,
TLSTimeout: defaultTLSTimeout,
TCPKeepAlive: defaultTCPKeepAlive,
KeepAliveTimeout: defaultKeepAliveTimeout,
KeepAliveConnections: defaultKeepAliveConnections,
ProxyAddress: defaultProxyAddress,
}
}
func originRequestFromRemote(input remoteOriginRequestJSON) OriginRequestConfig {
config := defaultOriginRequestConfig()
if input.ConnectTimeout != 0 {
config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second
}
if input.TLSTimeout != 0 {
config.TLSTimeout = time.Duration(input.TLSTimeout) * time.Second
}
if input.TCPKeepAlive != 0 {
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second
}
if input.KeepAliveTimeout != 0 {
config.KeepAliveTimeout = time.Duration(input.KeepAliveTimeout) * time.Second
}
if input.KeepAliveConnections != nil {
config.KeepAliveConnections = *input.KeepAliveConnections
}
if input.NoHappyEyeballs != nil {
config.NoHappyEyeballs = *input.NoHappyEyeballs
}
config.HTTPHostHeader = input.HTTPHostHeader
config.OriginServerName = input.OriginServerName
if input.MatchSNIToHost != nil {
config.MatchSNIToHost = *input.MatchSNIToHost
}
config.CAPool = input.CAPool
if input.NoTLSVerify != nil {
config.NoTLSVerify = *input.NoTLSVerify
}
if input.DisableChunkedEncoding != nil {
config.DisableChunkedEncoding = *input.DisableChunkedEncoding
}
if input.BastionMode != nil {
config.BastionMode = *input.BastionMode
}
if input.ProxyAddress != "" {
config.ProxyAddress = input.ProxyAddress
}
if input.ProxyPort != nil {
config.ProxyPort = *input.ProxyPort
}
config.ProxyType = input.ProxyType
if input.HTTP2Origin != nil {
config.HTTP2Origin = *input.HTTP2Origin
}
if input.Access != nil {
config.Access = AccessConfig{
Required: input.Access.Required,
TeamName: input.Access.TeamName,
AudTag: append([]string(nil), input.Access.AudTag...),
Environment: input.Access.Environment,
}
}
for _, rule := range input.IPRules {
config.IPRules = append(config.IPRules, IPRule{
Prefix: rule.Prefix,
Ports: append([]int(nil), rule.Ports...),
Allow: rule.Allow,
})
}
return config
}
func mergeRemoteOriginRequest(base OriginRequestConfig, override remoteOriginRequestJSON) OriginRequestConfig {
result := base
if override.ConnectTimeout != 0 {
result.ConnectTimeout = time.Duration(override.ConnectTimeout) * time.Second
}
if override.TLSTimeout != 0 {
result.TLSTimeout = time.Duration(override.TLSTimeout) * time.Second
}
if override.TCPKeepAlive != 0 {
result.TCPKeepAlive = time.Duration(override.TCPKeepAlive) * time.Second
}
if override.NoHappyEyeballs != nil {
result.NoHappyEyeballs = *override.NoHappyEyeballs
}
if override.KeepAliveTimeout != 0 {
result.KeepAliveTimeout = time.Duration(override.KeepAliveTimeout) * time.Second
}
if override.KeepAliveConnections != nil {
result.KeepAliveConnections = *override.KeepAliveConnections
}
if override.HTTPHostHeader != "" {
result.HTTPHostHeader = override.HTTPHostHeader
}
if override.OriginServerName != "" {
result.OriginServerName = override.OriginServerName
}
if override.MatchSNIToHost != nil {
result.MatchSNIToHost = *override.MatchSNIToHost
}
if override.CAPool != "" {
result.CAPool = override.CAPool
}
if override.NoTLSVerify != nil {
result.NoTLSVerify = *override.NoTLSVerify
}
if override.DisableChunkedEncoding != nil {
result.DisableChunkedEncoding = *override.DisableChunkedEncoding
}
if override.BastionMode != nil {
result.BastionMode = *override.BastionMode
}
if override.ProxyAddress != "" {
result.ProxyAddress = override.ProxyAddress
}
if override.ProxyPort != nil {
result.ProxyPort = *override.ProxyPort
}
if override.ProxyType != "" {
result.ProxyType = override.ProxyType
}
if len(override.IPRules) > 0 {
result.IPRules = nil
for _, rule := range override.IPRules {
result.IPRules = append(result.IPRules, IPRule{
Prefix: rule.Prefix,
Ports: append([]int(nil), rule.Ports...),
Allow: rule.Allow,
})
}
}
if override.HTTP2Origin != nil {
result.HTTP2Origin = *override.HTTP2Origin
}
if override.Access != nil {
result.Access = AccessConfig{
Required: override.Access.Required,
TeamName: override.Access.TeamName,
AudTag: append([]string(nil), override.Access.AudTag...),
Environment: override.Access.Environment,
}
}
return result
}
func warpRoutingFromRemote(input remoteWarpRoutingJSON) WarpRoutingConfig {
config := WarpRoutingConfig{
ConnectTimeout: defaultWarpRoutingConnectTime,
TCPKeepAlive: defaultWarpRoutingTCPKeepAlive,
MaxActiveFlows: input.MaxActiveFlows,
}
if input.ConnectTimeout != 0 {
config.ConnectTimeout = time.Duration(input.ConnectTimeout) * time.Second
}
if input.TCPKeepAlive != 0 {
config.TCPKeepAlive = time.Duration(input.TCPKeepAlive) * time.Second
}
return config
}

View File

@@ -1,63 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"io"
"time"
E "github.com/sagernet/sing/common/exceptions"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/rpc"
)
const (
safeTransportMaxRetries = 3
safeTransportRetryInterval = 500 * time.Millisecond
)
type safeReadWriteCloser struct {
io.ReadWriteCloser
retries int
}
func (s *safeReadWriteCloser) Read(p []byte) (int, error) {
n, err := s.ReadWriteCloser.Read(p)
if n == 0 && err != nil && isTemporaryError(err) {
if s.retries >= safeTransportMaxRetries {
return 0, E.Cause(err, "read capnproto transport after multiple temporary errors")
}
s.retries++
time.Sleep(safeTransportRetryInterval)
return n, err
}
if err == nil {
s.retries = 0
}
return n, err
}
func isTemporaryError(err error) bool {
type temporary interface{ Temporary() bool }
t, ok := err.(temporary)
return ok && t.Temporary()
}
func safeTransport(stream io.ReadWriteCloser) rpc.Transport {
return rpc.StreamTransport(&safeReadWriteCloser{ReadWriteCloser: stream})
}
type noopCapnpLogger struct{}
func (noopCapnpLogger) Infof(ctx context.Context, format string, args ...interface{}) {}
func (noopCapnpLogger) Errorf(ctx context.Context, format string, args ...interface{}) {}
func newRPCClientConn(transport rpc.Transport, ctx context.Context) *rpc.Conn {
return rpc.NewConn(transport, rpc.ConnLog(noopCapnpLogger{}))
}
func newRPCServerConn(transport rpc.Transport, client capnp.Client) *rpc.Conn {
return rpc.NewConn(transport, rpc.MainInterface(client), rpc.ConnLog(noopCapnpLogger{}))
}

View File

@@ -1,355 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"crypto/sha1"
"encoding/base64"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"strconv"
"strings"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/transport/v2raywebsocket"
"github.com/sagernet/sing/common/bufio"
E "github.com/sagernet/sing/common/exceptions"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/ws"
)
var wsAcceptGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
const (
socksReplySuccess = 0
socksReplyRuleFailure = 2
socksReplyNetworkUnreachable = 3
socksReplyHostUnreachable = 4
socksReplyConnectionRefused = 5
socksReplyCommandNotSupported = 7
)
func (i *Inbound) handleBastionStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
destination, err := resolveBastionDestination(request)
if err != nil {
respWriter.WriteResponse(err, nil)
return
}
i.handleRouterBackedStream(ctx, stream, respWriter, request, M.ParseSocksaddr(destination), service.OriginRequest.ProxyType)
}
func (i *Inbound) handleStreamService(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
if !service.StreamHasPort {
respWriter.WriteResponse(E.New("address ", streamServiceHostname(service), ": missing port in address"), nil)
return
}
i.handleRouterBackedStream(ctx, stream, respWriter, request, service.Destination, service.OriginRequest.ProxyType)
}
func (i *Inbound) handleRouterBackedStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, destination M.Socksaddr, proxyType string) {
targetConn, cleanup, err := i.dialRouterTCP(ctx, destination)
if err != nil {
respWriter.WriteResponse(err, nil)
return
}
defer cleanup()
err = respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusSwitchingProtocols, websocketResponseHeaders(request)))
if err != nil {
i.logger.ErrorContext(ctx, "write bastion websocket response: ", err)
return
}
wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide)
defer wsConn.Close()
if isSocksProxyType(proxyType) {
if err := serveFixedSocksStream(ctx, wsConn, targetConn); err != nil && !E.IsClosedOrCanceled(err) {
i.logger.DebugContext(ctx, "socks-over-websocket stream closed: ", err)
}
return
}
_ = bufio.CopyConn(ctx, wsConn, targetConn)
}
func (i *Inbound) handleSocksProxyStream(ctx context.Context, stream io.ReadWriteCloser, respWriter ConnectResponseWriter, request *ConnectRequest, metadata adapter.InboundContext, service ResolvedService) {
err := respWriter.WriteResponse(nil, encodeResponseHeaders(http.StatusSwitchingProtocols, websocketResponseHeaders(request)))
if err != nil {
i.logger.ErrorContext(ctx, "write socks-proxy websocket response: ", err)
return
}
wsConn := v2raywebsocket.NewConn(newStreamConn(stream), nil, ws.StateServerSide)
defer wsConn.Close()
if err := i.serveSocksProxy(ctx, wsConn, service.SocksPolicy); err != nil && !E.IsClosedOrCanceled(err) {
i.logger.DebugContext(ctx, "socks-proxy stream closed: ", err)
}
}
func resolveBastionDestination(request *ConnectRequest) (string, error) {
headerValue := requestHeaderValue(request, "Cf-Access-Jump-Destination")
if headerValue == "" {
return "", E.New("missing Cf-Access-Jump-Destination header")
}
if parsed, err := url.Parse(headerValue); err == nil && parsed.Host != "" {
headerValue = parsed.Host
}
return strings.SplitN(headerValue, "/", 2)[0], nil
}
func websocketResponseHeaders(request *ConnectRequest) http.Header {
header := http.Header{}
header.Set("Connection", "Upgrade")
header.Set("Upgrade", "websocket")
secKey := requestHeaderValue(request, "Sec-WebSocket-Key")
if secKey != "" {
sum := sha1.Sum(append([]byte(secKey), wsAcceptGUID...))
header.Set("Sec-WebSocket-Accept", base64.StdEncoding.EncodeToString(sum[:]))
}
return header
}
func isSocksProxyType(proxyType string) bool {
lower := strings.ToLower(strings.TrimSpace(proxyType))
return lower == "socks" || lower == "socks5"
}
func serveFixedSocksStream(ctx context.Context, conn net.Conn, targetConn net.Conn) error {
version := make([]byte, 1)
if _, err := io.ReadFull(conn, version); err != nil {
return err
}
if version[0] != 5 {
return E.New("unsupported SOCKS version: ", version[0])
}
methodCount := make([]byte, 1)
if _, err := io.ReadFull(conn, methodCount); err != nil {
return err
}
methods := make([]byte, int(methodCount[0]))
if _, err := io.ReadFull(conn, methods); err != nil {
return err
}
var supportsNoAuth bool
for _, method := range methods {
if method == 0 {
supportsNoAuth = true
break
}
}
if !supportsNoAuth {
_, err := conn.Write([]byte{5, 255})
if err != nil {
return err
}
return E.New("unknown authentication type")
}
if _, err := conn.Write([]byte{5, 0}); err != nil {
return err
}
requestHeader := make([]byte, 4)
if _, err := io.ReadFull(conn, requestHeader); err != nil {
return err
}
if requestHeader[0] != 5 {
return E.New("unsupported SOCKS request version: ", requestHeader[0])
}
if requestHeader[1] != 1 {
_ = writeSocksReply(conn, socksReplyCommandNotSupported)
return E.New("unsupported SOCKS command: ", requestHeader[1])
}
if _, err := readSocksDestination(conn, requestHeader[3]); err != nil {
return err
}
if err := writeSocksReply(conn, socksReplySuccess); err != nil {
return err
}
return bufio.CopyConn(ctx, conn, targetConn)
}
func requestHeaderValue(request *ConnectRequest, headerName string) string {
for _, entry := range request.Metadata {
if !strings.HasPrefix(entry.Key, metadataHTTPHeader+":") {
continue
}
name := strings.TrimPrefix(entry.Key, metadataHTTPHeader+":")
if strings.EqualFold(name, headerName) {
return entry.Val
}
}
return ""
}
func streamServiceHostname(service ResolvedService) string {
if service.BaseURL != nil && service.BaseURL.Hostname() != "" {
return service.BaseURL.Hostname()
}
parsedURL, err := url.Parse(service.Service)
if err == nil && parsedURL.Hostname() != "" {
return parsedURL.Hostname()
}
return service.Destination.AddrString()
}
func (i *Inbound) dialRouterTCP(ctx context.Context, destination M.Socksaddr) (net.Conn, func(), error) {
metadata := adapter.InboundContext{
Inbound: i.Tag(),
InboundType: i.Type(),
Network: N.NetworkTCP,
Destination: destination,
}
return i.dialRouterTCPWithMetadata(ctx, metadata, routedPipeTCPOptions{})
}
func (i *Inbound) serveSocksProxy(ctx context.Context, conn net.Conn, policy *ipRulePolicy) error {
version := make([]byte, 1)
if _, err := io.ReadFull(conn, version); err != nil {
return err
}
if version[0] != 5 {
return E.New("unsupported SOCKS version: ", version[0])
}
methodCount := make([]byte, 1)
if _, err := io.ReadFull(conn, methodCount); err != nil {
return err
}
methods := make([]byte, int(methodCount[0]))
if _, err := io.ReadFull(conn, methods); err != nil {
return err
}
var supportsNoAuth bool
for _, method := range methods {
if method == 0 {
supportsNoAuth = true
break
}
}
if !supportsNoAuth {
if _, err := conn.Write([]byte{5, 255}); err != nil {
return err
}
return E.New("unknown authentication type")
}
if _, err := conn.Write([]byte{5, 0}); err != nil {
return err
}
requestHeader := make([]byte, 4)
if _, err := io.ReadFull(conn, requestHeader); err != nil {
return err
}
if requestHeader[0] != 5 {
return E.New("unsupported SOCKS request version: ", requestHeader[0])
}
if requestHeader[1] != 1 {
_ = writeSocksReply(conn, socksReplyCommandNotSupported)
return E.New("unsupported SOCKS command: ", requestHeader[1])
}
destination, err := readSocksDestination(conn, requestHeader[3])
if err != nil {
return err
}
allowed, err := policy.Allow(ctx, destination)
if err != nil {
_ = writeSocksReply(conn, socksReplyRuleFailure)
return err
}
if !allowed {
_ = writeSocksReply(conn, socksReplyRuleFailure)
return E.New("connect to ", destination, " denied by ip_rules")
}
targetConn, cleanup, err := i.dialRouterTCP(ctx, destination)
if err != nil {
_ = writeSocksReply(conn, socksReplyForDialError(err))
return err
}
defer cleanup()
if err := writeSocksReply(conn, socksReplySuccess); err != nil {
return err
}
return bufio.CopyConn(ctx, conn, targetConn)
}
func writeSocksReply(conn net.Conn, reply byte) error {
_, err := conn.Write([]byte{5, reply, 0, 1, 0, 0, 0, 0, 0, 0})
return err
}
func socksReplyForDialError(err error) byte {
lower := strings.ToLower(err.Error())
switch {
case strings.Contains(lower, "refused"):
return socksReplyConnectionRefused
case strings.Contains(lower, "network is unreachable"):
return socksReplyNetworkUnreachable
default:
return socksReplyHostUnreachable
}
}
func readSocksDestination(conn net.Conn, addressType byte) (M.Socksaddr, error) {
switch addressType {
case 1:
addr := make([]byte, 4)
if _, err := io.ReadFull(conn, addr); err != nil {
return M.Socksaddr{}, err
}
port, err := readSocksPort(conn)
if err != nil {
return M.Socksaddr{}, err
}
ipAddr, ok := netip.AddrFromSlice(addr)
if !ok {
return M.Socksaddr{}, E.New("invalid IPv4 SOCKS destination")
}
return M.SocksaddrFrom(ipAddr, port), nil
case 3:
length := make([]byte, 1)
if _, err := io.ReadFull(conn, length); err != nil {
return M.Socksaddr{}, err
}
host := make([]byte, int(length[0]))
if _, err := io.ReadFull(conn, host); err != nil {
return M.Socksaddr{}, err
}
port, err := readSocksPort(conn)
if err != nil {
return M.Socksaddr{}, err
}
return M.ParseSocksaddr(net.JoinHostPort(string(host), strconv.Itoa(int(port)))), nil
case 4:
addr := make([]byte, 16)
if _, err := io.ReadFull(conn, addr); err != nil {
return M.Socksaddr{}, err
}
port, err := readSocksPort(conn)
if err != nil {
return M.Socksaddr{}, err
}
ipAddr, ok := netip.AddrFromSlice(addr)
if !ok {
return M.Socksaddr{}, E.New("invalid IPv6 SOCKS destination")
}
return M.SocksaddrFrom(ipAddr, port), nil
default:
return M.Socksaddr{}, E.New("unsupported SOCKS address type: ", addressType)
}
}
func readSocksPort(conn net.Conn) (uint16, error) {
port := make([]byte, 2)
if _, err := io.ReadFull(conn, port); err != nil {
return 0, err
}
return uint16(port[0])<<8 | uint16(port[1]), nil
}

View File

@@ -1,680 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"context"
"errors"
"io"
"net"
"net/http"
"net/url"
"strconv"
"sync/atomic"
"testing"
"time"
"github.com/sagernet/sing-box/adapter"
"github.com/sagernet/sing-box/adapter/inbound"
C "github.com/sagernet/sing-box/constant"
"github.com/sagernet/sing-box/log"
"github.com/sagernet/sing-box/option"
M "github.com/sagernet/sing/common/metadata"
N "github.com/sagernet/sing/common/network"
"github.com/sagernet/ws"
"github.com/sagernet/ws/wsutil"
)
type fakeConnectResponseWriter struct {
status int
headers http.Header
err error
done chan struct{}
}
func (w *fakeConnectResponseWriter) WriteResponse(responseError error, metadata []Metadata) error {
w.err = responseError
w.headers = make(http.Header)
for _, entry := range metadata {
switch {
case entry.Key == metadataHTTPStatus:
status, _ := strconv.Atoi(entry.Val)
w.status = status
case len(entry.Key) > len(metadataHTTPHeader)+1 && entry.Key[:len(metadataHTTPHeader)+1] == metadataHTTPHeader+":":
w.headers.Add(entry.Key[len(metadataHTTPHeader)+1:], entry.Val)
}
}
if w.done != nil {
close(w.done)
w.done = nil
}
return nil
}
func newSpecialServiceInbound(t *testing.T) *Inbound {
return newSpecialServiceInboundWithRouter(t, &testRouter{})
}
func newSpecialServiceInboundWithRouter(t *testing.T, router adapter.Router) *Inbound {
t.Helper()
logFactory, err := log.New(log.Options{Options: option.LogOptions{Level: "debug"}})
if err != nil {
t.Fatal(err)
}
configManager, err := NewConfigManager()
if err != nil {
t.Fatal(err)
}
return &Inbound{
Adapter: inbound.NewAdapter(C.TypeCloudflared, "test"),
router: router,
logger: logFactory.NewLogger("test"),
configManager: configManager,
flowLimiter: &FlowLimiter{},
}
}
type countingRouter struct {
testRouter
count atomic.Int32
}
func (r *countingRouter) RouteConnectionEx(ctx context.Context, conn net.Conn, metadata adapter.InboundContext, onClose N.CloseHandlerFunc) {
r.count.Add(1)
r.testRouter.RouteConnectionEx(ctx, conn, metadata, onClose)
}
func startEchoListener(t *testing.T) net.Listener {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatal(err)
}
go func() {
for {
conn, err := listener.Accept()
if err != nil {
return
}
go func(conn net.Conn) {
defer conn.Close()
_, _ = io.Copy(conn, conn)
}(conn)
}
}()
return listener
}
func newSocksProxyService(t *testing.T, rules []IPRule) ResolvedService {
t.Helper()
service, err := parseResolvedService("socks-proxy", OriginRequestConfig{IPRules: rules})
if err != nil {
t.Fatal(err)
}
return service
}
func newSocksProxyConnectRequest() *ConnectRequest {
return &ConnectRequest{
Type: ConnectionTypeWebsocket,
Metadata: []Metadata{
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
},
}
}
func startSocksProxyStream(t *testing.T, inboundInstance *Inbound, service ResolvedService) (net.Conn, <-chan struct{}) {
t.Helper()
serverSide, clientSide := net.Pipe()
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
done := make(chan struct{})
go func() {
defer close(done)
inboundInstance.handleSocksProxyStream(context.Background(), serverSide, respWriter, newSocksProxyConnectRequest(), adapter.InboundContext{}, service)
}()
select {
case <-respWriter.done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for socks-proxy connect response")
}
if respWriter.err != nil {
t.Fatal(respWriter.err)
}
if respWriter.status != http.StatusSwitchingProtocols {
t.Fatalf("expected 101 response, got %d", respWriter.status)
}
return clientSide, done
}
func writeSocksAuth(t *testing.T, conn net.Conn) {
t.Helper()
if err := wsutil.WriteClientMessage(conn, ws.OpBinary, []byte{5, 1, 0}); err != nil {
t.Fatal(err)
}
data, _, err := wsutil.ReadServerData(conn)
if err != nil {
t.Fatal(err)
}
if string(data) != string([]byte{5, 0}) {
t.Fatalf("unexpected auth response: %v", data)
}
}
func writeSocksConnectIPv4(t *testing.T, conn net.Conn, address string) []byte {
t.Helper()
host, portText, err := net.SplitHostPort(address)
if err != nil {
t.Fatal(err)
}
port, err := strconv.Atoi(portText)
if err != nil {
t.Fatal(err)
}
requestBytes := []byte{5, 1, 0, 1}
requestBytes = append(requestBytes, net.ParseIP(host).To4()...)
requestBytes = append(requestBytes, byte(port>>8), byte(port))
if err := wsutil.WriteClientMessage(conn, ws.OpBinary, requestBytes); err != nil {
t.Fatal(err)
}
data, _, err := wsutil.ReadServerData(conn)
if err != nil {
t.Fatal(err)
}
return data
}
func TestServeSocksProxyRejectsMissingNoAuth(t *testing.T) {
inboundInstance := newSpecialServiceInbound(t)
serverSide, clientSide := net.Pipe()
defer clientSide.Close()
errCh := make(chan error, 1)
go func() {
errCh <- inboundInstance.serveSocksProxy(context.Background(), serverSide, nil)
}()
if _, err := clientSide.Write([]byte{5, 1, 2}); err != nil {
t.Fatal(err)
}
response := make([]byte, 2)
if _, err := io.ReadFull(clientSide, response); err != nil {
t.Fatal(err)
}
if string(response) != string([]byte{5, 255}) {
t.Fatalf("unexpected auth rejection response: %v", response)
}
if err := <-errCh; err == nil {
t.Fatal("expected socks auth rejection error")
}
}
func TestSocksReplyForDialError(t *testing.T) {
if reply := socksReplyForDialError(io.EOF); reply != socksReplyHostUnreachable {
t.Fatalf("expected host unreachable for generic error, got %d", reply)
}
if reply := socksReplyForDialError(errors.New("connection refused")); reply != 5 {
t.Fatalf("expected connection refused reply, got %d", reply)
}
if reply := socksReplyForDialError(errors.New("network is unreachable")); reply != 3 {
t.Fatalf("expected network unreachable reply, got %d", reply)
}
}
func TestHandleBastionStream(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
serverSide, clientSide := net.Pipe()
defer clientSide.Close()
inboundInstance := newSpecialServiceInbound(t)
request := &ConnectRequest{
Type: ConnectionTypeWebsocket,
Metadata: []Metadata{
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
{Key: metadataHTTPHeader + ":Cf-Access-Jump-Destination", Val: listener.Addr().String()},
},
}
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
done := make(chan struct{})
go func() {
defer close(done)
inboundInstance.handleBastionStream(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{})
}()
select {
case <-respWriter.done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for bastion connect response")
}
if respWriter.err != nil {
t.Fatal(respWriter.err)
}
if respWriter.status != http.StatusSwitchingProtocols {
t.Fatalf("expected 101 response, got %d", respWriter.status)
}
if respWriter.headers.Get("Sec-WebSocket-Accept") == "" {
t.Fatal("expected websocket accept header")
}
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
t.Fatal(err)
}
data, opCode, err := wsutil.ReadServerData(clientSide)
if err != nil {
t.Fatal(err)
}
if opCode != ws.OpBinary {
t.Fatalf("expected binary frame, got %v", opCode)
}
if string(data) != "hello" {
t.Fatalf("expected echoed payload, got %q", string(data))
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("bastion stream did not exit")
}
}
func TestHandleSocksProxyStream(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
_, portText, _ := net.SplitHostPort(listener.Addr().String())
port, _ := strconv.Atoi(portText)
service := newSocksProxyService(t, []IPRule{{
Prefix: "127.0.0.0/8",
Ports: []int{port},
Allow: true,
}})
clientSide, done := startSocksProxyStream(t, newSpecialServiceInbound(t), service)
defer clientSide.Close()
writeSocksAuth(t, clientSide)
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
if len(data) != 10 || data[1] != 0 {
t.Fatalf("unexpected connect response: %v", data)
}
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
t.Fatal(err)
}
data, _, err := wsutil.ReadServerData(clientSide)
if err != nil {
t.Fatal(err)
}
if string(data) != "hello" {
t.Fatalf("expected echoed payload, got %q", string(data))
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("socks-proxy stream did not exit")
}
}
func TestHandleSocksProxyStreamDenyRule(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
_, portText, _ := net.SplitHostPort(listener.Addr().String())
port, _ := strconv.Atoi(portText)
service := newSocksProxyService(t, []IPRule{{
Prefix: "127.0.0.0/8",
Ports: []int{port},
Allow: false,
}})
router := &countingRouter{}
clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), service)
defer clientSide.Close()
writeSocksAuth(t, clientSide)
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
if len(data) != 10 || data[1] != socksReplyRuleFailure {
t.Fatalf("unexpected deny response: %v", data)
}
if router.count.Load() != 0 {
t.Fatalf("expected no router dial, got %d", router.count.Load())
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("socks-proxy stream did not exit")
}
}
func TestHandleSocksProxyStreamPortMismatchDefaultDeny(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
_, portText, _ := net.SplitHostPort(listener.Addr().String())
port, _ := strconv.Atoi(portText)
service := newSocksProxyService(t, []IPRule{{
Prefix: "127.0.0.0/8",
Ports: []int{port + 1},
Allow: true,
}})
router := &countingRouter{}
clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), service)
defer clientSide.Close()
writeSocksAuth(t, clientSide)
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
if len(data) != 10 || data[1] != socksReplyRuleFailure {
t.Fatalf("unexpected port mismatch response: %v", data)
}
if router.count.Load() != 0 {
t.Fatalf("expected no router dial, got %d", router.count.Load())
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("socks-proxy stream did not exit")
}
}
func TestHandleSocksProxyStreamEmptyRulesDefaultDeny(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
router := &countingRouter{}
clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), newSocksProxyService(t, nil))
defer clientSide.Close()
writeSocksAuth(t, clientSide)
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
if len(data) != 10 || data[1] != socksReplyRuleFailure {
t.Fatalf("unexpected empty-rule response: %v", data)
}
if router.count.Load() != 0 {
t.Fatalf("expected no router dial, got %d", router.count.Load())
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("socks-proxy stream did not exit")
}
}
func TestHandleSocksProxyStreamRuleOrderFirstMatchWins(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
_, portText, _ := net.SplitHostPort(listener.Addr().String())
port, _ := strconv.Atoi(portText)
allowFirst := newSocksProxyService(t, []IPRule{
{Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true},
{Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false},
})
denyFirst := newSocksProxyService(t, []IPRule{
{Prefix: "127.0.0.1/32", Ports: []int{port}, Allow: false},
{Prefix: "127.0.0.0/8", Ports: []int{port}, Allow: true},
})
t.Run("allow-first", func(t *testing.T) {
clientSide, done := startSocksProxyStream(t, newSpecialServiceInbound(t), allowFirst)
defer clientSide.Close()
writeSocksAuth(t, clientSide)
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
if len(data) != 10 || data[1] != socksReplySuccess {
t.Fatalf("unexpected allow-first response: %v", data)
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("socks-proxy stream did not exit")
}
})
t.Run("deny-first", func(t *testing.T) {
router := &countingRouter{}
clientSide, done := startSocksProxyStream(t, newSpecialServiceInboundWithRouter(t, router), denyFirst)
defer clientSide.Close()
writeSocksAuth(t, clientSide)
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
if len(data) != 10 || data[1] != socksReplyRuleFailure {
t.Fatalf("unexpected deny-first response: %v", data)
}
if router.count.Load() != 0 {
t.Fatalf("expected no router dial, got %d", router.count.Load())
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("socks-proxy stream did not exit")
}
})
}
func TestHandleStreamService(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
serverSide, clientSide := net.Pipe()
defer clientSide.Close()
inboundInstance := newSpecialServiceInbound(t)
request := &ConnectRequest{
Type: ConnectionTypeWebsocket,
Metadata: []Metadata{
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
},
}
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
done := make(chan struct{})
go func() {
defer close(done)
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceStream,
Destination: M.ParseSocksaddr(listener.Addr().String()),
StreamHasPort: true,
})
}()
select {
case <-respWriter.done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for stream service connect response")
}
if respWriter.err != nil {
t.Fatal(respWriter.err)
}
if respWriter.status != http.StatusSwitchingProtocols {
t.Fatalf("expected 101 response, got %d", respWriter.status)
}
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
t.Fatal(err)
}
data, opCode, err := wsutil.ReadServerData(clientSide)
if err != nil {
t.Fatal(err)
}
if opCode != ws.OpBinary {
t.Fatalf("expected binary frame, got %v", opCode)
}
if string(data) != "hello" {
t.Fatalf("expected echoed payload, got %q", string(data))
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("stream service did not exit")
}
}
func TestHandleStreamServiceProxyTypeSocks(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
serverSide, clientSide := net.Pipe()
defer clientSide.Close()
inboundInstance := newSpecialServiceInbound(t)
request := &ConnectRequest{
Type: ConnectionTypeWebsocket,
Metadata: []Metadata{
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
},
}
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
done := make(chan struct{})
go func() {
defer close(done)
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceStream,
Destination: M.ParseSocksaddr(listener.Addr().String()),
StreamHasPort: true,
OriginRequest: OriginRequestConfig{
ProxyType: "socks",
},
})
}()
select {
case <-respWriter.done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for stream service connect response")
}
if respWriter.err != nil {
t.Fatal(respWriter.err)
}
if respWriter.status != http.StatusSwitchingProtocols {
t.Fatalf("expected 101 response, got %d", respWriter.status)
}
writeSocksAuth(t, clientSide)
data := writeSocksConnectIPv4(t, clientSide, listener.Addr().String())
if len(data) != 10 || data[1] != socksReplySuccess {
t.Fatalf("unexpected socks connect response: %v", data)
}
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
t.Fatal(err)
}
data, _, err := wsutil.ReadServerData(clientSide)
if err != nil {
t.Fatal(err)
}
if string(data) != "hello" {
t.Fatalf("expected echoed payload, got %q", string(data))
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("socks stream service did not exit")
}
}
func TestHandleStreamServiceGenericSchemeWithPort(t *testing.T) {
listener := startEchoListener(t)
defer listener.Close()
serverSide, clientSide := net.Pipe()
defer clientSide.Close()
inboundInstance := newSpecialServiceInbound(t)
request := &ConnectRequest{
Type: ConnectionTypeWebsocket,
Metadata: []Metadata{
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
},
}
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
done := make(chan struct{})
go func() {
defer close(done)
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceStream,
Service: "ftp://" + listener.Addr().String(),
Destination: M.ParseSocksaddr(listener.Addr().String()),
StreamHasPort: true,
})
}()
select {
case <-respWriter.done:
case <-time.After(2 * time.Second):
t.Fatal("timed out waiting for stream service connect response")
}
if respWriter.err != nil {
t.Fatal(respWriter.err)
}
if respWriter.status != http.StatusSwitchingProtocols {
t.Fatalf("expected 101 response, got %d", respWriter.status)
}
if err := wsutil.WriteClientMessage(clientSide, ws.OpBinary, []byte("hello")); err != nil {
t.Fatal(err)
}
data, _, err := wsutil.ReadServerData(clientSide)
if err != nil {
t.Fatal(err)
}
if string(data) != "hello" {
t.Fatalf("expected echoed payload, got %q", string(data))
}
_ = clientSide.Close()
select {
case <-done:
case <-time.After(2 * time.Second):
t.Fatal("generic stream service did not exit")
}
}
func TestHandleStreamServiceGenericSchemeWithoutPort(t *testing.T) {
serverSide, clientSide := net.Pipe()
defer clientSide.Close()
defer serverSide.Close()
router := &countingRouter{}
inboundInstance := newSpecialServiceInboundWithRouter(t, router)
request := &ConnectRequest{
Type: ConnectionTypeWebsocket,
Metadata: []Metadata{
{Key: metadataHTTPHeader + ":Sec-WebSocket-Key", Val: "dGhlIHNhbXBsZSBub25jZQ=="},
},
}
respWriter := &fakeConnectResponseWriter{done: make(chan struct{})}
inboundInstance.handleStreamService(context.Background(), serverSide, respWriter, request, adapter.InboundContext{}, ResolvedService{
Kind: ResolvedServiceStream,
Service: "ftp://127.0.0.1",
Destination: M.ParseSocksaddrHostPort("127.0.0.1", 0),
StreamHasPort: false,
BaseURL: &url.URL{
Scheme: "ftp",
Host: "127.0.0.1",
},
})
if respWriter.err == nil {
t.Fatal("expected missing port error")
}
if respWriter.err.Error() != "address 127.0.0.1: missing port in address" {
t.Fatalf("unexpected error: %v", respWriter.err)
}
if respWriter.status == http.StatusSwitchingProtocols {
t.Fatalf("expected non-upgrade response on error, got %d", respWriter.status)
}
if router.count.Load() != 0 {
t.Fatalf("expected router not to be used, got %d", router.count.Load())
}
}

View File

@@ -1,235 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"io"
"net"
"time"
"github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc"
E "github.com/sagernet/sing/common/exceptions"
"github.com/google/uuid"
capnp "zombiezen.com/go/capnproto2"
"zombiezen.com/go/capnproto2/pogs"
)
// Protocol signatures distinguish stream types.
var (
dataStreamSignature = [6]byte{0x0A, 0x36, 0xCD, 0x12, 0xA1, 0x3E}
rpcStreamSignature = [6]byte{0x52, 0xBB, 0x82, 0x5C, 0xDB, 0x65}
)
const protocolVersion = "01"
// StreamType identifies the kind of QUIC stream.
type StreamType int
const (
StreamTypeData StreamType = iota
StreamTypeRPC
)
const metadataFlowConnectRateLimited = "FlowConnectRateLimited"
// ConnectionType indicates the proxied connection type within a data stream.
type ConnectionType uint16
const (
ConnectionTypeHTTP ConnectionType = iota
ConnectionTypeWebsocket
ConnectionTypeTCP
)
func (c ConnectionType) String() string {
switch c {
case ConnectionTypeHTTP:
return "http"
case ConnectionTypeWebsocket:
return "websocket"
case ConnectionTypeTCP:
return "tcp"
default:
return "unknown"
}
}
// Metadata is a key-value pair in stream metadata.
type Metadata struct {
Key string `capnp:"key"`
Val string `capnp:"val"`
}
func flowConnectRateLimitedMetadata() []Metadata {
return []Metadata{{
Key: metadataFlowConnectRateLimited,
Val: "true",
}}
}
func hasFlowConnectRateLimited(metadata []Metadata) bool {
for _, entry := range metadata {
if entry.Key == metadataFlowConnectRateLimited && entry.Val == "true" {
return true
}
}
return false
}
// ConnectRequest is sent by the edge at the start of a data stream.
type ConnectRequest struct {
Dest string `capnp:"dest"`
Type ConnectionType `capnp:"type"`
Metadata []Metadata `capnp:"metadata"`
}
func (r *ConnectRequest) MetadataMap() map[string]string {
result := make(map[string]string, len(r.Metadata))
for _, m := range r.Metadata {
result[m.Key] = m.Val
}
return result
}
func (r *ConnectRequest) fromCapnp(msg *capnp.Message) error {
root, err := tunnelrpc.ReadRootConnectRequest(msg)
if err != nil {
return err
}
return pogs.Extract(r, tunnelrpc.ConnectRequest_TypeID, root.Struct)
}
// ConnectResponse is sent back to the edge after processing a ConnectRequest.
type ConnectResponse struct {
Error string `capnp:"error"`
Metadata []Metadata `capnp:"metadata"`
}
func (r *ConnectResponse) toCapnp() (*capnp.Message, error) {
msg, seg, err := capnp.NewMessage(capnp.SingleSegment(nil))
if err != nil {
return nil, err
}
root, err := tunnelrpc.NewRootConnectResponse(seg)
if err != nil {
return nil, err
}
err = pogs.Insert(tunnelrpc.ConnectResponse_TypeID, root.Struct, r)
if err != nil {
return nil, err
}
return msg, nil
}
// ReadStreamSignature reads the 6-byte stream type signature.
func ReadStreamSignature(r io.Reader) (StreamType, error) {
var signature [6]byte
_, err := io.ReadFull(r, signature[:])
if err != nil {
return 0, err
}
switch signature {
case dataStreamSignature:
return StreamTypeData, nil
case rpcStreamSignature:
return StreamTypeRPC, nil
default:
return 0, E.New("unknown stream signature")
}
}
// ReadConnectRequest reads the version and ConnectRequest from a data stream.
func ReadConnectRequest(r io.Reader) (*ConnectRequest, error) {
version := make([]byte, 2)
_, err := io.ReadFull(r, version)
if err != nil {
return nil, E.Cause(err, "read version")
}
msg, err := capnp.NewDecoder(r).Decode()
if err != nil {
return nil, E.Cause(err, "decode connect request")
}
request := &ConnectRequest{}
err = request.fromCapnp(msg)
if err != nil {
return nil, E.Cause(err, "extract connect request")
}
return request, nil
}
// WriteConnectResponse writes a ConnectResponse with the data stream preamble.
func WriteConnectResponse(w io.Writer, responseError error, metadata ...Metadata) error {
response := &ConnectResponse{
Metadata: metadata,
}
if responseError != nil {
response.Error = responseError.Error()
}
msg, err := response.toCapnp()
if err != nil {
return E.Cause(err, "encode connect response")
}
// Write data stream preamble
_, err = w.Write(dataStreamSignature[:])
if err != nil {
return err
}
_, err = w.Write([]byte(protocolVersion))
if err != nil {
return err
}
return capnp.NewEncoder(w).Encode(msg)
}
func WriteRPCStreamSignature(w io.Writer) error {
_, err := w.Write(rpcStreamSignature[:])
return err
}
// Registration data structures for the control stream.
type RegistrationTunnelAuth struct {
AccountTag string `capnp:"accountTag"`
TunnelSecret []byte `capnp:"tunnelSecret"`
}
type RegistrationClientInfo struct {
ClientID []byte `capnp:"clientId"`
Features []string `capnp:"features"`
Version string `capnp:"version"`
Arch string `capnp:"arch"`
}
type RegistrationConnectionOptions struct {
Client RegistrationClientInfo `capnp:"client"`
OriginLocalIP net.IP `capnp:"originLocalIp"`
ReplaceExisting bool `capnp:"replaceExisting"`
CompressionQuality uint8 `capnp:"compressionQuality"`
NumPreviousAttempts uint8 `capnp:"numPreviousAttempts"`
}
// RegistrationResult is the parsed result of a RegisterConnection RPC.
type RegistrationResult struct {
ConnectionID uuid.UUID
Location string
TunnelIsRemotelyManaged bool
}
// RetryableError signals the edge wants us to retry after a delay.
type RetryableError struct {
Err error
Delay time.Duration
}
func (e *RetryableError) Error() string {
return e.Err.Error()
}
func (e *RetryableError) Unwrap() error {
return e.Err
}

View File

@@ -1,95 +0,0 @@
//go:build with_cloudflared
package cloudflare
import (
"bytes"
"errors"
"io"
"testing"
)
func TestReadStreamSignatureData(t *testing.T) {
buf := bytes.NewBuffer(dataStreamSignature[:])
streamType, err := ReadStreamSignature(buf)
if err != nil {
t.Fatal("ReadStreamSignature: ", err)
}
if streamType != StreamTypeData {
t.Error("expected StreamTypeData, got ", streamType)
}
}
func TestReadStreamSignatureRPC(t *testing.T) {
buf := bytes.NewBuffer(rpcStreamSignature[:])
streamType, err := ReadStreamSignature(buf)
if err != nil {
t.Fatal("ReadStreamSignature: ", err)
}
if streamType != StreamTypeRPC {
t.Error("expected StreamTypeRPC, got ", streamType)
}
}
func TestReadStreamSignatureUnknown(t *testing.T) {
buf := bytes.NewBuffer([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
_, err := ReadStreamSignature(buf)
if err == nil {
t.Fatal("expected error for unknown signature")
}
}
func TestReadStreamSignatureTooShort(t *testing.T) {
buf := bytes.NewBuffer([]byte{0x0A, 0x36, 0xCD})
_, err := ReadStreamSignature(buf)
if err == nil {
t.Fatal("expected error for short input")
}
if !errors.Is(err, io.ErrUnexpectedEOF) {
t.Error("expected ErrUnexpectedEOF, got ", err)
}
}
func TestWriteConnectResponseSuccess(t *testing.T) {
var buf bytes.Buffer
metadata := Metadata{Key: "testKey", Val: "testVal"}
err := WriteConnectResponse(&buf, nil, metadata)
if err != nil {
t.Fatal("WriteConnectResponse: ", err)
}
data := buf.Bytes()
if len(data) < 8 {
t.Fatal("response too short: ", len(data))
}
var signature [6]byte
copy(signature[:], data[:6])
if signature != dataStreamSignature {
t.Error("expected data stream signature")
}
version := string(data[6:8])
if version != "01" {
t.Error("expected version 01, got ", version)
}
}
func TestWriteConnectResponseError(t *testing.T) {
var buf bytes.Buffer
err := WriteConnectResponse(&buf, errors.New("test failure"))
if err != nil {
t.Fatal("WriteConnectResponse: ", err)
}
data := buf.Bytes()
if len(data) < 8 {
t.Fatal("response too short")
}
var signature [6]byte
copy(signature[:], data[:6])
if signature != dataStreamSignature {
t.Error("expected data stream signature")
}
}

View File

@@ -1,31 +0,0 @@
# Generate go.capnp.out with:
# capnp compile -o- go.capnp > go.capnp.out
# Must run inside this directory to preserve paths.
@0xd12a1c51fedd6c88;
annotation package(file) :Text;
# The Go package name for the generated file.
annotation import(file) :Text;
# The Go import path that the generated file is accessible from.
# Used to generate import statements and check if two types are in the
# same package.
annotation doc(struct, field, enum) :Text;
# Adds a doc comment to the generated code.
annotation tag(enumerant) :Text;
# Changes the string representation of the enum in the generated code.
annotation notag(enumerant) :Void;
# Removes the string representation of the enum in the generated code.
annotation customtype(field) :Text;
# OBSOLETE, not used by code generator.
annotation name(struct, field, union, enum, enumerant, interface, method, param, annotation, const, group) :Text;
# Used to rename the element in the generated code.
$package("capnp");
$import("zombiezen.com/go/capnproto2");

View File

@@ -1,28 +0,0 @@
using Go = import "go.capnp";
@0xb29021ef7421cc32;
$Go.package("tunnelrpc");
$Go.import("github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc");
struct ConnectRequest @0xc47116a1045e4061 {
dest @0 :Text;
type @1 :ConnectionType;
metadata @2 :List(Metadata);
}
enum ConnectionType @0xc52e1bac26d379c8 {
http @0;
websocket @1;
tcp @2;
}
struct Metadata @0xe1446b97bfd1cd37 {
key @0 :Text;
val @1 :Text;
}
struct ConnectResponse @0xb1032ec91cef8727 {
error @0 :Text;
metadata @1 :List(Metadata);
}

View File

@@ -1,394 +0,0 @@
// Code generated by capnpc-go. DO NOT EDIT.
package tunnelrpc
import (
capnp "zombiezen.com/go/capnproto2"
text "zombiezen.com/go/capnproto2/encoding/text"
schemas "zombiezen.com/go/capnproto2/schemas"
)
type ConnectRequest struct{ capnp.Struct }
// ConnectRequest_TypeID is the unique identifier for the type ConnectRequest.
const ConnectRequest_TypeID = 0xc47116a1045e4061
func NewConnectRequest(s *capnp.Segment) (ConnectRequest, error) {
st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2})
return ConnectRequest{st}, err
}
func NewRootConnectRequest(s *capnp.Segment) (ConnectRequest, error) {
st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2})
return ConnectRequest{st}, err
}
func ReadRootConnectRequest(msg *capnp.Message) (ConnectRequest, error) {
root, err := msg.RootPtr()
return ConnectRequest{root.Struct()}, err
}
func (s ConnectRequest) String() string {
str, _ := text.Marshal(0xc47116a1045e4061, s.Struct)
return str
}
func (s ConnectRequest) Dest() (string, error) {
p, err := s.Struct.Ptr(0)
return p.Text(), err
}
func (s ConnectRequest) HasDest() bool {
p, err := s.Struct.Ptr(0)
return p.IsValid() || err != nil
}
func (s ConnectRequest) DestBytes() ([]byte, error) {
p, err := s.Struct.Ptr(0)
return p.TextBytes(), err
}
func (s ConnectRequest) SetDest(v string) error {
return s.Struct.SetText(0, v)
}
func (s ConnectRequest) Type() ConnectionType {
return ConnectionType(s.Struct.Uint16(0))
}
func (s ConnectRequest) SetType(v ConnectionType) {
s.Struct.SetUint16(0, uint16(v))
}
func (s ConnectRequest) Metadata() (Metadata_List, error) {
p, err := s.Struct.Ptr(1)
return Metadata_List{List: p.List()}, err
}
func (s ConnectRequest) HasMetadata() bool {
p, err := s.Struct.Ptr(1)
return p.IsValid() || err != nil
}
func (s ConnectRequest) SetMetadata(v Metadata_List) error {
return s.Struct.SetPtr(1, v.List.ToPtr())
}
// NewMetadata sets the metadata field to a newly
// allocated Metadata_List, preferring placement in s's segment.
func (s ConnectRequest) NewMetadata(n int32) (Metadata_List, error) {
l, err := NewMetadata_List(s.Struct.Segment(), n)
if err != nil {
return Metadata_List{}, err
}
err = s.Struct.SetPtr(1, l.List.ToPtr())
return l, err
}
// ConnectRequest_List is a list of ConnectRequest.
type ConnectRequest_List struct{ capnp.List }
// NewConnectRequest creates a new list of ConnectRequest.
func NewConnectRequest_List(s *capnp.Segment, sz int32) (ConnectRequest_List, error) {
l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 8, PointerCount: 2}, sz)
return ConnectRequest_List{l}, err
}
func (s ConnectRequest_List) At(i int) ConnectRequest { return ConnectRequest{s.List.Struct(i)} }
func (s ConnectRequest_List) Set(i int, v ConnectRequest) error { return s.List.SetStruct(i, v.Struct) }
func (s ConnectRequest_List) String() string {
str, _ := text.MarshalList(0xc47116a1045e4061, s.List)
return str
}
// ConnectRequest_Promise is a wrapper for a ConnectRequest promised by a client call.
type ConnectRequest_Promise struct{ *capnp.Pipeline }
func (p ConnectRequest_Promise) Struct() (ConnectRequest, error) {
s, err := p.Pipeline.Struct()
return ConnectRequest{s}, err
}
type ConnectionType uint16
// ConnectionType_TypeID is the unique identifier for the type ConnectionType.
const ConnectionType_TypeID = 0xc52e1bac26d379c8
// Values of ConnectionType.
const (
ConnectionType_http ConnectionType = 0
ConnectionType_websocket ConnectionType = 1
ConnectionType_tcp ConnectionType = 2
)
// String returns the enum's constant name.
func (c ConnectionType) String() string {
switch c {
case ConnectionType_http:
return "http"
case ConnectionType_websocket:
return "websocket"
case ConnectionType_tcp:
return "tcp"
default:
return ""
}
}
// ConnectionTypeFromString returns the enum value with a name,
// or the zero value if there's no such value.
func ConnectionTypeFromString(c string) ConnectionType {
switch c {
case "http":
return ConnectionType_http
case "websocket":
return ConnectionType_websocket
case "tcp":
return ConnectionType_tcp
default:
return 0
}
}
type ConnectionType_List struct{ capnp.List }
func NewConnectionType_List(s *capnp.Segment, sz int32) (ConnectionType_List, error) {
l, err := capnp.NewUInt16List(s, sz)
return ConnectionType_List{l.List}, err
}
func (l ConnectionType_List) At(i int) ConnectionType {
ul := capnp.UInt16List{List: l.List}
return ConnectionType(ul.At(i))
}
func (l ConnectionType_List) Set(i int, v ConnectionType) {
ul := capnp.UInt16List{List: l.List}
ul.Set(i, uint16(v))
}
type Metadata struct{ capnp.Struct }
// Metadata_TypeID is the unique identifier for the type Metadata.
const Metadata_TypeID = 0xe1446b97bfd1cd37
func NewMetadata(s *capnp.Segment) (Metadata, error) {
st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2})
return Metadata{st}, err
}
func NewRootMetadata(s *capnp.Segment) (Metadata, error) {
st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2})
return Metadata{st}, err
}
func ReadRootMetadata(msg *capnp.Message) (Metadata, error) {
root, err := msg.RootPtr()
return Metadata{root.Struct()}, err
}
func (s Metadata) String() string {
str, _ := text.Marshal(0xe1446b97bfd1cd37, s.Struct)
return str
}
func (s Metadata) Key() (string, error) {
p, err := s.Struct.Ptr(0)
return p.Text(), err
}
func (s Metadata) HasKey() bool {
p, err := s.Struct.Ptr(0)
return p.IsValid() || err != nil
}
func (s Metadata) KeyBytes() ([]byte, error) {
p, err := s.Struct.Ptr(0)
return p.TextBytes(), err
}
func (s Metadata) SetKey(v string) error {
return s.Struct.SetText(0, v)
}
func (s Metadata) Val() (string, error) {
p, err := s.Struct.Ptr(1)
return p.Text(), err
}
func (s Metadata) HasVal() bool {
p, err := s.Struct.Ptr(1)
return p.IsValid() || err != nil
}
func (s Metadata) ValBytes() ([]byte, error) {
p, err := s.Struct.Ptr(1)
return p.TextBytes(), err
}
func (s Metadata) SetVal(v string) error {
return s.Struct.SetText(1, v)
}
// Metadata_List is a list of Metadata.
type Metadata_List struct{ capnp.List }
// NewMetadata creates a new list of Metadata.
func NewMetadata_List(s *capnp.Segment, sz int32) (Metadata_List, error) {
l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz)
return Metadata_List{l}, err
}
func (s Metadata_List) At(i int) Metadata { return Metadata{s.List.Struct(i)} }
func (s Metadata_List) Set(i int, v Metadata) error { return s.List.SetStruct(i, v.Struct) }
func (s Metadata_List) String() string {
str, _ := text.MarshalList(0xe1446b97bfd1cd37, s.List)
return str
}
// Metadata_Promise is a wrapper for a Metadata promised by a client call.
type Metadata_Promise struct{ *capnp.Pipeline }
func (p Metadata_Promise) Struct() (Metadata, error) {
s, err := p.Pipeline.Struct()
return Metadata{s}, err
}
type ConnectResponse struct{ capnp.Struct }
// ConnectResponse_TypeID is the unique identifier for the type ConnectResponse.
const ConnectResponse_TypeID = 0xb1032ec91cef8727
func NewConnectResponse(s *capnp.Segment) (ConnectResponse, error) {
st, err := capnp.NewStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2})
return ConnectResponse{st}, err
}
func NewRootConnectResponse(s *capnp.Segment) (ConnectResponse, error) {
st, err := capnp.NewRootStruct(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2})
return ConnectResponse{st}, err
}
func ReadRootConnectResponse(msg *capnp.Message) (ConnectResponse, error) {
root, err := msg.RootPtr()
return ConnectResponse{root.Struct()}, err
}
func (s ConnectResponse) String() string {
str, _ := text.Marshal(0xb1032ec91cef8727, s.Struct)
return str
}
func (s ConnectResponse) Error() (string, error) {
p, err := s.Struct.Ptr(0)
return p.Text(), err
}
func (s ConnectResponse) HasError() bool {
p, err := s.Struct.Ptr(0)
return p.IsValid() || err != nil
}
func (s ConnectResponse) ErrorBytes() ([]byte, error) {
p, err := s.Struct.Ptr(0)
return p.TextBytes(), err
}
func (s ConnectResponse) SetError(v string) error {
return s.Struct.SetText(0, v)
}
func (s ConnectResponse) Metadata() (Metadata_List, error) {
p, err := s.Struct.Ptr(1)
return Metadata_List{List: p.List()}, err
}
func (s ConnectResponse) HasMetadata() bool {
p, err := s.Struct.Ptr(1)
return p.IsValid() || err != nil
}
func (s ConnectResponse) SetMetadata(v Metadata_List) error {
return s.Struct.SetPtr(1, v.List.ToPtr())
}
// NewMetadata sets the metadata field to a newly
// allocated Metadata_List, preferring placement in s's segment.
func (s ConnectResponse) NewMetadata(n int32) (Metadata_List, error) {
l, err := NewMetadata_List(s.Struct.Segment(), n)
if err != nil {
return Metadata_List{}, err
}
err = s.Struct.SetPtr(1, l.List.ToPtr())
return l, err
}
// ConnectResponse_List is a list of ConnectResponse.
type ConnectResponse_List struct{ capnp.List }
// NewConnectResponse creates a new list of ConnectResponse.
func NewConnectResponse_List(s *capnp.Segment, sz int32) (ConnectResponse_List, error) {
l, err := capnp.NewCompositeList(s, capnp.ObjectSize{DataSize: 0, PointerCount: 2}, sz)
return ConnectResponse_List{l}, err
}
func (s ConnectResponse_List) At(i int) ConnectResponse { return ConnectResponse{s.List.Struct(i)} }
func (s ConnectResponse_List) Set(i int, v ConnectResponse) error {
return s.List.SetStruct(i, v.Struct)
}
func (s ConnectResponse_List) String() string {
str, _ := text.MarshalList(0xb1032ec91cef8727, s.List)
return str
}
// ConnectResponse_Promise is a wrapper for a ConnectResponse promised by a client call.
type ConnectResponse_Promise struct{ *capnp.Pipeline }
func (p ConnectResponse_Promise) Struct() (ConnectResponse, error) {
s, err := p.Pipeline.Struct()
return ConnectResponse{s}, err
}
const schema_b29021ef7421cc32 = "x\xda\xac\x91Ak\x13A\x1c\xc5\xdf\x9bI\\\x85\xe8" +
"fH\x15DCi\x11\xb5\xc5\x06\x9b\x08\x82\xa7\x80\x15" +
"TZ\xcc\x14\xcf\x96\xedv\xb05\xed\xee$;\xb5\xe4" +
"\x13x\xf5&\x1e=\x0a\x8a\xe8\x17\xf0\xa2\xa0\xa0\x88\x88" +
"\x1f\xc0\x83\x07O\xfd\x04\xb22\x0b\xdb@\xc9\xc1Co" +
"\x7f\xde<\xde\xfc\xfe\xffW\xff\xd6\x15\x8b\xd5\x87\x04t" +
"\xbdz,\xbf\xf4d\xff\xfc\xe7\x96|\x0b\xd5d\xde\xfe" +
"2\xe3\xf6g\x9e\xbeCU\x04\xc0\xe2\xf3GT\xaf\x03" +
"@\xbd\xdc\x03\xf3\xa8\xfb\xa0\xf2\xe2\xcc\xe0\x03t\x93\x87" +
"\xad\x9d\xb3\\gc\x81\x01\xd0\x98\xe3\x1b0\xff4\xfa" +
"q\xf1\xd5\xb9\xd6G\xa8\xa6\x18\x9b\xc1\xceO\xef\xfcS" +
"8\x7f\xf3\x1e\x98_\xff\xfa\xfd\xfd\xb3\xfe\xd2\xaf\x09\x04" +
"\x9d\xbfl\xb3q\xd2\x8f\x8d\x13\xc2C\x0cv\xb7\xe2\xb5" +
"\x1d\xe3*\xd1F\xe4\xa25;L]\x1a\xa7\xdb\xad8" +
"\xb2\x89\xbdq3M\x12\x13\xbbU\x93\xd90M2\xd3" +
"#\xf5qY\x01*\x04\xd4\\\x1b\xd0\x17$\xf5UA" +
"EN\xd1\x8b\x0bw\x01}ER\xdf\x16\x9c6\xc3a" +
":d\x0d\x8250\xdf1\xae\xf8\x05\x00O\x81=I" +
"\xd6\xc7\xb4\xa0\x17\xff\x17h\xb0\x1b\x98\xccy\x9e\xda\x01" +
"\xcf\xady@w%\xf5\xb2`\x89s\xc7kK\x92\xba" +
"'\xa8\x04\xa7(\x00\xb5\xe2\x19\x97%\xf5\xa6`\xb8a" +
"2W\"\x86nd\x0d\xc3\xf1\xb1A\x86GJ\xbe\x95" +
"&\xf7\x83\x91-.Y+`\x9a\xf3>@\x9d^\x05" +
"(\x94\x9a\x05\xc2M\xe7l\xbeg\xd6\xb34\xee\x1b\xd0" +
"\x05.\xb6\x07\xf1rb\xfc\x8aq\xd3\xc5\xc3\xa1\x8af" +
"'U\xe4\xc5\xcb\x92\xfa\x9a`\xd07\xa3r\xfb\xe0q" +
"\xb4]\xce\xff\x02\x00\x00\xff\xff\x14\xd5\xb6\xda"
func init() {
schemas.Register(schema_b29021ef7421cc32,
0xb1032ec91cef8727,
0xc47116a1045e4061,
0xc52e1bac26d379c8,
0xe1446b97bfd1cd37)
}

View File

@@ -1,195 +0,0 @@
using Go = import "go.capnp";
@0xdb8274f9144abc7e;
$Go.package("tunnelrpc");
$Go.import("github.com/sagernet/sing-box/protocol/cloudflare/tunnelrpc");
# === DEPRECATED Legacy Tunnel Authentication and Registration methods/servers ===
#
# These structs and interfaces are no longer used but it is important to keep
# them around to make sure backwards compatibility within the rpc protocol is
# maintained.
struct Authentication @0xc082ef6e0d42ed1d {
# DEPRECATED: Legacy tunnel authentication mechanism
key @0 :Text;
email @1 :Text;
originCAKey @2 :Text;
}
struct TunnelRegistration @0xf41a0f001ad49e46 {
# DEPRECATED: Legacy tunnel authentication mechanism
err @0 :Text;
# the url to access the tunnel
url @1 :Text;
# Used to inform the client of actions taken.
logLines @2 :List(Text);
# In case of error, whether the client should attempt to reconnect.
permanentFailure @3 :Bool;
# Displayed to user
tunnelID @4 :Text;
# How long should this connection wait to retry in seconds, if the error wasn't permanent
retryAfterSeconds @5 :UInt16;
# A unique ID used to reconnect this tunnel.
eventDigest @6 :Data;
# A unique ID used to prove this tunnel was previously connected to a given metal.
connDigest @7 :Data;
}
struct RegistrationOptions @0xc793e50592935b4a {
# DEPRECATED: Legacy tunnel authentication mechanism
# The tunnel client's unique identifier, used to verify a reconnection.
clientId @0 :Text;
# Information about the running binary.
version @1 :Text;
os @2 :Text;
# What to do with existing tunnels for the given hostname.
existingTunnelPolicy @3 :ExistingTunnelPolicy;
# If using the balancing policy, identifies the LB pool to use.
poolName @4 :Text;
# Client-defined tags to associate with the tunnel
tags @5 :List(Tag);
# A unique identifier for a high-availability connection made by a single client.
connectionId @6 :UInt8;
# origin LAN IP
originLocalIp @7 :Text;
# whether Argo Tunnel client has been autoupdated
isAutoupdated @8 :Bool;
# whether Argo Tunnel client is run from a terminal
runFromTerminal @9 :Bool;
# cross stream compression setting, 0 - off, 3 - high
compressionQuality @10 :UInt64;
uuid @11 :Text;
# number of previous attempts to send RegisterTunnel/ReconnectTunnel
numPreviousAttempts @12 :UInt8;
# Set of features this cloudflared knows it supports
features @13 :List(Text);
}
enum ExistingTunnelPolicy @0x84cb9536a2cf6d3c {
# DEPRECATED: Legacy tunnel registration mechanism
ignore @0;
disconnect @1;
balance @2;
}
struct ServerInfo @0xf2c68e2547ec3866 {
# DEPRECATED: Legacy tunnel registration mechanism
locationName @0 :Text;
}
struct AuthenticateResponse @0x82c325a07ad22a65 {
# DEPRECATED: Legacy tunnel registration mechanism
permanentErr @0 :Text;
retryableErr @1 :Text;
jwt @2 :Data;
hoursUntilRefresh @3 :UInt8;
}
interface TunnelServer @0xea58385c65416035 extends (RegistrationServer) {
# DEPRECATED: Legacy tunnel authentication server
registerTunnel @0 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
getServerInfo @1 () -> (result :ServerInfo);
unregisterTunnel @2 (gracePeriodNanoSec :Int64) -> ();
# obsoleteDeclarativeTunnelConnect RPC deprecated in TUN-3019
obsoleteDeclarativeTunnelConnect @3 () -> ();
authenticate @4 (originCert :Data, hostname :Text, options :RegistrationOptions) -> (result :AuthenticateResponse);
reconnectTunnel @5 (jwt :Data, eventDigest :Data, connDigest :Data, hostname :Text, options :RegistrationOptions) -> (result :TunnelRegistration);
}
struct Tag @0xcbd96442ae3bb01a {
# DEPRECATED: Legacy tunnel additional HTTP header mechanism
name @0 :Text;
value @1 :Text;
}
# === End DEPRECATED Objects ===
struct ClientInfo @0x83ced0145b2f114b {
# The tunnel client's unique identifier, used to verify a reconnection.
clientId @0 :Data;
# Set of features this cloudflared knows it supports
features @1 :List(Text);
# Information about the running binary.
version @2 :Text;
# Client OS and CPU info
arch @3 :Text;
}
struct ConnectionOptions @0xb4bf9861fe035d04 {
# client details
client @0 :ClientInfo;
# origin LAN IP
originLocalIp @1 :Data;
# What to do if connection already exists
replaceExisting @2 :Bool;
# cross stream compression setting, 0 - off, 3 - high
compressionQuality @3 :UInt8;
# number of previous attempts to send RegisterConnection
numPreviousAttempts @4 :UInt8;
}
struct ConnectionResponse @0xdbaa9d03d52b62dc {
result :union {
error @0 :ConnectionError;
connectionDetails @1 :ConnectionDetails;
}
}
struct ConnectionError @0xf5f383d2785edb86 {
cause @0 :Text;
# How long should this connection wait to retry in ns
retryAfter @1 :Int64;
shouldRetry @2 :Bool;
}
struct ConnectionDetails @0xb5f39f082b9ac18a {
# identifier of this connection
uuid @0 :Data;
# airport code of the colo where this connection landed
locationName @1 :Text;
# tells if the tunnel is remotely managed
tunnelIsRemotelyManaged @2: Bool;
}
struct TunnelAuth @0x9496331ab9cd463f {
accountTag @0 :Text;
tunnelSecret @1 :Data;
}
interface RegistrationServer @0xf71695ec7fe85497 {
registerConnection @0 (auth :TunnelAuth, tunnelId :Data, connIndex :UInt8, options :ConnectionOptions) -> (result :ConnectionResponse);
unregisterConnection @1 () -> ();
updateLocalConfiguration @2 (config :Data) -> ();
}
struct RegisterUdpSessionResponse @0xab6d5210c1f26687 {
err @0 :Text;
spans @1 :Data;
}
interface SessionManager @0x839445a59fb01686 {
# Let the edge decide closeAfterIdle to make sure cloudflared doesn't close session before the edge closes its side
registerUdpSession @0 (sessionId :Data, dstIp :Data, dstPort :UInt16, closeAfterIdleHint :Int64, traceContext :Text = "") -> (result :RegisterUdpSessionResponse);
unregisterUdpSession @1 (sessionId :Data, message :Text) -> ();
}
struct UpdateConfigurationResponse @0xdb58ff694ba05cf9 {
# Latest configuration that was applied successfully. The err field might be populated at the same time to indicate
# that cloudflared is using an older configuration because the latest cannot be applied
latestAppliedVersion @0 :Int32;
# Any error encountered when trying to apply the last configuration
err @1 :Text;
}
# ConfigurationManager defines RPC to manage cloudflared configuration remotely
interface ConfigurationManager @0xb48edfbdaa25db04 {
updateConfiguration @0 (version :Int32, config :Data) -> (result: UpdateConfigurationResponse);
}
interface CloudflaredServer @0xf548cef9dea2a4a1 extends(SessionManager, ConfigurationManager) {}

File diff suppressed because it is too large Load Diff